Coverage for python/felis/sql.py: 20%
146 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-14 10:16 -0700
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-14 10:16 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["SQLVisitor"]
26import logging
27import re
28from collections.abc import Iterable, Mapping, MutableMapping
29from typing import Any, NamedTuple
31from sqlalchemy import (
32 CheckConstraint,
33 Column,
34 Constraint,
35 ForeignKeyConstraint,
36 Index,
37 MetaData,
38 Numeric,
39 PrimaryKeyConstraint,
40 UniqueConstraint,
41 types,
42)
43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
44from sqlalchemy.schema import Table
46from .check import FelisValidator
47from .db import sqltypes
48from .types import FelisType
49from .visitor import Visitor
# Aliases for the loosely-typed mappings that Felis schema objects arrive as.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Dialect names used as the prefix of Felis ``<dialect>:<option>`` keys.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Felis table option key -> presumably the matching SQLAlchemy ``Table``
# keyword argument (not referenced in this part of the module; confirm usage).
TABLE_OPTS = {
    "mysql:engine": "mysql_engine",
    "mysql:charset": "mysql_charset",
    "oracle:compress": "oracle_compress",
}

# Per-dialect datatype override key -> dialect name, i.e.
# {"mysql:datatype": "mysql", "oracle:datatype": "oracle", ...}.
COLUMN_VARIANT_OVERRIDE = {
    f"{dialect}:datatype": dialect for dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)
}

# Dialect name -> SQLAlchemy dialect module whose types overrides refer to.
DIALECT_MODULES = {MYSQL: mysql, ORACLE: oracle, SQLITE: sqlite, POSTGRES: postgresql}

# Captures everything between the first "(" and the last ")" of a type
# spelling, e.g. group(1) of "VARCHAR(128)" is "128".
length_regex = re.compile(r"\((.+)\)")
class Schema(NamedTuple):
    """Bundle of SQLAlchemy objects produced by `SQLVisitor.visit_schema`."""

    # The visitor's schema-name override when given, else the Felis ``name``.
    name: str | None
    # Tables built from the Felis schema, in the order they were listed.
    tables: list[Table]
    # The metadata object every table was registered with.
    metadata: MetaData
    # Felis ``@id`` -> SQLAlchemy object (table, column, constraint) built for it.
    graph_index: Mapping[str, Any]
class SQLVisitor(Visitor[Schema, Table, Column, PrimaryKeyConstraint | None, Constraint, Index, None]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: str | None = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Felis ``@id`` -> SQLAlchemy object built for it (tables, columns,
        # constraints). Primary keys, indexes and constraints resolve their
        # column references through this mapping.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)

        # Create tables but don't add constraints yet.
        tables = [self.visit_table(t, schema_obj) for t in schema_obj["tables"]]

        # Process constraints after the tables are created so that all
        # referenced columns are available.
        for table_obj in schema_obj["tables"]:
            constraints = [
                self.visit_constraint(constraint, table_obj)
                for constraint in table_obj.get("constraints", [])
            ]
            table = self.graph_index[table_obj["@id"]]
            for constraint in constraints:
                table.append_constraint(constraint)

        schema = Schema(
            name=self.schema_name or schema_obj["name"],
            tables=tables,
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            # NOTE: ``_set_parent`` is private SQLAlchemy API; it binds the
            # index to the table (needed for expression-only indexes that
            # are not auto-attached through their columns).
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Per-dialect type overrides ("mysql:datatype", ...) become keyword
        # arguments named after the dialect for the felis type factory.
        kwargs: dict[str, types.TypeEngine] = {}
        for column_opt in column_obj.keys():
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                variant = _process_variant_override(dialect, column_obj[column_opt])
                kwargs[dialect] = variant

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric types (and their subclasses) default to NOT NULL; all other
        # types default to nullable unless the Felis column says otherwise.
        nullable_default = True
        if isinstance(datatype, Numeric):
            nullable_default = False

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        if column_id in self.graph_index:
            logger.warning("Duplication of @id %s", column_id)
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: str | Iterable[str], table_obj: _Mapping
    ) -> PrimaryKeyConstraint | None:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column @id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        # Keyword arguments shared by every constraint class; each is only set
        # when present so SQLAlchemy's own defaults apply otherwise.
        # ``expression`` is deliberately NOT part of this mapping: it is the
        # positional ``sqltext`` of a CheckConstraint, and SQLAlchemy rejects
        # it as an unknown keyword argument on every constraint class.
        constraint_args: _MutableMapping = {}
        _set_if("name", constraint_obj.get("name"), constraint_args)
        # NOTE(review): SQLAlchemy documents ``info`` as a dictionary; the
        # description string is stored as-is — confirm this is intended.
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)
257def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None:
258 if value is not None:
259 mapping[key] = value
def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Return variant type for given dialect.

    Parameters
    ----------
    dialect_name : `str`
        Key into ``DIALECT_MODULES`` selecting the SQLAlchemy dialect module.
    variant_override_str : `str`
        Type spelling such as ``"VARCHAR(128)"``. Everything before the first
        ``(`` names an attribute of the dialect module; a parenthesised,
        comma-separated list of integers is passed to it positionally.

    Returns
    -------
    variant : `sqlalchemy.types.TypeEngine`
        Result of calling the named dialect attribute with those integers.

    Raises
    ------
    ValueError
        Raised if the dialect module has no attribute with that name.
    """
    params_match = length_regex.search(variant_override_str)
    dialect_module = DIALECT_MODULES[dialect_name]
    type_name = variant_override_str.partition("(")[0]

    if type_name not in dir(dialect_module):
        raise ValueError(f"Type {type_name} not found in dialect {dialect_name}")
    type_factory = getattr(dialect_module, type_name)

    # "NUMERIC(10, 2)" -> [10, 2]; int() tolerates the surrounding spaces.
    type_args = [int(arg) for arg in params_match.group(1).split(",")] if params_match else []
    return type_factory(*type_args)