Coverage for tests / test_query_utilities.py: 13%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:30 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests for non-public Butler._query functionality that is not specific to 

29any Butler or QueryDriver implementation. 

30""" 

31 

32from __future__ import annotations 

33 

34import unittest 

35from collections.abc import Iterable, Set 

36 

37import astropy.time 

38 

39from lsst.daf.butler import DimensionUniverse, InvalidQueryError, Timespan 

40from lsst.daf.butler.dimensions import DimensionElement, DimensionGroup 

41from lsst.daf.butler.queries import tree as qt 

42from lsst.daf.butler.queries.expression_factory import ExpressionFactory 

43from lsst.daf.butler.queries.overlaps import OverlapsVisitor, _NaiveDisjointSet 

44from lsst.daf.butler.queries.visitors import PredicateVisitFlags 

45from lsst.sphgeom import Mq3cPixelization, Region 

46 

47 

class ColumnSetTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.ColumnSet."""

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def test_basics(self) -> None:
        """Test construction, mutation, iteration, string formatting,
        set-like comparisons, copying, and updating of a ColumnSet.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        self.assertNotEqual(columns, columns.dimensions.names)  # intentionally not comparable to other sets
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        self.assertFalse(columns.dataset_fields)
        columns.dataset_fields["bias"].add("dataset_id")
        self.assertEqual(dict(columns.dataset_fields), {"bias": {"dataset_id"}})
        columns.dimension_fields["detector"].add("purpose")
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns)
        # Iteration yields dimension keys first, then dimension fields, then
        # dataset fields.
        self.assertEqual(
            list(columns),
            [(k, None) for k in columns.dimensions.data_coordinate_keys]
            + [("detector", "purpose"), ("bias", "dataset_id")],
        )
        self.assertEqual(str(columns), "{instrument, detector, detector:purpose, bias:dataset_id}")
        empty = qt.ColumnSet(self.universe.empty)
        self.assertFalse(empty)
        self.assertFalse(columns.issubset(empty))
        self.assertTrue(columns.issuperset(empty))
        self.assertTrue(columns.isdisjoint(empty))
        copy = columns.copy()
        self.assertEqual(columns, copy)
        self.assertTrue(columns.issubset(copy))
        self.assertTrue(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Mutating the copy must not affect the original.
        copy.dataset_fields["bias"].add("timespan")
        copy.dimension_fields["detector"].add("name")
        copy.update_dimensions(self.universe.conform(["band"]))
        self.assertEqual(copy.dataset_fields["bias"], {"dataset_id", "timespan"})
        self.assertEqual(columns.dataset_fields["bias"], {"dataset_id"})
        self.assertEqual(copy.dimension_fields["detector"], {"purpose", "name"})
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns.issubset(copy))
        self.assertFalse(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        columns.update(copy)
        self.assertEqual(columns, copy)
        self.assertTrue(columns.is_timespan("visit", "timespan"))
        self.assertFalse(columns.is_timespan("visit", None))
        self.assertFalse(columns.is_timespan("detector", "purpose"))

    def test_drop_dimension_keys(self) -> None:
        """Test dropping implied dimension key columns and restoring them."""
        columns = qt.ColumnSet(self.universe.conform(["physical_filter"]))
        columns.drop_implied_dimension_keys()
        self.assertEqual(list(columns), [("instrument", None), ("physical_filter", None)])
        undropped = qt.ColumnSet(columns.dimensions)
        self.assertTrue(columns.issubset(undropped))
        self.assertFalse(columns.issuperset(undropped))
        self.assertFalse(columns.isdisjoint(undropped))
        band_only = qt.ColumnSet(self.universe.conform(["band"]))
        self.assertFalse(columns.issubset(band_only))
        self.assertFalse(columns.issuperset(band_only))
        self.assertTrue(columns.isdisjoint(band_only))
        # Adding the dropped key back by hand is equivalent to never
        # having dropped it.
        copy = columns.copy()
        copy.update(band_only)
        self.assertEqual(copy, undropped)
        columns.restore_dimension_keys()
        self.assertEqual(columns, undropped)

    def test_get_column_spec(self) -> None:
        """Test the name, type, and nullability reported for dimension key,
        dimension field, and dataset field columns.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        columns.dimension_fields["detector"].add("purpose")
        columns.dataset_fields["bias"].update(["dataset_id", "run", "collection", "timespan", "ingest_date"])
        # Each row is (logical table, field, expected name, expected type,
        # expected nullable).
        expectations = [
            ("instrument", None, "instrument", "string", False),
            ("detector", None, "detector", "int", False),
            ("detector", "purpose", "detector:purpose", "string", True),
            ("bias", "dataset_id", "bias:dataset_id", "uuid", False),
            ("bias", "run", "bias:run", "string", False),
            ("bias", "collection", "bias:collection", "string", False),
            ("bias", "timespan", "bias:timespan", "timespan", True),
            ("bias", "ingest_date", "bias:ingest_date", "datetime", True),
        ]
        for logical_table, field, name, type_, nullable in expectations:
            # subTest reports which column failed and keeps checking the rest.
            with self.subTest(logical_table=logical_table, field=field):
                spec = columns.get_column_spec(logical_table, field)
                self.assertEqual(spec.name, name)
                self.assertEqual(spec.type, type_)
                self.assertEqual(spec.nullable, nullable)

143 

class _RecordingOverlapsVisitor(OverlapsVisitor):
    """An `OverlapsVisitor` that logs every overlap callback it receives and
    then defers to the base class.

    Dimension elements are logged by name (`str`), together with the
    `PredicateVisitFlags` passed to the callback. Tests can therefore compare
    the logs against plain tuples. Each log keeps entries in call order.
    """

    def __init__(self, dimensions: DimensionGroup, calibration_dataset_types: Set[str] = frozenset()) -> None:
        super().__init__(dimensions, calibration_dataset_types)
        # One log per overlap callback.
        self.spatial_constraints: list[tuple[str, PredicateVisitFlags]] = []
        self.spatial_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.temporal_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.validity_range_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        self.validity_range_joins: list[tuple[str, str, PredicateVisitFlags]] = []

    def visit_spatial_constraint(
        self, element: DimensionElement, region: Region, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        """Log ``(element name, flags)``, then delegate to the base class."""
        entry = (element.name, flags)
        self.spatial_constraints.append(entry)
        result = super().visit_spatial_constraint(element, region, flags)
        return result

    def visit_spatial_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        """Log ``(a name, b name, flags)``, then delegate to the base class."""
        entry = (a.name, b.name, flags)
        self.spatial_joins.append(entry)
        result = super().visit_spatial_join(a, b, flags)
        return result

    def visit_temporal_dimension_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        """Log ``(a name, b name, flags)``, then delegate to the base class."""
        entry = (a.name, b.name, flags)
        self.temporal_dimension_joins.append(entry)
        result = super().visit_temporal_dimension_join(a, b, flags)
        return result

    def visit_validity_range_dimension_join(
        self, a: str, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        """Log ``(a, b name, flags)``, then delegate to the base class."""
        entry = (a, b.name, flags)
        self.validity_range_dimension_joins.append(entry)
        result = super().visit_validity_range_dimension_join(a, b, flags)
        return result

    def visit_validity_range_join(self, a: str, b: str, flags: PredicateVisitFlags) -> qt.Predicate | None:
        """Log ``(a, b, flags)``, then delegate to the base class."""
        entry = (a, b, flags)
        self.validity_range_joins.append(entry)
        result = super().visit_validity_range_join(a, b, flags)
        return result

180 

181 

class OverlapsVisitorTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.overlaps.OverlapsVisitor, which is
    responsible for validating and inferring spatial and temporal joins and
    constraints.
    """

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def run_visitor(
        self,
        dimensions: Iterable[str],
        predicate: qt.Predicate,
        expected: str | None = None,
        join_operands: Iterable[DimensionGroup] = (),
        calibration_dataset_types: Set[str] = frozenset(),
    ) -> _RecordingOverlapsVisitor:
        """Run a recording overlaps visitor on a predicate and check the
        string form of the predicate it returns.

        Parameters
        ----------
        dimensions : `~collections.abc.Iterable` [ `str` ]
            Names of the query's dimensions; conformed with this test case's
            universe.
        predicate : `qt.Predicate`
            Predicate to pass to the visitor.
        expected : `str`, optional
            Expected string form of the predicate returned by the visitor.
            Defaults to ``str(predicate)``, i.e. the predicate is expected
            to come back unchanged.
        join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ], optional
            Dimensions of other operands joined into the query (e.g. a
            materialization or dataset search), forwarded to the visitor.
        calibration_dataset_types : `~collections.abc.Set` [ `str` ], optional
            Names of calibration dataset types, forwarded to the visitor.

        Returns
        -------
        visitor : `_RecordingOverlapsVisitor`
            The visitor, so callers can inspect the callbacks it recorded.
        """
        visitor = _RecordingOverlapsVisitor(self.universe.conform(dimensions), calibration_dataset_types)
        if expected is None:
            expected = str(predicate)
        new_predicate = visitor.run(predicate, join_operands=join_operands)
        self.assertEqual(str(new_predicate), expected)
        return visitor

    def test_trivial(self) -> None:
        """Test the overlaps visitor when there is nothing spatial or temporal
        in the query at all.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate.
        visitor = self.run_visitor(["physical_filter"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["physical_filter"], x.any(x.band == "r", x.band == "i"))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_spatial_family(self) -> None:
        """Test the overlaps visitor when there is one spatial family."""
        x = ExpressionFactory(self.universe)
        pixelization = Mq3cPixelization(10)
        region = pixelization.quad(12058870)
        # Trivial predicate.
        visitor = self.run_visitor(["visit"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["visit"], x.any(x.band == "r", x.visit > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Spatial constraint predicate, in various positions relative to other
        # non-overlap predicates. The recording visitor stores element names,
        # so compare against names rather than DimensionElement objects.
        visitor = self.run_visitor(["visit"], x.visit.region.overlaps(region))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags(0))])
        visitor = self.run_visitor(["visit"], x.all(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        visitor = self.run_visitor(["visit"], x.any(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        visitor = self.run_visitor(
            ["visit"],
            x.all(
                x.any(x.literal(region).overlaps(x.visit.region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS)],
        )
        visitor = self.run_visitor(
            ["visit"],
            x.any(
                x.all(x.visit.region.overlaps(region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS)],
        )
        # A spatial join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["patch", "tract"], x.patch.region.overlaps(x.tract.region))

    def test_single_unambiguous_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families with
        one dimension element in each, and hence exactly one join is needed.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join is added. Order of elements in
        # automatic joins is lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "tract"], qt.Predicate.from_bool(True), "tract.region OVERLAPS visit.region"
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate; an automatic join is added.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.all(x.band == "r", x.visit > 2),
            "band == 'r' AND visit > 2 AND tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in an OR expression, which is unusual
        # but enough to block the addition of an automatic join; we assume the
        # user knows what they're doing.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.any(x.visit > 2, x.tract.region.overlaps(x.visit.region)),
            "visit > 2 OR tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in a NOT expression, which is unusual
        # but permitted in the same sense as OR expressions.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.not_(x.tract.region.overlaps(x.visit.region)),
            "NOT tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.INVERTED)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "tract"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_single_flexible_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families and
        one has multiple dimension elements.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join between the fine-grained
        # elements is added. Order of elements in automatic joins is
        # lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(
            visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags.HAS_AND_SIBLINGS)]
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.patch.region.overlaps(x.visit_detector_region.region),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A coarse overlap join has been added; respect it and do not add an
        # automatic one.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families
        # with the most fine-grained dimensions in the query.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["patch", "visit_detector_region"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_multiple_spatial_joins(self) -> None:
        """Test the overlaps visitor when there are >2 spatial families."""
        x = ExpressionFactory(self.universe)
        # Trivial predicate. This is an error, because we cannot generate
        # automatic spatial joins when there are more than two families.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], qt.Predicate.from_bool(True))
        # Predicate that joins one pair of families but orphans the other;
        # also an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], x.visit.region.overlaps(x.htm7.region))
        # A sufficient overlap join predicate has been added; each family is
        # connected to at least one other.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.all(x.tract.region.overlaps(x.visit.region), x.tract.region.overlaps(x.htm7.region)),
            "tract.region OVERLAPS visit.region AND tract.region OVERLAPS htm7.region",
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS),
                ("tract", "htm7", PredicateVisitFlags.HAS_AND_SIBLINGS),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include two spatial families,
        # with the most fine-grained dimensions in the query, and a predicate
        # that joins the third in.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.tract.region.overlaps(x.htm7.region),
            "tract.region OVERLAPS htm7.region",
            join_operands=[self.universe.conform(["visit", "patch"])],
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "htm7", PredicateVisitFlags(0)),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_temporal_family(self) -> None:
        """Test the overlaps visitor when there is one temporal family."""
        x = ExpressionFactory(self.universe)
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        timespan = Timespan(begin, end)
        # Trivial predicate.
        visitor = self.run_visitor(["exposure"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["exposure"], x.any(x.band == "r", x.exposure > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Temporal constraint predicate.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A temporal join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["exposure", "visit"], x.exposure.timespan.overlaps(x.visit.timespan))
        # Overlap join with a calibration dataset's validity ranges.
        visitor = self.run_visitor(
            ["exposure"], x.exposure.timespan.overlaps(x["bias"].timespan), calibration_dataset_types={"bias"}
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        self.assertEqual(
            visitor.validity_range_dimension_joins, [("bias", "exposure", PredicateVisitFlags(0))]
        )
        self.assertFalse(visitor.validity_range_joins)
        # Overlap join between two calibration dataset validity ranges.
        # (It's not clear this kind of query is ever useful in practice, but
        # there's a good consistency argument for what it ought to do).
        visitor = self.run_visitor(
            [], x["flat"].timespan.overlaps(x["bias"].timespan), calibration_dataset_types={"bias", "flat"}
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        self.assertFalse(visitor.validity_range_dimension_joins)
        self.assertEqual(visitor.validity_range_joins, [("flat", "bias", PredicateVisitFlags(0))])

    # There are no tests for temporal dimension joins, because the default
    # dimension universe only has one temporal family, and the untested logic
    # trivially duplicates the spatial-join logic.

500 

501 

class NaiveDisjointSetTestCase(unittest.TestCase):
    """Tests for `_NaiveDisjointSet`, the disjoint-set structure used when
    creating automatic overlap joins.
    """

    def test_naive_disjoint_set(self) -> None:
        """Merge subsets step by step and check the partition after each
        merge.
        """
        disjoint = _NaiveDisjointSet(range(8))
        # Before any merge, every element is in its own singleton subset.
        self.assertCountEqual(disjoint.subsets(), [{n} for n in range(8)])
        # Each step is ((a, b) to merge, expected subsets after merging).
        steps = [
            ((3, 4), [{0}, {1}, {2}, {3, 4}, {5}, {6}, {7}]),
            ((2, 1), [{0}, {1, 2}, {3, 4}, {5}, {6}, {7}]),
            ((1, 3), [{0}, {1, 2, 3, 4}, {5}, {6}, {7}]),
        ]
        for (first, second), expected_subsets in steps:
            disjoint.merge(first, second)
            self.assertCountEqual(disjoint.subsets(), expected_subsets)

516 

517 

if __name__ == "__main__":
    # Executed as a script rather than imported: run every TestCase above.
    unittest.main()