Coverage for python/felis/tap.py: 12%

262 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-14 10:16 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"] 

25 

26import logging 

27from collections.abc import Iterable, Mapping, MutableMapping 

28from typing import Any 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.engine.mock import MockConnection 

33from sqlalchemy.orm import Session, declarative_base, sessionmaker 

34from sqlalchemy.schema import MetaData 

35from sqlalchemy.sql.expression import Insert, insert 

36 

37from .check import FelisValidator 

38from .types import FelisType 

39from .visitor import Visitor 

40 

# Read-only mapping type used for every Felis schema object in this module.
_Mapping = Mapping[str, Any]

# Declarative base shared by all TAP table classes built by init_tables().
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2
logger = logging.getLogger("felis")

# Maximum lengths of the string columns declared in init_tables().
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for three identifiers joined by two one-character separators
# (presumably "catalog.schema.table" -- confirm against the TAP deployment).
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Flipped to True by the first init_tables() call; later calls use it to
# decide whether Tap11Base.metadata has to be replaced.
_init_table_once = False

53 

54 

def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Build the declarative classes describing the five TAP tables.

    Parameters
    ----------
    tap_schema_name : `str` or `None`
        If set, assigned as the ``schema`` of ``Tap11Base.metadata``.
    tap_tables_postfix : `str` or `None`
        Suffix appended to every table name; nothing is appended if unset.
    tap_schemas_table : `str` or `None`
        Base name of the schemas table, ``"schemas"`` if unset.
    tap_tables_table : `str` or `None`
        Base name of the tables table, ``"tables"`` if unset.
    tap_columns_table : `str` or `None`
        Base name of the columns table, ``"columns"`` if unset.
    tap_keys_table : `str` or `None`
        Base name of the keys table, ``"keys"`` if unset.
    tap_key_columns_table : `str` or `None`
        Base name of the key columns table, ``"key_columns"`` if unset.

    Returns
    -------
    tables : `dict`
        Newly created classes keyed by ``"schemas"``, ``"tables"``,
        ``"columns"``, ``"keys"`` and ``"key_columns"``.
    """
    global _init_table_once

    suffix = tap_tables_postfix or ""

    def _table_name(override: str | None, default: str) -> str:
        # Empty or missing overrides fall back to the default base name.
        return (override or default) + suffix

    # Dirty hack that lets this function be called repeatedly: every call
    # after the first swaps in a fresh MetaData instance before the classes
    # below are declared on Tap11Base.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = _table_name(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _table_name(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _table_name(tap_columns_table, "columns")
        # Composite primary key: (table_name, column_name).
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), primary_key=True, nullable=False)
        column_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated ``size`` column is intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _table_name(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _table_name(tap_key_columns_table, "key_columns")
        # All three columns together form the primary key.
        key_id = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        from_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        target_column = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

134 

135 

class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None, None]):
    """Felis schema visitor for generating TAP schema.

    The visitor walks a Felis schema mapping and builds one row object per
    schema, table, column, foreign key and key column using the classes in
    ``tap_tables``. With an engine the rows are added to a SQLAlchemy session
    and committed; without one (dry run) an INSERT statement per row is
    executed on the mock connection.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance.
    catalog_name : `str` or `None`
        Name of the database catalog.
    schema_name : `str` or `None`
        Name of the database schema.
    tap_tables : `~collections.abc.Mapping`
        Optional mapping of table name to its declarative base class.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ):
        # Maps the "@id" of every visited table and column to its row object.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        # Only set through from_mock_connection(); used when engine is None.
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
    ) -> TapLoadingVisitor:
        """Create a dry-run visitor that writes to a mock connection.

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection on which the INSERT statements are executed.
        catalog_name : `str` or `None`
            Name of the database catalog.
        schema_name : `str` or `None`
            Name of the database schema.
        tap_tables : `~collections.abc.Mapping`
            Optional mapping of table name to its declarative base class.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor with no engine and ``mock_connection`` attached.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Load a whole Felis schema into the TAP tables.

        Sets ``self.schema_name`` from ``schema_obj["name"]`` when no schema
        name was configured.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object; ``"name"`` and ``"tables"`` are read
            unconditionally.
        """
        self.checker.check_schema(schema_obj)
        if (version_obj := schema_obj.get("version")) is not None:
            self.visit_schema_version(version_obj, schema_obj)
        schema = self.tables["schemas"]()
        # An explicitly configured schema name wins over the schema's own.
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            # The context manager closes the session on every exit path, so
            # nothing is leaked when a visit or the commit raises.
            with sessionmaker(self.engine)() as session:
                session.add(schema)

                for table_obj in schema_obj["tables"]:
                    table, columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)

                keys, key_columns = self.visit_constraints(schema_obj)
                session.add_all(keys)
                session.add_all(key_columns)

                session.commit()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))

            for table_obj in schema_obj["tables"]:
                table, columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))

            keys, key_columns = self.visit_constraints(schema_obj)
            for key in keys:
                conn.execute(_insert(self.tables["keys"], key))
            for key_column in key_columns:
                conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_constraints(self, schema_obj: _Mapping) -> tuple:
        """Collect key rows for every constraint of every table.

        Must run after the tables have been visited, because
        `visit_constraint` looks columns up in ``self.graph_index``.

        Parameters
        ----------
        schema_obj : `~collections.abc.Mapping`
            Felis schema object.

        Returns
        -------
        keys, key_columns : `tuple` [`list`, `list`]
            Key rows and key column rows; constraints for which
            `visit_constraint` produced no key are skipped.
        """
        all_keys = []
        all_key_columns = []
        for table_obj in schema_obj["tables"]:
            for constraint_obj in table_obj.get("constraints", []):
                key, key_columns = self.visit_constraint(constraint_obj, table_obj)
                if not key:
                    continue
                all_keys.append(key)
                all_key_columns += key_columns
        return all_keys, all_key_columns

    def visit_schema_version(
        self, version_obj: str | Mapping[str, Any], schema_obj: Mapping[str, Any]
    ) -> None:
        # Docstring is inherited.

        # For now we ignore schema versioning completely, still do some checks.
        self.checker.check_schema_version(version_obj, schema_obj)

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the row for one table together with its column rows.

        Parameters
        ----------
        table_obj : `~collections.abc.Mapping`
            Felis table object; ``"@id"``, ``"name"`` and ``"columns"`` are
            read unconditionally.
        schema_obj : `~collections.abc.Mapping`
            Felis schema object the table belongs to.

        Returns
        -------
        table, columns : `tuple`
            The table row and the list of its column rows.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns go first: the primary key and index visits below look them
        # up in graph_index to set their ``indexed`` flag.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)

        for index_obj in table_obj.get("indexes", []):
            # visit_index expects the Felis table mapping, not the row object.
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when no array size can be determined.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object.
        table_obj : `~collections.abc.Mapping`
            Felis table object the column belongs to.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        # datetime types really should have a votable:arraysize, because
        # they are converted to strings and the `length` is only loosely
        # related to the string size
        if felis_type.is_timestamp and "votable:arraysize" not in column_obj:
            logger.warning(
                f"votable:arraysize for {_id} is None for type {datatype_name}. "
                'Using length "*". '
                "Consider setting `votable:arraysize` to an appropriate size for "
                "materialized datetime/timestamp strings."
            )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the row for one column and register it in ``graph_index``.

        Parameters
        ----------
        column_obj : `~collections.abc.Mapping`
            Felis column object; ``"@id"``, ``"name"`` and ``"datatype"`` are
            read unconditionally.
        table_obj : `~collections.abc.Mapping`
            Felis table object the column belongs to.

        Returns
        -------
        column : `Tap11Base`
            The column row, with ``indexed`` initialised to 0.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # `length` is deliberately not consulted for timestamps.
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        # An empty or missing ivoa:unit falls through to fits:tunit.
        unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.unit = unit
        column.ucd = column_obj.get("ivoa:ucd")

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        Parameters
        ----------
        primary_key_obj : `str` or `~collections.abc.Iterable` [`str`]
            One column ``"@id"`` or several; an empty value is ignored.
        table_obj : `~collections.abc.Mapping`
            Felis table object owning the primary key.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a single-column key sets the flag; composite keys leave
            # every column's ``indexed`` untouched.
            if len(columns) == 1:
                columns[0].indexed = 1

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build key and key column rows for a foreign key constraint.

        Parameters
        ----------
        constraint_obj : `~collections.abc.Mapping`
            Felis constraint object; ``"@type"`` is read unconditionally and
            ``"name"`` for foreign keys.
        table_obj : `~collections.abc.Mapping`
            Felis table object owning the constraint.

        Returns
        -------
        key, key_columns : `tuple`
            The key row and its key column rows, or ``(None, [])`` for any
            constraint type other than ``"ForeignKey"``.

        Raises
        ------
        ValueError
            Raised if the foreign key lists no columns or no referenced
            columns, or if either list spans more than one table.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[col["@id"]] for col in constraint_obj.get("columns", [])]
            refcolumns = [
                self.graph_index[refcol["@id"]] for refcol in constraint_obj.get("referencedColumns", [])
            ]

            # Fail with a readable message instead of an IndexError below.
            if not columns or not refcolumns:
                raise ValueError(
                    f"ForeignKey constraint {constraint_name} must have both "
                    "columns and referencedColumns"
                )

            self._check_single_table(columns)
            self._check_single_table(refcolumns)
            first_column = columns[0]
            first_refcolumn = refcolumns[0]

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = first_column.table_name
            key.target_table = first_refcolumn.table_name
            key.description = description
            key.utype = utype
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        """Raise `ValueError` unless all column rows share one table name."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed.

        Parameters
        ----------
        index_obj : `~collections.abc.Mapping`
            Felis index object; its ``"columns"`` entries carry an ``"@id"``.
        table_obj : `~collections.abc.Mapping`
            Felis table object owning the index.
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[col["@id"]] for col in index_obj.get("columns", [])]
        # Only a single-column index sets the flag.
        if len(columns) == 1:
            columns[0].indexed = 1

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return the schema name, prefixed with the catalog when both are set."""
        # If the resolved name is None, SQLAlchemy will catch it
        name = schema_name or self.schema_name
        if self.catalog_name and name:
            return ".".join([self.catalog_name, name])
        return name

    def _table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified with the schema name, if there is one."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

409 

410 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """Build an INSERT statement for a single row object.

    Parameters
    ----------
    table : `Tap11Base`
        Declarative class of the table to insert into.
    value : `Any`
        Row object; one attribute is read for each column of ``table``.

    Returns
    -------
    statement
        A SQLAlchemy insert statement carrying one value per column.
    """

    def _escaped(raw: Any) -> Any:
        # Single quotes in strings are doubled; other values pass unchanged.
        # NOTE(review): presumably needed because the dry-run output renders
        # values as SQL literals -- confirm before relying on it.
        return raw.replace("'", "''") if isinstance(raw, str) else raw

    row = {col.name: _escaped(getattr(value, col.name)) for col in table.__table__.columns}
    return insert(table).values(row)