Coverage for python/felis/tap.py: 11%
246 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 20:29 -0700
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 20:29 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"]
26import logging
27from collections.abc import Iterable, Mapping, MutableMapping
28from typing import Any, Optional, Union
30from sqlalchemy import Column, Integer, String
31from sqlalchemy.engine import Engine
32from sqlalchemy.ext.declarative import declarative_base
33from sqlalchemy.orm import DeclarativeMeta, Session, sessionmaker
34from sqlalchemy.schema import MetaData
35from sqlalchemy.sql.expression import Insert, insert
37from .check import FelisValidator
38from .types import FelisType
39from .visitor import Visitor
# Type alias for the Felis objects (schema/table/column/... mappings) that
# the visitor methods receive.
_Mapping = Mapping[str, Any]

# Declarative base on which ``init_tables`` declares the TAP_SCHEMA ORM
# classes; ``init_tables`` may swap out its ``metadata`` on repeated calls.
Tap11Base: DeclarativeMeta = declarative_base()
logger = logging.getLogger("felis")

# Character widths used for the String columns of the TAP_SCHEMA tables.
IDENTIFIER_LENGTH: int = 128
SMALL_FIELD_LENGTH: int = 32
SIMPLE_FIELD_LENGTH: int = 128
TEXT_FIELD_LENGTH: int = 2048
# Enough room for "catalog.schema.table": three identifiers and two dots.
QUALIFIED_TABLE_LENGTH: int = IDENTIFIER_LENGTH * 3 + 2

# Flipped to True by the first ``init_tables`` call; subsequent calls then
# replace ``Tap11Base.metadata`` with a fresh instance.
_init_table_once: bool = False
def init_tables(
    tap_schema_name: Optional[str] = None,
    tap_tables_postfix: Optional[str] = None,
    tap_schemas_table: Optional[str] = None,
    tap_tables_table: Optional[str] = None,
    tap_columns_table: Optional[str] = None,
    tap_keys_table: Optional[str] = None,
    tap_key_columns_table: Optional[str] = None,
) -> MutableMapping[str, Any]:
    """Declare the TAP_SCHEMA 1.1 ORM classes on ``Tap11Base``.

    Parameters
    ----------
    tap_schema_name : `str`, optional
        Database schema assigned to ``Tap11Base.metadata.schema`` when given.
    tap_tables_postfix : `str`, optional
        Suffix appended to every generated table name.
    tap_schemas_table, tap_tables_table, tap_columns_table, tap_keys_table, tap_key_columns_table : `str`, optional
        Base table names replacing the defaults ``schemas``, ``tables``,
        ``columns``, ``keys`` and ``key_columns`` respectively.

    Returns
    -------
    tables : `dict`
        The ORM classes keyed by ``"schemas"``, ``"tables"``, ``"columns"``,
        ``"keys"`` and ``"key_columns"``.

    Notes
    -----
    Every call after the first replaces ``Tap11Base.metadata`` with a new,
    empty `MetaData` so that the same table names can be declared again.
    """
    global _init_table_once

    # Dirty hack allowing repeated calls: the first call keeps the metadata
    # Tap11Base was created with, later calls start from a fresh MetaData.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    suffix = tap_tables_postfix or ""

    def _named(override: Optional[str], default: str) -> str:
        # An explicit override wins over the default; the suffix always applies.
        return (override or default) + suffix

    class Tap11Schemas(Tap11Base):
        __tablename__ = _named(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _named(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _named(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The TAP "size" column is deprecated and intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _named(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _named(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }
class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
    """Visitor that loads a Felis schema into TAP_SCHEMA 1.1 tables.

    Each visited schema, table, column and foreign-key constraint becomes an
    instance of the ORM classes returned by `init_tables`. The rows are then
    persisted through an ORM session or, in mock (dry-run) mode, sent as
    individual INSERT statements to ``engine.execute``.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        Engine used to create the ORM session, or whose ``execute`` receives
        the INSERT statements when ``mock`` is true.
    catalog_name : `str`, optional
        Prefix joined with "." in front of the schema name.
    schema_name : `str`, optional
        Schema name to record; defaults to the Felis schema's ``name``.
    mock : `bool`
        Dry-run mode: execute INSERT statements on ``engine`` instead of
        using an ORM session.
    tap_tables : `dict`, optional
        ORM classes as returned by `init_tables`; created with default table
        names when not supplied.
    """

    def __init__(
        self,
        engine: Engine,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        mock: bool = False,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ):
        # Maps Felis "@id" values to the ORM objects built for them, so that
        # primary keys, indexes and constraints can refer back to columns.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self.mock = mock
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate a Felis schema and load it, with all its tables, into TAP.

        Parameters
        ----------
        schema_obj : `Mapping`
            Felis schema object; its ``tables`` entry is visited in order.
        """
        self.checker.check_schema(schema_obj)
        schema = self.tables["schemas"]()
        # A schema name given to the constructor overrides the one in the file.
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if not self.mock:
            session: Session = sessionmaker(self.engine)()
            # The context manager closes the session (discarding anything
            # uncommitted) even when visiting a table raises.
            with session:
                session.add(schema)
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)
                    session.add_all(keys)
                    session.add_all(key_columns)
                session.commit()
        else:
            # Only if we are mocking (dry run)
            conn = self.engine
            conn.execute(_insert(self.tables["schemas"], schema))
            for table_obj in schema_obj["tables"]:
                table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))
                for key in keys:
                    conn.execute(_insert(self.tables["keys"], key))
                for key_column in key_columns:
                    conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the TAP rows for one Felis table.

        Returns
        -------
        rows : `tuple`
            ``(table, columns, keys, key_columns)`` where ``table`` is the
            ``tables`` row, ``columns`` the list of ``columns`` rows, and
            ``keys``/``key_columns`` the rows for its foreign-key constraints.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns must be visited first: the key/index visitors below look
        # them up in graph_index and may update their ``indexed`` flag.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        all_keys = []
        all_key_columns = []
        for c in table_obj.get("constraints", []):
            # Pass the Felis table mapping (not the ORM row) as declared by
            # visit_constraint and expected by the checker.
            key, key_columns = self.visit_constraint(c, table_obj)
            if not key:
                continue
            all_keys.append(key)
            all_key_columns += key_columns

        for i in table_obj.get("indexes", []):
            self.visit_index(i, table_obj)

        self.graph_index[table_id] = table
        return table, columns, all_keys, all_key_columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column, warning when a sized or timestamp type has no size."""
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is loosely to the
            # string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the TAP ``columns`` row for a Felis column and register it."""
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        column.unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.ucd = column_obj.get("ivoa:ucd")

        # Set to 1 later by visit_primary_key / visit_index when this column
        # is the sole column of a primary key or index.
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed."""
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # if just one column and it's indexed, update the object
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    @staticmethod
    def _common_table_name(columns: Iterable[Any]) -> Optional[str]:
        """Return the ``table_name`` shared by ``columns`` (None if empty).

        Raises
        ------
        ValueError
            Raised if the columns belong to different tables.
        """
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")
        return table_name

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build TAP ``keys``/``key_columns`` rows for a foreign-key constraint.

        Returns
        -------
        result : `tuple`
            ``(key, key_columns)``; ``(None, [])`` for constraints that are
            not of ``@type`` ``ForeignKey``.

        Raises
        ------
        ValueError
            Raised if the columns (or referenced columns) span several tables,
            or if the two column lists are empty or of different lengths.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]

            self._common_table_name(columns)
            self._common_table_name(refcolumns)
            # Guard against an IndexError on columns[0] and against zip()
            # silently dropping unmatched columns.
            if not columns or len(columns) != len(refcolumns):
                raise ValueError(
                    f"Foreign key {constraint_name} must list the same, non-zero number "
                    "of columns and referenced columns"
                )

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = columns[0].table_name
            key.target_table = refcolumns[0].table_name
            key.description = description
            key.utype = utype
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed."""
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        # if just one column and it's indexed, update the object
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: Optional[str] = None) -> Optional[str]:
        """Return the schema name, prefixed with ``catalog_name.`` when set."""
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified with the (catalog-prefixed) schema."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name
def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy INSERT statement for one ORM object.

    Used by `TapLoadingVisitor` in mock (dry-run) mode, where statements are
    passed to ``engine.execute`` instead of going through an ORM session.

    Parameters
    ----------
    table : `Tap11Base`
        ORM class whose ``__table__`` is the insert target.
    value : `Any`
        Object exposing one attribute per column of ``table.__table__``.

    Returns
    -------
    statement : `sqlalchemy.sql.expression.Insert`
        INSERT statement with every column set from ``value``.

    Notes
    -----
    Single quotes inside string values are doubled. NOTE(review): this looks
    meant for dry-run output where values end up spliced into SQL text; with
    real bound parameters it would alter the stored data — confirm before
    using this helper outside mock mode.
    """
    values_dict = {}
    for column in table.__table__.columns:
        column_value = getattr(value, column.name)
        if isinstance(column_value, str):
            column_value = column_value.replace("'", "''")
        values_dict[column.name] = column_value
    # ``insert(table, values=...)`` is deprecated since SQLAlchemy 1.4 and
    # removed in 2.0; the generative ``.values()`` form works across versions.
    return insert(table).values(values_dict)