Coverage for python/felis/tap.py: 11%

246 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-14 02:21 -0800

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Tap11Base", "TapLoadingVisitor", "init_tables"] 

25 

26import logging 

27from collections.abc import Iterable, Mapping, MutableMapping 

28from typing import Any, Optional, Union 

29 

30from sqlalchemy import Column, Integer, String 

31from sqlalchemy.engine import Engine 

32from sqlalchemy.ext.declarative import declarative_base 

33from sqlalchemy.orm import DeclarativeMeta, Session, sessionmaker 

34from sqlalchemy.schema import MetaData 

35from sqlalchemy.sql.expression import Insert, insert 

36 

37from .check import FelisValidator 

38from .types import FelisType 

39from .visitor import Visitor 

40 

# Read-only, string-keyed mapping type used to annotate the schema, table,
# column, constraint and index objects handed to the visitor methods.
_Mapping = Mapping[str, Any]

# Declarative base shared by the ORM classes that init_tables() defines.
# init_tables() replaces its ``metadata`` attribute on repeated calls.
Tap11Base: DeclarativeMeta = declarative_base()
logger = logging.getLogger("felis")

# Maximum lengths of the String columns declared in init_tables().
IDENTIFIER_LENGTH = 128
SMALL_FIELD_LENGTH = 32
SIMPLE_FIELD_LENGTH = 128
TEXT_FIELD_LENGTH = 2048
# Room for three identifiers plus two separator characters; the visitor's
# _table_name() joins catalog, schema and table names with "." -- presumably
# that is the fully qualified form this length is sized for.
QUALIFIED_TABLE_LENGTH = 3 * IDENTIFIER_LENGTH + 2

# Flipped to True by the first init_tables() call; later calls check it to
# decide whether Tap11Base.metadata must be swapped for a fresh MetaData.
_init_table_once = False

53 

54 

def init_tables(
    tap_schema_name: Optional[str] = None,
    tap_tables_postfix: Optional[str] = None,
    tap_schemas_table: Optional[str] = None,
    tap_tables_table: Optional[str] = None,
    tap_columns_table: Optional[str] = None,
    tap_keys_table: Optional[str] = None,
    tap_key_columns_table: Optional[str] = None,
) -> MutableMapping[str, Any]:
    """
    Declare the five TAP_SCHEMA ORM classes on ``Tap11Base`` and return them.

    :param tap_schema_name: If truthy, assigned to ``Tap11Base.metadata.schema``
    :param tap_tables_postfix: Suffix appended to every table name
    :param tap_schemas_table: Name override for the "schemas" table
    :param tap_tables_table: Name override for the "tables" table
    :param tap_columns_table: Name override for the "columns" table
    :param tap_keys_table: Name override for the "keys" table
    :param tap_key_columns_table: Name override for the "key_columns" table
    :return: A dict with the keys "schemas", "tables", "columns", "keys" and
             "key_columns", each mapped to the corresponding ORM class
    """
    global _init_table_once

    suffix = tap_tables_postfix or ""

    def _physical_name(override: Optional[str], default: str) -> str:
        # A falsy override (None or "") selects the default table name.
        return (override or default) + suffix

    # Dirty hack to enable this method to be called more than once: every
    # call after the first swaps in a fresh MetaData instance before the
    # classes below are declared again.
    # TODO: probably replace ORM stuff with core sqlalchemy functions.
    if _init_table_once:
        Tap11Base.metadata = MetaData()
    else:
        _init_table_once = True

    if tap_schema_name:
        Tap11Base.metadata.schema = tap_schema_name

    class Tap11Schemas(Tap11Base):
        __tablename__ = _physical_name(tap_schemas_table, "schemas")
        schema_name = Column(String(IDENTIFIER_LENGTH), primary_key=True, nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        schema_index = Column(Integer)

    class Tap11Tables(Tap11Base):
        __tablename__ = _physical_name(tap_tables_table, "tables")
        schema_name = Column(String(IDENTIFIER_LENGTH), nullable=False)
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        table_type = Column(String(SMALL_FIELD_LENGTH), nullable=False)
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        description = Column(String(TEXT_FIELD_LENGTH))
        table_index = Column(Integer)

    class Tap11Columns(Tap11Base):
        __tablename__ = _physical_name(tap_columns_table, "columns")
        table_name = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False, primary_key=True)
        column_name = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        datatype = Column(String(SIMPLE_FIELD_LENGTH), nullable=False)
        arraysize = Column(String(10))
        xtype = Column(String(SIMPLE_FIELD_LENGTH))
        # Size is deprecated
        # size = Column(Integer(), quote=True)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))
        unit = Column(String(SIMPLE_FIELD_LENGTH))
        ucd = Column(String(SIMPLE_FIELD_LENGTH))
        indexed = Column(Integer, nullable=False)
        principal = Column(Integer, nullable=False)
        std = Column(Integer, nullable=False)
        column_index = Column(Integer)

    class Tap11Keys(Tap11Base):
        __tablename__ = _physical_name(tap_keys_table, "keys")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        target_table = Column(String(QUALIFIED_TABLE_LENGTH), nullable=False)
        description = Column(String(TEXT_FIELD_LENGTH))
        utype = Column(String(SIMPLE_FIELD_LENGTH))

    class Tap11KeyColumns(Tap11Base):
        __tablename__ = _physical_name(tap_key_columns_table, "key_columns")
        key_id = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        from_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)
        target_column = Column(String(IDENTIFIER_LENGTH), nullable=False, primary_key=True)

    return {
        "schemas": Tap11Schemas,
        "tables": Tap11Tables,
        "columns": Tap11Columns,
        "keys": Tap11Keys,
        "key_columns": Tap11KeyColumns,
    }

133 

134 

class TapLoadingVisitor(Visitor[None, tuple, Tap11Base, None, tuple, None]):
    """
    Visitor that turns a schema description into TAP_SCHEMA rows and loads
    them through a SQLAlchemy engine.

    Rows are instances of the ORM classes found in ``self.tables``. Every
    visited table and column row is also stored in ``self.graph_index`` under
    its "@id" so that primary keys, constraints and indexes can refer back
    to it.
    """

    def __init__(
        self,
        engine: Engine,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
        mock: bool = False,
        tap_tables: Optional[MutableMapping[str, Any]] = None,
    ):
        """
        :param engine: Engine used for the ORM session, or executed against
                       directly when ``mock`` is true
        :param catalog_name: Optional catalog name prefixed to schema names
        :param schema_name: Optional schema name; when not given, the name of
                            the visited schema object is used
        :param mock: When true, emit one INSERT statement per row on the
                     engine instead of using an ORM session (dry run)
        :param tap_tables: Mapping of ORM classes as returned by
                           ``init_tables``; a falsy value triggers a call to
                           ``init_tables()`` with its defaults
        """
        self.graph_index: MutableMapping[str, Any] = {}
        self.catalog_name = catalog_name
        self.schema_name = schema_name
        self.engine = engine
        self.mock = mock
        self.tables = tap_tables or init_tables()
        self.checker = FelisValidator()

    def visit_schema(self, schema_obj: _Mapping) -> None:
        """
        Validate a schema object and load it together with all of its tables.

        :param schema_obj: The schema mapping; its "name" and "tables" keys
                           are read unconditionally
        """
        self.checker.check_schema(schema_obj)
        schema = self.tables["schemas"]()
        # Override with default
        self.schema_name = self.schema_name or schema_obj["name"]

        schema.schema_name = self._schema_name()
        schema.description = schema_obj.get("description")
        schema.utype = schema_obj.get("votable:utype")
        schema.schema_index = int(schema_obj.get("tap:schema_index", 0))

        if not self.mock:
            session: Session = sessionmaker(self.engine)()
            try:
                session.add(schema)
                for table_obj in schema_obj["tables"]:
                    table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                    session.add(table)
                    session.add_all(columns)
                    session.add_all(keys)
                    session.add_all(key_columns)
                session.commit()
            except Exception:
                # Do not leave a half-finished transaction behind when a
                # table fails validation or the commit itself fails.
                session.rollback()
                raise
            finally:
                # Always hand the session's resources back, success or not.
                session.close()
        else:
            # Only if we are mocking (dry run)
            conn = self.engine
            conn.execute(_insert(self.tables["schemas"], schema))
            for table_obj in schema_obj["tables"]:
                table, columns, keys, key_columns = self.visit_table(table_obj, schema_obj)
                conn.execute(_insert(self.tables["tables"], table))
                for column in columns:
                    conn.execute(_insert(self.tables["columns"], column))
                for key in keys:
                    conn.execute(_insert(self.tables["keys"], key))
                for key_column in key_columns:
                    conn.execute(_insert(self.tables["key_columns"], key_column))

    def visit_table(self, table_obj: _Mapping, schema_obj: _Mapping) -> tuple:
        """
        Validate a table object and build its rows.

        :param table_obj: The table mapping; "@id", "name" and "columns" are
                          read unconditionally
        :param schema_obj: The schema mapping the table belongs to
        :return: A tuple ``(table, columns, keys, key_columns)`` holding the
                 table row and the lists of column, key and key-column rows
        """
        self.checker.check_table(table_obj, schema_obj)
        table_id = table_obj["@id"]
        table = self.tables["tables"]()
        table.schema_name = self._schema_name()
        table.table_name = self._table_name(table_obj["name"])
        table.table_type = "table"
        table.utype = table_obj.get("votable:utype")
        table.description = table_obj.get("description")
        table.table_index = int(table_obj.get("tap:table_index", 0))

        # Columns go first: the primary key, constraints and indexes below
        # look the column rows up in graph_index.
        columns = [self.visit_column(c, table_obj) for c in table_obj["columns"]]
        self.visit_primary_key(table_obj.get("primaryKey", []), table_obj)
        all_keys = []
        all_key_columns = []
        # NOTE(review): the ORM row ``table`` (not the ``table_obj`` mapping)
        # is forwarded to visit_constraint/visit_index even though both
        # annotate that parameter as a mapping -- confirm which is intended.
        for c in table_obj.get("constraints", []):
            key, key_columns = self.visit_constraint(c, table)
            if not key:
                # Not a foreign key, so there is nothing to record for it.
                continue
            all_keys.append(key)
            all_key_columns += key_columns

        for i in table_obj.get("indexes", []):
            self.visit_index(i, table)

        self.graph_index[table_id] = table
        return table, columns, all_keys, all_key_columns

    def check_column(self, column_obj: _Mapping, table_obj: _Mapping) -> None:
        """
        Validate a column object and warn about missing size information.

        :param column_obj: The column mapping; "@id" is read unconditionally
        :param table_obj: The table mapping the column belongs to
        """
        self.checker.check_column(column_obj, table_obj)
        _id = column_obj["@id"]
        # Guaranteed to exist at this point, for mypy use "" as default
        datatype_name = column_obj.get("datatype", "")
        felis_type = FelisType.felis_type(datatype_name)
        if felis_type.is_sized:
            # It is expected that both arraysize and length are fine for
            # length types.
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length"))
            if arraysize is None:
                logger.warning(
                    f"votable:arraysize and length for {_id} are None for type {datatype_name}. "
                    'Using length "*". '
                    "Consider setting `votable:arraysize` or `length`."
                )
        if felis_type.is_timestamp:
            # datetime types really should have a votable:arraysize, because
            # they are converted to strings and the `length` is only loosely
            # related to the string size
            if "votable:arraysize" not in column_obj:
                logger.warning(
                    f"votable:arraysize for {_id} is None for type {datatype_name}. "
                    f'Using length "*". '
                    "Consider setting `votable:arraysize` to an appropriate size for "
                    "materialized datetime/timestamp strings."
                )

    def visit_column(self, column_obj: _Mapping, table_obj: _Mapping) -> Tap11Base:
        """
        Validate a column object and build its row.

        :param column_obj: The column mapping; "@id", "name" and "datatype"
                           are read unconditionally
        :param table_obj: The table mapping the column belongs to
        :return: The column row, also stored in ``graph_index`` under its "@id"
        """
        self.check_column(column_obj, table_obj)
        column_id = column_obj["@id"]
        table_name = self._table_name(table_obj["name"])

        column = self.tables["columns"]()
        column.table_name = table_name
        column.column_name = column_obj["name"]

        felis_datatype = column_obj["datatype"]
        felis_type = FelisType.felis_type(felis_datatype)
        column.datatype = column_obj.get("votable:datatype", felis_type.votable_name)

        arraysize = None
        if felis_type.is_sized:
            # prefer votable:arraysize to length, fall back to `*`
            arraysize = column_obj.get("votable:arraysize", column_obj.get("length", "*"))
        if felis_type.is_timestamp:
            # `length` is ignored for timestamps; only an explicit
            # votable:arraysize overrides the `*` default.
            arraysize = column_obj.get("votable:arraysize", "*")
        column.arraysize = arraysize

        column.xtype = column_obj.get("votable:xtype")
        column.description = column_obj.get("description")
        column.utype = column_obj.get("votable:utype")

        unit = column_obj.get("ivoa:unit") or column_obj.get("fits:tunit")
        column.unit = unit
        column.ucd = column_obj.get("ivoa:ucd")

        # We modify this after we process columns
        column.indexed = 0

        column.principal = column_obj.get("tap:principal", 0)
        column.std = column_obj.get("tap:std", 0)
        column.column_index = column_obj.get("tap:column_index")

        self.graph_index[column_id] = column
        return column

    def visit_primary_key(self, primary_key_obj: Union[str, Iterable[str]], table_obj: _Mapping) -> None:
        """
        Validate a primary key and flag a single-column key as indexed.

        :param primary_key_obj: One column "@id" or an iterable of them
        :param table_obj: The table mapping the primary key belongs to
        """
        self.checker.check_primary_key(primary_key_obj, table_obj)
        if primary_key_obj:
            if isinstance(primary_key_obj, str):
                primary_key_obj = [primary_key_obj]
            columns = [self.graph_index[c_id] for c_id in primary_key_obj]
            # Only a single-column key marks its column as indexed
            if len(columns) == 1:
                columns[0].indexed = 1
        return None

    @staticmethod
    def _check_same_table(columns: Iterable[Any]) -> None:
        """Raise ValueError unless all column rows share one ``table_name``."""
        table_name = None
        for column in columns:
            if not table_name:
                table_name = column.table_name
            if table_name != column.table_name:
                raise ValueError("Inconsisent use of table names")

    def visit_constraint(self, constraint_obj: _Mapping, table_obj: _Mapping) -> tuple:
        """
        Validate a constraint and build key rows for foreign keys.

        :param constraint_obj: The constraint mapping; "@type" is read
                               unconditionally and "name" for foreign keys
        :param table_obj: Passed through to the checker unchanged
        :return: A tuple ``(key, key_columns)``; ``(None, [])`` for any
                 constraint whose "@type" is not "ForeignKey"
        :raises ValueError: If a foreign key has no columns or no referenced
                            columns, if the two lists differ in length, or
                            if either list spans more than one table
        """
        self.checker.check_constraint(constraint_obj, table_obj)
        constraint_type = constraint_obj["@type"]
        key = None
        key_columns = []
        if constraint_type == "ForeignKey":
            constraint_name = constraint_obj["name"]
            description = constraint_obj.get("description")
            utype = constraint_obj.get("votable:utype")

            columns = [self.graph_index[c_id] for c_id in constraint_obj.get("columns", [])]
            refcolumns = [self.graph_index[c_id] for c_id in constraint_obj.get("referencedColumns", [])]

            # Fail with a clear message instead of an IndexError further down.
            if not columns or not refcolumns:
                raise ValueError(
                    f"ForeignKey constraint {constraint_name} requires non-empty "
                    "`columns` and `referencedColumns`"
                )
            # zip() below would otherwise silently drop the unmatched columns.
            if len(columns) != len(refcolumns):
                raise ValueError(
                    f"ForeignKey constraint {constraint_name} has {len(columns)} columns "
                    f"but {len(refcolumns)} referenced columns"
                )

            self._check_same_table(columns)
            self._check_same_table(refcolumns)
            first_column = columns[0]
            first_refcolumn = refcolumns[0]

            key = self.tables["keys"]()
            key.key_id = constraint_name
            key.from_table = first_column.table_name
            key.target_table = first_refcolumn.table_name
            key.description = description
            key.utype = utype
            for column, refcolumn in zip(columns, refcolumns):
                key_column = self.tables["key_columns"]()
                key_column.key_id = constraint_name
                key_column.from_column = column.column_name
                key_column.target_column = refcolumn.column_name
                key_columns.append(key_column)
        return key, key_columns

    def visit_index(self, index_obj: _Mapping, table_obj: _Mapping) -> None:
        """
        Validate an index and flag a single-column index's column as indexed.

        :param index_obj: The index mapping; "columns" is optional
        :param table_obj: Passed through to the checker unchanged
        """
        self.checker.check_index(index_obj, table_obj)
        columns = [self.graph_index[c_id] for c_id in index_obj.get("columns", [])]
        # Only a single-column index marks its column as indexed
        if len(columns) == 1:
            columns[0].indexed = 1
        return None

    def _schema_name(self, schema_name: Optional[str] = None) -> Optional[str]:
        """
        Return the schema name, prefixed with the catalog name when both exist.

        :param schema_name: Optional override for ``self.schema_name``
        :return: "catalog.schema", the bare schema name, or a falsy value
                 when no schema name is known
        """
        # If _schema_name is None, SQLAlchemy will catch it
        _schema_name = schema_name or self.schema_name
        if self.catalog_name and _schema_name:
            return ".".join([self.catalog_name, _schema_name])
        return _schema_name

    def _table_name(self, table_name: str) -> str:
        """
        Return the table name qualified with the current schema name, if any.

        :param table_name: The unqualified table name
        :return: "<schema name>.<table_name>", or ``table_name`` unchanged
                 when no schema name is known
        """
        schema_name = self._schema_name()
        if schema_name:
            return ".".join([schema_name, table_name])
        return table_name

356 

357 

def _insert(table: Tap11Base, value: Any) -> Insert:
    """
    Return a SQLAlchemy insert statement based on the attributes of an object.

    One value is read from ``value`` for every column of ``table``, using the
    column name as the attribute name.

    :param table: The table we are inserting to
    :param value: An object representing the object we are inserting
                  to the table
    :return: A SQLAlchemy insert statement
    """
    values_dict = {}
    for column in table.__table__.columns:
        column_value = getattr(value, column.name)
        if isinstance(column_value, str):
            # Double up single quotes. NOTE(review): presumably needed because
            # the dry-run engine renders values as SQL literals -- confirm.
            column_value = column_value.replace("'", "''")
        values_dict[column.name] = column_value
    # Insert.values() replaces the ``values=`` constructor keyword, which
    # SQLAlchemy 1.4 deprecates and 2.0 removes.
    return insert(table).values(values_dict)