Coverage for python/felis/tap.py: 12%
263 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-23 10:44 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-23 10:44 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]
26import logging
27from collections.abc import Iterable, Mapping, MutableMapping
28from typing import Any
30from sqlalchemy import Column, Integer, String
31from sqlalchemy.engine import Engine
32from sqlalchemy.engine.mock import MockConnection
33from sqlalchemy.ext.declarative import declarative_base
34from sqlalchemy.orm import Session, sessionmaker
35from sqlalchemy.schema import MetaData
36from sqlalchemy.sql.expression import Insert, insert
38from .check import FelisValidator
39from .types import FelisType
40from .visitor import Visitor
# Shorthand for the read-only mappings (schema, table, column, constraint and
# index objects) that are handed to the visitor methods below.
_Mapping = Mapping[str, Any]

# Declarative base that every generated TAP table class derives from.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2

logger = logging.getLogger("felis")

# Widths of the string columns used by the TAP table definitions.
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
IDENTIFIER_LENGTH = 128
# A fully qualified table name is up to three identifiers (catalog, schema,
# table) joined by two "." separators.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Flipped to True by the first ``init_tables`` call; see that function.
_init_table_once = False
def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Generate definitions for TAP tables.

    Parameters
    ----------
    tap_schema_name : `str` or `None`
        If given, assigned to ``Tap11Base.metadata.schema``.
    tap_tables_postfix : `str` or `None`
        Suffix appended to the name of every generated table.
    tap_schemas_table : `str` or `None`
        Base name of the schemas table, defaults to ``"schemas"``.
    tap_tables_table : `str` or `None`
        Base name of the tables table, defaults to ``"tables"``.
    tap_columns_table : `str` or `None`
        Base name of the columns table, defaults to ``"columns"``.
    tap_keys_table : `str` or `None`
        Base name of the keys table, defaults to ``"keys"``.
    tap_key_columns_table : `str` or `None`
        Base name of the key columns table, defaults to ``"key_columns"``.

    Returns
    -------
    tables : `dict`
        Declarative classes keyed by ``"schemas"``, ``"tables"``,
        ``"columns"``, ``"keys"`` and ``"key_columns"``.
    """
    global _init_table_once

    suffix = tap_tables_postfix if tap_tables_postfix else ""

    # Dirty hack that lets this function be called repeatedly: every call
    # after the first one swaps in a brand new MetaData instance.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    # Resolve the final table names up front.
    schemas_name = (tap_schemas_table or "schemas") + suffix
    tables_name = (tap_tables_table or "tables") + suffix
    columns_name = (tap_columns_table or "columns") + suffix
    keys_name = (tap_keys_table or "keys") + suffix
    key_columns_name = (tap_key_columns_table or "key_columns") + suffix

    class Tap11Schemas(Tap11Base):
        __tablename__ = schemas_name
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = tables_name
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = columns_name
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        column_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated ``size`` column is intentionally not defined.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = keys_name
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = key_columns_name
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        target_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }
class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None, None]):
    """Felis schema visitor for generating TAP schema.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance. When `None`, `visit_schema` does not
        open a session and sends its inserts to the mock connection
        installed by `from_mock_connection` instead.
    catalog_name : `str` or `None`
        Name of the database catalog.
    schema_name : `str` or `None`
        Name of the database schema.
    tap_tables : `~collections.abc.Mapping`
        Optional mapping of table name to its declarative base class.
        When omitted or empty, the result of `init_tables` is used.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ) -> None:
        # Maps Felis "@id" values to the TAP row objects created for them.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ) -> TapLoadingVisitor:
        """Create a visitor that has no engine and uses a mock connection.

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection that receives the insert statements.
        catalog_name : `str` or `None`
            Name of the database catalog.
        schema_name : `str` or `None`
            Name of the database schema.
        tap_tables : `~collections.abc.Mapping`
            Optional mapping of table name to its declarative base class.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor with ``engine`` set to `None`.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate a schema and load it with its tables, columns and keys.

        With an engine, all rows are added to one session which is
        committed at the end. Without an engine, one insert statement per
        row is executed on the mock connection.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object.
        """
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            session: Session = sessionmaker(self.engine)()
            # Using the session as a context manager closes it when the
            # block is left, including when a visit_* call or the commit
            # raises.
            with session:
                session.add(schema)

                for table_obj in schema_obj["tables"]:
                    table, columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)

                keys, key_columns = self.visit_constraints(schema_obj)
                session.add_all(keys)
                session.add_all(key_columns)

                session.commit()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))

            for table_obj in schema_obj["tables"]:
                table, columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))

            keys, key_columns = self.visit_constraints(schema_obj)
            for key in keys:
                conn.execute(_insert(self.tables["keys"], key))
            for key_column in key_columns:
                conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_constraints(self, schema_obj: _Mapping) -> tuple:
        """Collect key and key-column rows for every constraint in a schema.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object.

        Returns
        -------
        keys : `list`
            One key row per constraint for which `visit_constraint`
            produced a key.
        key_columns : `list`
            Key-column rows belonging to ``keys``.
        """
        all_keys = []
        all_key_columns = []
        for table_obj in schema_obj["tables"]:
            for constraint_obj in table_obj.get("constraints", []):
                key, key_columns = self.visit_constraint(constraint_obj, table_obj)
                if not key:
                    # Constraint types other than ForeignKey produce no key.
                    continue
                all_keys.append(key)
                all_key_columns.extend(key_columns)
        return all_keys, all_key_columns

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the TAP table row and column rows for one Felis table.

        Parameters
        ----------
        table_obj : `~collections.abc.Mapping`
            Felis table object.
        schema_obj : `~collections.abc.Mapping`
            Felis schema object that contains the table.

        Returns
        -------
        table : `Tap11Base`
            Row for the TAP tables table.
        columns : `list`
            Rows for the TAP columns table, one per Felis column.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns must be visited first: the primary key and index visitors
        # look the column rows up in ``graph_index``.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        for index_obj in table_obj.get("indexes", []):
            # Pass the Felis table mapping, as the visit_index signature
            # declares, not the ORM row built above.
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when its array size is unspecified.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Felis table object that contains the column.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the TAP column row for one Felis column.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Felis table object that contains the column.

        Returns
        -------
        column : `Tap11Base`
            Row for the TAP columns table. It is also stored in
            ``graph_index`` under the column's ``@id``.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # ``length`` is not consulted for timestamps.
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        column.unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.ucd = column_obj.get("ivoa:ucd")

        # We modify this after we process columns (see visit_primary_key
        # and visit_index).
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        Parameters
        ----------
        primary_key_obj : `str` or `~collections.abc.Iterable` [`str`]
            ``@id`` of the primary key column, or several of them.
        table_obj : `~collections.abc.Mapping`
            Felis table object that owns the primary key.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if not primary_key_obj:
            return
        if isinstance(primary_key_obj, str):
            primary_key_obj = [primary_key_obj]
        columns = [self.graph_index[c_id] for c_id in primary_key_obj]
        # Only a primary key made of exactly one column flags that column
        # as indexed.
        if len(columns) == 1:
            columns[0].indexed = 1

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        """Raise `ValueError` unless all columns share one ``table_name``."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build TAP key rows for a foreign key constraint.

        Parameters
        ----------
        constraint_obj : `~collections.abc.Mapping`
            Felis constraint object.
        table_obj : `~collections.abc.Mapping`
            Felis table object that owns the constraint.

        Returns
        -------
        key : `Tap11Base` or `None`
            Row for the TAP keys table, `None` when the constraint is not
            a ``ForeignKey``.
        key_columns : `list`
            Rows for the TAP key columns table, empty when ``key`` is
            `None`.

        Raises
        ------
        ValueError
            Raised if the constraint columns, or the referenced columns,
            do not all belong to one table, or if either list is empty.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        key = None
        key_columns: list[Any] = []
        if constraint_obj["@type"] != "ForeignKey":
            return key, key_columns

        constraint_name = constraint_obj["name"]
        description = constraint_obj.get("description")
        utype = constraint_obj.get("votable:utype")

        columns = [self.graph_index[col["@id"]] for col in constraint_obj.get("columns", [])]
        refcolumns = [
            self.graph_index[refcol["@id"]] for refcol in constraint_obj.get("referencedColumns", [])
        ]

        self._check_single_table(columns)
        self._check_single_table(refcolumns)

        # Without this guard the indexing below fails with a bare IndexError
        # that does not say which constraint is at fault.
        if not columns or not refcolumns:
            raise ValueError(
                f"ForeignKey constraint {constraint_name} must define both "
                "columns and referencedColumns"
            )
        first_column = columns[0]
        first_refcolumn = refcolumns[0]

        key = self.tables["keys"]()
        key.key_id = constraint_name
        key.from_table = first_column.table_name
        key.target_table = first_refcolumn.table_name
        key.description = description
        key.utype = utype
        for column, refcolumn in zip(columns, refcolumns):
            key_column = self.tables["key_columns"]()
            key_column.key_id = constraint_name
            key_column.from_column = column.column_name
            key_column.target_column = refcolumn.column_name
            key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed.

        Parameters
        ----------
        index_obj : `~collections.abc.Mapping`
            Felis index object.
        table_obj : `~collections.abc.Mapping`
            Felis table object that owns the index.
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[col["@id"]] for col in index_obj.get("columns", [])]
        # Only an index over exactly one column flags that column as indexed.
        if len(columns) == 1:
            columns[0].indexed = 1

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return the schema name, prefixed with the catalog name if set."""
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return the table name qualified with the schema name, if any."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name
def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy insert statement.

    Parameters
    ----------
    table : `Tap11Base`
        The table we are inserting into.
    value : `Any`
        An object representing the object we are inserting to the table.
        It must have one attribute per column of ``table``.

    Returns
    -------
    statement
        A SQLAlchemy insert statement
    """

    def _escaped(raw: Any) -> Any:
        # Double every single quote in string values, leave other values
        # alone. NOTE(review): presumably needed because the dry-run output
        # renders values as quoted SQL literals -- confirm against the caller.
        return raw.replace("'", "''") if isinstance(raw, str) else raw

    row = {col.name: _escaped(getattr(value, col.name)) for col in table.__table__.columns}
    return insert(table).values(row)