Coverage for python/felis/sql.py: 21%

139 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-14 02:21 -0800

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SQLVisitor"] 

25 

26import logging 

27import re 

28from collections.abc import Iterable, Mapping, MutableMapping 

29from typing import Any, NamedTuple, Optional, Union 

30 

31from sqlalchemy import ( 

32 CheckConstraint, 

33 Column, 

34 Constraint, 

35 ForeignKeyConstraint, 

36 Index, 

37 MetaData, 

38 Numeric, 

39 PrimaryKeyConstraint, 

40 UniqueConstraint, 

41 types, 

42) 

43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite 

44from sqlalchemy.schema import Table 

45 

46from .check import FelisValidator 

47from .db import sqltypes 

48from .types import FelisType 

49from .visitor import Visitor 

50 

# Shorthand for the string-keyed mappings used throughout this module.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Names of the supported SQL dialects.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Table-level option key -> dialect-prefixed keyword name.
TABLE_OPTS = {
    "mysql:engine": "mysql_engine",
    "mysql:charset": "mysql_charset",
    "oracle:compress": "oracle_compress",
}

# Column-level "<dialect>:datatype" override key -> dialect name.
COLUMN_VARIANT_OVERRIDE = {
    f"{_dialect}:datatype": _dialect for _dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)
}

73 

# Dialect name -> SQLAlchemy dialect module that provides its types.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Captures whatever sits between the first "(" and the last ")" of a
# datatype string, e.g. "10,2" in "DECIMAL(10,2)".
length_regex = re.compile(r"\((.+)\)")

77 

78 

class Schema(NamedTuple):
    """Result of visiting a Felis schema with `SQLVisitor`.

    Bundles the values that `SQLVisitor.visit_schema` produces for one
    schema object.
    """

    # Schema name: the visitor's override if set, else the schema's "name".
    name: Optional[str]
    # One SQLAlchemy table per entry of the schema's "tables" list.
    tables: list[Table]
    # The visitor's MetaData object that the tables were created against.
    metadata: MetaData
    # The visitor's index of "@id" -> object built for that id.
    graph_index: Mapping[str, Any]

85 

86 

class SQLVisitor(Visitor[Schema, Table, Column, Optional[PrimaryKeyConstraint], Constraint, Index]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: Optional[str] = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Every visited table, column and constraint is registered here under
        # its "@id" so that later objects can refer to it.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        schema = Schema(
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )
        return schema

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)
        # Columns are visited first: this registers them in ``graph_index`` so
        # the primary key, constraints and indexes below can find them by @id.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        # The nested visitors are declared as taking the Felis table mapping
        # (``table_obj: _Mapping``), so pass that rather than the SQLAlchemy
        # Table, matching visit_column and visit_primary_key.
        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect datatype overrides ("<dialect>:datatype" keys);
        # they are handed to the type factory keyed by dialect name.
        kwargs = {}
        for column_opt in column_obj:
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                variant = _process_variant_override(dialect, column_obj[column_opt])
                kwargs[dialect] = variant

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized types take the column length as a positional argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns default to NOT NULL; everything else defaults to
        # nullable.  An explicit "nullable" entry always wins.
        nullable_default = not isinstance(datatype, Numeric)

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        # A repeated @id is only reported; the newest column replaces the
        # earlier entry in the index.
        if column_id in self.graph_index:
            logger.warning(f"Duplication of @id {column_id}")
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping
    ) -> Optional[PrimaryKeyConstraint]:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        constraint_args: _MutableMapping = {}
        # The following are not used on every constraint.
        # "expression" is deliberately not added here: CheckConstraint receives
        # it positionally below, and SQLAlchemy constraints reject extra
        # keywords that are not of the form <dialect>_<argument>, so passing
        # ``expression=...`` raised TypeError for every constraint carrying one.
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        # Column ids are resolved to the Column objects registered earlier.
        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)

235 

236 

237def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None: 

238 if value is not None: 

239 mapping[key] = value 

240 

241 

def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build a dialect-specific type from a datatype override string.

    Parameters
    ----------
    dialect_name : `str`
        Key into ``DIALECT_MODULES`` selecting the dialect module.
    variant_override_str : `str`
        Type name, optionally followed by comma-separated integer
        parameters in parentheses, e.g. ``"VARCHAR(64)"``.  Whitespace
        around the type name is ignored.

    Returns
    -------
    variant : `sqlalchemy.types.TypeEngine`
        Result of calling the named attribute of the dialect module with
        the parsed integers as positional arguments.

    Raises
    ------
    KeyError
        Raised if ``dialect_name`` is not a key of ``DIALECT_MODULES``.
    ValueError
        Raised if the dialect module has no attribute with the given type
        name, or if a parenthesised parameter is not an integer.
    """
    match = length_regex.search(variant_override_str)
    dialect = DIALECT_MODULES[dialect_name]
    # Strip so that "VARCHAR (64)" or " DOUBLE" still resolve to the type.
    variant_type_name = variant_override_str.split("(")[0].strip()

    # Process Variant Type
    variant_type = getattr(dialect, variant_type_name, None)
    if variant_type is None:
        raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")

    length_params: list[int] = []
    if match:
        try:
            length_params.extend(int(param) for param in match.group(1).split(","))
        except ValueError as err:
            # Same exception type as before, but say which override is bad.
            raise ValueError(
                f"Non-integer type parameters in {dialect_name} datatype override: "
                f"{variant_override_str!r}"
            ) from err
    return variant_type(*length_params)