Coverage for python/felis/tap.py: 11%

246 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-09 03:01 -0800

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module: the declarative base for the TAP_SCHEMA ORM
# classes, the visitor that loads a Felis schema, and the table factory.
__all__ = [
    "Tap11Base",
    "TapLoadingVisitor",
    "init_tables",
]

25 

26import logging 

27from collections.abc import Iterable, Mapping, MutableMapping 

28from typing import Any, Optional, Union 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.ext.declarative import declarative_base 

33from sqlalchemy.orm import Session, sessionmaker 

34from sqlalchemy.schema import MetaData 

35from sqlalchemy.sql.expression import Insert, insert 

36 

37from .check import FelisValidator 

38from .types import FelisType 

39from .visitor import Visitor 

40 

# Read-only mapping type used for the Felis schema/table/column/constraint
# objects handed to the visitor methods.
_Mapping = Mapping[str, Any]

logger = logging.getLogger("felis")

# Declarative base shared by the ORM classes declared in `init_tables`.
# Annotated as Any to avoid mypy mess with SA 2.
Tap11Base: Any = declarative_base()

# Maximum lengths of the string columns of the TAP_SCHEMA tables.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for a fully qualified ``catalog.schema.table`` name: three
# identifiers joined by two dots.
QUALIFIED_TABLE_LENGTH = IDENTIFIER_LENGTH * 3 + 2

# Set to True by the first `init_tables` call; later calls then replace
# ``Tap11Base.metadata`` with a fresh MetaData.
_init_table_once = False

53 

54 

def init_tables(
    tap_schema_name: Optional[str] = None,
    tap_tables_postfix: Optional[str] = None,
    tap_schemas_table: Optional[str] = None,
    tap_tables_table: Optional[str] = None,
    tap_columns_table: Optional[str] = None,
    tap_keys_table: Optional[str] = None,
    tap_key_columns_table: Optional[str] = None,
) -> MutableMapping[str, Any]:
    """Declare the TAP_SCHEMA 1.1 ORM classes on ``Tap11Base``.

    :param tap_schema_name: Database schema assigned to
        ``Tap11Base.metadata``; ignored when `None` or empty.
    :param tap_tables_postfix: Suffix appended to every table name.
    :param tap_schemas_table: Base name of the schemas table
        (default ``"schemas"``).
    :param tap_tables_table: Base name of the tables table
        (default ``"tables"``).
    :param tap_columns_table: Base name of the columns table
        (default ``"columns"``).
    :param tap_keys_table: Base name of the keys table (default ``"keys"``).
    :param tap_key_columns_table: Base name of the key columns table
        (default ``"key_columns"``).
    :return: A dict mapping ``"schemas"``, ``"tables"``, ``"columns"``,
        ``"keys"`` and ``"key_columns"`` to the newly declared ORM classes.

    The first call declares the classes on the metadata ``Tap11Base`` was
    created with; every later call first replaces ``Tap11Base.metadata`` with
    a fresh `MetaData` so that the function can be called more than once.
    """
    global _init_table_once

    suffix = tap_tables_postfix or ""
    schemas_name = (tap_schemas_table or "schemas") + suffix
    tables_name = (tap_tables_table or "tables") + suffix
    columns_name = (tap_columns_table or "columns") + suffix
    keys_name = (tap_keys_table or "keys") + suffix
    key_columns_name = (tap_key_columns_table or "key_columns") + suffix

    # HACK: swapping in a fresh MetaData on repeat calls is what lets this
    # function be called more than once.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = schemas_name
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = tables_name
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = columns_name
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # ``size`` is deprecated and therefore intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = keys_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = key_columns_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

133 

134 

class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
    """Visitor that writes a Felis schema description into TAP_SCHEMA tables.

    :param engine: Engine used to write the TAP_SCHEMA rows.
    :param catalog_name: Optional catalog name; when set together with a
        schema name, recorded schema names become ``catalog.schema``.
    :param schema_name: Schema name to record. When `None`, the ``name`` of
        the visited schema object is used.
    :param mock: If `True` (dry run), rows are written as explicit INSERT
        statements on an ``engine.begin()`` connection instead of through an
        ORM session.
    :param tap_tables: ORM classes as returned by `init_tables`; when not
        given (or empty), `init_tables` is called with its defaults.
    """

    def __init__(
        self,
        engine: Engine,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        mock: bool = False,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ):
        # Maps Felis ``@id`` values to the ORM objects built for them (tables
        # and columns) so primary keys, constraints and indexes can find them.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self.mock = mock
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate ``schema_obj`` and write it and all of its tables.

        :param schema_obj: Felis schema object; each entry of its ``tables``
            list is visited in order.
        """
        self.checker.check_schema(schema_obj)
        schema = self.tables["schemas"]()
        # A schema name given to the constructor takes precedence over the
        # one in the schema object.
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if not self.mock:
            session: Session = sessionmaker(self.engine)()
            # The context manager closes the session (discarding uncommitted
            # work) even if visiting a table raises; previously it leaked.
            with session:
                session.add(schema)
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)
                    session.add_all(keys)
                    session.add_all(key_columns)
                session.commit()
        else:
            # Only if we are mocking (dry run): emit plain INSERT statements.
            with self.engine.begin() as conn:
                conn.execute(_insert(self.tables["schemas"], schema))
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    conn.execute(_insert(self.tables["tables"], table))
                    for column in columns:
                        conn.execute(_insert(self.tables["columns"], column))
                    for key in keys:
                        conn.execute(_insert(self.tables["keys"], key))
                    for key_column in key_columns:
                        conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the TAP_SCHEMA rows for one table.

        :param table_obj: Felis table object.
        :param schema_obj: Felis schema object the table belongs to.
        :return: Tuple ``(table, columns, keys, key_columns)`` holding the
            tables-table row, the list of columns-table rows, and the lists
            of keys-table and key_columns-table rows from foreign keys.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns are visited first: the primary key, constraints and indexes
        # below look them up in ``graph_index`` by ``@id``.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        all_keys = []
        all_key_columns = []
        for constraint_obj in table_obj.get("constraints", []):
            # Pass the Felis table mapping (not the ORM row), as the
            # ``table_obj: _Mapping`` parameter expects.
            key, key_columns = self.visit_constraint(constraint_obj, table_obj)
            if key is None:
                # Constraint types other than ForeignKey produce no rows.
                continue
            all_keys.append(key)
            all_key_columns += key_columns

        for index_obj in table_obj.get("indexes", []):
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns, all_keys, all_key_columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn about missing size information.

        :param column_obj: Felis column object.
        :param table_obj: Felis table object the column belongs to.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    "votable:arraysize and length for %s are None for type %s. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`.",
                    _id,
                    datatype_name,
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size.
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    "votable:arraysize for %s is None for type %s. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings.",
                    _id,
                    datatype_name,
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the columns-table row for one column and index it by ``@id``.

        :param column_obj: Felis column object.
        :param table_obj: Felis table object the column belongs to.
        :return: The columns-table ORM instance.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        column.unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.ucd = column_obj.get("ivoa:ucd")

        # Set to 1 later by visit_primary_key/visit_index for single-column
        # keys and indexes.
        column.indexed = 0

        # Coerce to int (as schema_index/table_index are) so that e.g. YAML
        # booleans are stored as 0/1 in these Integer columns.
        column.principal = int(column_obj.get("tap:principal", 0))
        column.std = int(column_obj.get("tap:std", 0))
        column_index = column_obj.get("tap:column_index")
        column.column_index = None if column_index is None else int(column_index)

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        :param primary_key_obj: ``@id`` of one column, or an iterable of
            column ``@id`` values.
        :param table_obj: Felis table object the key belongs to.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if not primary_key_obj:
            return None
        column_ids = [primary_key_obj] if isinstance(primary_key_obj, str) else primary_key_obj
        columns = [self.graph_index[c_id] for c_id in column_ids]
        # Only a single-column key sets ``indexed``; composite keys leave it 0.
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build the keys/key_columns rows for a constraint.

        Only ``ForeignKey`` constraints produce rows; any other type yields
        ``(None, [])``.

        :param constraint_obj: Felis constraint object.
        :param table_obj: Felis table object the constraint belongs to.
        :return: Tuple ``(key, key_columns)``: the keys-table row (or `None`)
            and the list of key_columns-table rows.
        :raises ValueError: If the columns (or the referenced columns) span
            several tables, if either list is empty, or if the two lists
            differ in length.
        :raises KeyError: If a listed column ``@id`` has not been visited.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        key = None
        key_columns: list = []
        if constraint_obj["@type"] != "ForeignKey":
            return key, key_columns

        constraint_name = constraint_obj["name"]
        # Lookups go through ``graph_index``, so referenced columns must
        # belong to a table that has already been visited.
        columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
        refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]

        from_table = self._common_table_name(columns)
        target_table = self._common_table_name(refcolumns)
        if not columns or not refcolumns:
            raise ValueError(f"Foreign key {constraint_name} needs both columns and referencedColumns")
        if len(columns) != len(refcolumns):
            # zip() below would otherwise silently drop the extra columns.
            raise ValueError(
                f"Foreign key {constraint_name} has {len(columns)} columns "
                f"but {len(refcolumns)} referencedColumns"
            )

        key = self.tables["keys"]()
        key.key_id = constraint_name
        key.from_table = from_table
        key.target_table = target_table
        key.description = constraint_obj.get("description")
        key.utype = constraint_obj.get("votable:utype")
        for column, refcolumn in zip(columns, refcolumns):
            key_column = self.tables["key_columns"]()
            key_column.key_id = constraint_name
            key_column.from_column = column.column_name
            key_column.target_column = refcolumn.column_name
            key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed.

        :param index_obj: Felis index object listing column ``@id`` values.
        :param table_obj: Felis table object the index belongs to.
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        # Only a single-column index sets ``indexed``.
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    @staticmethod
    def _common_table_name(columns: Iterable[Any]) -> Optional[str]:
        """Return the ``table_name`` shared by all of ``columns``.

        :param columns: Columns-table ORM objects taken from ``graph_index``.
        :return: The common table name, or `None` if ``columns`` is empty.
        :raises ValueError: If the columns belong to different tables.
        """
        table_names = {column.table_name for column in columns}
        if len(table_names) > 1:
            raise ValueError("Inconsistent use of table names")
        return next(iter(table_names), None)

    def _schema_name(self, schema_name: Optional[str] = None) -> Optional[str]:
        """Return the schema name, prefixed with ``catalog.`` when a catalog is set.

        :param schema_name: Overrides ``self.schema_name`` when given.
        :return: The (possibly catalog-qualified) schema name, or `None`.
        """
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified with the schema name, if any."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

356 

357 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy insert statement for one TAP_SCHEMA row.

    :param table: The ORM class of the table we are inserting to.
    :param value: An object representing the row we are inserting; each
        column's value is read from the attribute with the column's name.
    :return: A SQLAlchemy insert statement with every column of ``table``
        set from ``value``.
    """
    return insert(table).values(_row_values(table, value))


def _row_values(table: Any, value: Any) -> dict[str, Any]:
    """Collect the values of ``value`` for every column of ``table``.

    Single quotes in string values (including `str` subclasses, which the
    previous ``type(x) == str`` check skipped) are doubled.

    :param table: Object whose ``__table__.columns`` lists the columns.
    :param value: Object carrying one attribute per column name.
    :return: Dict mapping column names to values, in column order.
    """
    # NOTE(review): if the statement is later compiled with literal binds,
    # SQLAlchemy escapes quotes itself — confirm this does not double-escape.
    row: dict[str, Any] = {}
    for column in table.__table__.columns:
        column_value = getattr(value, column.name)
        if isinstance(column_value, str):
            column_value = column_value.replace("'", "''")
        row[column.name] = column_value
    return row