Coverage for python/felis/tap.py: 12%
253 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-12 10:48 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]
26import logging
27from collections.abc import Iterable, Mapping, MutableMapping
28from typing import Any, Optional, Union
30from sqlalchemy import Column, Integer, String
31from sqlalchemy.engine import Engine
32from sqlalchemy.engine.mock import MockConnection
33from sqlalchemy.ext.declarative import declarative_base
34from sqlalchemy.orm import Session, sessionmaker
35from sqlalchemy.schema import MetaData
36from sqlalchemy.sql.expression import Insert, insert
38from .check import FelisValidator
39from .types import FelisType
40from .visitor import Visitor
# Read-only, string-keyed mapping: the type used below for every Felis
# object handed to the visitor (schemas, tables, columns, constraints...).
_Mapping = Mapping[str, Any]

logger = logging.getLogger("felis")
# Declarative base shared by the ORM classes declared in ``init_tables``;
# annotated as ``Any`` so mypy does not trip over SQLAlchemy 2 typing.
Tap11Base: Any = declarative_base()

# Lengths of the String columns of the TAP_SCHEMA tables.
IDENTIFIER_LENGTH = 128
SIMPLE_FIELD_LENGTH = 128
SMALL_FIELD_LENGTH = 32
TEXT_FIELD_LENGTH = 2048
# Room for a dotted "catalog.schema.table" name: three identifiers, two dots.
QUALIFIED_TABLE_LENGTH = IDENTIFIER_LENGTH * 3 + 2

# Set to True by the first ``init_tables`` call; every later call then swaps
# in a fresh ``Tap11Base.metadata`` before declaring its classes.
_init_table_once = False
def init_tables(
    tap_schema_name: Optional[str] = None,
    tap_tables_postfix: Optional[str] = None,
    tap_schemas_table: Optional[str] = None,
    tap_tables_table: Optional[str] = None,
    tap_columns_table: Optional[str] = None,
    tap_keys_table: Optional[str] = None,
    tap_key_columns_table: Optional[str] = None,
) -> MutableMapping[str, Any]:
    """Declare the TAP_SCHEMA 1.1 ORM classes on ``Tap11Base``.

    Parameters
    ----------
    tap_schema_name
        When non-empty, assigned as the schema of ``Tap11Base.metadata``.
    tap_tables_postfix
        Suffix appended to every table name (nothing when `None` or empty).
    tap_schemas_table, tap_tables_table, tap_columns_table, tap_keys_table, tap_key_columns_table
        Base table names replacing the defaults ``schemas``, ``tables``,
        ``columns``, ``keys`` and ``key_columns`` respectively.

    Returns
    -------
    dict
        The new ORM classes keyed by ``"schemas"``, ``"tables"``,
        ``"columns"``, ``"keys"`` and ``"key_columns"``.
    """
    suffix = tap_tables_postfix if tap_tables_postfix else ""

    # HACK: so that this function can run more than once, every call after
    # the first replaces the shared MetaData with an empty one before the
    # classes below are declared again.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    global _init_table_once
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    schemas_name = (tap_schemas_table or "schemas") + suffix
    tables_name = (tap_tables_table or "tables") + suffix
    columns_name = (tap_columns_table or "columns") + suffix
    keys_name = (tap_keys_table or "keys") + suffix
    key_columns_name = (tap_key_columns_table or "key_columns") + suffix

    # Attribute order inside each class fixes the table's column order.
    class Tap11Schemas(Tap11Base):
        __tablename__ = schemas_name
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = tables_name
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = columns_name
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated TAP ``size`` column is intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = keys_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = key_columns_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }
class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
    """Visitor that loads a Felis schema into TAP_SCHEMA 1.1 tables.

    Every visited object becomes an instance of one of the ORM classes in
    ``self.tables`` (built by `init_tables` unless supplied). With an engine
    the rows are written through an ORM session. Without one, the visitor
    must come from `from_mock_connection`, and INSERT statements are
    executed on that mock connection instead (dry run).
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ):
        """Create the visitor.

        Parameters
        ----------
        engine
            Engine the rows are written to, or `None` for a dry run (create
            the visitor with `from_mock_connection` in that case).
        catalog_name
            Optional catalog prepended, with a dot, to the schema name.
        schema_name
            Schema name to record; when `None` the ``name`` of the visited
            schema is used.
        tap_tables
            ORM classes keyed like the result of `init_tables`; that
            function is called when this is `None` or empty.
        """
        # Felis ``@id`` -> ORM row, so primary keys, indexes and constraints
        # can resolve the ids they reference.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ) -> TapLoadingVisitor:
        """Create an engine-less visitor that executes on ``mock_connection``."""
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate ``schema_obj``, build its TAP_SCHEMA rows and write them.

        With an engine, all rows are added to one ORM session and committed.
        Otherwise one INSERT per row is executed on the mock connection.

        Raises
        ------
        AssertionError
            If there is neither an engine nor a mock connection.
        """
        self.checker.check_schema(schema_obj)
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            session: Session
            # The context manager closes the session even when a visit or the
            # commit raises, so its connection is always released.
            with sessionmaker(self.engine)() as session:
                session.add(schema)
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)
                    session.add_all(keys)
                    session.add_all(key_columns)
                session.commit()
        else:
            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))
            for table_obj in schema_obj["tables"]:
                table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))
                for key in keys:
                    conn.execute(_insert(self.tables["keys"], key))
                for key_column in key_columns:
                    conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the rows describing one table.

        Returns
        -------
        tuple
            ``(table, columns, keys, key_columns)``: the table row, its column
            rows, and the key / key-column rows of its foreign keys.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        all_keys = []
        all_key_columns = []
        for constraint_obj in table_obj.get("constraints", []):
            # Pass the Felis table mapping (not the ORM row): that is what the
            # ``table_obj`` parameter and the validator expect.
            key, key_columns = self.visit_constraint(constraint_obj, table_obj)
            if key is None:
                continue
            all_keys.append(key)
            all_key_columns.extend(key_columns)

        for index_obj in table_obj.get("indexes", []):
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns, all_keys, all_key_columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when its array size must default to "*"."""
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # Either votable:arraysize or length is acceptable for sized types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    'votable:arraysize and length for %s are None for type %s. Using length "*". '
                    "Consider setting `votable:arraysize` or `length`.",
                    _id,
                    datatype_name,
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and `length` only loosely relates
            # to the string size.
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    'votable:arraysize for %s is None for type %s. Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings.",
                    _id,
                    datatype_name,
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the TAP ``columns`` row for one column and index it by ``@id``."""
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        column.unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.ucd = column_obj.get("ivoa:ucd")

        # visit_primary_key / visit_index may set this to 1 afterwards.
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        ``primary_key_obj`` is one column ``@id`` or an iterable of them.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # if just one column and it's indexed, update the object
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        """Raise `ValueError` if ``columns`` belong to more than one table."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsistent use of table names")

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build the key rows for a constraint.

        Returns
        -------
        tuple
            ``(key, key_columns)``. For constraints other than ``ForeignKey``
            this is ``(None, [])``.

        Raises
        ------
        ValueError
            If a foreign key lists no columns or no referenced columns, lists
            different numbers of each, or spans several tables on either side.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        key = None
        key_columns = []
        if constraint_obj["@type"] == "ForeignKey":
            constraint_name = constraint_obj["name"]
            columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            # Without these guards an empty list fails with a bare IndexError
            # below and a length mismatch is silently truncated by zip().
            if not columns or not refcolumns:
                raise ValueError(
                    f"Foreign key {constraint_name} must list at least one column and one referenced column"
                )
            if len(columns) != len(refcolumns):
                raise ValueError(
                    f"Foreign key {constraint_name} has {len(columns)} columns but "
                    f"{len(refcolumns)} referenced columns"
                )
            self._check_single_table(columns)
            self._check_single_table(refcolumns)

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = columns[0].table_name
            key.target_table = refcolumns[0].table_name
            key.description = constraint_obj.get("description")
            key.utype = constraint_obj.get("votable:utype")
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed."""
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        # if just one column and it's indexed, update the object
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: Optional[str] = None) -> Optional[str]:
        """Return the schema name, as ``catalog.schema`` when a catalog is set."""
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified by the (catalog-qualified) schema name."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name
def _insert(table: Tap11Base, value: Any) -> Insert:
    """Build an INSERT statement copying one ORM row into ``table``.

    Parameters
    ----------
    table
        ORM class of the TAP_SCHEMA table being inserted into.
    value
        Object whose attributes, one per column of ``table``, supply the
        values to insert.

    Returns
    -------
    Insert
        A SQLAlchemy insert statement with a value for every column.
    """
    values_dict = {}
    for column in table.__table__.columns:
        column_value = getattr(value, column.name)
        # isinstance (not an exact type check) so str subclasses are escaped too.
        if isinstance(column_value, str):
            # NOTE(review): single quotes are doubled here, presumably because the
            # dry-run consumer inlines string values into quoted SQL literals — confirm.
            column_value = column_value.replace("'", "''")
        values_dict[column.name] = column_value
    return insert(table).values(values_dict)