Coverage for python/felis/tap.py: 12%

263 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-23 10:44 +0000

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"] 

25 

26import logging 

27from collections.abc import Iterable, Mapping, MutableMapping 

28from typing import Any 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.engine.mock import MockConnection 

33from sqlalchemy.ext.declarative import declarative_base 

34from sqlalchemy.orm import Session, sessionmaker 

35from sqlalchemy.schema import MetaData 

36from sqlalchemy.sql.expression import Insert, insert 

37 

38from .check import FelisValidator 

39from .types import FelisType 

40from .visitor import Visitor 

41 

# Alias used to annotate the schema/table/column/constraint objects that the
# visitor methods below receive.
_Mapping = Mapping[str, Any]

# Declarative base that the TAP table classes built by ``init_tables`` derive
# from.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2
logger = logging.getLogger("felis")

# Widths of the ``String`` columns declared in ``init_tables``.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Presumably room for "catalog.schema.table": three identifiers plus the two
# "." separators added by ``_schema_name``/``_table_name`` -- TODO confirm.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Set to True by the first ``init_tables`` call; later calls use it to decide
# whether ``Tap11Base.metadata`` has to be replaced.
_init_table_once = False

54 

55 

def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Generate definitions for TAP tables.

    Parameters
    ----------
    tap_schema_name : `str` or `None`
        If set, assigned to ``Tap11Base.metadata.schema``.
    tap_tables_postfix : `str` or `None`
        Suffix appended to every table name; `None` means no suffix.
    tap_schemas_table : `str` or `None`
        Override for the ``schemas`` table name.
    tap_tables_table : `str` or `None`
        Override for the ``tables`` table name.
    tap_columns_table : `str` or `None`
        Override for the ``columns`` table name.
    tap_keys_table : `str` or `None`
        Override for the ``keys`` table name.
    tap_key_columns_table : `str` or `None`
        Override for the ``key_columns`` table name.

    Returns
    -------
    tables : `~collections.abc.MutableMapping`
        Mapping with keys ``schemas``, ``tables``, ``columns``, ``keys`` and
        ``key_columns`` whose values are the declarative classes defined by
        this call.
    """
    suffix = tap_tables_postfix or ""

    def _named(override: str | None, default: str) -> str:
        # An unset (or empty) override falls back to the standard name.
        return (override or default) + suffix

    # Dirty hack to enable this method to be called more than once: every
    # call after the first one swaps in a fresh MetaData instance.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    global _init_table_once
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = _named(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _named(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _named(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        column_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated "size" column is intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _named(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _named(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        target_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

135 

136 

class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None, None]):
    """Felis schema visitor for generating TAP schema.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance. When `None`, `visit_schema` does a dry
        run against the mock connection instead of opening a session.
    catalog_name : `str` or `None`
        Name of the database catalog.
    schema_name : `str` or `None`
        Name of the database schema.
    tap_tables : `~collections.abc.Mapping`
        Optional mapping of table name to its declarative base class.
        When not given, the result of `init_tables` is used.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ):
        # Maps "@id" values to the table/column row objects created while
        # visiting, so keys, constraints and indexes can look them up later.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ) -> TapLoadingVisitor:
        """Build a visitor that has no engine and uses a mock connection.

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection that receives the insert statements in `visit_schema`.
        catalog_name : `str` or `None`
            Name of the database catalog.
        schema_name : `str` or `None`
            Name of the database schema.
        tap_tables : `~collections.abc.Mapping`
            Optional mapping of table name to its declarative base class.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor with ``engine`` set to `None`.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate a schema and load its TAP rows.

        With an engine, all rows are added to one session and committed.
        Without an engine, one insert statement per row is executed on the
        mock connection.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object; ``name`` is read when no schema name was
            given to the constructor, and ``tables`` is iterated.
        """
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            session: Session = sessionmaker(self.engine)()
            # Close the session on every exit path so it is not leaked when
            # validation or the commit raises.
            try:
                session.add(schema)

                for table_obj in schema_obj["tables"]:
                    table, columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)

                keys, key_columns = self.visit_constraints(schema_obj)
                session.add_all(keys)
                session.add_all(key_columns)

                session.commit()
            finally:
                session.close()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))

            for table_obj in schema_obj["tables"]:
                table, columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))

            keys, key_columns = self.visit_constraints(schema_obj)
            for key in keys:
                conn.execute(_insert(self.tables["keys"], key))
            for key_column in key_columns:
                conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_constraints(self, schema_obj: _Mapping) -> tuple:
        """Collect key rows for every constraint of every table.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object whose ``tables`` are scanned.

        Returns
        -------
        keys, key_columns : `tuple`
            Two lists: the key rows and the key-column rows. Constraints for
            which `visit_constraint` returns no key are skipped.
        """
        all_keys = []
        all_key_columns = []
        for table_obj in schema_obj["tables"]:
            for c in table_obj.get("constraints", []):
                key, key_columns = self.visit_constraint(c, table_obj)
                if not key:
                    continue
                all_keys.append(key)
                all_key_columns += key_columns
        return all_keys, all_key_columns

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Validate a table and build its TAP table and column rows.

        Parameters
        ----------
        table_obj : `~collections.abc.Mapping`
            Felis table object; ``@id``, ``name`` and ``columns`` are required.
        schema_obj : `~collections.abc.Mapping`
            Schema object the table belongs to; only passed to the checker.

        Returns
        -------
        table, columns : `tuple`
            The table row and the list of its column rows.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns go first: the primary key and indexes look them up in
        # ``graph_index`` to flag them as indexed.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        for i in table_obj.get("indexes", []):
            # Pass the Felis table mapping, as `visit_index` is annotated to
            # receive, rather than the TAP row object built above.
            self.visit_index(i, table_obj)

        self.graph_index[table_id] = table
        return table, columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when no array size is configured.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Table object the column belongs to; only passed to the checker.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Validate a column and build its TAP column row.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object; ``@id``, ``name`` and ``datatype`` are
            required.
        table_obj : `~collections.abc.Mapping`
            Table object the column belongs to; its ``name`` is read.

        Returns
        -------
        column : `Tap11Base`
            The column row, also registered in ``graph_index`` under the
            column ``@id``.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # `length` is deliberately not consulted for timestamps
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.unit = unit
        column.ucd = column_obj.get("ivoa:ucd")

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Mapping) -> None:
        """Validate a primary key and flag a single-column key as indexed.

        Parameters
        ----------
        primary_key_obj : `str` or `~collections.abc.Iterable` [`str`]
            One column ``@id`` or several of them; may be empty.
        table_obj : `~collections.abc.Mapping`
            Table object the key belongs to; only passed to the checker.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a key made of exactly one column marks that column indexed
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        # Raise if the column rows do not all carry the same table name.
        table_name = None
        for column in columns:
            table_name = table_name or column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Validate a constraint and build key rows for a foreign key.

        Parameters
        ----------
        constraint_obj : `~collections.abc.Mapping`
            Felis constraint object; ``@type`` is required, and ``name`` is
            required when the type is ``ForeignKey``.
        table_obj : `~collections.abc.Mapping`
            Table object the constraint belongs to; only passed to the
            checker.

        Returns
        -------
        key, key_columns : `tuple`
            The key row and the list of key-column rows. For any type other
            than ``ForeignKey`` this is ``(None, [])``.

        Raises
        ------
        ValueError
            Raised if the referencing columns, or the referenced columns, do
            not all belong to the same table.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[col["@id"]] for col in constraint_obj.get("columns", [])]
            refcolumns = [
                self.graph_index[refcol["@id"]] for refcol in constraint_obj.get("referencedColumns", [])
            ]

            self._check_single_table(columns)
            self._check_single_table(refcolumns)
            first_column = columns[0]
            first_refcolumn = refcolumns[0]

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = first_column.table_name
            key.target_table = first_refcolumn.table_name
            key.description = description
            key.utype = utype
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate an index and flag a single-column index as indexed.

        Parameters
        ----------
        index_obj : `~collections.abc.Mapping`
            Felis index object; its ``columns`` entries are looked up by
            ``@id``.
        table_obj : `~collections.abc.Mapping`
            Table object the index belongs to; only passed to the checker.
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[col["@id"]] for col in index_obj.get("columns", [])]
        # Only an index over exactly one column marks that column indexed
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return the schema name, prefixed with the catalog name if set."""
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return the table name, prefixed with the schema name if set."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

410 

411 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """Build an insert statement for one row object.

    Parameters
    ----------
    table : `Tap11Base`
        Declarative class of the table being inserted into.
    value : `Any`
        Row object; one attribute is read for every column of ``table``.

    Returns
    -------
    statement
        A SQLAlchemy insert statement carrying the row values, with every
        single quote in string values doubled.
    """

    def _escaped(raw: Any) -> Any:
        # Only strings are rewritten; any other value passes through as is.
        return raw.replace("'", "''") if isinstance(raw, str) else raw

    row = {col.name: _escaped(getattr(value, col.name)) for col in table.__table__.columns}
    return insert(table).values(row)