Coverage for python/lsst/daf/butler/direct_query_driver/_query_builder.py: 27%
167 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:58 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:58 -0700
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("QueryJoiner", "QueryBuilder")

import dataclasses
import itertools
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, ClassVar

import sqlalchemy

from .. import ddl
from ..name_shrinker import NameShrinker
from ..nonempty_mapping import NonemptyMapping
from ..queries import tree as qt
from ._postprocessing import Postprocessing

if TYPE_CHECKING:
    from ..registry.interfaces import Database
    from ..timespan_database_representation import TimespanDatabaseRepresentation


@dataclasses.dataclass
class QueryBuilder:
    """A struct used to represent an under-construction SQL SELECT query.

    This object's methods frequently "consume" ``self``, by either returning
    it after modification or returning a related copy that may share state
    with the original.  Users should be careful never to use consumed
    instances, and are recommended to reuse the same variable name to make
    that hard to do accidentally.
    """

    joiner: QueryJoiner
    """Struct representing the SQL FROM and WHERE clauses, as well as the
    columns *available* to the query (but not necessarily in the SELECT
    clause).
    """

    columns: qt.ColumnSet
    """Columns to include in the SELECT clause.

    This does not include columns required only by `postprocessing` and columns
    in `QueryJoiner.special`, which are also always included in the SELECT
    clause.
    """

    postprocessing: Postprocessing = dataclasses.field(default_factory=Postprocessing)
    """Postprocessing that will be needed in Python after the SQL query has
    been executed.
    """

    distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = ()
    """A representation of a DISTINCT or DISTINCT ON clause.

    If `True`, this represents a SELECT DISTINCT.  If a non-empty sequence,
    this represents a SELECT DISTINCT ON.  If `False` or an empty sequence,
    there is no DISTINCT clause.
    """

    group_by: Sequence[sqlalchemy.ColumnElement[Any]] = ()
    """A representation of a GROUP BY clause.

    If not-empty, a GROUP BY clause with these columns is added.  This
    generally requires that every `sqlalchemy.ColumnElement` held in the nested
    `joiner` that is part of `columns` must either be part of `group_by` or
    hold an aggregate function.
    """

    EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
    """Name of the column added to a SQL SELECT clause in order to construct
    queries that have no real columns.
    """

    EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean
    """Type of the column added to a SQL SELECT clause in order to construct
    queries that have no real columns.
    """

    @classmethod
    def handle_empty_columns(
        cls, columns: list[sqlalchemy.sql.ColumnElement]
    ) -> list[sqlalchemy.ColumnElement]:
        """Handle the edge case where a SELECT statement has no columns, by
        adding a literal column that should be ignored.

        Parameters
        ----------
        columns : `list` [ `sqlalchemy.ColumnElement` ]
            List of SQLAlchemy column objects.  This may have no elements when
            this method is called, and will always have at least one element
            when it returns.

        Returns
        -------
        columns : `list` [ `sqlalchemy.ColumnElement` ]
            The same list that was passed in, after any modification.
        """
        if not columns:
            # The list is modified in place; callers may ignore the return.
            columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME))
        return columns

    def select(self) -> sqlalchemy.Select:
        """Transform this builder into a SQLAlchemy representation of a SELECT
        query.

        The nested `joiner` gets a `NameShrinker` assigned as a side effect if
        it did not already have one.

        Returns
        -------
        select : `sqlalchemy.Select`
            SQLAlchemy SELECT statement.
        """
        assert not (self.distinct and self.group_by), "At most one of distinct and group_by can be set."
        if self.joiner.name_shrinker is None:
            self.joiner.name_shrinker = self.joiner._make_name_shrinker()
        sql_columns: list[sqlalchemy.ColumnElement[Any]] = []
        for logical_table, field in self.columns:
            name = self.columns.get_qualified_name(logical_table, field)
            if field is None:
                # A dimension key: any of the equivalent columns will do, so
                # the first one is used, labeled with the unshrunk name.
                sql_columns.append(self.joiner.dimension_keys[logical_table][0].label(name))
            else:
                name = self.joiner.name_shrinker.shrink(name)
                if self.columns.is_timespan(logical_table, field):
                    sql_columns.extend(self.joiner.timespans[logical_table].flatten(name))
                else:
                    sql_columns.append(self.joiner.fields[logical_table][field].label(name))
        if self.postprocessing is not None:
            # Region columns that postprocessing needs but `columns` lacks.
            for element in self.postprocessing.iter_missing(self.columns):
                sql_columns.append(
                    self.joiner.fields[element.name]["region"].label(
                        self.joiner.name_shrinker.shrink(
                            self.columns.get_qualified_name(element.name, "region")
                        )
                    )
                )
        # NOTE(review): special columns are added unconditionally (outside the
        # postprocessing branch), matching the `columns` attribute docs.
        for label, sql_column in self.joiner.special.items():
            sql_columns.append(sql_column.label(label))
        self.handle_empty_columns(sql_columns)
        result = sqlalchemy.select(*sql_columns)
        if self.joiner.from_clause is not None:
            result = result.select_from(self.joiner.from_clause)
        if self.distinct is True:
            result = result.distinct()
        elif self.distinct:
            # Non-empty sequence: DISTINCT ON the given expressions.
            result = result.distinct(*self.distinct)
        if self.group_by:
            result = result.group_by(*self.group_by)
        if self.joiner.where_terms:
            result = result.where(*self.joiner.where_terms)
        return result

    def join(self, other: QueryJoiner) -> QueryBuilder:
        """Join tables, subqueries, and WHERE clauses from another query into
        this one, in place.

        Parameters
        ----------
        other : `QueryJoiner`
            Object holding the FROM and WHERE clauses to add to this one.
            JOIN ON clauses are generated via the dimension keys in common.

        Returns
        -------
        self : `QueryBuilder`
            This `QueryBuilder` instance (never a copy); returned to enable
            method-chaining.
        """
        self.joiner.join(other)
        return self

    def to_joiner(self, cte: bool = False, force: bool = False) -> QueryJoiner:
        """Convert this builder into a `QueryJoiner`, nesting it in a subquery
        or common table expression only if needed to apply DISTINCT or GROUP BY
        clauses.

        This method consumes ``self``.

        Parameters
        ----------
        cte : `bool`, optional
            If `True`, nest via a common table expression instead of a
            subquery.
        force : `bool`, optional
            If `True`, nest via a subquery or common table expression even if
            there is no DISTINCT or GROUP BY.

        Returns
        -------
        joiner : `QueryJoiner`
            `QueryJoiner` with at least all columns in `columns` available.
            This may or may not be the `joiner` attribute of this object.
        """
        if force or self.distinct or self.group_by:
            # `select()` must run before `self.joiner.name_shrinker` is read
            # below, because it creates the shrinker when there is none.
            select = self.select()
            sql_from_clause = select.cte() if cte else select.subquery()
            return QueryJoiner(
                self.joiner.db, sql_from_clause, name_shrinker=self.joiner.name_shrinker
            ).extract_columns(self.columns, self.postprocessing, special=self.joiner.special.keys())
        return self.joiner

    def nested(self, cte: bool = False, force: bool = False) -> QueryBuilder:
        """Convert this builder into a `QueryBuilder` that is guaranteed to
        have no DISTINCT or GROUP BY, nesting it in a subquery or common table
        expression only if needed to apply any current DISTINCT or GROUP BY
        clauses.

        This method consumes ``self``.

        Parameters
        ----------
        cte : `bool`, optional
            If `True`, nest via a common table expression instead of a
            subquery.
        force : `bool`, optional
            If `True`, nest via a subquery or common table expression even if
            there is no DISTINCT or GROUP BY.

        Returns
        -------
        builder : `QueryBuilder`
            New `QueryBuilder` with at least all columns in `columns`
            available.  Its `joiner` may or may not be the `joiner` attribute
            of this object.
        """
        return QueryBuilder(
            self.to_joiner(cte=cte, force=force), columns=self.columns, postprocessing=self.postprocessing
        )

    def union_subquery(
        self,
        others: Iterable[QueryBuilder],
    ) -> QueryJoiner:
        """Combine this builder with others to make a SELECT UNION subquery.

        Parameters
        ----------
        others : `~collections.abc.Iterable` [ `QueryBuilder` ]
            Other query builders to union with.  Their `columns` attributes
            must be the same as those of ``self``.

        Returns
        -------
        joiner : `QueryJoiner`
            New `QueryJoiner` whose FROM clause is the UNION subquery, with at
            least all columns in `columns` available.
        """
        # `select()` runs first so that the name shrinker read below exists.
        select0 = self.select()
        other_selects = [other.select() for other in others]
        return QueryJoiner(
            self.joiner.db,
            from_clause=select0.union(*other_selects).subquery(),
            name_shrinker=self.joiner.name_shrinker,
        ).extract_columns(self.columns, self.postprocessing)

    def make_table_spec(self) -> ddl.TableSpec:
        """Make a specification that can be used to create a table to store
        this query's outputs.

        The nested `joiner` gets a `NameShrinker` assigned as a side effect if
        it did not already have one.

        Returns
        -------
        spec : `.ddl.TableSpec`
            Table specification for this query's result columns (including
            region columns needed only by `postprocessing`).  Columns in
            `QueryJoiner.special` are not supported; the joiner must not have
            any.
        """
        assert not self.joiner.special, "special columns not supported in make_table_spec"
        if self.joiner.name_shrinker is None:
            self.joiner.name_shrinker = self.joiner._make_name_shrinker()
        results = ddl.TableSpec(
            [
                self.columns.get_column_spec(logical_table, field).to_sql_spec(
                    name_shrinker=self.joiner.name_shrinker
                )
                for logical_table, field in self.columns
            ]
        )
        # NOTE(review): this relies on the truthiness of `Postprocessing`,
        # unlike `select`, which tests ``is not None`` - confirm intended.
        if self.postprocessing:
            for element in self.postprocessing.iter_missing(self.columns):
                results.fields.add(
                    ddl.FieldSpec.for_region(
                        self.joiner.name_shrinker.shrink(
                            self.columns.get_qualified_name(element.name, "region")
                        )
                    )
                )
        return results
312@dataclasses.dataclass
313class QueryJoiner:
314 """A struct used to represent the FROM and WHERE clauses of an
315 under-construction SQL SELECT query.
317 This object's methods frequently "consume" ``self``, by either returning
318 it after modification or returning related copy that may share state with
319 the original. Users should be careful never to use consumed instances, and
320 are recommended to reuse the same variable name to make that hard to do
321 accidentally.
322 """
324 db: Database
325 """Object that abstracts over the database engine."""
327 from_clause: sqlalchemy.FromClause | None = None
328 """SQLAlchemy representation of the FROM clause.
330 This is initialized to `None` but in almost all cases is immediately
331 replaced.
332 """
334 where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list)
335 """Sequence of WHERE clause terms to be combined with AND."""
337 dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field( 337 ↛ exitline 337 didn't jump to the function exit
338 default_factory=lambda: NonemptyMapping(list)
339 )
340 """Mapping of dimension keys included in the FROM clause.
342 Nested lists correspond to different tables that have the same dimension
343 key (which should all have equal values for all result rows).
344 """
346 fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field( 346 ↛ exitline 346 didn't jump to the function exit
347 default_factory=lambda: NonemptyMapping(dict)
348 )
349 """Mapping of columns that are neither dimension keys nor timespans.
351 Inner and outer keys correspond to the "logical table" and "field" pairs
352 that result from iterating over `~.queries.tree.ColumnSet`, with the former
353 either a dimension element name or dataset type name.
354 """
356 timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict)
357 """Mapping of timespan columns.
359 Keys are "logical tables" - dimension element names or dataset type names.
360 """
362 special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
363 """Special columns that are available from the FROM clause and
364 automatically included in the SELECT clause when this joiner is nested
365 within a `QueryBuilder`.
367 These columns are not part of the dimension universe and are not associated
368 with a dataset. They are never returned to users, even if they may be
369 included in raw SQL results.
370 """
372 name_shrinker: NameShrinker | None = None
373 """An object that can be used to shrink field names to fit within the
374 identifier limit of the database engine.
376 This is important for PostgreSQL (which has a 64-character limit) and
377 dataset fields, since dataset type names are used to qualify those and they
378 can be quite long. `DimensionUniverse` guarantees at construction that
379 dimension names and fully-qualified dimension fields do not exceed this
380 limit.
381 """
383 def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> QueryJoiner:
384 """Add dimension key columns from `from_clause` into `dimension_keys`.
386 Parameters
387 ----------
388 dimensions : `~collections.abc.Iterable` [ `str` ]
389 Names of dimensions to include, assuming that their names in
390 `sql_columns` are just the dimension names.
391 **kwargs : `str`
392 Additional dimensions to include, with the names in `sql_columns`
393 as keys and the actual dimension names as values.
395 Returns
396 -------
397 self : `QueryJoiner`
398 This `QueryJoiner` instance (never a copy). Provided to enable
399 method chaining.
400 """
401 assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
402 for dimension_name in dimensions:
403 self.dimension_keys[dimension_name].append(self.from_clause.columns[dimension_name])
404 for k, v in kwargs.items():
405 self.dimension_keys[v].append(self.from_clause.columns[k])
406 return self
408 def extract_columns(
409 self,
410 columns: qt.ColumnSet,
411 postprocessing: Postprocessing | None = None,
412 special: Iterable[str] = (),
413 ) -> QueryJoiner:
414 """Add columns from `from_clause` into `dimension_keys`.
416 Parameters
417 ----------
418 columns : `.queries.tree.ColumnSet`
419 Columns to include, assuming that
420 `.queries.tree.ColumnSet.get_qualified_name` corresponds to the
421 name used in `sql_columns` (after name shrinking).
422 postprocessing : `Postprocessing`, optional
423 Postprocessing object whose needed columns should also be included.
424 special : `~collections.abc.Iterable` [ `str` ], optional
425 Additional special columns to extract.
427 Returns
428 -------
429 self : `QueryJoiner`
430 This `QueryJoiner` instance (never a copy). Provided to enable
431 method chaining.
432 """
433 assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
434 if self.name_shrinker is None:
435 self.name_shrinker = self._make_name_shrinker()
436 for logical_table, field in columns:
437 name = columns.get_qualified_name(logical_table, field)
438 if field is None:
439 self.dimension_keys[logical_table].append(self.from_clause.columns[name])
440 else:
441 name = self.name_shrinker.shrink(name)
442 if columns.is_timespan(logical_table, field):
443 self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
444 self.from_clause.columns, name
445 )
446 else:
447 self.fields[logical_table][field] = self.from_clause.columns[name]
448 if postprocessing is not None:
449 for element in postprocessing.iter_missing(columns):
450 self.fields[element.name]["region"] = self.from_clause.columns[
451 self.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
452 ]
453 if postprocessing.check_validity_match_count:
454 self.special[postprocessing.VALIDITY_MATCH_COUNT] = self.from_clause.columns[
455 postprocessing.VALIDITY_MATCH_COUNT
456 ]
457 for name in special:
458 self.special[name] = self.from_clause.columns[name]
459 return self
461 def join(self, other: QueryJoiner) -> QueryJoiner:
462 """Combine this `QueryJoiner` with another via an INNER JOIN on
463 dimension keys.
465 This method consumes ``self``.
467 Parameters
468 ----------
469 other : `QueryJoiner`
470 Other joiner to combine with this one.
472 Returns
473 -------
474 joined : `QueryJoiner`
475 A `QueryJoiner` with all columns present in either operand, with
476 its `from_clause` representing a SQL INNER JOIN where the dimension
477 key columns common to both operands are constrained to be equal.
478 If either operand does not have `from_clause`, the other's is used.
479 The `where_terms` of the two operands are concatenated,
480 representing a logical AND (with no attempt at deduplication).
481 """
482 join_on: list[sqlalchemy.ColumnElement] = []
483 for dimension_name in other.dimension_keys.keys():
484 if dimension_name in self.dimension_keys:
485 for column1, column2 in itertools.product(
486 self.dimension_keys[dimension_name], other.dimension_keys[dimension_name]
487 ):
488 join_on.append(column1 == column2)
489 self.dimension_keys[dimension_name].extend(other.dimension_keys[dimension_name])
490 if self.from_clause is None:
491 self.from_clause = other.from_clause
492 elif other.from_clause is not None:
493 join_on_sql: sqlalchemy.ColumnElement[bool]
494 match len(join_on):
495 case 0:
496 join_on_sql = sqlalchemy.true()
497 case 1:
498 (join_on_sql,) = join_on
499 case _:
500 join_on_sql = sqlalchemy.and_(*join_on)
501 self.from_clause = self.from_clause.join(other.from_clause, onclause=join_on_sql)
502 for logical_table, fields in other.fields.items():
503 self.fields[logical_table].update(fields)
504 self.timespans.update(other.timespans)
505 self.special.update(other.special)
506 self.where_terms += other.where_terms
507 if other.name_shrinker:
508 if self.name_shrinker is not None:
509 self.name_shrinker.update(other.name_shrinker)
510 else:
511 self.name_shrinker = other.name_shrinker
512 return self
514 def where(self, *args: sqlalchemy.ColumnElement[bool]) -> QueryJoiner:
515 """Add a WHERE clause term.
517 Parameters
518 ----------
519 *args : `sqlalchemy.ColumnElement`
520 SQL boolean column expressions to be combined with AND.
522 Returns
523 -------
524 self : `QueryJoiner`
525 This `QueryJoiner` instance (never a copy). Provided to enable
526 method chaining.
527 """
528 self.where_terms.extend(args)
529 return self
531 def to_builder(
532 self,
533 columns: qt.ColumnSet,
534 postprocessing: Postprocessing | None = None,
535 distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = (),
536 group_by: Sequence[sqlalchemy.ColumnElement[Any]] = (),
537 ) -> QueryBuilder:
538 """Convert this joiner into a `QueryBuilder` by providing SELECT clause
539 columns and optional DISTINCT or GROUP BY clauses.
541 This method consumes ``self``.
543 Parameters
544 ----------
545 columns : `~.queries.tree.ColumnSet`
546 Regular columns to include in the SELECT clause.
547 postprocessing : `Postprocessing`, optional
548 Addition processing to be performed on result rows after executing
549 the SQL query.
550 distinct : `bool` or `~collections.abc.Sequence` [ \
551 `sqlalchemy.ColumnElement` ], optional
552 Specification of the DISTINCT clause (see `QueryBuilder.distinct`).
553 group_by : `~collections.abc.Sequence` [ \
554 `sqlalchemy.ColumnElement` ], optional
555 Specification of the GROUP BY clause (see `QueryBuilder.group_by`).
557 Returns
558 -------
559 builder : `QueryBuilder`
560 New query builder.
561 """
562 return QueryBuilder(
563 self,
564 columns,
565 postprocessing=postprocessing if postprocessing is not None else Postprocessing(),
566 distinct=distinct,
567 group_by=group_by,
568 )
570 def _make_name_shrinker(self) -> NameShrinker:
571 return NameShrinker(self.db.dialect.max_identifier_length, 6)