Coverage for python/lsst/daf/butler/direct_query_driver/_query_builder.py: 27%

167 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-03 02:48 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("QueryJoiner", "QueryBuilder") 

31 

32import dataclasses 

33import itertools 

34from collections.abc import Iterable, Sequence 

35from typing import TYPE_CHECKING, Any, ClassVar 

36 

37import sqlalchemy 

38 

39from .. import ddl 

40from ..name_shrinker import NameShrinker 

41from ..nonempty_mapping import NonemptyMapping 

42from ..queries import tree as qt 

43from ._postprocessing import Postprocessing 

44 

45if TYPE_CHECKING: 

46 from ..registry.interfaces import Database 

47 from ..timespan_database_representation import TimespanDatabaseRepresentation 

48 

49 

@dataclasses.dataclass
class QueryBuilder:
    """A struct used to represent an under-construction SQL SELECT query.

    This object's methods frequently "consume" ``self``, by either returning
    it after modification or returning a related copy that may share state
    with the original.  Users should be careful never to use consumed
    instances, and are recommended to reuse the same variable name to make
    that hard to do accidentally.
    """

    joiner: QueryJoiner
    """Struct representing the SQL FROM and WHERE clauses, as well as the
    columns *available* to the query (but not necessarily in the SELECT
    clause).
    """

    columns: qt.ColumnSet
    """Columns to include in the SELECT clause.

    This does not include columns required only by `postprocessing` and columns
    in `QueryJoiner.special`, which are also always included in the SELECT
    clause.
    """

    postprocessing: Postprocessing = dataclasses.field(default_factory=Postprocessing)
    """Postprocessing that will be needed in Python after the SQL query has
    been executed.
    """

    distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = ()
    """A representation of a DISTINCT or DISTINCT ON clause.

    If `True`, this represents a SELECT DISTINCT.  If a non-empty sequence,
    this represents a SELECT DISTINCT ON.  If `False` or an empty sequence,
    there is no DISTINCT clause.
    """

    group_by: Sequence[sqlalchemy.ColumnElement[Any]] = ()
    """A representation of a GROUP BY clause.

    If not-empty, a GROUP BY clause with these columns is added.  This
    generally requires that every `sqlalchemy.ColumnElement` held in the nested
    `joiner` that is part of `columns` must either be part of `group_by` or
    hold an aggregate function.
    """

    EMPTY_COLUMNS_NAME: ClassVar[str] = "IGNORED"
    """Name of the column added to a SQL SELECT clause in order to construct
    queries that have no real columns.
    """

    EMPTY_COLUMNS_TYPE: ClassVar[type] = sqlalchemy.Boolean
    """Type of the column added to a SQL SELECT clause in order to construct
    queries that have no real columns.
    """

    @classmethod
    def handle_empty_columns(
        cls, columns: list[sqlalchemy.sql.ColumnElement]
    ) -> list[sqlalchemy.ColumnElement]:
        """Handle the edge case where a SELECT statement has no columns, by
        adding a literal column that should be ignored.

        Parameters
        ----------
        columns : `list` [ `sqlalchemy.ColumnElement` ]
            List of SQLAlchemy column objects.  This may have no elements when
            this method is called, and will always have at least one element
            when it returns.

        Returns
        -------
        columns : `list` [ `sqlalchemy.ColumnElement` ]
            The same list that was passed in, after any modification.
        """
        if not columns:
            # The list is modified in place; callers may ignore the return
            # value and keep using their own reference.
            columns.append(sqlalchemy.sql.literal(True).label(cls.EMPTY_COLUMNS_NAME))
        return columns

    def _ensure_name_shrinker(self) -> NameShrinker:
        """Return the nested joiner's name shrinker, creating it on first use.

        Returns
        -------
        name_shrinker : `NameShrinker`
            The object stored in ``self.joiner.name_shrinker`` (which is set
            by this call if it was `None`).
        """
        if self.joiner.name_shrinker is None:
            self.joiner.name_shrinker = self.joiner._make_name_shrinker()
        return self.joiner.name_shrinker

    def select(self) -> sqlalchemy.Select:
        """Transform this builder into a SQLAlchemy representation of a SELECT
        query.

        Returns
        -------
        select : `sqlalchemy.Select`
            SQLAlchemy SELECT statement.

        Notes
        -----
        As a side effect this initializes ``joiner.name_shrinker`` if it was
        `None`.
        """
        assert not (self.distinct and self.group_by), "At most one of distinct and group_by can be set."
        shrinker = self._ensure_name_shrinker()
        sql_columns: list[sqlalchemy.ColumnElement[Any]] = []
        for logical_table, field in self.columns:
            name = self.columns.get_qualified_name(logical_table, field)
            if field is None:
                # A dimension key.  Only the first of the available columns
                # for the key is selected; its name is not shrunk.
                sql_columns.append(self.joiner.dimension_keys[logical_table][0].label(name))
                continue
            name = shrinker.shrink(name)
            if self.columns.is_timespan(logical_table, field):
                # Timespans contribute whatever columns their database
                # representation's ``flatten`` yields.
                sql_columns.extend(self.joiner.timespans[logical_table].flatten(name))
            else:
                sql_columns.append(self.joiner.fields[logical_table][field].label(name))
        if self.postprocessing is not None:
            # Region columns needed by postprocessing that are not already
            # part of `columns`.
            for element in self.postprocessing.iter_missing(self.columns):
                region_column = self.joiner.fields[element.name]["region"]
                region_name = shrinker.shrink(self.columns.get_qualified_name(element.name, "region"))
                sql_columns.append(region_column.label(region_name))
        for label, sql_column in self.joiner.special.items():
            sql_columns.append(sql_column.label(label))
        self.handle_empty_columns(sql_columns)
        result = sqlalchemy.select(*sql_columns)
        if self.joiner.from_clause is not None:
            result = result.select_from(self.joiner.from_clause)
        if self.distinct is True:
            result = result.distinct()
        elif self.distinct:
            # A non-empty sequence: DISTINCT ON those expressions.
            result = result.distinct(*self.distinct)
        if self.group_by:
            result = result.group_by(*self.group_by)
        if self.joiner.where_terms:
            result = result.where(*self.joiner.where_terms)
        return result

    def join(self, other: QueryJoiner) -> QueryBuilder:
        """Join tables, subqueries, and WHERE clauses from another query into
        this one, in place.

        Parameters
        ----------
        other : `QueryJoiner`
            Object holding the FROM and WHERE clauses to add to this one.
            JOIN ON clauses are generated via the dimension keys in common.

        Returns
        -------
        self : `QueryBuilder`
            This `QueryBuilder` instance (never a copy); returned to enable
            method-chaining.
        """
        self.joiner.join(other)
        return self

    def to_joiner(self, cte: bool = False, force: bool = False) -> QueryJoiner:
        """Convert this builder into a `QueryJoiner`, nesting it in a subquery
        or common table expression only if needed to apply DISTINCT or GROUP BY
        clauses.

        This method consumes ``self``.

        Parameters
        ----------
        cte : `bool`, optional
            If `True`, nest via a common table expression instead of a
            subquery.
        force : `bool`, optional
            If `True`, nest via a subquery or common table expression even if
            there is no DISTINCT or GROUP BY.

        Returns
        -------
        joiner : `QueryJoiner`
            `QueryJoiner` with at least all columns in `columns` available.
            This may or may not be the `joiner` attribute of this object.
        """
        if not (force or self.distinct or self.group_by):
            # Nothing requires nesting, so the existing joiner is reused.
            return self.joiner
        select = self.select()
        sql_from_clause = select.cte() if cte else select.subquery()
        # ``select()`` has ensured that ``joiner.name_shrinker`` is set, so the
        # nested joiner shares the same shrinker object.
        nested_joiner = QueryJoiner(self.joiner.db, sql_from_clause, name_shrinker=self.joiner.name_shrinker)
        return nested_joiner.extract_columns(
            self.columns, self.postprocessing, special=self.joiner.special.keys()
        )

    def nested(self, cte: bool = False, force: bool = False) -> QueryBuilder:
        """Convert this builder into a `QueryBuilder` that is guaranteed to
        have no DISTINCT or GROUP BY, nesting it in a subquery or common table
        expression only if needed to apply any current DISTINCT or GROUP BY
        clauses.

        This method consumes ``self``.

        Parameters
        ----------
        cte : `bool`, optional
            If `True`, nest via a common table expression instead of a
            subquery.
        force : `bool`, optional
            If `True`, nest via a subquery or common table expression even if
            there is no DISTINCT or GROUP BY.

        Returns
        -------
        builder : `QueryBuilder`
            New `QueryBuilder` with at least all columns in `columns`
            available.  Its `joiner` may or may not be the `joiner` attribute
            of this object.
        """
        return QueryBuilder(
            self.to_joiner(cte=cte, force=force),
            columns=self.columns,
            postprocessing=self.postprocessing,
        )

    def union_subquery(
        self,
        others: Iterable[QueryBuilder],
    ) -> QueryJoiner:
        """Combine this builder with others to make a SELECT UNION subquery.

        Parameters
        ----------
        others : `~collections.abc.Iterable` [ `QueryBuilder` ]
            Other query builders to union with.  Their `columns` attributes
            must be the same as those of ``self``.

        Returns
        -------
        joiner : `QueryJoiner`
            New `QueryJoiner` whose FROM clause is the UNION subquery, with at
            least all columns in `columns` available.
        """
        # ``self.select()`` must run first: it guarantees that
        # ``self.joiner.name_shrinker`` is set before it is handed on below.
        select0 = self.select()
        other_selects = [other.select() for other in others]
        union_joiner = QueryJoiner(
            self.joiner.db,
            from_clause=select0.union(*other_selects).subquery(),
            name_shrinker=self.joiner.name_shrinker,
        )
        return union_joiner.extract_columns(self.columns, self.postprocessing)

    def make_table_spec(self) -> ddl.TableSpec:
        """Make a specification that can be used to create a table to store
        this query's outputs.

        Returns
        -------
        spec : `.ddl.TableSpec`
            Table specification for this query's result columns (including
            those needed only by `postprocessing`).  Special columns
            (`QueryJoiner.special`) are not supported; the joiner must not
            have any.
        """
        assert not self.joiner.special, "special columns not supported in make_table_spec"
        shrinker = self._ensure_name_shrinker()
        results = ddl.TableSpec(
            [
                self.columns.get_column_spec(logical_table, field).to_sql_spec(name_shrinker=shrinker)
                for logical_table, field in self.columns
            ]
        )
        if self.postprocessing:
            for element in self.postprocessing.iter_missing(self.columns):
                region_name = shrinker.shrink(self.columns.get_qualified_name(element.name, "region"))
                results.fields.add(ddl.FieldSpec.for_region(region_name))
        return results

310 

311 

@dataclasses.dataclass
class QueryJoiner:
    """A struct used to represent the FROM and WHERE clauses of an
    under-construction SQL SELECT query.

    This object's methods frequently "consume" ``self``, by either returning
    it after modification or returning a related copy that may share state
    with the original.  Users should be careful never to use consumed
    instances, and are recommended to reuse the same variable name to make
    that hard to do accidentally.
    """

    db: Database
    """Object that abstracts over the database engine."""

    from_clause: sqlalchemy.FromClause | None = None
    """SQLAlchemy representation of the FROM clause.

    This is initialized to `None` but in almost all cases is immediately
    replaced.
    """

    where_terms: list[sqlalchemy.ColumnElement[bool]] = dataclasses.field(default_factory=list)
    """Sequence of WHERE clause terms to be combined with AND."""

    dimension_keys: NonemptyMapping[str, list[sqlalchemy.ColumnElement]] = dataclasses.field(
        default_factory=lambda: NonemptyMapping(list)
    )
    """Mapping of dimension keys included in the FROM clause.

    Nested lists correspond to different tables that have the same dimension
    key (which should all have equal values for all result rows).
    """

    fields: NonemptyMapping[str, dict[str, sqlalchemy.ColumnElement[Any]]] = dataclasses.field(
        default_factory=lambda: NonemptyMapping(dict)
    )
    """Mapping of columns that are neither dimension keys nor timespans.

    Inner and outer keys correspond to the "logical table" and "field" pairs
    that result from iterating over `~.queries.tree.ColumnSet`, with the former
    either a dimension element name or dataset type name.
    """

    timespans: dict[str, TimespanDatabaseRepresentation] = dataclasses.field(default_factory=dict)
    """Mapping of timespan columns.

    Keys are "logical tables" - dimension element names or dataset type names.
    """

    special: dict[str, sqlalchemy.ColumnElement[Any]] = dataclasses.field(default_factory=dict)
    """Special columns that are available from the FROM clause and
    automatically included in the SELECT clause when this joiner is nested
    within a `QueryBuilder`.

    These columns are not part of the dimension universe and are not associated
    with a dataset.  They are never returned to users, even if they may be
    included in raw SQL results.
    """

    name_shrinker: NameShrinker | None = None
    """An object that can be used to shrink field names to fit within the
    identifier limit of the database engine.

    This is important for PostgreSQL (which has a 64-character limit) and
    dataset fields, since dataset type names are used to qualify those and they
    can be quite long.  `DimensionUniverse` guarantees at construction that
    dimension names and fully-qualified dimension fields do not exceed this
    limit.
    """

    def extract_dimensions(self, dimensions: Iterable[str], **kwargs: str) -> QueryJoiner:
        """Add dimension key columns from `from_clause` into `dimension_keys`.

        Parameters
        ----------
        dimensions : `~collections.abc.Iterable` [ `str` ]
            Names of dimensions to include, assuming that their column names
            in `from_clause` are just the dimension names.
        **kwargs : `str`
            Additional dimensions to include, with the column names in
            `from_clause` as keys and the actual dimension names as values.

        Returns
        -------
        self : `QueryJoiner`
            This `QueryJoiner` instance (never a copy).  Provided to enable
            method chaining.
        """
        assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
        for dimension_name in dimensions:
            self.dimension_keys[dimension_name].append(self.from_clause.columns[dimension_name])
        for column_name, dimension_name in kwargs.items():
            self.dimension_keys[dimension_name].append(self.from_clause.columns[column_name])
        return self

    def extract_columns(
        self,
        columns: qt.ColumnSet,
        postprocessing: Postprocessing | None = None,
        special: Iterable[str] = (),
    ) -> QueryJoiner:
        """Add columns from `from_clause` into `dimension_keys`, `fields`,
        `timespans`, and `special`.

        Parameters
        ----------
        columns : `.queries.tree.ColumnSet`
            Columns to include, assuming that
            `.queries.tree.ColumnSet.get_qualified_name` corresponds to the
            name used in `from_clause` (after name shrinking).
        postprocessing : `Postprocessing`, optional
            Postprocessing object whose needed columns should also be included.
        special : `~collections.abc.Iterable` [ `str` ], optional
            Additional special columns to extract.

        Returns
        -------
        self : `QueryJoiner`
            This `QueryJoiner` instance (never a copy).  Provided to enable
            method chaining.

        Notes
        -----
        As a side effect this initializes `name_shrinker` if it was `None`.
        """
        assert self.from_clause is not None, "Cannot extract columns with no FROM clause."
        if self.name_shrinker is None:
            self.name_shrinker = self._make_name_shrinker()
        available = self.from_clause.columns
        for logical_table, field in columns:
            name = columns.get_qualified_name(logical_table, field)
            if field is None:
                # Dimension keys are looked up by their unshrunk name.
                self.dimension_keys[logical_table].append(available[name])
                continue
            name = self.name_shrinker.shrink(name)
            if columns.is_timespan(logical_table, field):
                self.timespans[logical_table] = self.db.getTimespanRepresentation().from_columns(
                    available, name
                )
            else:
                self.fields[logical_table][field] = available[name]
        if postprocessing is not None:
            # Region columns needed by postprocessing that are not already
            # part of `columns`.
            for element in postprocessing.iter_missing(columns):
                region_name = self.name_shrinker.shrink(columns.get_qualified_name(element.name, "region"))
                self.fields[element.name]["region"] = available[region_name]
            if postprocessing.check_validity_match_count:
                self.special[postprocessing.VALIDITY_MATCH_COUNT] = available[
                    postprocessing.VALIDITY_MATCH_COUNT
                ]
        for special_name in special:
            self.special[special_name] = available[special_name]
        return self

    def join(self, other: QueryJoiner) -> QueryJoiner:
        """Combine this `QueryJoiner` with another via an INNER JOIN on
        dimension keys.

        This method consumes ``self``.

        Parameters
        ----------
        other : `QueryJoiner`
            Other joiner to combine with this one.

        Returns
        -------
        joined : `QueryJoiner`
            A `QueryJoiner` with all columns present in either operand, with
            its `from_clause` representing a SQL INNER JOIN where the dimension
            key columns common to both operands are constrained to be equal.
            If either operand does not have `from_clause`, the other's is used.
            The `where_terms` of the two operands are concatenated,
            representing a logical AND (with no attempt at deduplication).
            As implemented this is always ``self``, modified in place.
        """
        join_on: list[sqlalchemy.ColumnElement] = []
        for dimension_name in other.dimension_keys.keys():
            other_columns = other.dimension_keys[dimension_name]
            if dimension_name in self.dimension_keys:
                # Constrain every pair of columns for a shared dimension key
                # to be equal.
                join_on.extend(
                    column1 == column2
                    for column1, column2 in itertools.product(
                        self.dimension_keys[dimension_name], other_columns
                    )
                )
            self.dimension_keys[dimension_name].extend(other_columns)
        if self.from_clause is None:
            self.from_clause = other.from_clause
        elif other.from_clause is not None:
            join_on_sql: sqlalchemy.ColumnElement[bool]
            if not join_on:
                # No dimension keys in common: join on a constant TRUE.
                join_on_sql = sqlalchemy.true()
            elif len(join_on) == 1:
                (join_on_sql,) = join_on
            else:
                join_on_sql = sqlalchemy.and_(*join_on)
            self.from_clause = self.from_clause.join(other.from_clause, onclause=join_on_sql)
        for logical_table, fields in other.fields.items():
            self.fields[logical_table].update(fields)
        self.timespans.update(other.timespans)
        self.special.update(other.special)
        self.where_terms += other.where_terms
        # Both sides are tested with ``is not None`` for consistency.
        # NOTE(review): assumes `NameShrinker` defines no custom truthiness.
        if other.name_shrinker is not None:
            if self.name_shrinker is not None:
                self.name_shrinker.update(other.name_shrinker)
            else:
                self.name_shrinker = other.name_shrinker
        return self

    def where(self, *args: sqlalchemy.ColumnElement[bool]) -> QueryJoiner:
        """Add a WHERE clause term.

        Parameters
        ----------
        *args : `sqlalchemy.ColumnElement`
            SQL boolean column expressions to be combined with AND.

        Returns
        -------
        self : `QueryJoiner`
            This `QueryJoiner` instance (never a copy).  Provided to enable
            method chaining.
        """
        self.where_terms.extend(args)
        return self

    def to_builder(
        self,
        columns: qt.ColumnSet,
        postprocessing: Postprocessing | None = None,
        distinct: bool | Sequence[sqlalchemy.ColumnElement[Any]] = (),
        group_by: Sequence[sqlalchemy.ColumnElement[Any]] = (),
    ) -> QueryBuilder:
        """Convert this joiner into a `QueryBuilder` by providing SELECT clause
        columns and optional DISTINCT or GROUP BY clauses.

        This method consumes ``self``.

        Parameters
        ----------
        columns : `~.queries.tree.ColumnSet`
            Regular columns to include in the SELECT clause.
        postprocessing : `Postprocessing`, optional
            Additional processing to be performed on result rows after
            executing the SQL query.  If `None`, a default-constructed
            `Postprocessing` is used.
        distinct : `bool` or `~collections.abc.Sequence` [ \
                `sqlalchemy.ColumnElement` ], optional
            Specification of the DISTINCT clause (see `QueryBuilder.distinct`).
        group_by : `~collections.abc.Sequence` [ \
                `sqlalchemy.ColumnElement` ], optional
            Specification of the GROUP BY clause (see `QueryBuilder.group_by`).

        Returns
        -------
        builder : `QueryBuilder`
            New query builder.
        """
        if postprocessing is None:
            postprocessing = Postprocessing()
        return QueryBuilder(
            self,
            columns,
            postprocessing=postprocessing,
            distinct=distinct,
            group_by=group_by,
        )

    def _make_name_shrinker(self) -> NameShrinker:
        """Construct a new `NameShrinker` sized for this joiner's database.

        Returns
        -------
        name_shrinker : `NameShrinker`
            New shrinker built from the dialect's maximum identifier length.
        """
        # NOTE(review): the second argument (6) is presumably the size of the
        # hash used in shrunk names - confirm against `NameShrinker`.
        return NameShrinker(self.db.dialect.max_identifier_length, 6)