Coverage for python/felis/sql.py: 20%

146 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-27 11:44 +0000

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SQLVisitor"] 

25 

26import logging 

27import re 

28from collections.abc import Iterable, Mapping, MutableMapping 

29from typing import Any, NamedTuple 

30 

31from sqlalchemy import ( 

32 CheckConstraint, 

33 Column, 

34 Constraint, 

35 ForeignKeyConstraint, 

36 Index, 

37 MetaData, 

38 Numeric, 

39 PrimaryKeyConstraint, 

40 UniqueConstraint, 

41 types, 

42) 

43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite 

44from sqlalchemy.schema import Table 

45 

46from .check import FelisValidator 

47from .db import sqltypes 

48from .types import FelisType 

49from .visitor import Visitor 

50 

# Type aliases for the parsed schema/table/column objects handled by the
# visitor: a read-only view and a mutable one.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Dialect names; these are the keys of DIALECT_MODULES below.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Felis "<dialect>:<option>" table keys mapped to "<dialect>_<option>" names.
# NOTE(review): not referenced in the visible part of this module; presumably
# these are SQLAlchemy ``Table`` keyword names -- confirm against callers.
TABLE_OPTS = {
    "mysql:engine": "mysql_engine",
    "mysql:charset": "mysql_charset",
    "oracle:compress": "oracle_compress",
}

# Column keys of the form "<dialect>:datatype" mapped to the dialect whose
# type should be overridden.
COLUMN_VARIANT_OVERRIDE = {f"{dialect}:datatype": dialect for dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)}

# Dialect name -> SQLAlchemy dialect module in which override type names are
# looked up.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Captures everything between the first "(" and the last ")" of a type
# override string, e.g. the "10,2" of "DECIMAL(10,2)".
length_regex = re.compile(r"\((.+)\)")

77 

78 

class Schema(NamedTuple):
    """Bundle of objects returned by `SQLVisitor.visit_schema`."""

    # Schema name: the visitor's override if one was given, otherwise the
    # "name" entry of the visited schema object.
    name: str | None

    # Tables built from the schema's "tables" entries, in input order.
    tables: list[Table]

    # The SQLAlchemy MetaData instance the tables were registered with.
    metadata: MetaData

    # Lookup from "@id" values to the SQLAlchemy objects created for them.
    graph_index: Mapping[str, Any]

84 

85 

class SQLVisitor(Visitor[Schema, Table, Column, PrimaryKeyConstraint | None, Constraint, Index, None]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: str | None = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Maps the "@id" of every visited table, column and constraint to the
        # SQLAlchemy object created for it, so later objects can refer to
        # earlier ones by id.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)

        # Create tables but don't add constraints yet.
        tables = [self.visit_table(t, schema_obj) for t in schema_obj["tables"]]

        # Process constraints after the tables are created so that all
        # referenced columns are available.
        for table_obj in schema_obj["tables"]:
            constraints = [
                self.visit_constraint(constraint, table_obj)
                for constraint in table_obj.get("constraints", [])
            ]
            table = self.graph_index[table_obj["@id"]]
            for constraint in constraints:
                table.append_constraint(constraint)

        return Schema(
            name=self.schema_name or schema_obj["name"],
            tables=tables,
            metadata=self.metadata,
            graph_index=self.graph_index,
        )

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first: this registers them in ``graph_index``,
        # which the primary key and index lookups below rely on.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect type overrides ("<dialect>:datatype" keys) as
        # keyword arguments for the datatype factory, keyed by dialect name.
        kwargs: dict[str, Any] = {}
        for column_opt in column_obj:
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                kwargs[dialect] = _process_variant_override(dialect, column_obj[column_opt])

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized types take the column length as a positional argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns are NOT NULL unless the column says otherwise; all
        # other columns are nullable by default.
        nullable_default = not isinstance(datatype, Numeric)

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        if column_id in self.graph_index:
            # A duplicate id is only reported; the newer column replaces the
            # older entry in the index.
            logger.warning("Duplication of @id %s", column_id)
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: str | Iterable[str], table_obj: _Mapping
    ) -> PrimaryKeyConstraint | None:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # The following are not used on every constraint
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)
        # "expression" is deliberately NOT forwarded as a keyword argument:
        # the SQLAlchemy constraint constructors have no such parameter and
        # reject extra keywords that are not of the form
        # ``<dialect>_<argname>`` with a TypeError.  The expression is passed
        # positionally to CheckConstraint below, the only type that uses it.

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        # Column objects and raw expressions are passed together as the
        # index's positional members.
        return Index(name, *columns, *expressions, info=description)

255 

256 

def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None:
    """Store ``value`` under ``key`` in ``mapping`` unless it is `None`.

    Parameters
    ----------
    key : `str`
        Key to assign.
    value : `Any`
        Value to store; `None` means "leave the mapping untouched".
    mapping : `MutableMapping`
        Mapping updated in place.
    """
    if value is None:
        return
    mapping[key] = value

260 

261 

def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Return variant type for given dialect.

    Parameters
    ----------
    dialect_name : `str`
        Name of the dialect; must be a key of ``DIALECT_MODULES``.
    variant_override_str : `str`
        Type name as found in the dialect module, optionally followed by
        comma-separated integer parameters in parentheses, for example
        ``"VARCHAR(64)"`` or ``"DECIMAL(10, 2)"``.

    Returns
    -------
    variant_type : `sqlalchemy.types.TypeEngine`
        Instance of the named dialect type, constructed with the parsed
        integer parameters (if any) as positional arguments.

    Raises
    ------
    KeyError
        Raised if ``dialect_name`` is not a key of ``DIALECT_MODULES``.
    ValueError
        Raised if the type name is not found in the dialect module or if a
        parameter is not an integer.
    """
    dialect = DIALECT_MODULES[dialect_name]

    # Everything before the first "(" is the type name.  Surrounding
    # whitespace is dropped so that "VARCHAR (64)" resolves like "VARCHAR(64)".
    variant_type_name = variant_override_str.split("(")[0].strip()
    if not hasattr(dialect, variant_type_name):
        raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")
    variant_type = getattr(dialect, variant_type_name)

    length_params: list[int] = []
    match = length_regex.search(variant_override_str)
    if match:
        try:
            # int() tolerates whitespace around each parameter.
            length_params = [int(param) for param in match.group(1).split(",")]
        except ValueError as exc:
            # Re-raise with the offending override string for context.
            raise ValueError(
                f"Invalid parameters in type override {variant_override_str!r} "
                f"for dialect {dialect_name}: expected comma-separated integers"
            ) from exc
    return variant_type(*length_params)