# python/lsst/dax/apdb/apdbSqlSchema.py

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module responsible for APDB schema operations.
"""

from __future__ import annotations

__all__ = ["ColumnDef", "IndexDef", "TableDef", "ApdbSqlSchema"]

import logging
import os
from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Type
import yaml

import sqlalchemy
from sqlalchemy import (Column, Index, MetaData, PrimaryKeyConstraint,
                        UniqueConstraint, Table)


_LOG = logging.getLogger(__name__)


class ColumnDef(NamedTuple):
    """Column representation in schema.
44 """
45 name: str
46 """column name"""
47 type: str
48 """name of cat type (INT, FLOAT, etc.)"""
49 nullable: bool
50 """True for nullable columns"""
51 default: Any
52 """default value for column, can be None"""
53 description: Optional[str]
54 """documentation, can be None or empty"""
55 unit: Optional[str]
56 """string with unit name, can be None"""
57 ucd: Optional[str]
58 """string with ucd, can be None"""


class IndexDef(NamedTuple):
    """Index description.
    """
    name: str
    """index name, can be empty"""
    type: str
    """one of "PRIMARY", "UNIQUE", "INDEX"
    """
    columns: List[str]
    """list of column names in index"""


class TableDef(NamedTuple):
    """Table description.
75 """
76 name: str
77 """table name"""
78 description: Optional[str]
79 """documentation, can be None or empty"""
80 columns: List[ColumnDef]
81 """list of ColumnDef instances"""
82 indices: List[IndexDef]
83 """list of IndexDef instances, can be empty"""


class ApdbSqlSchema(object):
    """Class for management of APDB schema.

    Attributes
    ----------
    objects : `sqlalchemy.Table`
        DiaObject table instance
    objects_last : `sqlalchemy.Table`
        DiaObjectLast table instance, may be None
    sources : `sqlalchemy.Table`
        DiaSource table instance
    forcedSources : `sqlalchemy.Table`
        DiaForcedSource table instance

    Parameters
    ----------
    engine : `sqlalchemy.engine.Engine`
        SQLAlchemy engine instance
    dia_object_index : `str`
        Indexing mode for DiaObject table, see `ApdbConfig.dia_object_index`
        for details.
    schema_file : `str`
        Name of the YAML schema file.
    extra_schema_file : `str`, optional
        Name of the YAML schema file with extra column definitions.
    prefix : `str`, optional
        Prefix to add to all schema elements.
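
    Examples
    --------
    A minimal usage sketch, assuming an in-memory SQLite engine and a schema
    file named ``apdb-schema.yaml``; both names, as well as the ``"baseline"``
    indexing mode, are illustrative choices::

        engine = sqlalchemy.create_engine("sqlite://")
        schema = ApdbSqlSchema(engine, dia_object_index="baseline",
                               schema_file="apdb-schema.yaml")
        schema.makeSchema(drop=True)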
113 """
    def __init__(self, engine: sqlalchemy.engine.Engine, dia_object_index: str,
                 schema_file: str, extra_schema_file: Optional[str] = None, prefix: str = ""):

        self._engine = engine
        self._dia_object_index = dia_object_index
        self._prefix = prefix

        self._metadata = MetaData(self._engine)

        self.objects = None
        self.objects_last = None
        self.sources = None
        self.forcedSources = None

        # build complete table schema
        self._schemas = self._buildSchemas(schema_file, extra_schema_file)

        # map cat column types to alchemy
        self._type_map = dict(DOUBLE=self._getDoubleType(engine),
                              FLOAT=sqlalchemy.types.Float,
                              DATETIME=sqlalchemy.types.TIMESTAMP,
                              BIGINT=sqlalchemy.types.BigInteger,
                              INTEGER=sqlalchemy.types.Integer,
                              INT=sqlalchemy.types.Integer,
                              TINYINT=sqlalchemy.types.Integer,
                              BLOB=sqlalchemy.types.LargeBinary,
                              CHAR=sqlalchemy.types.CHAR,
                              BOOL=sqlalchemy.types.Boolean)

        # generate schema for all tables, must be called last
        self._makeTables()

    def _makeTables(self, mysql_engine: str = 'InnoDB') -> None:
        """Generate schema for all tables.

        Parameters
        ----------
        mysql_engine : `str`, optional
            MySQL engine type to use for new tables.
        """

        info: Dict[str, Any] = {}

        if self._dia_object_index == 'pix_id_iov':
            # Special PK with HTM column in first position
            constraints = self._tableIndices('DiaObjectIndexHtmFirst', info)
        else:
            constraints = self._tableIndices('DiaObject', info)
        table = Table(self._prefix+'DiaObject', self._metadata,
                      *(self._tableColumns('DiaObject') + constraints),
                      mysql_engine=mysql_engine,
                      info=info)
        self.objects = table

        if self._dia_object_index == 'last_object_table':
            # Same as DiaObject but with special index
            table = Table(self._prefix+'DiaObjectLast', self._metadata,
                          *(self._tableColumns('DiaObjectLast')
                            + self._tableIndices('DiaObjectLast', info)),
                          mysql_engine=mysql_engine,
                          info=info)
            self.objects_last = table

        # for all other tables use index definitions in schema
        for table_name in ('DiaSource', 'SSObject', 'DiaForcedSource', 'DiaObject_To_Object_Match'):
            table = Table(self._prefix+table_name, self._metadata,
                          *(self._tableColumns(table_name)
                            + self._tableIndices(table_name, info)),
                          mysql_engine=mysql_engine,
                          info=info)
            if table_name == 'DiaSource':
                self.sources = table
            elif table_name == 'DiaForcedSource':
                self.forcedSources = table

    def makeSchema(self, drop: bool = False, mysql_engine: str = 'InnoDB') -> None:
        """Create or re-create all tables.

        Parameters
        ----------
        drop : `bool`, optional
            If True then drop tables before creating new ones.
        mysql_engine : `str`, optional
            MySQL engine type to use for new tables.
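
        Examples
        --------
        A sketch, assuming ``schema`` is an existing `ApdbSqlSchema` instance
        and ``engine`` is the engine it was constructed with::

            schema.makeSchema(drop=True)
            with engine.connect() as conn:
                rows = conn.execute(schema.objects.select()).fetchall()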
198 """
200 # re-make table schema for all needed tables with possibly different options
201 _LOG.debug("clear metadata")
202 self._metadata.clear()
203 _LOG.debug("re-do schema mysql_engine=%r", mysql_engine)
204 self._makeTables(mysql_engine=mysql_engine)
206 # create all tables (optionally drop first)
207 if drop:
208 _LOG.info('dropping all tables')
209 self._metadata.drop_all()
210 _LOG.info('creating all tables')
211 self._metadata.create_all()

    def _buildSchemas(self, schema_file: str, extra_schema_file: Optional[str] = None,
                      ) -> Mapping[str, TableDef]:
        """Create schema definitions for all tables.

        Reads YAML schemas and builds dictionary containing `TableDef`
        instances for each table.

        Parameters
        ----------
        schema_file : `str`
            Name of YAML file with standard cat schema.
        extra_schema_file : `str`, optional
            Name of YAML file with extra table information or `None`.

        Returns
        -------
        schemas : `dict`
            Mapping of table names to `TableDef` instances.
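
        Notes
        -----
        Each YAML document is expected to describe one table, roughly in this
        shape (the table, column, and index values below are illustrative;
        only the keys actually read by this method are shown)::

            table: DiaObject
            description: Table with DIAObject records
            columns:
            - name: diaObjectId
              type: BIGINT
              nullable: false
              description: Unique identifier
              ucd: meta.id
            indices:
            - name: PK_DiaObject
              type: PRIMARY
              columns: [diaObjectId]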
231 """
233 schema_file = os.path.expandvars(schema_file)
234 _LOG.debug("Reading schema file %s", schema_file)
235 with open(schema_file) as yaml_stream:
236 tables = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))
237 # index it by table name
238 _LOG.debug("Read %d tables from schema", len(tables))
240 if extra_schema_file:
241 extra_schema_file = os.path.expandvars(extra_schema_file)
242 _LOG.debug("Reading extra schema file %s", extra_schema_file)
243 with open(extra_schema_file) as yaml_stream:
244 extras = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))
245 # index it by table name
246 schemas_extra = {table['table']: table for table in extras}
247 else:
248 schemas_extra = {}
250 # merge extra schema into a regular schema, for now only columns are merged
251 for table in tables:
252 table_name = table['table']
253 if table_name in schemas_extra:
254 columns = table['columns']
255 extra_columns = schemas_extra[table_name].get('columns', [])
256 extra_columns = {col['name']: col for col in extra_columns}
257 _LOG.debug("Extra columns for table %s: %s", table_name, extra_columns.keys())
258 columns = []
259 for col in table['columns']:
260 if col['name'] in extra_columns:
261 columns.append(extra_columns.pop(col['name']))
262 else:
263 columns.append(col)
264 # add all remaining extra columns
265 table['columns'] = columns + list(extra_columns.values())
267 if 'indices' in schemas_extra[table_name]:
268 raise RuntimeError("Extra table definition contains indices, "
269 "merging is not implemented")
271 del schemas_extra[table_name]
273 # Pure "extra" table definitions may contain indices
274 tables += schemas_extra.values()

        # convert all dicts into named tuples
        schemas = {}
        for table in tables:

            columns = table.get('columns', [])

            table_name = table['table']

            table_columns = []
            for col in columns:
                # For prototype set default to 0 even if columns don't specify it
                if "default" not in col:
                    default = None
                    if col['type'] not in ("BLOB", "DATETIME"):
                        default = 0
                else:
                    default = col["default"]

                column = ColumnDef(name=col['name'],
                                   type=col['type'],
                                   nullable=col.get("nullable"),
                                   default=default,
                                   description=col.get("description"),
                                   unit=col.get("unit"),
                                   ucd=col.get("ucd"))
                table_columns.append(column)

            table_indices = []
            for idx in table.get('indices', []):
                index = IndexDef(name=idx.get('name'),
                                 type=idx.get('type'),
                                 columns=idx.get('columns'))
                table_indices.append(index)

            schemas[table_name] = TableDef(name=table_name,
                                           description=table.get('description'),
                                           columns=table_columns,
                                           indices=table_indices)

        return schemas

    def _tableColumns(self, table_name: str) -> List[Column]:
        """Return set of columns in a table.

        Parameters
        ----------
        table_name : `str`
            Name of the table.

        Returns
        -------
        column_defs : `list`
            List of `Column` objects.
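
        For example, with an SQLite engine a column described as
        ``ColumnDef(name="ra", type="DOUBLE", nullable=False, default=0, ...)``
        (illustrative values) maps roughly to::

            from sqlalchemy.dialects.sqlite import REAL
            Column("ra", REAL, nullable=False, server_default="0")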
329 """
331 # get the list of columns in primary key, they are treated somewhat
332 # specially below
333 table_schema = self._schemas[table_name]
334 pkey_columns = set()
335 for index in table_schema.indices:
336 if index.type == 'PRIMARY':
337 pkey_columns = set(index.columns)
338 break
340 # convert all column dicts into alchemy Columns
341 column_defs = []
342 for column in table_schema.columns:
343 kwargs: Dict[str, Any] = dict(nullable=column.nullable)
344 if column.default is not None:
345 kwargs.update(server_default=str(column.default))
346 if column.name in pkey_columns:
347 kwargs.update(autoincrement=False)
348 ctype = self._type_map[column.type]
349 column_defs.append(Column(column.name, ctype, **kwargs))
351 return column_defs

    def _tableIndices(self, table_name: str, info: Dict) -> List[sqlalchemy.schema.Constraint]:
        """Return set of constraints/indices in a table.

        Parameters
        ----------
        table_name : `str`
            Name of the table.
        info : `dict`
            Additional options passed to SQLAlchemy index constructor.

        Returns
        -------
        index_defs : `list`
            List of SQLAlchemy index/constraint objects.
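
        For example, an index definition like
        ``IndexDef(name="IDX_diaObjectId", type="INDEX", columns=["diaObjectId"])``
        (names are illustrative) is turned roughly into::

            Index(self._prefix + "IDX_diaObjectId", "diaObjectId", info=info)

        while ``PRIMARY`` and ``UNIQUE`` definitions become
        `PrimaryKeyConstraint` and `UniqueConstraint` objects respectively.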
367 """
369 table_schema = self._schemas[table_name]
371 # convert all index dicts into alchemy Columns
372 index_defs: List[sqlalchemy.schema.Constraint] = []
373 for index in table_schema.indices:
374 if index.type == "INDEX":
375 index_defs.append(Index(self._prefix + index.name, *index.columns, info=info))
376 else:
377 kwargs = {}
378 if index.name:
379 kwargs['name'] = self._prefix+index.name
380 if index.type == "PRIMARY":
381 index_defs.append(PrimaryKeyConstraint(*index.columns, **kwargs))
382 elif index.type == "UNIQUE":
383 index_defs.append(UniqueConstraint(*index.columns, **kwargs))
385 return index_defs

    @classmethod
    def _getDoubleType(cls, engine: sqlalchemy.engine.Engine) -> Type:
        """DOUBLE type is database-specific, select one based on dialect.

        Parameters
        ----------
        engine : `sqlalchemy.engine.Engine`
            Database engine.

        Returns
        -------
        type_object : `object`
            Database-specific type definition.
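
        Examples
        --------
        A sketch using an in-memory SQLite engine (the URL is illustrative)::

            engine = sqlalchemy.create_engine("sqlite://")
            double_type = ApdbSqlSchema._getDoubleType(engine)  # sqlite REAL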
400 """
401 if engine.name == 'mysql':
402 from sqlalchemy.dialects.mysql import DOUBLE
403 return DOUBLE(asdecimal=False)
404 elif engine.name == 'postgresql':
405 from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION
406 return DOUBLE_PRECISION
407 elif engine.name == 'oracle':
408 from sqlalchemy.dialects.oracle import DOUBLE_PRECISION
409 return DOUBLE_PRECISION
410 elif engine.name == 'sqlite':
411 # all floats in sqlite are 8-byte
412 from sqlalchemy.dialects.sqlite import REAL
413 return REAL
414 else:
415 raise TypeError('cannot determine DOUBLE type, unexpected dialect: ' + engine.name)