Coverage for python/felis/tap.py: 12%

255 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-25 10:20 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"] 

25 

26import logging 

27from collections.abc import Iterable, MutableMapping 

28from typing import Any 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.engine.mock import MockConnection 

33from sqlalchemy.orm import Session, declarative_base, sessionmaker 

34from sqlalchemy.schema import MetaData 

35from sqlalchemy.sql.expression import Insert, insert 

36 

37from felis import datamodel 

38 

39from .datamodel import Constraint, Index, Schema, Table 

40from .types import FelisType 

41 

# Declarative base shared by the TAP_SCHEMA table classes built in
# ``init_tables``.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2
logger = logging.getLogger(__name__)

# Maximum string lengths used by the TAP_SCHEMA column definitions below.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for three identifiers joined by two separator characters; qualified
# table names are built as "catalog.schema.table" by the visitor.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Set to True by the first ``init_tables`` call; later calls see it and
# replace ``Tap11Base.metadata`` with a fresh ``MetaData`` instance.
_init_table_once = False

52 

53 

def init_tables(
    tap_schema_name: str | None = None,
    tap_tables_postfix: str | None = None,
    tap_schemas_table: str | None = None,
    tap_tables_table: str | None = None,
    tap_columns_table: str | None = None,
    tap_keys_table: str | None = None,
    tap_key_columns_table: str | None = None,
) -> MutableMapping[str, Any]:
    """Generate definitions for TAP tables.

    Parameters
    ----------
    tap_schema_name : `str` or `None`
        If set, assigned to ``Tap11Base.metadata.schema`` before the table
        classes are declared.
    tap_tables_postfix : `str` or `None`
        Suffix appended to every table name; `None` means no suffix.
    tap_schemas_table : `str` or `None`
        Name of the schemas table, ``"schemas"`` if not given.
    tap_tables_table : `str` or `None`
        Name of the tables table, ``"tables"`` if not given.
    tap_columns_table : `str` or `None`
        Name of the columns table, ``"columns"`` if not given.
    tap_keys_table : `str` or `None`
        Name of the keys table, ``"keys"`` if not given.
    tap_key_columns_table : `str` or `None`
        Name of the key columns table, ``"key_columns"`` if not given.

    Returns
    -------
    tables : `~collections.abc.MutableMapping`
        Mapping with the keys ``"schemas"``, ``"tables"``, ``"columns"``,
        ``"keys"`` and ``"key_columns"``, each holding the declarative class
        created for that table.

    Notes
    -----
    The first call declares the classes against the existing
    ``Tap11Base.metadata``. Every later call first replaces that attribute
    with a fresh ``MetaData`` so the same table names can be declared again.
    """
    global _init_table_once

    suffix = tap_tables_postfix or ""
    schemas_name = (tap_schemas_table or "schemas") + suffix
    tables_name = (tap_tables_table or "tables") + suffix
    columns_name = (tap_columns_table or "columns") + suffix
    keys_name = (tap_keys_table or "keys") + suffix
    key_columns_name = (tap_key_columns_table or "key_columns") + suffix

    # Dirty hack to allow repeated calls: after the first call the MetaData
    # instance is swapped for a fresh copy.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = schemas_name
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = tables_name
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = columns_name
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The deprecated ``size`` column is intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = keys_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = key_columns_name
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

133 

134 

class TapLoadingVisitor:
    """Felis schema visitor for generating TAP schema.

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine` or `None`
        SQLAlchemy engine instance. When `None`, `visit_schema` performs a
        dry run against the mock connection instead of opening a session.
    catalog_name : `str` or `None`
        Name of the database catalog.
    schema_name : `str` or `None`
        Name of the database schema. If not given, the name of the visited
        schema object is used.
    tap_tables : `~collections.abc.MutableMapping` or `None`
        Optional mapping of table name to its declarative base class. If it
        is `None` or empty, the result of `init_tables` is used.
    tap_schema_index : `int` or `None`
        Value stored in the ``schema_index`` attribute of the schema record.
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
        tap_schema_index: int | None = None,
    ):
        # Maps Felis object IDs to the TAP records created for them, so that
        # keys and indexes can look up previously visited tables and columns.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.tap_schema_index = tap_schema_index

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: str | None = None,
        schema_name: str | None = None,
        tap_tables: MutableMapping[str, Any] | None = None,
        tap_schema_index: int | None = None,
    ) -> TapLoadingVisitor:
        """Create a visitor that writes to a mock connection (dry run).

        Parameters
        ----------
        mock_connection : `sqlalchemy.engine.mock.MockConnection`
            Connection that receives the generated insert statements.
        catalog_name : `str` or `None`
            Name of the database catalog.
        schema_name : `str` or `None`
            Name of the database schema.
        tap_tables : `~collections.abc.MutableMapping` or `None`
            Optional mapping of table name to its declarative base class.
        tap_schema_index : `int` or `None`
            Value stored in the ``schema_index`` attribute of the schema
            record.

        Returns
        -------
        visitor : `TapLoadingVisitor`
            Visitor without an engine, bound to ``mock_connection``.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        visitor.tap_schema_index = tap_schema_index
        return visitor

    def visit_schema(self, schema_obj: Schema) -> None:
        """Build the TAP records for a schema and store them.

        With an engine, all records are added to one session and committed.
        Without an engine, one insert statement per record is executed on
        the mock connection.

        Parameters
        ----------
        schema_obj : `Schema`
            The Felis schema to load.

        Raises
        ------
        AssertionError
            Raised if there is neither an engine nor a mock connection.
        ValueError
            Raised by `visit_constraint` for an inconsistent foreign key.
        """
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj.name

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.description
        schema.utype = schema_obj.votable_utype
        schema.schema_index = self.tap_schema_index
        logger.debug("Set TAP_SCHEMA index: %s", self.tap_schema_index)

        if self.engine is not None:
            session: Session
            # The context manager closes the session (releasing its
            # connection) on success and on error alike.
            with sessionmaker(self.engine)() as session:
                session.add(schema)

                for table_obj in schema_obj.tables:
                    table, columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)

                keys, key_columns = self.visit_constraints(schema_obj)
                session.add_all(keys)
                session.add_all(key_columns)

                logger.debug("Committing TAP schema: %s", schema_obj.name)
                logger.debug("TAP tables: %s", len(self.tables))
                session.commit()
        else:
            logger.info("Dry run, not inserting into database")

            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))

            for table_obj in schema_obj.tables:
                table, columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))

            keys, key_columns = self.visit_constraints(schema_obj)
            for key in keys:
                conn.execute(_insert(self.tables["keys"], key))
            for key_column in key_columns:
                conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_constraints(self, schema_obj: Schema) -> tuple[list[Any], list[Any]]:
        """Collect key records for every constraint of every table.

        Parameters
        ----------
        schema_obj : `Schema`
            The Felis schema whose table constraints are visited.

        Returns
        -------
        keys : `list`
            One key record per foreign key constraint.
        key_columns : `list`
            The key column records of all those keys, in visiting order.
        """
        all_keys: list[Any] = []
        all_key_columns: list[Any] = []
        for table_obj in schema_obj.tables:
            for constraint_obj in table_obj.constraints:
                key, key_columns = self.visit_constraint(constraint_obj)
                # Constraints other than foreign keys produce no key record.
                if key is None:
                    continue
                all_keys.append(key)
                all_key_columns.extend(key_columns)
        return all_keys, all_key_columns

    def visit_table(self, table_obj: Table, schema_obj: Schema) -> tuple:
        """Build the TAP table record and its column records.

        Parameters
        ----------
        table_obj : `Table`
            The Felis table to convert.
        schema_obj : `Schema`
            The schema containing the table (not used by this method).

        Returns
        -------
        table : `Tap11Base`
            The TAP table record; it is also stored in ``graph_index`` under
            the table ID.
        columns : `list`
            The TAP column records, one per table column.
        """
        table_id = table_obj.id
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj.name)
        table.table_type = "table"
        table.utype = table_obj.votable_utype
        table.description = table_obj.description
        table.table_index = 0 if table_obj.tap_table_index is None else table_obj.tap_table_index

        # Columns must be visited first: the primary key and index visitors
        # look the column records up in ``graph_index``.
        columns = [self.visit_column(c, table_obj) for c in table_obj.columns]
        self.visit_primary_key(table_obj.primary_key, table_obj)

        for index_obj in table_obj.indexes:
            # Pass the Felis table, as declared by ``visit_index``.
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns

    def check_column(self, column_obj: datamodel.Column) -> None:
        """Log warnings for columns that lack size information.

        Parameters
        ----------
        column_obj : `datamodel.Column`
            The Felis column to check.
        """
        _id = column_obj.id
        datatype_name = column_obj.datatype
        felis_type = FelisType.felis_type(datatype_name.value)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.votable_arraysize or column_obj.length
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size
            if not column_obj.votable_arraysize:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: datamodel.Column, table_obj: Table) -> Tap11Base:
        """Build the TAP column record for a Felis column.

        Parameters
        ----------
        column_obj : `datamodel.Column`
            The Felis column to convert.
        table_obj : `Table`
            The table containing the column; its name is used to build the
            qualified table name of the record.

        Returns
        -------
        column : `Tap11Base`
            The TAP column record; it is also stored in ``graph_index``
            under the column ID.
        """
        self.check_column(column_obj)
        column_id = column_obj.id
        table_name = self._table_name(table_obj.name)

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj.name

        felis_datatype = column_obj.datatype
        felis_type = FelisType.felis_type(felis_datatype.value)
        column.datatype = column_obj.votable_datatype or felis_type.votable_name

        # Sized types fall back to ``length`` and then "*"; timestamps only
        # honor an explicit votable arraysize.
        arraysize = None
        if felis_type.is_sized:
            arraysize = column_obj.votable_arraysize or column_obj.length or "*"
        if felis_type.is_timestamp:
            arraysize = column_obj.votable_arraysize or "*"
        column.arraysize = arraysize

        column.xtype = column_obj.votable_xtype
        column.description = column_obj.description
        column.utype = column_obj.votable_utype

        unit = column_obj.ivoa_unit or column_obj.fits_tunit
        column.unit = unit
        column.ucd = column_obj.ivoa_ucd

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.tap_principal
        column.std = column_obj.tap_std
        column.column_index = column_obj.tap_column_index

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: str | Iterable[str] | None, table_obj: Table) -> None:
        """Mark the column of a single-column primary key as indexed.

        Parameters
        ----------
        primary_key_obj : `str`, `~collections.abc.Iterable` [`str`] or `None`
            ID or IDs of the primary key columns; nothing happens if empty.
        table_obj : `Table`
            The table owning the primary key (not used by this method).
        """
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a single-column key marks its column as indexed.
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    @staticmethod
    def _require_single_table(columns: Iterable[Any]) -> None:
        """Raise `ValueError` unless all column records share one table name."""
        if len({column.table_name for column in columns}) > 1:
            raise ValueError("Inconsisent use of table names")

    def visit_constraint(self, constraint_obj: Constraint) -> tuple:
        """Build the TAP key records for a foreign key constraint.

        Parameters
        ----------
        constraint_obj : `Constraint`
            The Felis constraint to convert.

        Returns
        -------
        key : `Tap11Base` or `None`
            The TAP key record, or `None` if the constraint is not a
            foreign key.
        key_columns : `list`
            One key column record per (column, referenced column) pair;
            empty if the constraint is not a foreign key.

        Raises
        ------
        ValueError
            Raised if the foreign key has no columns or no referenced
            columns, if their counts differ, or if either side spans more
            than one table.
        """
        if constraint_obj.type != "ForeignKey":
            return None, []

        constraint_name = constraint_obj.name
        description = constraint_obj.description
        utype = constraint_obj.votable_utype

        columns = [self.graph_index[col_id] for col_id in getattr(constraint_obj, "columns", [])]
        refcolumns = [
            self.graph_index[refcol_id] for refcol_id in getattr(constraint_obj, "referenced_columns", [])
        ]

        # Fail with a clear error instead of an IndexError below, and
        # instead of letting zip() silently drop unpaired columns.
        if not columns or not refcolumns:
            raise ValueError(
                f"Foreign key {constraint_name!r} needs at least one column and one referenced column"
            )
        if len(columns) != len(refcolumns):
            raise ValueError(
                f"Foreign key {constraint_name!r} has {len(columns)} column(s) "
                f"but {len(refcolumns)} referenced column(s)"
            )

        self._require_single_table(columns)
        self._require_single_table(refcolumns)

        key = self.tables["keys"]()
        key.key_id = constraint_name
        key.from_table = columns[0].table_name
        key.target_table = refcolumns[0].table_name
        key.description = description
        key.utype = utype

        key_columns = []
        for column, refcolumn in zip(columns, refcolumns):
            key_column = self.tables["key_columns"]()
            key_column.key_id = constraint_name
            key_column.from_column = column.column_name
            key_column.target_column = refcolumn.column_name
            key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: Index, table_obj: Table) -> None:
        """Mark the column of a single-column index as indexed.

        Parameters
        ----------
        index_obj : `Index`
            The Felis index; its ``columns`` attribute lists column IDs.
        table_obj : `Table`
            The table owning the index (not used by this method).
        """
        columns = [self.graph_index[col_id] for col_id in getattr(index_obj, "columns", [])]
        # Only a single-column index marks its column as indexed.
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: str | None = None) -> str | None:
        """Return the schema name, prefixed with the catalog name if set.

        Parameters
        ----------
        schema_name : `str` or `None`
            Schema name to use instead of ``self.schema_name``.

        Returns
        -------
        name : `str` or `None`
            ``catalog.schema`` if both parts are set, otherwise the bare
            schema name, which may be `None`.
        """
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return the table name, qualified with the schema name if set.

        Parameters
        ----------
        table_name : `str`
            Unqualified table name.

        Returns
        -------
        name : `str`
            The table name prefixed with the result of `_schema_name` and a
            dot, or ``table_name`` unchanged if there is no schema name.
        """
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

396 

397 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy insert statement.

    Parameters
    ----------
    table : `Tap11Base`
        The table we are inserting into.
    value : `Any`
        An object representing the row; it must have one attribute per
        column of ``table``, named after the column.

    Returns
    -------
    statement
        A SQLAlchemy insert statement whose values are read from ``value``,
        with every single quote in string values doubled.
    """

    def _double_quotes(raw: Any) -> Any:
        # Only strings are altered; every other value passes through as is.
        return raw.replace("'", "''") if isinstance(raw, str) else raw

    row = {col.name: _double_quotes(getattr(value, col.name)) for col in table.__table__.columns}
    return insert(table).values(row)