Coverage for python/felis/tap.py: 12%
262 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-20 02:40 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]
26import logging
27from collections.abc import Iterable, Mapping, MutableMapping
28from typing import Any
30from sqlalchemy import Column, Integer, String
31from sqlalchemy.engine import Engine
32from sqlalchemy.engine.mock import MockConnection
33from sqlalchemy.orm import Session, declarative_base, sessionmaker
34from sqlalchemy.schema import MetaData
35from sqlalchemy.sql.expression import Insert, insert
37from .check import FelisValidator
38from .types import FelisType
39from .visitor import Visitor
# Shorthand for the read-only mapping objects handed to the visitor methods.
_Mapping = Mapping[str, Any]

# Declarative base shared by every TAP table class built in `init_tables`.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2

logger = logging.getLogger("felis")

# Maximum lengths (in characters) of the String columns of the TAP tables.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for "catalog.schema.table": three identifiers joined by two dots.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Set by the first `init_tables` call; later calls use it to decide whether
# `Tap11Base.metadata` has to be replaced with a fresh MetaData instance.
_init_table_once = False
def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Generate definitions for TAP tables.

    Parameters
    ----------
    tap_schema_name : `str` or `None`
        If given (and non-empty), assigned to ``Tap11Base.metadata.schema``.
    tap_tables_postfix : `str` or `None`
        Suffix appended to every table name; `None` means no suffix.
    tap_schemas_table : `str` or `None`
        Base name of the schemas table, ``"schemas"`` if not given.
    tap_tables_table : `str` or `None`
        Base name of the tables table, ``"tables"`` if not given.
    tap_columns_table : `str` or `None`
        Base name of the columns table, ``"columns"`` if not given.
    tap_keys_table : `str` or `None`
        Base name of the keys table, ``"keys"`` if not given.
    tap_key_columns_table : `str` or `None`
        Base name of the key columns table, ``"key_columns"`` if not given.

    Returns
    -------
    tables : `dict`
        Declarative classes keyed by ``"schemas"``, ``"tables"``,
        ``"columns"``, ``"keys"`` and ``"key_columns"``.
    """
    global _init_table_once

    # Dirty hack that lets this function be called repeatedly: every call
    # after the first one swaps in a brand new MetaData instance before the
    # classes are declared again.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    suffix = tap_tables_postfix or ""

    def _named(override: str | None, default: str) -> str:
        # Caller-supplied base name (or the default) plus the common suffix.
        return (override or default) + suffix

    class Tap11Schemas(Tap11Base):
        __tablename__ = _named(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _named(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _named(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated "size" column is intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _named(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _named(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }
class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None, None]):
    """Felis schema visitor for generating TAP schema.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance.
    catalog_name : `str` or `None`
        Name of the database catalog.
    schema_name : `str` or `None`
        Name of the database schema.
    tap_tables : `~collections.abc.Mapping`
        Optional mapping of table name to its declarative base class.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ):
        # Maps "@id" values of visited tables and columns to the row objects
        # created for them, so later visits can look them up by id.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ) -> TapLoadingVisitor:
        """Create a visitor without an engine that uses a mock connection.

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection that receives the insert statements.
        catalog_name : `str` or `None`
            Name of the database catalog.
        schema_name : `str` or `None`
            Name of the database schema.
        tap_tables : `~collections.abc.Mapping`
            Optional mapping of table name to its declarative base class.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor whose ``engine`` is `None`.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Load a felis schema and everything it contains into the TAP tables.

        With an engine, all rows are added to one session and committed at
        the end.  Without an engine, one insert statement per row is executed
        on the mock connection instead.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object.
        """
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            session: Session
            # The context manager closes the session even when one of the
            # visit methods raises before the commit is reached.
            with sessionmaker(self.engine)() as session:
                session.add(schema)

                for table_obj in schema_obj["tables"]:
                    table, columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)

                keys, key_columns = self.visit_constraints(schema_obj)
                session.add_all(keys)
                session.add_all(key_columns)

                session.commit()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))

            for table_obj in schema_obj["tables"]:
                table, columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))

            keys, key_columns = self.visit_constraints(schema_obj)
            for key in keys:
                conn.execute(_insert(self.tables["keys"], key))
            for key_column in key_columns:
                conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_constraints(self, schema_obj: _Mapping) -> tuple:
        """Visit the constraints of every table in the schema.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object.

        Returns
        -------
        keys : `list`
            Key rows, one per constraint for which `visit_constraint`
            produced a key.
        key_columns : `list`
            Key column rows of those constraints.
        """
        all_keys = []
        all_key_columns = []
        for table_obj in schema_obj["tables"]:
            for c in table_obj.get("constraints", []):
                key, key_columns = self.visit_constraint(c, table_obj)
                if not key:
                    # Constraint did not produce a key row, nothing to store.
                    continue
                all_keys.append(key)
                all_key_columns += key_columns
        return all_keys, all_key_columns

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the row for one table together with the rows of its columns.

        Parameters
        ----------
        table_obj : `~collections.abc.Mapping`
            Felis table object.
        schema_obj : `~collections.abc.Mapping`
            Felis schema object containing the table.

        Returns
        -------
        table : `Tap11Base`
            Row for the tables table; it is also stored in ``graph_index``
            under the table's ``@id``.
        columns : `list`
            Rows for the columns table, one per column of the table.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns have to be visited first: the primary key and the indexes
        # look the column rows up in graph_index.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        for i in table_obj.get("indexes", []):
            # visit_index expects the felis table mapping, not the row object.
            self.visit_index(i, table_obj)

        self.graph_index[table_id] = table
        return table, columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when its array size is not specified.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Felis table object containing the column.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    f'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the row describing one column.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Felis table object containing the column.

        Returns
        -------
        column : `Tap11Base`
            Row for the columns table; it is also stored in ``graph_index``
            under the column's ``@id``.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # `length` is deliberately not consulted for timestamps.
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.unit = unit
        column.ucd = column_obj.get("ivoa:ucd")

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        Parameters
        ----------
        primary_key_obj : `str` or `~collections.abc.Iterable` [`str`]
            ``@id`` of the primary key column, or ids of several columns.
        table_obj : `~collections.abc.Mapping`
            Felis table object the primary key belongs to.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a key made of exactly one column marks that column as
            # indexed; composite keys leave the flags untouched.
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build the key and key column rows for a foreign key constraint.

        Parameters
        ----------
        constraint_obj : `~collections.abc.Mapping`
            Felis constraint object.
        table_obj : `~collections.abc.Mapping`
            Felis table object the constraint belongs to.

        Returns
        -------
        key : `Tap11Base` or `None`
            Row for the keys table, `None` unless the constraint's ``@type``
            is ``"ForeignKey"``.
        key_columns : `list`
            Rows for the key columns table, one per pair of column and
            referenced column; pairing stops at the shorter of the two lists.

        Raises
        ------
        ValueError
            Raised if the columns, or the referenced columns, do not all
            belong to the same table.
        IndexError
            Raised if a foreign key has no columns or no referenced columns.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[col["@id"]] for col in constraint_obj.get("columns", [])]
            refcolumns = [
                self.graph_index[refcol["@id"]] for refcol in constraint_obj.get("referencedColumns", [])
            ]

            self._check_single_table(columns)
            self._check_single_table(refcolumns)
            first_column = columns[0]
            first_refcolumn = refcolumns[0]

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = first_column.table_name
            key.target_table = first_refcolumn.table_name
            key.description = description
            key.utype = utype
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        """Raise `ValueError` unless all column rows share one table name."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed.

        Parameters
        ----------
        index_obj : `~collections.abc.Mapping`
            Felis index object.
        table_obj : `~collections.abc.Mapping`
            Felis table object the index belongs to.
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[col["@id"]] for col in index_obj.get("columns", [])]
        # Only an index on exactly one column marks that column as indexed.
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return the schema name, prefixed with the catalog name if set.

        Parameters
        ----------
        schema_name : `str` or `None`
            Schema name to use instead of ``self.schema_name``.

        Returns
        -------
        name : `str` or `None`
            ``catalog.schema`` when both names are set, otherwise the schema
            name alone, which may be `None`.
        """
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return the table name qualified with the schema name, if any.

        Parameters
        ----------
        table_name : `str`
            Unqualified table name.

        Returns
        -------
        name : `str`
            ``schema.table`` (the schema part may itself include the
            catalog), or ``table_name`` unchanged when no schema name is set.
        """
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name
def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy insert statement.

    Parameters
    ----------
    table : `Tap11Base`
        The table we are inserting into.
    value : `Any`
        Object holding the row to insert; it must have one attribute per
        column of ``table``.

    Returns
    -------
    statement
        A SQLAlchemy insert statement with one value per table column.
    """
    row = {}
    for col in table.__table__.columns:
        raw = getattr(value, col.name)
        # Single quotes in strings are doubled, non-string values pass
        # through untouched.
        # NOTE(review): presumably needed because the dry-run output embeds
        # these values in SQL text as quoted literals -- confirm against the
        # mock connection's dumper.
        row[col.name] = raw.replace("'", "''") if isinstance(raw, str) else raw
    return insert(table).values(row)