Coverage for python/felis/sql.py: 21%

139 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-12 10:48 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["SQLVisitor"] 

25 

26import logging 

27import re 

28from collections.abc import Iterable, Mapping, MutableMapping 

29from typing import Any, NamedTuple, Optional, Union 

30 

31from sqlalchemy import ( 

32 CheckConstraint, 

33 Column, 

34 Constraint, 

35 ForeignKeyConstraint, 

36 Index, 

37 MetaData, 

38 Numeric, 

39 PrimaryKeyConstraint, 

40 UniqueConstraint, 

41 types, 

42) 

43from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite 

44from sqlalchemy.schema import Table 

45 

46from .check import FelisValidator 

47from .db import sqltypes 

48from .types import FelisType 

49from .visitor import Visitor 

50 

# Shorthand aliases for the JSON-like objects handed to the visitor methods.
_Mapping = Mapping[str, Any]
_MutableMapping = MutableMapping[str, Any]

logger = logging.getLogger("felis")

# Dialect names used as keys in the tables below.
MYSQL = "mysql"
ORACLE = "oracle"
POSTGRES = "postgresql"
SQLITE = "sqlite"

# Felis table option ("<dialect>:<option>") -> keyword name
# ("<dialect>_<option>").
TABLE_OPTS = {
    option: option.replace(":", "_")
    for option in ("mysql:engine", "mysql:charset", "oracle:compress")
}

# Felis column key ("<dialect>:datatype") -> dialect name.
COLUMN_VARIANT_OVERRIDE = {
    f"{dialect}:datatype": dialect for dialect in (MYSQL, ORACLE, POSTGRES, SQLITE)
}

73 

# Dialect name -> SQLAlchemy dialect module that the override type names are
# looked up in.
DIALECT_MODULES = {
    MYSQL: mysql,
    ORACLE: oracle,
    SQLITE: sqlite,
    POSTGRES: postgresql,
}

# Captures whatever sits between the first "(" and the last ")" of a datatype
# override string, i.e. its argument list.
length_regex = re.compile(r"\((.+)\)")

77 

78 

class Schema(NamedTuple):
    """Bundle of everything `SQLVisitor.visit_schema` produces for a schema."""

    # Schema name: the visitor's override if one was given, otherwise the
    # name read from the schema object.
    name: Optional[str]

    # One SQLAlchemy table per entry of the schema object's "tables" list.
    tables: list[Table]

    # The MetaData instance the tables were registered on.
    metadata: MetaData

    # Lookup from "@id" to the SQLAlchemy object built for that identifier.
    graph_index: Mapping[str, Any]

84 

85 

class SQLVisitor(Visitor[Schema, Table, Column, Optional[PrimaryKeyConstraint], Constraint, Index]):
    """A Felis Visitor which populates a SQLAlchemy metadata object.

    Every table, column and constraint that is built is also stored in
    ``graph_index`` under its ``@id`` so that objects visited later (primary
    keys, constraints, indexes) can refer to it by identifier.

    Parameters
    ----------
    schema_name : `str`, optional
        Override for the schema name.
    """

    def __init__(self, schema_name: Optional[str] = None):
        self.metadata = MetaData()
        self.schema_name = schema_name
        self.checker = FelisValidator()
        # Maps "@id" values to the SQLAlchemy objects created for them.
        self.graph_index: MutableMapping[str, Any] = {}

    def visit_schema(self, schema_obj: _Mapping) -> Schema:
        # Docstring is inherited.
        self.checker.check_schema(schema_obj)
        return Schema(
            # An explicit override wins over the name stored in the schema.
            name=self.schema_name or schema_obj["name"],
            tables=[self.visit_table(t, schema_obj) for t in schema_obj["tables"]],
            metadata=self.metadata,
            graph_index=self.graph_index,
        )

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> Table:
        # Docstring is inherited.
        self.checker.check_table(table_obj, schema_obj)

        # Columns must be visited first: the primary key, constraints and
        # indexes below resolve their columns through ``graph_index``, which
        # ``visit_column`` fills in.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]

        name = table_obj["name"]
        table_id = table_obj["@id"]
        description = table_obj.get("description")
        schema_name = self.schema_name or schema_obj["name"]

        table = Table(name, self.metadata, *columns, schema=schema_name, comment=description)

        primary_key = self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        if primary_key:
            table.append_constraint(primary_key)

        constraints = [self.visit_constraint(c, table_obj) for c in table_obj.get("constraints", [])]
        for constraint in constraints:
            table.append_constraint(constraint)

        indexes = [self.visit_index(i, table_obj) for i in table_obj.get("indexes", [])]
        for index in indexes:
            # FIXME: Hack because there's no table.add_index
            index._set_parent(table)
            table.indexes.add(index)
        self.graph_index[table_id] = table
        return table

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Column:
        # Docstring is inherited.
        self.checker.check_column(column_obj, table_obj)
        column_name = column_obj["name"]
        column_id = column_obj["@id"]
        datatype_name = column_obj["datatype"]
        column_description = column_obj.get("description")
        column_default = column_obj.get("value")
        column_length = column_obj.get("length")

        # Collect per-dialect datatype overrides ("<dialect>:datatype" keys);
        # they are handed to the type factory as ``<dialect>=<type>`` keywords.
        kwargs = {}
        for column_opt in column_obj:
            if column_opt in COLUMN_VARIANT_OVERRIDE:
                dialect = COLUMN_VARIANT_OVERRIDE[column_opt]
                variant = _process_variant_override(dialect, column_obj[column_opt])
                kwargs[dialect] = variant

        felis_type = FelisType.felis_type(datatype_name)
        datatype_fun = getattr(sqltypes, datatype_name)

        # Only sized types take the length as their positional argument.
        if felis_type.is_sized:
            datatype = datatype_fun(column_length, **kwargs)
        else:
            datatype = datatype_fun(**kwargs)

        # Numeric columns are NOT NULL unless the column says otherwise;
        # everything else is nullable by default.
        nullable_default = True
        if isinstance(datatype, Numeric):
            nullable_default = False

        column_nullable = column_obj.get("nullable", nullable_default)
        column_autoincrement = column_obj.get("autoincrement", "auto")

        column: Column = Column(
            column_name,
            datatype,
            comment=column_description,
            autoincrement=column_autoincrement,
            nullable=column_nullable,
            server_default=column_default,
        )
        # A repeated @id is only reported; the newest column replaces the
        # previous entry in the index.
        if column_id in self.graph_index:
            logger.warning("Duplication of @id %s", column_id)
        self.graph_index[column_id] = column
        return column

    def visit_primary_key(
        self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping
    ) -> Optional[PrimaryKeyConstraint]:
        # Docstring is inherited.
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            # A single column id may be given as a bare string.
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            return PrimaryKeyConstraint(*columns)
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> Constraint:
        # Docstring is inherited.
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        constraint_id = constraint_obj["@id"]

        # Keyword arguments shared by all constraint constructors; only the
        # ones present in the constraint object are set.
        #
        # "expression" is deliberately NOT collected here: no SQLAlchemy
        # constraint constructor accepts an ``expression`` keyword, so
        # forwarding it made the constructor reject it as an invalid
        # dialect-specific argument (TypeError). The Check branch passes the
        # expression positionally instead.
        #
        # NOTE(review): "description" (a string, presumably) is forwarded as
        # ``info``, which SQLAlchemy documents as a dictionary - confirm that
        # consumers expect a plain string there.
        constraint_args: _MutableMapping = {}
        _set_if("name", constraint_obj.get("name"), constraint_args)
        _set_if("info", constraint_obj.get("description"), constraint_args)
        _set_if("deferrable", constraint_obj.get("deferrable"), constraint_args)
        _set_if("initially", constraint_obj.get("initially"), constraint_args)

        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        constraint: Constraint
        if constraint_type == "ForeignKey":
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]
            constraint = ForeignKeyConstraint(columns, refcolumns, **constraint_args)
        elif constraint_type == "Check":
            expression = constraint_obj["expression"]
            constraint = CheckConstraint(expression, **constraint_args)
        elif constraint_type == "Unique":
            constraint = UniqueConstraint(*columns, **constraint_args)
        else:
            raise ValueError(f"Unexpected constraint type: {constraint_type}")
        self.graph_index[constraint_id] = constraint
        return constraint

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> Index:
        # Docstring is inherited.
        self.checker.check_index(index_obj, table_obj)
        name = index_obj["name"]
        description = index_obj.get("description")
        # An index may be built from column ids, raw expressions, or both.
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        expressions = index_obj.get("expressions", [])
        return Index(name, *columns, *expressions, info=description)

234 

235 

236def _set_if(key: str, value: Any, mapping: _MutableMapping) -> None: 

237 if value is not None: 

238 mapping[key] = value 

239 

240 

def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build a dialect-specific type from a datatype override string.

    Parameters
    ----------
    dialect_name : `str`
        Key into ``DIALECT_MODULES`` selecting the SQLAlchemy dialect module
        in which the type name is looked up.
    variant_override_str : `str`
        A type name, optionally followed by a parenthesised, comma-separated
        list of integer arguments, e.g. ``"VARCHAR(64)"``. Whitespace around
        the type name is ignored.

    Returns
    -------
    variant : `sqlalchemy.types.TypeEngine`
        The result of calling the dialect attribute of that name with the
        parsed integer arguments.

    Raises
    ------
    KeyError
        Raised if ``dialect_name`` is not a key of ``DIALECT_MODULES``.
    ValueError
        Raised if the dialect module has no attribute with the given type
        name, or if an argument in the parentheses is not an integer.
    """
    dialect = DIALECT_MODULES[dialect_name]

    # Everything before the first "(" is the type name; strip it so that
    # spellings such as "VARCHAR (64)" do not fail the lookup.
    variant_type_name = variant_override_str.split("(")[0].strip()
    if not variant_type_name or not hasattr(dialect, variant_type_name):
        raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")
    variant_type = getattr(dialect, variant_type_name)

    # Parse the optional "(a, b, ...)" argument list; int() tolerates the
    # whitespace around each item.
    length_params = []
    match = length_regex.search(variant_override_str)
    if match:
        try:
            length_params = [int(param) for param in match.group(1).split(",")]
        except ValueError as err:
            raise ValueError(
                f"Non-integer argument in {dialect_name} datatype override {variant_override_str!r}"
            ) from err
    return variant_type(*length_params)