Coverage for python/felis/sql.py: 21%
139 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 20:29 -0700
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-12 20:29 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["SQLVisitor"]
26import logging
27import re
28from collections.abc import Iterable, Mapping, MutableMapping
29from typing import Any, NamedTuple, Optional, Union
31from sqlalchemy import (
32 CheckConstraint,
33 Column,
34 Constraint,
35 ForeignKeyConstraint,
36 Index,
37 MetaData,
38 Numeric,
39 PrimaryKeyConstraint,
40 UniqueConstraint,
41 types,
42)
43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
44from sqlalchemy.schema import Table
46from .check import FelisValidator
47from .db import sqltypes
48from .types import FelisType
49from .visitor import Visitor
# Shorthand for the string-keyed dictionaries that the visitor methods in
# this module receive and build.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Names of the supported SQL dialects.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Dialect-prefixed table options and the keyword name each one maps to.
# NOTE(review): presumably these are SQLAlchemy ``Table`` keyword arguments;
# nothing in this part of the module reads the mapping -- confirm usage.
TABLE_OPTS = {
    "mysql:engine": "mysql_engine",
    "mysql:charset": "mysql_charset",
    "oracle:compress": "oracle_compress",
}

# Column keys of the form "<dialect>:datatype", each mapped to the dialect
# whose type it overrides.
COLUMN_VARIANT_OVERRIDE = {
    f"{dialect_name}:datatype": dialect_name for dialect_name in (MYSQL, ORACLE, POSTGRES, SQLITE)
}

# Dialect name -> SQLAlchemy dialect module that provides its types.
DIALECT_MODULES = {MYSQL: mysql, ORACLE: oracle, SQLITE: sqlite, POSTGRES: postgresql}

# Captures everything between the first "(" and the last ")" of a type
# string, e.g. the "10,2" in "DECIMAL(10,2)".
length_regex = re.compile(r"\((.+)\)")
class Schema(NamedTuple):
    """Bundle of the objects produced by visiting one schema description."""

    # Schema name; `SQLVisitor.visit_schema` uses the visitor's override when
    # one was given, otherwise the name found in the schema description.
    name: Optional[str]
    # One SQLAlchemy table per entry of the schema's "tables" list.
    tables: list[Table]
    # The MetaData object the tables were registered with.
    metadata: MetaData
    # Lookup from "@id" strings to the objects created for them.
    graph_index: Mapping[str, Any]
class SQLVisitor(Visitor[Schema, Table, Column, Optional[PrimaryKeyConstraint], Constraint, Index]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: Optional[str] = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Maps the "@id" of each visited table, column and constraint to the
        # SQLAlchemy object built for it, so that primary keys, constraints
        # and indexes can look columns up by identifier.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        schema = Schema(
            # An explicit override wins over the name in the description.
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first: this registers them in ``graph_index``,
        # which the primary key, constraints and indexes below rely on.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        # The visit methods are declared to take the Felis table mapping, so
        # pass ``table_obj`` rather than the SQLAlchemy ``Table`` built above.
        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect type overrides ("<dialect>:datatype" keys) as
        # keyword arguments named after the dialect.
        kwargs = {}
        for column_opt in column_obj:
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                variant = _process_variant_override(dialect, column_obj[column_opt])
                kwargs[dialect] = variant

        felis_type = FelisType.felis_type(datatype_name)
        # The type factory is looked up in ``sqltypes`` by the Felis datatype
        # name.
        datatype_fun = getattr(sqltypes, datatype_name)

        # Sized types take the column length as their first argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Unless the column says otherwise, numeric columns are NOT NULL and
        # every other column is nullable.
        nullable_default = True
        if isinstance(datatype, Numeric):
            nullable_default = False

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        # A repeated @id is only reported; the newer column replaces the
        # earlier entry in the index.
        if column_id in self.graph_index:
            logger.warning(f"Duplication of @id {column_id}")
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping
    ) -> Optional[PrimaryKeyConstraint]:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # The following are not used on every constraint.
        # "expression" is intentionally absent: SQLAlchemy constraints treat
        # unknown keywords as dialect options that must be named
        # "<dialect>_<argument>", so forwarding it raises TypeError. The
        # expression is passed positionally to CheckConstraint instead.
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        # An index may be declared over column ids, expressions, or both;
        # columns are passed to Index first, followed by the expressions.
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)
def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None:
    """Store ``value`` under ``key`` in ``mapping`` unless it is `None`.

    Parameters
    ----------
    key : `str`
        Key to assign.
    value : `Any`
        Value to store; `None` means "not provided" and is skipped, while
        other falsy values (``0``, ``""``, `False`) are stored.
    mapping : `MutableMapping`
        Mapping updated in place.
    """
    if value is None:
        return
    mapping[key] = value
def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build a dialect-specific type from an override string.

    Parameters
    ----------
    dialect_name : `str`
        Key into ``DIALECT_MODULES`` selecting the dialect module.
    variant_override_str : `str`
        Type name, optionally followed by comma-separated integer
        parameters in parentheses, e.g. ``"NAME(10,2)"``.

    Returns
    -------
    variant : `sqlalchemy.types.TypeEngine`
        Result of calling the named dialect type with the parsed integers.

    Raises
    ------
    ValueError
        Raised if the dialect module has no attribute with the given type
        name, or if a parameter cannot be converted to `int`.
    """
    params_match = length_regex.search(variant_override_str)
    dialect_module = DIALECT_MODULES[dialect_name]
    # Everything before the first "(" is the type name.
    type_name, _, _ = variant_override_str.partition("(")

    if type_name not in dir(dialect_module):
        raise ValueError(f"Type {type_name} not found in dialect {dialect_name}")
    type_factory = getattr(dialect_module, type_name)

    # Without a parenthesised part the type is created with no arguments.
    if not params_match:
        return type_factory()
    type_params = [int(part) for part in params_match.group(1).split(",")]
    return type_factory(*type_params)
255 return variant_type(*length_params)