# Coverage report header (coverage.py v7.5.0, created at 2024-04-25 10:20 -0700):
# python/felis/metadata.py — 17% of 179 statements covered.
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24import logging
25from typing import IO, Any, Literal
27import sqlalchemy.schema as sqa_schema
28from lsst.utils.iteration import ensure_iterable
29from sqlalchemy import (
30 CheckConstraint,
31 Column,
32 Constraint,
33 Engine,
34 ForeignKeyConstraint,
35 Index,
36 MetaData,
37 PrimaryKeyConstraint,
38 ResultProxy,
39 Table,
40 UniqueConstraint,
41 create_mock_engine,
42 make_url,
43 text,
44)
45from sqlalchemy.engine.interfaces import Dialect
46from sqlalchemy.engine.mock import MockConnection
47from sqlalchemy.engine.url import URL
48from sqlalchemy.exc import SQLAlchemyError
49from sqlalchemy.types import TypeEngine
51from felis.datamodel import Schema
52from felis.db._variants import make_variant_dict
54from . import datamodel
55from .db import sqltypes
56from .types import FelisType
# Module-level logger, named after this module so log records can be
# filtered by the ``felis`` logging hierarchy.
logger = logging.getLogger(__name__)
class InsertDump:
    """An Insert Dumper for SQL statements which supports writing messages
    to stdout or a file.
    """

    def __init__(self, file: IO[str] | None = None) -> None:
        """Initialize the insert dumper.

        Parameters
        ----------
        file : `io.TextIOBase` or `None`, optional
            The file to write the SQL statements to. If None, the statements
            will be written to stdout.
        """
        self.file = file
        # Dialect used for compiling statements; expected to be assigned by
        # the caller after the mock engine is created (see
        # ``DatabaseContext.create_mock_engine``).
        self.dialect: Dialect | None = None

    def dump(self, sql: Any, *multiparams: Any, **params: Any) -> None:
        """Dump the SQL statement to a file or stdout.

        Statements with parameters will be formatted with the values
        inserted into the resultant SQL output.

        Parameters
        ----------
        sql : `typing.Any`
            The SQL statement to dump.
        multiparams : `typing.Any`
            The multiparams to use for the SQL statement.
        params : `typing.Any`
            The params to use for the SQL statement.
        """
        compiled = sql.compile(dialect=self.dialect)
        sql_str = str(compiled) + ";"
        params_list = [compiled.params]
        # Renamed loop variable so it no longer shadows the ``**params``
        # keyword argument of this method.
        for param_set in params_list:
            if not param_set:
                # No bound parameters: emit the statement verbatim.
                print(sql_str, file=self.file)
                continue
            new_params = {}
            for key, value in param_set.items():
                if isinstance(value, str):
                    # Double embedded single quotes so the emitted SQL
                    # literal stays valid (standard SQL escaping).
                    escaped = value.replace("'", "''")
                    new_params[key] = f"'{escaped}'"
                elif value is None:
                    new_params[key] = "null"
                else:
                    new_params[key] = value
            # The compiled string contains pyformat placeholders such as
            # ``%(name)s`` which are filled in here.
            print(sql_str % new_params, file=self.file)
def get_datatype_with_variants(column_obj: datamodel.Column) -> TypeEngine:
    """Use the Felis type system to get a SQLAlchemy datatype with variant
    overrides from the information in a `Column` object.

    Parameters
    ----------
    column_obj : `felis.datamodel.Column`
        The column object from which to get the datatype.

    Returns
    -------
    datatype : `sqlalchemy.types.TypeEngine`
        The SQLAlchemy type, with any dialect variant overrides applied.

    Raises
    ------
    ValueError
        If the column has a sized type but no length.
    """
    variants = make_variant_dict(column_obj)
    type_name = column_obj.datatype.value
    felis_type = FelisType.felis_type(type_name)
    type_factory = getattr(sqltypes, type_name)
    if not felis_type.is_sized:
        return type_factory(**variants)
    # Sized types (e.g. char/varchar) require an explicit length.
    if not column_obj.length:
        raise ValueError(f"Column {column_obj.name} has sized type '{column_obj.datatype}' but no length")
    return type_factory(column_obj.length, **variants)
class MetaDataBuilder:
    """A class for building a `MetaData` object from a Felis `Schema`."""

    def __init__(
        self, schema: Schema, apply_schema_to_metadata: bool = True, apply_schema_to_tables: bool = True
    ) -> None:
        """Initialize the metadata builder.

        Parameters
        ----------
        schema : `felis.datamodel.Schema`
            The schema object from which to build the SQLAlchemy metadata.
        apply_schema_to_metadata : `bool`, optional
            Whether to apply the schema name to the metadata object.
        apply_schema_to_tables : `bool`, optional
            Whether to apply the schema name to the tables.
        """
        self.schema = schema
        if not apply_schema_to_metadata:
            logger.debug("Schema name will not be applied to metadata")
        if not apply_schema_to_tables:
            logger.debug("Schema name will not be applied to tables")
        self.metadata = MetaData(schema=schema.name if apply_schema_to_metadata else None)
        # Maps Felis object IDs to the SQLAlchemy objects built from them
        # (tables, columns, constraints, indexes) so that later build passes
        # can resolve cross-references by ID.
        self._objects: dict[str, Any] = {}
        self.apply_schema_to_tables = apply_schema_to_tables

    def build(self) -> MetaData:
        """Build the SQLAlchemy tables and constraints from the schema.

        Returns
        -------
        metadata : `sqlalchemy.MetaData`
            The metadata object populated from the Felis schema.
        """
        self.build_tables()
        # Constraints are built in a second pass so every object they
        # reference by ID already exists in ``self._objects``.
        self.build_constraints()
        return self.metadata

    def build_tables(self) -> None:
        """Build the SQLAlchemy tables from the schema.

        Notes
        -----
        This function builds all the tables by calling ``build_table`` on
        each Pydantic object. It also calls ``build_primary_key`` to create the
        primary key constraints.
        """
        for table in self.schema.tables:
            self.build_table(table)
            if table.primary_key:
                primary_key = self.build_primary_key(table.primary_key)
                # build_table registered the table under its ID, so it can
                # be looked up here to attach the primary key.
                self._objects[table.id].append_constraint(primary_key)

    def build_primary_key(self, primary_key_columns: str | list[str]) -> PrimaryKeyConstraint:
        """Build a SQLAlchemy `PrimaryKeyConstraint` from a single column ID
        or a list.

        The `primary_key_columns` are strings or a list of strings representing
        IDs pointing to columns that will be looked up in the internal object
        dictionary.

        Parameters
        ----------
        primary_key_columns : `str` or `list` of `str`
            The column ID or list of column IDs from which to build the primary
            key.

        Returns
        -------
        primary_key : `sqlalchemy.PrimaryKeyConstraint`
            The SQLAlchemy primary key constraint object.
        """
        # ensure_iterable lets a single column ID be handled the same way
        # as a list of IDs.
        return PrimaryKeyConstraint(
            *[self._objects[column_id] for column_id in ensure_iterable(primary_key_columns)]
        )

    def build_table(self, table_obj: datamodel.Table) -> None:
        """Build a `sqlalchemy.Table` from a `felis.datamodel.Table` and add
        it to the `sqlalchemy.MetaData` object.

        Several MySQL table options are handled by annotations on the table,
        including the engine and charset. This is not needed for Postgres,
        which does not have equivalent options.

        Parameters
        ----------
        table_obj : `felis.datamodel.Table`
            The table object to build the SQLAlchemy table from.
        """
        # Process mysql table options; only set the keyword arguments when
        # the schema actually provides a value.
        optargs = {}
        if table_obj.mysql_engine:
            optargs["mysql_engine"] = table_obj.mysql_engine
        if table_obj.mysql_charset:
            optargs["mysql_charset"] = table_obj.mysql_charset

        # Create the SQLAlchemy table object and its columns.
        name = table_obj.name
        id = table_obj.id
        description = table_obj.description
        columns = [self.build_column(column) for column in table_obj.columns]
        table = Table(
            name,
            self.metadata,
            *columns,
            comment=description,
            schema=self.schema.name if self.apply_schema_to_tables else None,
            **optargs,  # type: ignore[arg-type]
        )

        # Create the indexes and add them to the table.
        # NOTE: ``_set_parent`` is a private SQLAlchemy API; it attaches the
        # index to the table without going through Table's constructor.
        indexes = [self.build_index(index) for index in table_obj.indexes]
        for index in indexes:
            index._set_parent(table)
            table.indexes.add(index)

        # Register the table under its Felis ID for later lookups.
        self._objects[id] = table

    def build_column(self, column_obj: datamodel.Column) -> Column:
        """Build a SQLAlchemy column from a `felis.datamodel.Column` object.

        Parameters
        ----------
        column_obj : `felis.datamodel.Column`
            The column object from which to build the SQLAlchemy column.

        Returns
        -------
        column : `sqlalchemy.Column`
            The SQLAlchemy column object.
        """
        # Get basic column attributes.
        name = column_obj.name
        id = column_obj.id
        description = column_obj.description
        default = column_obj.value
        nullable = column_obj.nullable

        # Get datatype, handling variant overrides such as "mysql:datatype".
        datatype = get_datatype_with_variants(column_obj)

        # Set autoincrement, depending on if it was provided explicitly;
        # "auto" defers the decision to SQLAlchemy.
        autoincrement: Literal["auto"] | bool = (
            column_obj.autoincrement if column_obj.autoincrement is not None else "auto"
        )

        column: Column = Column(
            name,
            datatype,
            comment=description,
            autoincrement=autoincrement,
            nullable=nullable,
            server_default=default,
        )

        # Register the column under its Felis ID for later lookups.
        self._objects[id] = column

        return column

    def build_constraints(self) -> None:
        """Build the SQLAlchemy constraints in the Felis schema and append them
        to the associated `Table`.

        Notes
        -----
        This is performed as a separate step after building the tables so that
        all the referenced objects in the constraints will be present and can
        be looked up by their ID.
        """
        for table_obj in self.schema.tables:
            table = self._objects[table_obj.id]
            for constraint_obj in table_obj.constraints:
                constraint = self.build_constraint(constraint_obj)
                table.append_constraint(constraint)

    def build_constraint(self, constraint_obj: datamodel.Constraint) -> Constraint:
        """Build a SQLAlchemy `Constraint` from a `felis.datamodel.Constraint`
        object.

        Parameters
        ----------
        constraint_obj : `felis.datamodel.Constraint`
            The constraint object from which to build the SQLAlchemy
            constraint.

        Returns
        -------
        constraint : `sqlalchemy.Constraint`
            The SQLAlchemy constraint object.

        Raises
        ------
        ValueError
            If the constraint type is not recognized.
        TypeError
            If the constraint object is not the expected type.
        """
        # Common keyword arguments shared by all constraint types; empty
        # values are normalized to None.
        args: dict[str, Any] = {
            "name": constraint_obj.name or None,
            "info": constraint_obj.description or None,
            "deferrable": constraint_obj.deferrable or None,
            "initially": constraint_obj.initially or None,
        }
        constraint: Constraint
        constraint_type = constraint_obj.type

        # Dispatch on the concrete Pydantic constraint type.
        if isinstance(constraint_obj, datamodel.ForeignKeyConstraint):
            fk_obj: datamodel.ForeignKeyConstraint = constraint_obj
            columns = [self._objects[column_id] for column_id in fk_obj.columns]
            refcolumns = [self._objects[column_id] for column_id in fk_obj.referenced_columns]
            constraint = ForeignKeyConstraint(columns, refcolumns, **args)
        elif isinstance(constraint_obj, datamodel.CheckConstraint):
            check_obj: datamodel.CheckConstraint = constraint_obj
            expression = check_obj.expression
            constraint = CheckConstraint(expression, **args)
        elif isinstance(constraint_obj, datamodel.UniqueConstraint):
            uniq_obj: datamodel.UniqueConstraint = constraint_obj
            columns = [self._objects[column_id] for column_id in uniq_obj.columns]
            constraint = UniqueConstraint(*columns, **args)
        else:
            raise ValueError(f"Unknown constraint type: {constraint_type}")

        # Register the constraint under its Felis ID for later lookups.
        self._objects[constraint_obj.id] = constraint

        return constraint

    def build_index(self, index_obj: datamodel.Index) -> Index:
        """Build a SQLAlchemy `Index` from a `felis.datamodel.Index` object.

        Parameters
        ----------
        index_obj : `felis.datamodel.Index`
            The index object from which to build the SQLAlchemy index.

        Returns
        -------
        index : `sqlalchemy.Index`
            The SQLAlchemy index object.
        """
        # An index may be defined over column IDs, raw expressions, or both.
        columns = [self._objects[c_id] for c_id in (index_obj.columns if index_obj.columns else [])]
        expressions = index_obj.expressions if index_obj.expressions else []
        index = Index(index_obj.name, *columns, *expressions)
        # Register the index under its Felis ID for later lookups.
        self._objects[index_obj.id] = index
        return index
class ConnectionWrapper:
    """A wrapper for a SQLAlchemy engine or mock connection which provides a
    consistent interface for executing SQL statements.
    """

    def __init__(self, engine: Engine | MockConnection):
        """Initialize the connection wrapper.

        Parameters
        ----------
        engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
            The SQLAlchemy engine or mock connection to wrap.
        """
        self.engine = engine

    def execute(self, statement: Any) -> ResultProxy:
        """Execute a SQL statement on the engine and return the result."""
        # Promote plain strings to executable text clauses.
        stmt = text(statement) if isinstance(statement, str) else statement
        if isinstance(self.engine, MockConnection):
            # Mock connections do not support transactions; execute directly.
            return self.engine.connect().execute(stmt)
        # Real engines execute inside a transaction that commits on success.
        with self.engine.begin() as connection:
            return connection.execute(stmt)
class DatabaseContext:
    """A class for managing the schema and its database connection."""

    def __init__(self, metadata: MetaData, engine: Engine | MockConnection):
        """Initialize the database context.

        Parameters
        ----------
        metadata : `sqlalchemy.MetaData`
            The SQLAlchemy metadata object.
        engine : `sqlalchemy.Engine` or `sqlalchemy.MockConnection`
            The SQLAlchemy engine or mock connection object.
        """
        self.engine = engine
        self.metadata = metadata
        self.connection = ConnectionWrapper(engine)

    def create_if_not_exists(self) -> None:
        """Create the schema in the database if it does not exist.

        The schema name is taken from ``self.metadata.schema``. In MySQL,
        this will create a new database. In PostgreSQL, it will create a new
        schema. For other variants, this is an unsupported operation.

        Raises
        ------
        ValueError
            If the database type is not supported.
        sqlalchemy.exc.SQLAlchemyError
            If there is an error creating the schema.
        """
        db_type = self.engine.dialect.name
        schema_name = self.metadata.schema
        try:
            if db_type == "mysql":
                logger.info("Creating MySQL database: %s", schema_name)
                self.connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {schema_name}"))
            elif db_type == "postgresql":
                logger.info("Creating PG schema: %s", schema_name)
                self.connection.execute(sqa_schema.CreateSchema(schema_name, if_not_exists=True))
            else:
                # Same message format as drop_if_exists (the original string
                # concatenation was missing the separating space).
                raise ValueError(f"Unsupported database type: {db_type}")
        except SQLAlchemyError as e:
            logger.error("Error creating schema: %s", e)
            raise

    def drop_if_exists(self) -> None:
        """Drop the schema in the database if it exists.

        The schema name is taken from ``self.metadata.schema``. In MySQL,
        this will drop a database. In PostgreSQL, it will drop a schema. For
        other variants, this is unsupported for now.

        Raises
        ------
        ValueError
            If the database type is not supported.
        sqlalchemy.exc.SQLAlchemyError
            If there is an error dropping the schema.
        """
        db_type = self.engine.dialect.name
        schema_name = self.metadata.schema
        try:
            if db_type == "mysql":
                logger.info("Dropping MySQL database if exists: %s", schema_name)
                self.connection.execute(text(f"DROP DATABASE IF EXISTS {schema_name}"))
            elif db_type == "postgresql":
                logger.info("Dropping PostgreSQL schema if exists: %s", schema_name)
                self.connection.execute(sqa_schema.DropSchema(schema_name, if_exists=True))
            else:
                raise ValueError(f"Unsupported database type: {db_type}")
        except SQLAlchemyError as e:
            logger.error("Error dropping schema: %s", e)
            raise

    def create_all(self) -> None:
        """Create all tables in the schema using the metadata object."""
        self.metadata.create_all(self.engine)

    @staticmethod
    def create_mock_engine(engine_url: URL, output_file: IO[str] | None = None) -> MockConnection:
        """Create a mock engine for testing or dumping DDL statements.

        Parameters
        ----------
        engine_url : `sqlalchemy.engine.url.URL`
            The SQLAlchemy engine URL.
        output_file : `typing.IO` [ `str` ] or `None`, optional
            The file to write the SQL statements to. If None, the statements
            will be written to stdout.

        Returns
        -------
        engine : `sqlalchemy.MockConnection`
            A mock connection whose executor dumps SQL via `InsertDump`.
        """
        dumper = InsertDump(output_file)
        engine = create_mock_engine(make_url(engine_url), executor=dumper.dump)
        # The dumper needs the engine's dialect to compile statements.
        dumper.dialect = engine.dialect
        return engine