Coverage for python/felis/sql.py: 21%
143 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-23 10:44 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-23 10:44 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["SQLVisitor"]
26import logging
27import re
28from collections.abc import Iterable, Mapping, MutableMapping
29from typing import Any, NamedTuple
31from sqlalchemy import (
32 CheckConstraint,
33 Column,
34 Constraint,
35 ForeignKeyConstraint,
36 Index,
37 MetaData,
38 Numeric,
39 PrimaryKeyConstraint,
40 UniqueConstraint,
41 types,
42)
43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
44from sqlalchemy.schema import Table
46from .check import FelisValidator
47from .db import sqltypes
48from .types import FelisType
49from .visitor import Visitor
# Aliases for the JSON-like mappings (parsed Felis YAML) this module reads and writes.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Dialect names, as they appear in the "<dialect>:<option>" keys of a Felis schema.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Felis table-option keys and the SQLAlchemy dialect-prefixed ``Table`` keyword
# names they correspond to.
TABLE_OPTS = dict(
    [
        ("mysql:engine", "mysql_engine"),
        ("mysql:charset", "mysql_charset"),
        ("oracle:compress", "oracle_compress"),
    ]
)

# "<dialect>:datatype" column keys -> dialect name, e.g. "mysql:datatype" -> "mysql".
COLUMN_VARIANT_OVERRIDE = {f"{dialect}:datatype": dialect for dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)}

# Dialect name -> SQLAlchemy dialect module in which override type names are looked up.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Greedy capture of everything between the first "(" and the last ")" of a type spec.
length_regex = re.compile(r"\((.+)\)")
class Schema(NamedTuple):
    """SQLAlchemy objects produced by `SQLVisitor.visit_schema` for one Felis schema."""

    # Schema name: the visitor's ``schema_name`` override if set, else the Felis ``name``.
    name: str | None
    # One SQLAlchemy ``Table`` per entry of the Felis ``tables`` list, in order.
    tables: list[Table]
    # The ``MetaData`` object every table was registered in.
    metadata: MetaData
    # Felis ``@id`` -> SQLAlchemy object built for it (tables, columns, constraints).
    graph_index: Mapping[str, Any]
class SQLVisitor(Visitor[Schema, Table, Column, PrimaryKeyConstraint | None, Constraint, Index, None]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Every table, column and constraint that is created is also recorded in
    ``graph_index`` under its Felis ``@id``. This lets primary keys,
    constraints and indexes refer to columns by ID.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: str | None = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Felis ``@id`` -> SQLAlchemy object (Table, Column or Constraint).
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = Schema(
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first because they populate ``graph_index``.
        # The primary key, constraints and indexes below resolve column IDs
        # against it.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        # Forward Felis dialect-specific table options (e.g. "mysql:engine")
        # as the SQLAlchemy dialect keyword arguments listed in TABLE_OPTS.
        # Previously these options were silently dropped.
        table_kwargs = {TABLE_OPTS[opt]: value for opt, value in table_obj.items() if opt in TABLE_OPTS}

        table = Table(
            name,
            self.metadata,
            *columns,
            schema=schema_name,
            comment=description,
            **table_kwargs,
        )

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Per-dialect datatype overrides (e.g. "mysql:datatype") become keyword
        # arguments of the ``sqltypes`` factory, keyed by dialect name.
        kwargs = {}
        for column_opt, opt_value in column_obj.items():
            dialect = COLUMN_VARIANT_OVERRIDE.get(column_opt)
            if dialect is not None:
                kwargs[dialect] = _process_variant_override(dialect, opt_value)

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized Felis types take the ``length`` argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns default to NOT NULL; all others default to nullable.
        nullable_default = not isinstance(datatype, Numeric)

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        if column_id in self.graph_index:
            logger.warning("Duplication of @id %s", column_id)
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: str | Iterable[str], table_obj: _Mapping
    ) -> PrimaryKeyConstraint | None:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column ID may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # Optional keyword arguments accepted by all of the constraint classes below.
        # ``expression`` is deliberately not forwarded as a keyword. It only applies
        # to Check constraints, where it is passed positionally. SQLAlchemy treats
        # any other keyword as a dialect-specific argument, and rejects it with a
        # TypeError unless it is named "<dialect>_<argument>".
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        # Indexed columns are resolved by ``@id``; expressions are passed through as-is.
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)
246def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None:
247 if value is not None:
248 mapping[key] = value
def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Return variant type for given dialect.

    Parameters
    ----------
    dialect_name : `str`
        Dialect name; must be a key of ``DIALECT_MODULES``.
    variant_override_str : `str`
        Type spec such as ``"VARCHAR(128)"`` or ``"DOUBLE(10, 2)"``. The text
        before the first ``(`` names a type in the dialect module. The
        comma-separated integers inside the parentheses are passed to that
        type's constructor as positional arguments.

    Returns
    -------
    `sqlalchemy.types.TypeEngine`
        Instance of the dialect-specific type.

    Raises
    ------
    ValueError
        Raised if the type name is not found in the dialect module, or if a
        parenthesised parameter is not an integer.
    """
    dialect = DIALECT_MODULES[dialect_name]
    # Strip surrounding whitespace so specs like "VARCHAR (128)" resolve to
    # "VARCHAR". The integer parameters already tolerate spaces via int().
    variant_type_name = variant_override_str.split("(")[0].strip()

    # Process Variant Type
    if variant_type_name not in dir(dialect):
        raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")
    variant_type = getattr(dialect, variant_type_name)

    match = length_regex.search(variant_override_str)
    length_params = [int(i) for i in match.group(1).split(",")] if match else []
    return variant_type(*length_params)