Coverage for python/felis/sql.py: 21%

143 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-23 10:44 +0000

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SQLVisitor"] 

25 

26import logging 

27import re 

28from collections.abc import Iterable, Mapping, MutableMapping 

29from typing import Any, NamedTuple 

30 

31from sqlalchemy import ( 

32 CheckConstraint, 

33 Column, 

34 Constraint, 

35 ForeignKeyConstraint, 

36 Index, 

37 MetaData, 

38 Numeric, 

39 PrimaryKeyConstraint, 

40 UniqueConstraint, 

41 types, 

42) 

43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite 

44from sqlalchemy.schema import Table 

45 

46from .check import FelisValidator 

47from .db import sqltypes 

48from .types import FelisType 

49from .visitor import Visitor 

50 

# Loose aliases for the JSON-like objects handed to the visitor methods.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Dialect names, used as keys of DIALECT_MODULES and as the values of
# COLUMN_VARIANT_OVERRIDE.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Felis table option key -> SQLAlchemy-style keyword name.
TABLE_OPTS = {
    "mysql:engine": "mysql_engine",
    "mysql:charset": "mysql_charset",
    "oracle:compress": "oracle_compress",
}

# Column key "<dialect>:datatype" -> dialect name, for each supported dialect.
COLUMN_VARIANT_OVERRIDE = {
    f"{dialect_name}:datatype": dialect_name for dialect_name in (MYSQL, ORACLE, POSTGRES, SQLITE)
}

# Dialect name -> SQLAlchemy dialect module that provides its type classes.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Captures whatever sits between the parentheses of e.g. "VARCHAR(64)".
length_regex = re.compile(r"\((.+)\)")

77 

78 

class Schema(NamedTuple):
    """Bundle of the objects produced by `SQLVisitor.visit_schema`."""

    # Schema name: the visitor's override when set, otherwise the "name" of
    # the visited schema object.
    name: str | None
    # One Table per entry of the schema object's "tables".
    tables: list[Table]
    # The MetaData instance held by the visitor.
    metadata: MetaData
    # The visitor's "@id" -> built object lookup.
    graph_index: Mapping[str, Any]

84 

85 

class SQLVisitor(Visitor[Schema, Table, Column, PrimaryKeyConstraint | None, Constraint, Index, None]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Tables, columns and constraints are recorded in ``graph_index`` under
    their ``@id`` as they are visited, so that primary keys, constraints and
    indexes can look up the columns they refer to by ID.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: str | None = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # "@id" -> SQLAlchemy object built for it.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = Schema(
            # An explicit override takes precedence over the name in the file.
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first: this registers them in graph_index, which
        # the primary key, constraints and indexes below look them up in.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect type overrides ("<dialect>:datatype" keys); they
        # are forwarded to the sqltypes factory keyed by dialect name.
        kwargs = {}
        for column_opt in column_obj:
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                variant = _process_variant_override(dialect, column_obj[column_opt])
                kwargs[dialect] = variant

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized types take the length as a positional argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns are non-nullable unless the column says otherwise;
        # every other type defaults to nullable.
        nullable_default = True
        if isinstance(datatype, Numeric):
            nullable_default = False

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        if column_id in self.graph_index:
            # The later definition wins; the earlier object is overwritten.
            logger.warning("Duplication of @id %s", column_id)
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: str | Iterable[str], table_obj: _Mapping
    ) -> PrimaryKeyConstraint | None:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column ID may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # The following are not used on every constraint.
        #
        # "expression" is intentionally NOT forwarded as a keyword: none of
        # the SQLAlchemy constraint constructors used below accept it, and an
        # unknown keyword is treated as a dialect option and rejected with
        # TypeError.  Check constraints pass the expression positionally.
        #
        # NOTE(review): SQLAlchemy documents ``info`` as a dict; here it gets
        # the description string.  Left as is; confirm what consumers expect.
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        # Columns (looked up by ID) come first, followed by any expressions.
        return Index(name, *columns, *expressions, info=description)

244 

245 

246def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None: 

247 if value is not None: 

248 mapping[key] = value 

249 

250 

def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build the dialect-specific type described by an override string.

    The text before the first ``(`` is looked up as an attribute of the
    dialect module registered for ``dialect_name``.  If the string contains a
    parenthesised section, its comma-separated items are converted with
    `int` and passed to that attribute as positional arguments.

    Raises
    ------
    ValueError
        Raised if the type name is not listed by the dialect module.
    """
    params_match = length_regex.search(variant_override_str)
    dialect_module = DIALECT_MODULES[dialect_name]
    type_name, _, _ = variant_override_str.partition("(")

    # Process Variant Type
    if type_name not in dir(dialect_module):
        raise ValueError(f"Type {type_name} not found in dialect {dialect_name}")
    type_factory = getattr(dialect_module, type_name)

    if not params_match:
        return type_factory()
    args = [int(piece) for piece in params_match.group(1).split(",")]
    return type_factory(*args)