Coverage for python/felis/sql.py: 21%
139 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-28 10:01 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-28 10:01 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["SQLVisitor"]
26import logging
27import re
28from collections.abc import Iterable, Mapping, MutableMapping
29from typing import Any, NamedTuple, Optional, Union
31from sqlalchemy import (
32 CheckConstraint,
33 Column,
34 Constraint,
35 ForeignKeyConstraint,
36 Index,
37 MetaData,
38 Numeric,
39 PrimaryKeyConstraint,
40 UniqueConstraint,
41 types,
42)
43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite
44from sqlalchemy.schema import Table
46from .check import FelisValidator
47from .db import sqltypes
48from .types import FelisType
49from .visitor import Visitor
# Shorthand aliases for the string-keyed mappings used in the annotations
# of this module.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

# Logger registered under the "felis" name.
logger = logging.getLogger("felis")

# Dialect names; they serve as keys of the lookup tables in this module.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Dialect-prefixed table option keys mapped to the same name with the colon
# replaced by an underscore (presumably the matching SQLAlchemy ``Table``
# keyword arguments -- not referenced by the code visible in this module).
TABLE_OPTS = {
    option: option.replace(":", "_")
    for option in ("mysql:engine", "mysql:charset", "oracle:compress")
}

# Column keys of the form "<dialect>:datatype" mapped to the dialect name
# whose type they override.
COLUMN_VARIANT_OVERRIDE = {
    f"{dialect}:datatype": dialect for dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)
}
# SQLAlchemy dialect module to search for type names, keyed by dialect name.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Captures everything between the first "(" and the last ")" of a string.
length_regex = re.compile(r"\((.+)\)")
class Schema(NamedTuple):
    """Bundle of objects returned by `SQLVisitor.visit_schema`."""

    # Name override given to the visitor if set, otherwise the "name" entry
    # of the visited schema object.
    name: Optional[str]
    # One SQLAlchemy table per entry of the schema object's "tables" list.
    tables: list[Table]
    # The visitor's MetaData object; the tables are created against it.
    metadata: MetaData
    # The visitor's "@id" -> SQLAlchemy object lookup.
    graph_index: Mapping[str, Any]
class SQLVisitor(Visitor[Schema, Table, Column, Optional[PrimaryKeyConstraint], Constraint, Index]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: Optional[str] = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Maps the "@id" of each visited table, column and constraint to the
        # SQLAlchemy object built for it, so later objects can refer to
        # earlier ones by id.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        schema = Schema(
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first so that they are registered in
        # ``graph_index`` before the primary key, constraints and indexes
        # look them up by "@id".
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        # Compare against None explicitly: a constraint object defines its
        # own length, so a plain truthiness test would depend on how many
        # columns it currently reports rather than on whether one was built.
        if primary_key is not None:
            table.append_constraint(primary_key)

        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect type overrides ("<dialect>:datatype" keys) as
        # keyword arguments, keyed by dialect name, for the type factory.
        kwargs = {}
        for column_opt, override in column_obj.items():
            dialect = COLUMN_VARIANT_OVERRIDE.get(column_opt)
            if dialect is not None:
                kwargs[dialect] = _process_variant_override(dialect, override)

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized types take the column length as a positional argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns default to NOT NULL; all others default to nullable.
        nullable_default = not isinstance(datatype, Numeric)

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        if column_id in self.graph_index:
            logger.warning(f"Duplication of @id {column_id}")
        # A duplicate id is only warned about; the newer column replaces the
        # earlier entry.
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping
    ) -> Optional[PrimaryKeyConstraint]:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # The following are not used on every constraint
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)
        # "expression" is deliberately NOT forwarded as a keyword: none of
        # the SQLAlchemy constraint constructors used below accept such an
        # argument (unknown keywords are rejected with TypeError).  A Check
        # constraint receives its expression positionally instead.

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        # An index may be declared over column ids, raw expressions, or both.
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)
236def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None:
237 if value is not None:
238 mapping[key] = value
def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build a dialect-specific type from a simple data type override string.

    Parameters
    ----------
    dialect_name : `str`
        Key into ``DIALECT_MODULES`` selecting the dialect module to search.
    variant_override_str : `str`
        Type name, optionally followed by a parenthesised, comma-separated
        list of integers, e.g. ``"VARCHAR(64)"`` or ``"DECIMAL(10, 2)"``.

    Returns
    -------
    variant : `sqlalchemy.types.TypeEngine`
        Instance of the named type, constructed with the integer arguments
        (or with no arguments when none are given).

    Raises
    ------
    ValueError
        Raised if the name does not refer to a type in the dialect module,
        or if a parenthesised argument is not an integer.
    KeyError
        Raised if ``dialect_name`` is not a key of ``DIALECT_MODULES``.
    """
    dialect = DIALECT_MODULES[dialect_name]
    # Strip surrounding whitespace so that spellings such as "VARCHAR (64)"
    # resolve to the same type as "VARCHAR(64)".
    variant_type_name = variant_override_str.split("(")[0].strip()

    # Only accept actual type classes: a dialect module also exposes
    # functions and submodules, which must never be called as if they were
    # column types.
    variant_type = getattr(dialect, variant_type_name, None)
    if not (isinstance(variant_type, type) and issubclass(variant_type, types.TypeEngine)):
        raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")

    # Everything between the parentheses is a comma-separated list of ints.
    match = length_regex.search(variant_override_str)
    length_params = [int(i) for i in match.group(1).split(",")] if match else []
    return variant_type(*length_params)