Coverage for python/lsst/daf/butler/registries/sqlPreFlight.py : 7%

# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Type used as a sentinal value to indicate when a dataset ID was not queried for, as opposed to not found."""
"""Implement TreeVisitor to convert user expression into SQLAlchemy clause.
Parameters ---------- tables : `dict` Mapping of table names to `sqlalchemy.Table` instances. dimensions : `DimensionGraph` All Dimensions included in the query. """
"+": lambda x: +x, "-": lambda x: -x} """Mapping or unary operator names to corresponding functions"""
"AND": lambda x, y: and_(x, y), "=": lambda x, y: x == y, "!=": lambda x, y: x != y, "<": lambda x, y: x < y, "<=": lambda x, y: x <= y, ">": lambda x, y: x > y, ">=": lambda x, y: x >= y, "+": lambda x, y: x + y, "-": lambda x, y: x - y, "*": lambda x, y: x * y, "/": lambda x, y: x / y, "%": lambda x, y: x % y} """Mapping or binary operator names to corresponding functions"""
self.tables = tables self.dimensions = dimensions
    def visitNumericLiteral(self, value, node):
        # Docstring inherited from TreeVisitor.visitNumericLiteral
        # Convert string value into float or int
        if value.isdigit():
            value = int(value)
        else:
            value = float(value)
        return literal(value)

    def visitStringLiteral(self, value, node):
        # Docstring inherited from TreeVisitor.visitStringLiteral
        return literal(value)

    def visitIdentifier(self, name, node):
        # Docstring inherited from TreeVisitor.visitIdentifier
        table, sep, column = name.partition('.')
        if column:
            return self.tables[table].columns[column]
        else:
            link = table
            for dim in self.dimensions:
                if link in dim.links():
                    return self.tables[dim.name].columns[link]
            # can't find the link
            raise ValueError(f"Link name `{link}' is not in the dimensions for this query.")

    def visitUnaryOp(self, operator, operand, node):
        # Docstring inherited from TreeVisitor.visitUnaryOp
        func = self.unaryOps.get(operator)
        if func:
            return func(operand)
        else:
            raise ValueError(f"Unexpected unary operator `{operator}' in `{node}'.")

    def visitBinaryOp(self, operator, lhs, rhs, node):
        # Docstring inherited from TreeVisitor.visitBinaryOp
        func = self.binaryOps.get(operator)
        if func:
            return func(lhs, rhs)
        else:
            raise ValueError(f"Unexpected binary operator `{operator}' in `{node}'.")

    def visitIsIn(self, lhs, values, not_in, node):
        # Docstring inherited from TreeVisitor.visitIsIn
        if not_in:
            return lhs.notin_(values)
        else:
            return lhs.in_(values)

    def visitParens(self, expression, node):
        # Docstring inherited from TreeVisitor.visitParens
        return expression.self_group()
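
# Illustrative sketch (not part of the original module): how a visitor of
# this kind is driven.  A user expression string is parsed with ParserYacc
# and the resulting tree is walked with ``visit``, exactly as done in
# SqlPreFlight.selectDimensions() below; ``tables``, ``dimensions``,
# ``selectColumns``, ``fromJoin`` and the expression text are placeholders.
#
#     parser = ParserYacc()
#     tree = parser.parse("visit > 100 AND tract = 17")
#     visitor = _ClauseVisitor(tables, dimensions)
#     whereClause = tree.visit(visitor)       # SQLAlchemy boolean clause
#     query = select(selectColumns).select_from(fromJoin).where(whereClause)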
"""Class implementing part of preflight solver which extracts dimension data from registry.
This is an implementation detail only to be used by SqlRegistry class, not supposed to be used anywhere else.
Parameters ---------- registry : `SqlRegistry`` Registry instance originInfo : `DatasetOriginInfo` Object which provides names of the input/output collections. neededDatasetTypes : `list` of `DatasetType` The `list` of `DatasetTypes <DatasetType>` whose Dimensions will be included in the returned column set. Output is limited to the the Datasets of these DatasetTypes which already exist in the registry. futureDatasetTypes : `list` of `DatasetType` The `list` of `DatasetTypes <DatasetType>` whose Dimensions will be included in the returned column set. expandDataIds : `bool` If `True` (default), expand all data IDs when returning them. deferOutputIdQueries: `bool` If `True`, do not include subqueries for preexisting output datasets in the main initial query that constrains the results, and instead query for them one-by-one when processing the results of the initial query. This option should be used when keeping the number of joins in the query low is important (i.e. for SQLite, which has a maximum of 64 joins in a single query). """ expandDataIds=True, deferOutputIdQueries=False): self.registry = registry # Make a copy of the tables in the schema so we can modify it to fake # nonexistent tables without modifying registry state. self.tables = self.registry._schema.tables.copy() self.originInfo = originInfo self.neededDatasetTypes = neededDatasetTypes self.futureDatasetTypes = futureDatasetTypes self.expandDataIds = expandDataIds self.deferOutputIdQueries = deferOutputIdQueries
"""Add new table for join clause.
Assumption here is that this Dimension table has a foreign key to all other tables and names of columns are the same in both tables, so we just get primary key columns from other tables and join on them.
Parameters ---------- fromClause : `sqlalchemy.FromClause` May be `None`, in that case ``otherDimensions`` is expected to be empty and is ignored. dimension : `DimensionElement` `Dimension` or `DimensionJoin` to join with ``fromClause``. otherDimensions : iterable of `Dimension` Dimensions whose tables have PKs for ``dimension`` table's FK. These must be in ``fromClause`` already.
Returns ------- fromClause : `sqlalchemy.FromClause` SQLAlchemy FROM clause extended with new join. """ if fromClause is None: # starting point, first table in JOIN return self.tables[dimension.name] else: joinOn = [] for otherDimension in otherDimensions: primaryKeyColumns = {name: self.tables[otherDimension.name].c[name] for name in otherDimension.links()} for name, col in primaryKeyColumns.items(): joinOn.append(self.tables[dimension.name].c[name] == col) _LOG.debug("join %s with %s on columns %s", dimension.name, dimension.name, list(primaryKeyColumns.keys())) if joinOn: return fromClause.join(self.tables[dimension.name], and_(*joinOn)) else: # Completely unrelated tables, e.g. joining SkyMap and # Instrument. # We need a cross join here but SQLAlchemy does not have # specific method for that. Using join() without `onclause` # will try to join on FK and will raise an exception for # unrelated tables, so we have to use `onclause` which is # always true. return fromClause.join(self.tables[dimension.name], literal(True))
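
    # Illustrative sketch (not part of the original module): the cross-join
    # workaround used above, shown on two toy, unrelated tables.  Table and
    # column names are made up for the example.
    #
    #     import sqlalchemy as sa
    #
    #     metadata = sa.MetaData()
    #     skyMap = sa.Table("SkyMap", metadata,
    #                       sa.Column("skymap", sa.String, primary_key=True))
    #     instrument = sa.Table("Instrument", metadata,
    #                           sa.Column("instrument", sa.String, primary_key=True))
    #
    #     # join() without an onclause would raise (no FK relationship), so an
    #     # always-true onclause gives the effect of a CROSS JOIN.
    #     crossJoin = skyMap.join(instrument, sa.literal(True))
    #     query = sa.select([skyMap.c.skymap, instrument.c.instrument]).select_from(crossJoin)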
"""Evaluate a filter expression and lists of `DatasetTypes <DatasetType>` and return a set of dimension values.
Returned set consists of combinations of dimensions participating in data transformation from ``neededDatasetTypes`` to ``futureDatasetTypes``, restricted by existing data and filter expression.
Parameters ---------- expression : `str`, optional An expression that limits the `Dimensions <Dimension>` and (indirectly) the Datasets returned.
Yields ------ row : `PreFlightDimensionsRow` Single row is a unique combination of dimensions in a transform. """ # parse expression, can raise on errors try: parser = ParserYacc() expression = parser.parse(expression or "") except Exception as exc: raise ValueError(f"Failed to parse user expression `{expression}'") from exc
        # Brief overview of the code below:
        #  - extract all Dimensions used by all input/output dataset types
        #  - build a complex SQL query to run against registry database:
        #    - first do (natural) join for all tables for all Dimensions
        #      involved based on their foreign keys
        #    - then add Join tables to the mix, only use Join tables which
        #      have their lhs/rhs links in the above Dimensions set, also
        #      ignore Joins which summarize other Joins
        #    - next join with Dataset for each input dataset type, this
        #      limits result only to existing input datasets
        #    - also do outer join with Dataset for each output dataset type
        #      to see which output datasets are already there
        #    - append user filter expression
        #    - query returns all Dimension values, regions for region-based
        #      joins, and dataset IDs for all existing datasets
        #  - run this query
        #  - filter out records whose regions do not overlap
        #  - return result as iterator of records containing Dimension values
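
        # Illustrative sketch (not part of the original module): schematic
        # shape of the query assembled below, with made-up dimension and
        # dataset-type names.
        #
        #     SELECT
        #         Visit.visit, Tract.tract, ...,             -- dimension link columns
        #         VisitTractJoin.region, ...,                -- regions for spatial joins
        #         ds_calexp.dataset_id, ds_coadd.dataset_id  -- per-DatasetType dataset IDs
        #     FROM Visit
        #         JOIN Tract ON ...
        #         JOIN VisitTractJoin ON ...
        #         JOIN (<input dataset sub-query>) ds_calexp ON ...             -- INNER JOIN
        #         LEFT OUTER JOIN (<output dataset sub-query>) ds_coadd ON ...
        #     WHERE <user expression>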
        # Collect dimensions from both input and output dataset types
        dimensions = self.registry.dimensions.extract(
            itertools.chain(
                itertools.chain.from_iterable(dsType.dimensions.names
                                              for dsType in self.neededDatasetTypes),
                itertools.chain.from_iterable(dsType.dimensions.names
                                              for dsType in self.futureDatasetTypes),
            )
        )
        _LOG.debug("dimensions: %s", dimensions)

        def findSkyPixSubstitute():
            # SkyPix doesn't have its own table; if it's included in the
            # dimensions we care about, find a join table that we'll
            # henceforth treat as the dimension table for SkyPix.  Note
            # that there may be multiple SkyPix join tables, and we only
            # reserve one for this role, but we don't actually care which.
            for dimensionJoin in dimensions.joins(summaries=False):
                if "SkyPix" in dimensionJoin.dependencies():
                    _LOG.debug("Using %s as primary table for SkyPix", dimensionJoin.name)
                    return dimensionJoin
            raise AssertionError("At least one SkyPix join should be present if SkyPix dimension is.")

        # Build select column list
        selectColumns = []
        linkColumnIndices = {}
        skyPixSubstitute = None
        for dimension in dimensions:
            if dimension.name == "SkyPix":
                # Find a SkyPix join table to use in place of the
                # (nonexistent) SkyPix table, by adding it to our copy of
                # the dict of tables.
                skyPixSubstitute = findSkyPixSubstitute()
                self.tables[dimension.name] = self.tables[skyPixSubstitute.name]
            table = self.tables.get(dimension.name)
            if table is not None:
                # take link column names, usually there is one
                for link in dimension.links(expand=False):
                    linkColumnIndices[link] = len(selectColumns)
                    selectColumns.append(table.c[link])

        _LOG.debug("selectColumns: %s", selectColumns)
        _LOG.debug("linkColumnIndices: %s", linkColumnIndices)
        # Extend dimensions with the "implied" superset, so that joins work
        # correctly.  This may bring more tables into the query than are
        # really needed, which is a potential optimization.
        dimensions = dimensions.union(dimensions.implied())
        fromJoin = None
        for dimension in dimensions:
            _LOG.debug("processing Dimension: %s", dimension.name)
            if dimension.name == "SkyPix" and skyPixSubstitute is None:
                skyPixSubstitute = findSkyPixSubstitute()
                self.tables[dimension.name] = self.tables[skyPixSubstitute.name]
            if dimension.name in self.tables:
                fromJoin = self._joinOnForeignKey(fromJoin, dimension,
                                                  dimension.dependencies(implied=True))
        joinedRegionTables = set()
        regionColumnIndices = {}
        for dimensionJoin in dimensions.joins(summaries=False):
            if dimensionJoin == skyPixSubstitute:
                # this table has already been included
                continue
            _LOG.debug("processing DimensionJoin: %s", dimensionJoin.name)
            # Some `DimensionJoin`s have an associated region; in that case
            # they shouldn't be joined separately in the region lookup.
            if dimensionJoin.hasRegion:
                _LOG.debug("%s has a region, skipping", dimensionJoin.name)
                continue
            # Look at each side of the DimensionJoin and join it with
            # corresponding Dimension tables, including making all necessary
            # joins for special multi-Dimension region table(s).
            regionHolders = []
            for connection in (dimensionJoin.lhs, dimensionJoin.rhs):
                graph = self.registry.dimensions.extract(connection)
                try:
                    regionHolder = graph.getRegionHolder()
                except KeyError:
                    # means there is no region for these dimensions, want to
                    # skip it
                    _LOG.debug("Dimensions %s are not spatial, skipping", connection)
                    break
                if regionHolder.name == "SkyPix":
                    # SkyPix regions are constructed in Python as needed, not
                    # stored in the database.
                    # Note that by the time we've processed both connections,
                    # regionHolders should still be non-empty, since at least
                    # one of the connections will be to something other than
                    # SkyPix.
                    _LOG.debug("Dimension is SkyPix, continuing.")
                    continue
                if isinstance(regionHolder, DimensionJoin):
                    # If one of the connections is with a DimensionJoin, then
                    # it must be one with a region (and hence one we skip
                    # in the outermost 'for' loop).
                    # Bring that join in now, but (unlike the logic in the
                    # outermost 'for' loop) bring the region along too.
                    assert regionHolder.hasRegion, "Spatial join with a join that has no region."
                    if regionHolder.name in joinedRegionTables:
                        _LOG.debug("region table already joined: %s", regionHolder.name)
                    else:
                        _LOG.debug("joining region table: %s", regionHolder.name)
                        joinedRegionTables.add(regionHolder.name)

                        fromJoin = self._joinOnForeignKey(fromJoin, regionHolder, connection)

                # add to the list of tables this join table joins against
                regionHolders.append(regionHolder)

                # We also have to include regions from each side of the join
                # into result set so that we can filter-out non-overlapping
                # regions.
                # Note that a region holder may have already appeared in this
                # loop because it's a connection of multiple different join
                # tables (e.g. Visit for both VisitTractJoin and
                # VisitSkyPixJoin).  In that case we've already put its region
                # in the query output fields.
                if regionHolder.name not in regionColumnIndices:
                    regionColumnIndices[regionHolder.name] = len(selectColumns)
                    regionColumn = self.tables[regionHolder.name].c.region
                    selectColumns.append(regionColumn)

            if regionHolders:
                fromJoin = self._joinOnForeignKey(fromJoin, dimensionJoin, regionHolders)
_LOG.debug("selectColumns: %s", selectColumns) _LOG.debug("linkColumnIndices: %s", linkColumnIndices) _LOG.debug("regionColumnIndices: %s", regionColumnIndices)
# join with input datasets to restrict to existing inputs dsIdColumns = {} allDsTypes = [(dsType, False) for dsType in self.neededDatasetTypes] + \ [(dsType, True) for dsType in self.futureDatasetTypes] for dsType, isOutput in allDsTypes:
if isOutput and self.deferOutputIdQueries: _LOG.debug("deferring lookup for output dataset type: %s", dsType.name) dsIdColumns[dsType] = DATASET_ID_DEFERRED continue
            # Build a sub-query.
            subquery = self._buildDatasetSubquery(dsType, isOutput)
            if subquery is None:
                # If there is nothing to join (e.g. we know that the output
                # collection is empty) then just pass None as the column
                # index for this dataset type to the code below.
                _LOG.debug("nothing to join for %s dataset type: %s",
                           "output" if isOutput else "input", dsType.name)
                dsIdColumns[dsType] = None
                continue

            _LOG.debug("joining %s dataset type: %s",
                       "output" if isOutput else "input", dsType.name)
            # Join sub-query with all dimensions on their link names,
            # OUTER JOIN is used for output datasets (they don't usually exist)
            joinOn = []
            for dimension in dsType.dimensions:
                if dimension.name == "ExposureRange":
                    # very special handling of ExposureRange
                    # TODO: try to generalize this in some way, maybe using
                    # sql from ExposureRangeJoin
                    _LOG.debug("  joining on dimension: %s", dimension.name)
                    exposureTable = self.tables["Exposure"]
                    joinOn.append(between(exposureTable.c.datetime_begin,
                                          subquery.c.valid_first,
                                          subquery.c.valid_last))
                    linkColumnIndices[dsType.name + ".valid_first"] = len(selectColumns)
                    selectColumns.append(subquery.c.valid_first)
                    linkColumnIndices[dsType.name + ".valid_last"] = len(selectColumns)
                    selectColumns.append(subquery.c.valid_last)
                else:
                    for link in dimension.links():
                        _LOG.debug("  joining on link: %s", link)
                        joinOn.append(subquery.c[link] ==
                                      self.tables[dimension.name].c[link])
            fromJoin = fromJoin.join(subquery, and_(*joinOn), isouter=isOutput)

            # remember dataset_id column index for this dataset
            dsIdColumns[dsType] = len(selectColumns)
            selectColumns.append(subquery.c.dataset_id)

        # build full query
        q = select(selectColumns).select_from(fromJoin)
        if expression:
            visitor = _ClauseVisitor(self.tables, dimensions)
            where = expression.visit(visitor)
            _LOG.debug("full where: %s", where)
            q = q.where(where)
        _LOG.debug("full query: %s",
                   q.compile(bind=self.registry._connection.engine,
                             compile_kwargs={"literal_binds": True}))

        # execute and return result iterator
        rows = self.registry._connection.execute(q).fetchall()
        return self._convertResultRows(rows, dimensions, linkColumnIndices,
                                       regionColumnIndices, dsIdColumns)
"""Build a sub-query for a dataset type to be joined with "big join".
If there is only one collection then there is a guarantee that DataIds are all unique (by DataId I mean combination of all link values relevant for this dataset), in that case subquery can be written as:
SELECT Dataset.dataset_id AS dataset_id, Dataset.link1 AS link1 ... FROM Dataset JOIN DatasetCollection ON Dataset.dataset_id = DatasetCollection.dataset_id WHERE Dataset.dataset_type_name = :dsType_name AND DatasetCollection.collection = :collection_name
We only have single collection for output DatasetTypes so for them subqueries always look like above.
If there are multiple collections then there can be multiple matching Datasets for the same DataId. In that case we need only one Dataset record which comes from earliest collection (in the user-provided order). Here things become complicated, we have to: - replace collection names with their order in input list - select all combinations of rows from Dataset and DatasetCollection which match collection names and dataset type name - from those only select rows with lowest collection position if there are multiple collections for the same DataId
Replacing collection names with positions is easy:
SELECT dataset_id, CASE collection WHEN 'collection1' THEN 0 WHEN 'collection2' THEN 1 ... END AS collorder FROM DatasetCollection
Combined query will look like (CASE ... END is as above):
SELECT Dataset.dataset_id AS dataset_id, CASE DatasetCollection.collection ... END AS collorder, Dataset.DataId FROM Dataset JOIN DatasetCollection ON Dataset.dataset_id = DatasetCollection.dataset_id WHERE Dataset.dataset_type_name = <dsType.name> AND DatasetCollection.collection IN (<collections>)
(here ``Dataset.DataId`` means ``Dataset.link1, Dataset.link2, etc.``)
Filtering is complicated, it is simpler to use Common Table Expression (WITH clause) but not all databases support CTEs so we will have to do with the repeating sub-queries. Use GROUP BY for DataId and MIN(collorder) to find ``collorder`` for given DataId, then join it with previous combined selection:
SELECT DS.dataset_id AS dataset_id, DS.link1 AS link1 ... FROM (SELECT Dataset.dataset_id AS dataset_id, CASE ... END AS collorder, Dataset.DataId FROM Dataset JOIN DatasetCollection ON Dataset.dataset_id = DatasetCollection.dataset_id WHERE Dataset.dataset_type_name = <dsType.name> AND DatasetCollection.collection IN (<collections>)) DS INNER JOIN (SELECT MIN(CASE ... END AS) collorder, Dataset.DataId FROM Dataset JOIN DatasetCollection ON Dataset.dataset_id = DatasetCollection.dataset_id WHERE Dataset.dataset_type_name = <dsType.name> AND DatasetCollection.collection IN (<collections>) GROUP BY Dataset.DataId) DSG ON DS.colpos = DSG.colpos AND DS.DataId = DSG.DataId
        Parameters
        ----------
        dsType : `DatasetType`
        isOutput : `bool`
            `True` for output datasets.

        Returns
        -------
        subquery : `sqlalchemy.FromClause` or `None`
            Sub-query for this dataset type, or `None` if there is nothing
            to join (e.g. the output collection has no datasets yet).
        """
        # helper method
        def _columns(selectable, names):
            """Return list of columns for given column names"""
            return [selectable.c[name].label(name) for name in names]
        if isOutput:
            outputCollection = self.originInfo.getOutputCollection(dsType.name)
            if not outputCollection:
                # No output collection means no output datasets exist, we do
                # not need to do any joins here.
                return None
            dsCollections = [outputCollection]
        else:
            dsCollections = self.originInfo.getInputCollections(dsType.name)

        _LOG.debug("using collections: %s", dsCollections)
        # full set of link names for this DatasetType
        links = list(dsType.dimensions.links())

        dsTable = self.tables["Dataset"]
        dsCollTable = self.tables["DatasetCollection"]

        if len(dsCollections) == 1:
            # single collection, easy-peasy
            subJoin = dsTable.join(dsCollTable, dsTable.c.dataset_id == dsCollTable.c.dataset_id)
            subWhere = and_(dsTable.c.dataset_type_name == dsType.name,
                            dsCollTable.c.collection == dsCollections[0])

            columns = _columns(dsTable, ["dataset_id"] + links)
            subquery = select(columns).select_from(subJoin).where(subWhere)
        else:
            # multiple collections
            subJoin = dsTable.join(dsCollTable, dsTable.c.dataset_id == dsCollTable.c.dataset_id)
            subWhere = and_(dsTable.c.dataset_type_name == dsType.name,
                            dsCollTable.c.collection.in_(dsCollections))

            # CASE clause
            collorder = case([
                (dsCollTable.c.collection == coll, pos)
                for pos, coll in enumerate(dsCollections)
            ])
            # first GROUP BY sub-query, find minimum `collorder` for each
            # DataId
            columns = [functions.min(collorder).label("collorder")] + _columns(dsTable, links)
            groupSubq = select(columns).select_from(subJoin).where(subWhere)
            groupSubq = groupSubq.group_by(*links)
            groupSubq = groupSubq.alias("sub1" + dsType.name)

            # next combined sub-query
            columns = [collorder.label("collorder")] + _columns(dsTable, ["dataset_id"] + links)
            combined = select(columns).select_from(subJoin).where(subWhere)
            combined = combined.alias("sub2" + dsType.name)

            # now join these two
            joinsOn = [groupSubq.c.collorder == combined.c.collorder] + \
                      [groupSubq.c[colName] == combined.c[colName] for colName in links]
            subJoin = combined.join(groupSubq, and_(*joinsOn))
            columns = _columns(combined, ["dataset_id"] + links)
            subquery = select(columns).select_from(subJoin)

        # need a unique alias name for it, otherwise we'll see name conflicts
        subquery = subquery.alias("ds" + dsType.name)
        return subquery
"""Convert query result rows into `PreFlightDimensionsRow` instances.
Parameters ---------- rowIter : iterable Iterator for rows returned by the query on registry dimensions : `DimensionGraph` All Dimensions included in this query. linkColumnIndices : `dict` Dictionary of {dimension link name: column index} for the column that contains the link value regionColumnIndices : `dict` Dictionary of (Dimension name, column index), column contains encoded region data dsIdColumns : `dict` Dictionary of (DatasetType, column index), column contains dataset Id, or None if dataset does not exist
Yields ------ row : `PreFlightDimensionsRow` """
        total = 0
        count = 0
        for row in rowIter:

            total += 1
            # Filter result rows that have non-overlapping regions.
            # The result set generated by the query in selectDimensions() can
            # include a set of regions in each row (encoded as bytes).  Due
            # to pixel-based matching some regions may not overlap; this
            # generator method filters rows that have disjoint regions.  If a
            # result row contains more than two regions (this should not
            # happen with our current schema) then the row is filtered if any
            # two regions are disjoint.
            disjoint = False
            regions = {holder: Region.decode(row[col])
                       for holder, col in regionColumnIndices.items()}
            # SkyPix regions aren't in the query because they aren't in the
            # database.  If the data IDs we yield include a skypix key,
            # calculate their regions using sphgeom.
            if "skypix" in linkColumnIndices:
                skypix = row[linkColumnIndices["skypix"]]
                regions["SkyPix"] = self.registry.pixelization.pixel(skypix)

            for reg1, reg2 in itertools.combinations(regions.values(), 2):
                if reg1.relate(reg2) == DISJOINT:
                    disjoint = True
                    break
            if disjoint:
                continue
            def extractRegion(dims):
                try:
                    holder = dims.getRegionHolder()
                except ValueError:
                    return None
                if holder is not None:
                    return regions.get(holder.name)
                return None

            # Find all of the link columns that aren't NULL.
            rowDataIdDict = {link: row[col]
                             for link, col in linkColumnIndices.items()
                             if row[col] is not None}
            # Find all of the Dimensions we can uniquely identify with the
            # non-NULL link columns.
            rowDimensions = self.registry.dimensions.extract(
                dim for dim in dimensions if dim.links().issubset(rowDataIdDict.keys())
            )
            # Remove all of the link columns that weren't needed by the
            # Dimensions we selected (in practice this is just ExposureRange
            # links right now, so this step might not be needed once we make
            # that less of a special case).
            dataId = DataId(
                {k: v for k, v in rowDataIdDict.items() if k in rowDimensions.links()},
                dimensions=rowDimensions,
                region=extractRegion(rowDimensions)
            )
            # Row-wide Data IDs are never expanded, even if
            # expandDataIds=True; this is slightly confusing, but we don't
            # actually need them expanded, and it's actually quite slow.
            # get Dataset for each DatasetType
            datasetRefs = {}
            for dsType, col in dsIdColumns.items():
                linkNames = {}  # maps full link name in linkColumnIndices to dataId key
                for dimension in dsType.dimensions:
                    if dimension.name == "ExposureRange":
                        # special case of ExposureRange, its columns come from
                        # Dataset table instead of Dimension
                        linkNames[dsType.name + ".valid_first"] = "valid_first"
                        linkNames[dsType.name + ".valid_last"] = "valid_last"
                    else:
                        if self.tables.get(dimension.name) is not None:
                            linkNames.update((s, s) for s in dimension.links(expand=False))
                dsDataId = DataId({val: row[linkColumnIndices[key]]
                                   for key, val in linkNames.items()},
                                  dimensions=dsType.dimensions,
                                  region=extractRegion(dsType.dimensions))
                if self.expandDataIds:
                    self.registry.expandDataId(dsDataId)

                if col is None:
                    # Dataset does not exist yet.
                    datasetRefs[dsType] = DatasetRef(dsType, dsDataId, id=None)
                elif col is DATASET_ID_DEFERRED:
                    # We haven't searched for the dataset yet, because we've
                    # deferred these queries.
                    ref = self.registry.find(
                        collection=self.originInfo.getOutputCollection(dsType.name),
                        datasetType=dsType,
                        dataId=dsDataId
                    )
                    datasetRefs[dsType] = ref if ref is not None else DatasetRef(dsType, dsDataId,
                                                                                 id=None)
                else:
                    datasetRefs[dsType] = self.registry.getDataset(id=row[col],
                                                                   datasetType=dsType,
                                                                   dataId=dsDataId)

            count += 1
            yield PreFlightDimensionsRow(dataId, datasetRefs)
_LOG.debug("Total %d rows in result set, %d after region filtering", total, count) |