Coverage for tests/test_query_utilities.py: 13%

247 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 02:51 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Tests for non-public Butler._query functionality that is not specific to 

29any Butler or QueryDriver implementation. 

30""" 

31 

32from __future__ import annotations 

33 

34import unittest 

35from collections.abc import Iterable 

36 

37import astropy.time 

38from lsst.daf.butler import DimensionUniverse, InvalidQueryError, Timespan 

39from lsst.daf.butler.dimensions import DimensionElement, DimensionGroup 

40from lsst.daf.butler.queries import tree as qt 

41from lsst.daf.butler.queries.expression_factory import ExpressionFactory 

42from lsst.daf.butler.queries.overlaps import OverlapsVisitor, _NaiveDisjointSet 

43from lsst.daf.butler.queries.visitors import PredicateVisitFlags 

44from lsst.sphgeom import Mq3cPixelization, Region 

45 

46 

class ColumnSetTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.ColumnSet."""

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def test_basics(self) -> None:
        """Test construction, field mutation, iteration, stringification,
        set-like comparisons, copying, and updating of a ColumnSet.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        self.assertNotEqual(columns, columns.dimensions.names)  # intentionally not comparable to other sets
        self.assertEqual(columns.dimensions, self.universe.conform(["detector"]))
        self.assertFalse(columns.dataset_fields)
        columns.dataset_fields["bias"].add("dataset_id")
        self.assertEqual(dict(columns.dataset_fields), {"bias": {"dataset_id"}})
        columns.dimension_fields["detector"].add("purpose")
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns)
        self.assertEqual(
            list(columns),
            [(k, None) for k in columns.dimensions.data_coordinate_keys]
            + [("detector", "purpose"), ("bias", "dataset_id")],
        )
        self.assertEqual(str(columns), "{instrument, detector, detector:purpose, bias:dataset_id}")
        empty = qt.ColumnSet(self.universe.empty.as_group())
        self.assertFalse(empty)
        self.assertFalse(columns.issubset(empty))
        self.assertTrue(columns.issuperset(empty))
        self.assertTrue(columns.isdisjoint(empty))
        copy = columns.copy()
        self.assertEqual(columns, copy)
        self.assertTrue(columns.issubset(copy))
        self.assertTrue(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        # Mutating the copy must not affect the original.
        copy.dataset_fields["bias"].add("timespan")
        copy.dimension_fields["detector"].add("name")
        copy.update_dimensions(self.universe.conform(["band"]))
        self.assertEqual(copy.dataset_fields["bias"], {"dataset_id", "timespan"})
        self.assertEqual(columns.dataset_fields["bias"], {"dataset_id"})
        self.assertEqual(copy.dimension_fields["detector"], {"purpose", "name"})
        self.assertEqual(columns.dimension_fields["detector"], {"purpose"})
        self.assertTrue(columns.issubset(copy))
        self.assertFalse(columns.issuperset(copy))
        self.assertFalse(columns.isdisjoint(copy))
        columns.update(copy)
        self.assertEqual(columns, copy)
        self.assertTrue(columns.is_timespan("visit", "timespan"))
        self.assertFalse(columns.is_timespan("visit", None))
        self.assertFalse(columns.is_timespan("detector", "purpose"))

    def test_drop_dimension_keys(self) -> None:
        """Test dropping implied dimension keys from a ColumnSet, how that
        affects set-like comparisons, and restoring them again.
        """
        columns = qt.ColumnSet(self.universe.conform(["physical_filter"]))
        columns.drop_implied_dimension_keys()
        self.assertEqual(list(columns), [("instrument", None), ("physical_filter", None)])
        undropped = qt.ColumnSet(columns.dimensions)
        self.assertTrue(columns.issubset(undropped))
        self.assertFalse(columns.issuperset(undropped))
        self.assertFalse(columns.isdisjoint(undropped))
        band_only = qt.ColumnSet(self.universe.conform(["band"]))
        self.assertFalse(columns.issubset(band_only))
        self.assertFalse(columns.issuperset(band_only))
        self.assertTrue(columns.isdisjoint(band_only))
        copy = columns.copy()
        copy.update(band_only)
        self.assertEqual(copy, undropped)
        columns.restore_dimension_keys()
        self.assertEqual(columns, undropped)

    def test_get_column_spec(self) -> None:
        """Test the name, type, and nullability of the column specs returned
        for dimension keys, dimension fields, and dataset fields.
        """
        columns = qt.ColumnSet(self.universe.conform(["detector"]))
        columns.dimension_fields["detector"].add("purpose")
        columns.dataset_fields["bias"].update(["dataset_id", "run", "collection", "timespan", "ingest_date"])
        # (logical table, field, expected name, expected type, expected nullable)
        expected_specs = [
            ("instrument", None, "instrument", "string", False),
            ("detector", None, "detector", "int", False),
            ("detector", "purpose", "detector:purpose", "string", True),
            ("bias", "dataset_id", "bias:dataset_id", "uuid", False),
            ("bias", "run", "bias:run", "string", False),
            ("bias", "collection", "bias:collection", "string", False),
            ("bias", "timespan", "bias:timespan", "timespan", False),
            ("bias", "ingest_date", "bias:ingest_date", "datetime", True),
        ]
        for logical_table, field, name, type_, nullable in expected_specs:
            # subTest keeps checking the remaining columns after a failure and
            # reports which column failed.
            with self.subTest(logical_table=logical_table, field=field):
                spec = columns.get_column_spec(logical_table, field)
                self.assertEqual(spec.name, name)
                self.assertEqual(spec.type, type_)
                self.assertEqual(spec.nullable, nullable)

141 

142 

class _RecordingOverlapsVisitor(OverlapsVisitor):
    """An `OverlapsVisitor` that logs every spatial constraint, spatial join,
    and temporal dimension join it is asked to visit.

    Each visit is recorded by element name(s) together with the
    `PredicateVisitFlags` passed in, and then forwarded unchanged to the base
    class, whose return value is passed back to the caller.
    """

    def __init__(self, dimensions: DimensionGroup) -> None:
        super().__init__(dimensions)
        # Entries are (element name, flags).
        self.spatial_constraints: list[tuple[str, PredicateVisitFlags]] = []
        # Entries are (first element name, second element name, flags).
        self.spatial_joins: list[tuple[str, str, PredicateVisitFlags]] = []
        # Entries are (first element name, second element name, flags).
        self.temporal_dimension_joins: list[tuple[str, str, PredicateVisitFlags]] = []

    def visit_spatial_constraint(
        self, element: DimensionElement, region: Region, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        entry = (element.name, flags)
        self.spatial_constraints.append(entry)
        return super().visit_spatial_constraint(element, region, flags)

    def visit_spatial_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        entry = (a.name, b.name, flags)
        self.spatial_joins.append(entry)
        return super().visit_spatial_join(a, b, flags)

    def visit_temporal_dimension_join(
        self, a: DimensionElement, b: DimensionElement, flags: PredicateVisitFlags
    ) -> qt.Predicate | None:
        entry = (a.name, b.name, flags)
        self.temporal_dimension_joins.append(entry)
        return super().visit_temporal_dimension_join(a, b, flags)

167 

168 

class OverlapsVisitorTestCase(unittest.TestCase):
    """Tests for lsst.daf.butler.queries.overlaps.OverlapsVisitor, which is
    responsible for validating and inferring spatial and temporal joins and
    constraints.
    """

    def setUp(self) -> None:
        self.universe = DimensionUniverse()

    def run_visitor(
        self,
        dimensions: Iterable[str],
        predicate: qt.Predicate,
        expected: str | None = None,
        join_operands: Iterable[DimensionGroup] = (),
    ) -> _RecordingOverlapsVisitor:
        """Run a recording overlaps visitor and check the predicate it returns.

        Parameters
        ----------
        dimensions : `~collections.abc.Iterable` [ `str` ]
            Names of the query's dimensions; conformed via the universe.
        predicate : `qt.Predicate`
            Predicate passed to the visitor.
        expected : `str`, optional
            Expected string form of the predicate returned by the visitor. If
            `None`, the returned predicate must stringify the same as the
            input predicate.
        join_operands : `~collections.abc.Iterable` [ `DimensionGroup` ]
            Dimensions of other operands joined into the query.

        Returns
        -------
        visitor : `_RecordingOverlapsVisitor`
            The visitor, so callers can inspect what it recorded.
        """
        visitor = _RecordingOverlapsVisitor(self.universe.conform(dimensions))
        if expected is None:
            expected = str(predicate)
        new_predicate = visitor.run(predicate, join_operands=join_operands)
        self.assertEqual(str(new_predicate), expected)
        return visitor

    def test_trivial(self) -> None:
        """Test the overlaps visitor when there is nothing spatial or temporal
        in the query at all.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate.
        visitor = self.run_visitor(["physical_filter"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["physical_filter"], x.any(x.band == "r", x.band == "i"))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_spatial_family(self) -> None:
        """Test the overlaps visitor when there is one spatial family."""
        x = ExpressionFactory(self.universe)
        pixelization = Mq3cPixelization(10)
        region = pixelization.quad(12058870)
        # Trivial predicate.
        visitor = self.run_visitor(["visit"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["visit"], x.any(x.band == "r", x.visit > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Spatial constraint predicate, in various positions relative to other
        # non-overlap predicates.  The recording visitor stores element names,
        # so compare against names rather than DimensionElement objects.
        visitor = self.run_visitor(["visit"], x.visit.region.overlaps(region))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags(0))])
        visitor = self.run_visitor(["visit"], x.all(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        visitor = self.run_visitor(["visit"], x.any(x.visit.region.overlaps(region), x.band == "r"))
        self.assertEqual(visitor.spatial_constraints, [("visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        visitor = self.run_visitor(
            ["visit"],
            x.all(
                x.any(x.literal(region).overlaps(x.visit.region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [
                (
                    "visit",
                    PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
                )
            ],
        )
        visitor = self.run_visitor(
            ["visit"],
            x.any(
                x.all(x.visit.region.overlaps(region), x.band == "r"),
                x.visit.observation_reason == "science",
            ),
        )
        self.assertEqual(
            visitor.spatial_constraints,
            [
                (
                    "visit",
                    PredicateVisitFlags.HAS_OR_SIBLINGS | PredicateVisitFlags.HAS_AND_SIBLINGS,
                )
            ],
        )
        # A spatial join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["patch", "tract"], x.patch.region.overlaps(x.tract.region))

    def test_single_unambiguous_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families with
        one dimension element in each, and hence exactly one join is needed.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join is added. Order of elements in
        # automatic joins is lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "tract"], qt.Predicate.from_bool(True), "tract.region OVERLAPS visit.region"
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate; an automatic join is added.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.all(x.band == "r", x.visit > 2),
            "band == 'r' AND visit > 2 AND tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in an OR expression, which is unusual
        # but enough to block the addition of an automatic join; we assume the
        # user knows what they're doing.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.any(x.visit > 2, x.tract.region.overlaps(x.visit.region)),
            "visit > 2 OR tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.HAS_OR_SIBLINGS)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add the join overlap predicate in a NOT expression, which is unusual
        # but permitted in the same sense as OR expressions.
        visitor = self.run_visitor(
            ["visit", "tract"],
            x.not_(x.tract.region.overlaps(x.visit.region)),
            "NOT tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags.INVERTED)])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "tract"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_single_flexible_spatial_join(self) -> None:
        """Test the overlaps visitor when there are two spatial families and
        one has multiple dimension elements.
        """
        x = ExpressionFactory(self.universe)
        # Trivial predicate; an automatic join between the fine-grained
        # elements is added. Order of elements in automatic joins is
        # lexicographical in order to be deterministic.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(
            visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags.HAS_AND_SIBLINGS)]
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # The same overlap predicate that would be added automatically has been
        # added manually.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.patch.region.overlaps(x.visit_detector_region.region),
            "patch.region OVERLAPS visit_detector_region.region",
        )
        self.assertEqual(visitor.spatial_joins, [("patch", "visit_detector_region", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A coarse overlap join has been added; respect it and do not add an
        # automatic one.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            x.tract.region.overlaps(x.visit.region),
            "tract.region OVERLAPS visit.region",
        )
        self.assertEqual(visitor.spatial_joins, [("tract", "visit", PredicateVisitFlags(0))])
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include both spatial families.
        # This blocks an automatic join from being created, because we assume
        # that join operand (e.g. a materialization or dataset search) already
        # encodes some spatial join.
        visitor = self.run_visitor(
            ["visit", "detector", "patch"],
            qt.Predicate.from_bool(True),
            "True",
            join_operands=[self.universe.conform(["tract", "visit_detector_region"])],
        )
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_multiple_spatial_joins(self) -> None:
        """Test the overlaps visitor when there are >2 spatial families."""
        x = ExpressionFactory(self.universe)
        # Trivial predicate. This is an error, because we cannot generate
        # automatic spatial joins when there are more than two families.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], qt.Predicate.from_bool(True))
        # Predicate that joins one pair of families but orphans the other;
        # also an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["visit", "patch", "htm7"], x.visit.region.overlaps(x.htm7.region))
        # A sufficient overlap join predicate has been added; each family is
        # connected to at least one other.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.all(x.tract.region.overlaps(x.visit.region), x.tract.region.overlaps(x.htm7.region)),
            "tract.region OVERLAPS visit.region AND tract.region OVERLAPS htm7.region",
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "visit", PredicateVisitFlags.HAS_AND_SIBLINGS),
                ("tract", "htm7", PredicateVisitFlags.HAS_AND_SIBLINGS),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Add a "join operand" whose dimensions include two spatial families,
        # with a predicate that joins the third in.
        visitor = self.run_visitor(
            ["visit", "patch", "htm7"],
            x.tract.region.overlaps(x.htm7.region),
            "tract.region OVERLAPS htm7.region",
            join_operands=[self.universe.conform(["visit", "tract"])],
        )
        self.assertEqual(
            visitor.spatial_joins,
            [
                ("tract", "htm7", PredicateVisitFlags(0)),
            ],
        )
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    def test_one_temporal_family(self) -> None:
        """Test the overlaps visitor when there is one temporal family."""
        x = ExpressionFactory(self.universe)
        begin = astropy.time.Time("2020-01-01T00:00:00", format="isot", scale="tai")
        end = astropy.time.Time("2020-01-01T00:01:00", format="isot", scale="tai")
        timespan = Timespan(begin, end)
        # Trivial predicate.
        visitor = self.run_visitor(["exposure"], qt.Predicate.from_bool(True))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Non-overlap predicate.
        visitor = self.run_visitor(["exposure"], x.any(x.band == "r", x.exposure > 2))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # Temporal constraint predicate.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)
        # A temporal join between dimensions in the same family is an error.
        with self.assertRaises(InvalidQueryError):
            self.run_visitor(["exposure", "visit"], x.exposure.timespan.overlaps(x.visit.timespan))
        # Overlap join with a calibration dataset's validity ranges.
        visitor = self.run_visitor(["exposure"], x.exposure.timespan.overlaps(x["bias"].timespan))
        self.assertFalse(visitor.spatial_joins)
        self.assertFalse(visitor.spatial_constraints)
        self.assertFalse(visitor.temporal_dimension_joins)

    # There are no tests for temporal dimension joins, because the default
    # dimension universe only has one temporal family, and the untested logic
    # trivially duplicates the spatial-join logic.

467 

468 

class NaiveDisjointSetTestCase(unittest.TestCase):
    """Exercise the simple disjoint-set structure used when deciding which
    automatic overlap joins to create.
    """

    def test_naive_disjoint_set(self) -> None:
        """Start from eight singletons, merge pairs one at a time, and check
        the partition after every merge.
        """
        disjoint = _NaiveDisjointSet(range(8))
        self.assertCountEqual(disjoint.subsets(), [{i} for i in range(8)])
        # Each step is ((first, second), partition expected after merging).
        steps = [
            ((3, 4), [{0}, {1}, {2}, {3, 4}, {5}, {6}, {7}]),
            ((2, 1), [{0}, {1, 2}, {3, 4}, {5}, {6}, {7}]),
            ((1, 3), [{0}, {1, 2, 3, 4}, {5}, {6}, {7}]),
        ]
        for (first, second), partition in steps:
            disjoint.merge(first, second)
            self.assertCountEqual(disjoint.subsets(), partition)

483 

484 

if __name__ == "__main__":
    # Executing this module directly runs every test case it defines.
    unittest.main()