Coverage for python/felis/tap.py: 12%
255 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-14 09:10 +0000
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-14 09:10 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]
26import logging
27from collections.abc import Iterable, MutableMapping
28from typing import Any
30from sqlalchemy import Column, Integer, String
31from sqlalchemy.engine import Engine
32from sqlalchemy.engine.mock import MockConnection
33from sqlalchemy.orm import Session, declarative_base, sessionmaker
34from sqlalchemy.schema import MetaData
35from sqlalchemy.sql.expression import Insert, insert
37from felis import datamodel
39from .datamodel import Constraint, Index, Schema, Table
40from .types import FelisType
# Declarative base shared by every TAP_SCHEMA ORM class generated in
# ``init_tables``; typed as ``Any`` to sidestep mypy issues with SQLAlchemy 2.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2
logger = logging.getLogger(__name__)

# Maximum ``String`` (VARCHAR) lengths used for the TAP_SCHEMA table columns.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for a fully qualified ``catalog.schema.table`` name: three
# identifiers plus the two "." separators inserted by the visitor.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Flipped to True by the first ``init_tables`` call; later calls use it to
# replace ``Tap11Base.metadata`` with a fresh ``MetaData`` before redefining
# the TAP tables.
_init_table_once = False
def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Generate declarative definitions for the TAP_SCHEMA tables.

    Parameters
    ----------
    tap_schema_name : `str`, optional
        Database schema holding the TAP tables. When given, it is assigned
        to ``Tap11Base.metadata.schema``.
    tap_tables_postfix : `str`, optional
        Suffix appended to every generated table name.
    tap_schemas_table : `str`, optional
        Base name of the schemas table (default ``"schemas"``).
    tap_tables_table : `str`, optional
        Base name of the tables table (default ``"tables"``).
    tap_columns_table : `str`, optional
        Base name of the columns table (default ``"columns"``).
    tap_keys_table : `str`, optional
        Base name of the keys table (default ``"keys"``).
    tap_key_columns_table : `str`, optional
        Base name of the key columns table (default ``"key_columns"``).

    Returns
    -------
    tables : `~collections.abc.MutableMapping`
        Mapping of the keys ``"schemas"``, ``"tables"``, ``"columns"``,
        ``"keys"`` and ``"key_columns"`` to the generated declarative
        classes.

    Notes
    -----
    Every call after the first replaces ``Tap11Base.metadata`` with a fresh
    `~sqlalchemy.schema.MetaData` so the tables can be defined again.
    NOTE(review): the declarative registry itself is not reset; confirm that
    repeated calls do not trigger SQLAlchemy duplicate-class warnings.
    """
    global _init_table_once

    suffix = tap_tables_postfix or ""

    def _tablename(override: str | None, default: str) -> str:
        # An explicit override wins over the default; the suffix always applies.
        return (override or default) + suffix

    # Dirty hack so this function can run more than once: the first call
    # keeps the existing MetaData, later calls swap in a fresh one.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = _tablename(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _tablename(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _tablename(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The ``size`` column is deprecated and intentionally not defined.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _tablename(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _tablename(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }
class TapLoadingVisitor:
    """Felis schema visitor for generating TAP schema.

    The visitor converts a Felis `Schema` into rows for the TAP_SCHEMA
    ``schemas``, ``tables``, ``columns``, ``keys`` and ``key_columns``
    tables. When ``engine`` is set, the rows are written through an ORM
    session. Otherwise they are emitted as ``INSERT`` statements on a mock
    connection (dry run; see `from_mock_connection`).

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance. When `None`, a mock connection must be
        attached (see `from_mock_connection`).
    catalog_name : `str` or `None`
        Name of the database catalog, prepended to the schema name.
    schema_name : `str` or `None`
        Name of the database schema. If `None`, the Felis schema name is
        used when a schema is visited.
    tap_tables : `~collections.abc.Mapping`
        Optional mapping of table name to its declarative base class, as
        returned by `init_tables`. If `None` (or empty), `init_tables` is
        called with its defaults.
    tap_schema_index : `int` or `None`
        Value stored in the ``schema_index`` column of the ``schemas`` row.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
        tap_schema_index: int | None = None,
    ):
        # Felis object id (table or column) -> TAP row built for it, so that
        # primary keys, indexes and constraints can find their columns.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.tap_schema_index = tap_schema_index

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
        tap_schema_index: int | None = None,
    ) -> TapLoadingVisitor:
        """Create a dry-run visitor that executes inserts on a mock connection.

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection on which the generated ``INSERT`` statements are
            executed.
        catalog_name, schema_name, tap_tables, tap_schema_index
            Same as for the class constructor.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor with no engine and ``mock_connection`` attached.
        """
        visitor = cls(
            engine=None,
            catalog_name=catalog_name,
            schema_name=schema_name,
            tap_tables=tap_tables,
            tap_schema_index=tap_schema_index,
        )
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: Schema) -> None:
        """Write all TAP_SCHEMA rows describing ``schema_obj``.

        With an engine, the rows are added to an ORM session and committed.
        Without one, each row becomes an ``INSERT`` executed on the mock
        connection.

        Parameters
        ----------
        schema_obj : `felis.datamodel.Schema`
            The Felis schema to load.
        """
        schema = self.tables["schemas"]()
        # Override with default. This must happen before any table is
        # visited, because table names are qualified with the schema name.
        self.schema_name = self.schema_name or schema_obj.name

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.description
        schema.utype = schema_obj.votable_utype
        schema.schema_index = self.tap_schema_index
        logger.debug("Set TAP_SCHEMA index: %s", self.tap_schema_index)

        if self.engine is not None:
            rows = [("schemas", schema), *self._visit_schema_contents(schema_obj)]
            session_factory: sessionmaker[Session] = sessionmaker(self.engine)
            # The context manager closes the session (rolling back anything
            # uncommitted) even if add/commit raises.
            with session_factory() as session:
                session.add_all([row for _, row in rows])
                logger.debug("Committing TAP schema: %s", schema_obj.name)
                logger.debug("TAP tables: %s", len(self.tables))
                session.commit()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            rows = [("schemas", schema), *self._visit_schema_contents(schema_obj)]
            for tap_table_key, row in rows:
                conn.execute(_insert(self.tables[tap_table_key], row))

    def _visit_schema_contents(self, schema_obj: Schema) -> list[tuple[str, Any]]:
        """Build the table, column, key and key-column rows of a schema.

        Tables are visited before constraints because `visit_constraint`
        looks columns up in ``graph_index``, which `visit_table` fills.

        Returns
        -------
        rows : `list` [`tuple` [`str`, `Any`]]
            ``(key into self.tables, row object)`` pairs in insertion order.
        """
        rows: list[tuple[str, Any]] = []
        for table_obj in schema_obj.tables:
            table, columns = self.visit_table(table_obj, schema_obj)
            rows.append(("tables", table))
            rows.extend(("columns", column) for column in columns)

        keys, key_columns = self.visit_constraints(schema_obj)
        rows.extend(("keys", key) for key in keys)
        rows.extend(("key_columns", key_column) for key_column in key_columns)
        return rows

    def visit_constraints(self, schema_obj: Schema) -> tuple[list[Any], list[Any]]:
        """Collect the key and key-column rows for every table constraint.

        Constraints for which `visit_constraint` returns no key (anything
        other than a foreign key) are skipped.

        Returns
        -------
        keys, key_columns : `tuple` [`list`, `list`]
            All ``keys`` rows and all ``key_columns`` rows.
        """
        all_keys = []
        all_key_columns = []
        for table_obj in schema_obj.tables:
            for constraint_obj in table_obj.constraints:
                key, key_columns = self.visit_constraint(constraint_obj)
                if not key:
                    continue
                all_keys.append(key)
                all_key_columns.extend(key_columns)
        return all_keys, all_key_columns

    def visit_table(self, table_obj: Table, schema_obj: Schema) -> tuple[Any, list[Any]]:
        """Build the ``tables`` row and the ``columns`` rows for one table.

        This method also marks single-column primary keys and indexes as
        ``indexed``, and registers the table row in ``graph_index``.

        Returns
        -------
        table, columns : `tuple`
            The ``tables`` row and the list of its ``columns`` rows.
        """
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj.name)
        table.table_type = "table"
        table.utype = table_obj.votable_utype
        table.description = table_obj.description
        table.table_index = 0 if table_obj.tap_table_index is None else table_obj.tap_table_index

        # Columns must be in graph_index before primary keys/indexes refer to them.
        columns = [self.visit_column(column_obj, table_obj) for column_obj in table_obj.columns]
        self.visit_primary_key(table_obj.primary_key, table_obj)

        for index_obj in table_obj.indexes:
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_obj.id] = table
        return table, columns

    def check_column(self, column_obj: datamodel.Column) -> None:
        """Warn when a column's VOTable arraysize will fall back to ``"*"``.

        Sized types are checked for a ``votable:arraysize`` or a ``length``.
        Timestamp types are checked for an explicit ``votable:arraysize``.
        """
        _id = column_obj.id
        datatype_name = column_obj.datatype
        felis_type = FelisType.felis_type(datatype_name.value)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.votable_arraysize or column_obj.length
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is loosely related
            # to the string size.
            if not column_obj.votable_arraysize:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: datamodel.Column, table_obj: Table) -> Tap11Base:
        """Build the ``columns`` row for one column and register it.

        ``indexed`` starts at 0. `visit_primary_key` and `visit_index` may
        set it to 1 later.

        Returns
        -------
        column : `Tap11Base`
            The populated ``columns`` row.
        """
        self.check_column(column_obj)

        column = self.tables["columns"]()
        column.table_name = self._table_name(table_obj.name)
        column.column_name = column_obj.name

        felis_type = FelisType.felis_type(column_obj.datatype.value)
        column.datatype = column_obj.votable_datatype or felis_type.votable_name

        # Fallback arraysize matches the warnings emitted by check_column.
        arraysize = None
        if felis_type.is_sized:
            arraysize = column_obj.votable_arraysize or column_obj.length or "*"
        if felis_type.is_timestamp:
            arraysize = column_obj.votable_arraysize or "*"
        column.arraysize = arraysize

        column.xtype = column_obj.votable_xtype
        column.description = column_obj.description
        column.utype = column_obj.votable_utype
        column.unit = column_obj.ivoa_unit or column_obj.fits_tunit
        column.ucd = column_obj.ivoa_ucd

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.tap_principal
        column.std = column_obj.tap_std
        column.column_index = column_obj.tap_column_index

        self.graph_index[column_obj.id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str] | None, table_obj: Table) -> None:
        """Mark a single-column primary key as ``indexed``.

        ``primary_key_obj`` is a column id or an iterable of column ids.
        Composite keys do not change any column. ``table_obj`` is unused.
        """
        if not primary_key_obj:
            return None
        if isinstance(primary_key_obj, str):
            primary_key_obj = [primary_key_obj]
        columns = [self.graph_index[c_id] for c_id in primary_key_obj]
        # if just one column and it's indexed, update the object
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    @staticmethod
    def _check_single_table(columns: list[Any]) -> None:
        """Raise `ValueError` unless all ``columns`` share one ``table_name``."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsistent use of table names")

    def visit_constraint(self, constraint_obj: Constraint) -> tuple[Any, list[Any]]:
        """Build the ``keys`` / ``key_columns`` rows for one constraint.

        Only ``ForeignKey`` constraints produce rows.

        Returns
        -------
        key, key_columns : `tuple`
            The ``keys`` row and its ``key_columns`` rows, or ``(None, [])``
            for other constraint types.

        Raises
        ------
        ValueError
            Raised if the source or target columns span several tables, if
            either column list is empty, or if the two lists differ in
            length.
        """
        key = None
        key_columns: list[Any] = []
        if constraint_obj.type != "ForeignKey":
            return key, key_columns

        constraint_name = constraint_obj.name
        columns = [self.graph_index[col_id] for col_id in getattr(constraint_obj, "columns", [])]
        refcolumns = [
            self.graph_index[refcol_id] for refcol_id in getattr(constraint_obj, "referenced_columns", [])
        ]

        self._check_single_table(columns)
        self._check_single_table(refcolumns)
        if not columns or not refcolumns:
            raise ValueError(f"Foreign key {constraint_name} has no columns or no referenced columns")
        # zip() would silently drop unmatched columns; reject such keys instead.
        if len(columns) != len(refcolumns):
            raise ValueError(
                f"Foreign key {constraint_name} has {len(columns)} columns "
                f"but {len(refcolumns)} referenced columns"
            )

        key = self.tables["keys"]()
        key.key_id = constraint_name
        key.from_table = columns[0].table_name
        key.target_table = refcolumns[0].table_name
        key.description = constraint_obj.description
        key.utype = constraint_obj.votable_utype
        for column, refcolumn in zip(columns, refcolumns):
            key_column = self.tables["key_columns"]()
            key_column.key_id = constraint_name
            key_column.from_column = column.column_name
            key_column.target_column = refcolumn.column_name
            key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: Index, table_obj: Table) -> None:
        """Mark the column of a single-column index as ``indexed``.

        Multi-column indexes do not change any column. ``table_obj`` is
        unused.
        """
        columns = [self.graph_index[col_id] for col_id in getattr(index_obj, "columns", [])]
        # if just one column and it's indexed, update the object
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return ``catalog.schema`` (or just the schema name) for TAP rows.

        If no schema name is available, `None` is returned and SQLAlchemy
        rejects it on insert.
        """
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Qualify ``table_name`` with the (catalog-qualified) schema name, if any."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name
def _insert(table: Tap11Base, value: Any) -> Insert:
    """Build an ``INSERT`` statement for one TAP row.

    Every column of ``table.__table__`` is read from ``value`` by name.
    String values have each single quote doubled (``'`` -> ``''``); other
    values are used unchanged. NOTE(review): the escaping presumably exists
    because the dry-run executor renders these values inline; confirm against
    the mock connection's executor before changing it.

    Parameters
    ----------
    table : `Tap11Base`
        The declarative class of the table we are inserting into.
    value : `Any`
        Object whose attributes hold the column values to insert.

    Returns
    -------
    statement : `sqlalchemy.sql.expression.Insert`
        The insert statement with the collected values bound.
    """

    def _escaped(raw: Any) -> Any:
        # Only strings are touched; None and numbers pass through as-is.
        return raw.replace("'", "''") if isinstance(raw, str) else raw

    row = {col.name: _escaped(getattr(value, col.name)) for col in table.__table__.columns}
    return insert(table).values(row)