Coverage for python/felis/tap.py: 12%

253 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-15 02:20 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"] 

25 

26import logging 

27from collections.abc import Iterable, Mapping, MutableMapping 

28from typing import Any, Optional, Union 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.engine.mock import MockConnection 

33from sqlalchemy.ext.declarative import declarative_base 

34from sqlalchemy.orm import Session, sessionmaker 

35from sqlalchemy.schema import MetaData 

36from sqlalchemy.sql.expression import Insert, insert 

37 

38from .check import FelisValidator 

39from .types import FelisType 

40from .visitor import Visitor 

41 

# Alias for the read-only mappings (schema, table, column, constraint and
# index objects) that the visitor methods below receive.
_Mapping = Mapping[str, Any]

# Declarative base on which ``init_tables`` declares its ORM classes.
Tap11Base: Any = declarative_base()  # Any to avoid mypy mess with SA 2
logger = logging.getLogger("felis")

# Maximum lengths of the string columns declared in ``init_tables``.
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Three identifiers plus the two "." separators used when the visitor builds
# a ``catalog.schema.table`` name.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Flipped to True by the first ``init_tables`` call; later calls check it to
# decide whether ``Tap11Base.metadata`` has to be replaced.
_init_table_once = False

54 

55 

def init_tables(
    tap_schema_name: Optional[str] = None,
    tap_tables_postfix: Optional[str] = None,
    tap_schemas_table: Optional[str] = None,
    tap_tables_table: Optional[str] = None,
    tap_columns_table: Optional[str] = None,
    tap_keys_table: Optional[str] = None,
    tap_key_columns_table: Optional[str] = None,
) -> MutableMapping[str, Any]:
    """Declare the five TAP ORM classes on ``Tap11Base`` and return them.

    Parameters
    ----------
    tap_schema_name
        When truthy, assigned to ``Tap11Base.metadata.schema``.
    tap_tables_postfix
        Suffix appended to every table name; a falsy value means no suffix.
    tap_schemas_table, tap_tables_table, tap_columns_table, \
    tap_keys_table, tap_key_columns_table
        Overrides for the base table names. Falsy values select the defaults
        ``"schemas"``, ``"tables"``, ``"columns"``, ``"keys"`` and
        ``"key_columns"`` respectively.

    Returns
    -------
    A dict with the keys ``"schemas"``, ``"tables"``, ``"columns"``,
    ``"keys"`` and ``"key_columns"``, each mapped to the ORM class declared
    for that table.
    """
    global _init_table_once

    suffix = tap_tables_postfix if tap_tables_postfix else ""

    # Dirty hack that allows repeated calls: every call after the first swaps
    # in a fresh MetaData instance before the classes are declared again.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    def _physical_name(override: Optional[str], default: str) -> str:
        # Resolve the table name: override (or default) plus the suffix.
        return (override or default) + suffix

    class Tap11Schemas(Tap11Base):
        __tablename__ = _physical_name(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _physical_name(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _physical_name(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # The "size" column is deprecated and intentionally not declared.
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _physical_name(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _physical_name(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

134 

135 

class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
    """Visitor that turns a schema description into TAP table rows.

    Each ``visit_*`` method validates its input with a ``FelisValidator`` and
    builds instances of the ORM classes held in ``self.tables``.
    ``visit_schema`` is the entry point: it either persists the rows through
    a SQLAlchemy session (when an engine is set) or sends INSERT statements
    to a mock connection (dry run).
    """

    def __init__(
        self,
        engine: Engine | None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ):
        """Create a visitor.

        Parameters
        ----------
        engine
            Engine used to persist rows, or None for a dry run through a
            mock connection (see ``from_mock_connection``).
        catalog_name
            Optional catalog name prefixed to schema and table names.
        schema_name
            Optional schema name; when falsy, ``visit_schema`` takes it from
            the schema object.
        tap_tables
            Mapping of ORM classes as returned by ``init_tables``. When
            falsy, ``init_tables()`` is called with its defaults.
        """
        # Maps "@id" values to the ORM objects built for them, so that keys,
        # constraints and indexes can look up previously visited columns.
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self._mock_connection: MockConnection | None = None
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    @classmethod
    def from_mock_connection(
        cls,
        mock_connection: MockConnection,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ) -> TapLoadingVisitor:
        """Create a visitor without an engine that writes to
        ``mock_connection`` instead.
        """
        visitor = cls(engine=None, catalog_name=catalog_name, schema_name=schema_name, tap_tables=tap_tables)
        visitor._mock_connection = mock_connection
        return visitor

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """Validate ``schema_obj`` and load it together with all of its
        tables, columns, keys and key columns.

        With an engine, everything is added to one session and committed
        once at the end. Without an engine, one INSERT statement per row is
        executed on the mock connection.
        """
        self.checker.check_schema(schema_obj)
        schema = self.tables["schemas"]()
        # An explicitly configured schema name wins over the one in the file.
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if self.engine is not None:
            # The context manager closes the session (releasing its
            # connection) even when validation or the commit raises.
            session: Session
            with sessionmaker(self.engine)() as session:
                session.add(schema)
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)
                    session.add_all(keys)
                    session.add_all(key_columns)
                session.commit()
        else:
            # Only if we are mocking (dry run)
            assert self._mock_connection is not None, "Mock connection must not be None"
            conn = self._mock_connection
            conn.execute(_insert(self.tables["schemas"], schema))
            for table_obj in schema_obj["tables"]:
                table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))
                for key in keys:
                    conn.execute(_insert(self.tables["keys"], key))
                for key_column in key_columns:
                    conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """Build the rows describing one table.

        Returns
        -------
        A tuple ``(table, columns, keys, key_columns)``: the table row, the
        list of its column rows, and the key / key-column rows produced by
        its foreign key constraints.
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns go first: the primary key, constraints and indexes below
        # look them up in ``self.graph_index``.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        all_keys = []
        all_key_columns = []
        for constraint_obj in table_obj.get("constraints", []):
            # Pass the mapping, as the ``table_obj`` parameter declares, not
            # the ORM row built above.
            key, key_columns = self.visit_constraint(constraint_obj, table_obj)
            if not key:
                # Not a foreign key: nothing to record.
                continue
            all_keys.append(key)
            all_key_columns += key_columns

        for index_obj in table_obj.get("indexes", []):
            self.visit_index(index_obj, table_obj)

        self.graph_index[table_id] = table
        return table, columns, all_keys, all_key_columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """Validate a column and warn when a sized or timestamp type has no
        explicit array size.
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and `length` is only loosely
            # related to the string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """Build the row for one column and register it in
        ``self.graph_index`` under its ``@id``.
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # `length` is deliberately not consulted for timestamps.
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.unit = unit
        column.ucd = column_obj.get("ivoa:ucd")

        # May be set to 1 later by visit_primary_key / visit_index.
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping) -> None:
        """Mark the column of a single-column primary key as indexed.

        ``primary_key_obj`` is one column ``@id`` or an iterable of them;
        the columns must already have been visited.
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a key made of exactly one column flags that column.
            if len(columns) == 1:
                columns[0].indexed = 1

    @staticmethod
    def _check_single_table(columns: Iterable[Any]) -> None:
        """Raise ValueError unless all ``columns`` share one table name."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """Build the key rows for a foreign key constraint.

        Returns
        -------
        A tuple ``(key, key_columns)``. For a constraint whose ``@type`` is
        not ``"ForeignKey"`` this is ``(None, [])``.

        Raises
        ------
        ValueError
            If the referring columns, or the referenced columns, do not all
            belong to the same table.
        IndexError
            If a foreign key lists no ``columns`` or no
            ``referencedColumns``.
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]

            self._check_single_table(columns)
            self._check_single_table(refcolumns)
            first_column = columns[0]
            first_refcolumn = refcolumns[0]

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = first_column.table_name
            key.target_table = first_refcolumn.table_name
            key.description = description
            key.utype = utype
            # One key-column row per (referring, referenced) column pair.
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """Mark the column of a single-column index as indexed."""
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        # Only an index on exactly one column flags that column.
        if len(columns) == 1:
            columns[0].indexed = 1

    def _schema_name(self, schema_name: Optional[str] = None) -> Optional[str]:
        """Return the schema name, prefixed with ``catalog_name`` and a dot
        when both are set.
        """
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """Return ``table_name`` qualified with the schema name, if any."""
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

369 

370 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """Return a SQLAlchemy insert statement built from the attributes of
    ``value``.

    :param table: The table we are inserting to
    :param value: An object representing the object we are inserting
        to the table; it must have one attribute per column of ``table``
    :return: A SQLAlchemy insert statement
    """
    values_dict = {}
    for column in table.__table__.columns:
        column_value = getattr(value, column.name)
        if isinstance(column_value, str):
            # Double embedded single quotes. Presumably the dry-run output
            # places the value inside a single-quoted SQL literal; confirm
            # against the mock executor.
            column_value = column_value.replace("'", "''")
        values_dict[column.name] = column_value
    return insert(table).values(values_dict)