Coverage for tests/test_defineVisits.py: 21%
215 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-25 08:22 +0000
1# This file is part of obs_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22import logging
23import os
24import pickle
25import shutil
26import tempfile
27import unittest
28import warnings
29from collections import defaultdict
31import lsst.daf.butler.tests as butlerTests
32from lsst.daf.butler import Butler, DataCoordinate, DimensionRecord, SerializedDimensionRecord
33from lsst.daf.butler.registry import ConflictingDefinitionError
34from lsst.obs.base import DefineVisitsConfig, DefineVisitsTask
35from lsst.obs.base.instrument_tests import DummyCam
36from lsst.utils.iteration import ensure_iterable
# Directory that holds this test module; temporary repositories are created
# underneath it.
TESTDIR = os.path.dirname(__file__)

# Location of the serialized exposure records read by the test fixtures.
DATADIR = os.path.join(TESTDIR, "data", "visits")
42class DefineVisitsBase:
43 """General set up that can be shared."""
45 use_data_ids = True
46 """Use data IDs when calling defineVisits, else use dimension records."""
48 def setUpExposures(self):
49 """Create a new butler for each test since we are changing dimension
50 records.
51 """
52 self.root = tempfile.mkdtemp(dir=TESTDIR)
53 self.creatorButler = butlerTests.makeTestRepo(self.root, {})
54 self.enterContext(self.creatorButler)
55 self.butler = butlerTests.makeTestCollection(self.creatorButler, uniqueId=self.id())
56 self.enterContext(self.butler)
58 self.config = self.get_config()
59 self.task = DefineVisitsTask(config=self.config, butler=self.butler)
61 # Need to register the instrument.
62 DummyCam().register(self.butler.registry)
64 # Choose serializations based on universe.
65 universe = self.butler.dimensions
66 uversion = universe.version
67 # Not all universe changes result in visible changes.
68 match uversion:
69 case uversion if uversion < 2:
70 raise unittest.SkipTest(f"Universe {uversion} is not compatible with these test files.")
71 case 2 | 3 | 4 | 5:
72 # has_simulated, azimuth, seq_start, seq_end.
73 v = 2
74 case 6:
75 # group not group_name, group_id dropped.
76 v = 6
77 case 7:
78 # can_see_sky.
79 v = 7
80 case _:
81 # Might work.
82 warnings.warn(f"Universe {uversion} has not been validated.")
83 v = 7
85 # Read the exposure records.
86 self.records: dict[int, DimensionRecord] = {}
87 for i in (347, 348, 349):
88 with open(os.path.join(DATADIR, f"exp_v{v}_{i}.json")) as fh:
89 simple = SerializedDimensionRecord.model_validate_json(fh.read())
90 self.records[i] = DimensionRecord.from_simple(simple, registry=self.butler.registry)
92 def define_visits(
93 self,
94 exposures: list[DimensionRecord | list[DimensionRecord]],
95 incremental: bool,
96 ) -> None:
97 for records in exposures:
98 records = list(ensure_iterable(records))
99 if "group" in self.butler.dimensions["exposure"].implied:
100 # This is a group + day_obs universe.
101 for rec in records:
102 self.butler.registry.syncDimensionData(
103 "group", dict(instrument=rec.instrument, name=rec.group)
104 )
105 self.butler.registry.syncDimensionData(
106 "day_obs", dict(instrument=rec.instrument, id=rec.day_obs)
107 )
109 deduped_records = set(records)
110 self.butler.registry.insertDimensionData("exposure", *deduped_records)
111 # Include all records so far in definition.
112 if self.use_data_ids:
113 dataIds = sorted(self.butler.registry.queryDataIds("exposure", instrument="DummyCam"))
114 else:
115 dataIds = records
117 if not incremental:
118 # Force duplicate records in non-incremental mode to ensure
119 # that the task can deduplicate.
120 dataIds.extend(dataIds)
121 n_exposures = len(self.records)
122 with self.assertLogs(level=logging.INFO) as cm:
123 self.task.run(dataIds, incremental=incremental)
124 self.assertIn(f"Grouping {n_exposures} exposure(s) into visits", "\n".join(cm.output))
125 else:
126 self.task.run(dataIds, incremental=incremental)
class DefineVisitsTestCase(unittest.TestCase, DefineVisitsBase):
    """Test visit definition."""

    def setUp(self):
        self.setUpExposures()

    def tearDown(self):
        if self.root is not None:
            shutil.rmtree(self.root, ignore_errors=True)

    def get_config(self) -> DefineVisitsConfig:
        """Return a task configuration using ``one-to-one-and-by-counter``
        grouping.
        """
        config = DefineVisitsTask.ConfigClass()
        config.groupExposures.name = "one-to-one-and-by-counter"
        return config

    def assertVisits(self):
        """Check that the visits were registered as expected."""
        visits = list(self.butler.registry.queryDimensionRecords("visit"))
        self.assertEqual(len(visits), 4)
        self.assertEqual(
            {visit.id for visit in visits}, {2022040500347, 2022040500348, 2022040500349, 92022040500348}
        )

        # Ensure that the definitions are correct (ignoring order).
        defmap = defaultdict(set)
        definitions = list(self.butler.registry.queryDimensionRecords("visit_definition"))
        for defn in definitions:
            defmap[defn.visit].add(defn.exposure)

        self.assertEqual(
            dict(defmap),
            {
                92022040500348: {2022040500348},
                2022040500347: {2022040500347},
                2022040500348: {2022040500348, 2022040500349},
                2022040500349: {2022040500349},
            },
        )

    def test_defineVisits(self):
        """Define visits from all records in a single batch."""
        # Test visit definition with all the records.
        self.define_visits([list(self.records.values())], incremental=False)  # list inside a list
        self.assertVisits()

    def test_prefilter(self):
        """Re-run with ``prefilter=True`` after all visits are defined and
        check that every exposure is reported as filtered out.
        """
        self.define_visits([list(self.records.values())], incremental=False)
        self.assertVisits()
        with self.assertLogs(level=logging.INFO) as cm:
            result = self.task.run(self.records.values(), incremental=False, prefilter=True)
        self.assertIn(
            f"Filtered out {len(self.records)} on-sky exposure(s) (of {len(self.records)}) that were "
            "already associated with a visit.",
            "\n".join(cm.output),
        )
        self.assertEqual(result.n_filtered, len(self.records))
        self.assertEqual(result.n_visits, 0)
        self.assertEqual(result.n_skipped, 0)

    def test_check_detector_regions(self):
        """Check that ``check_detector_regions=True`` produces the same
        ``visit_detector_region`` records in a repository that lacks them.
        """
        self.define_visits([list(self.records.values())], incremental=False)
        self.assertVisits()
        # We can't remove dimension records from a repository, so to test
        # fixing a case of missing visit_detector regions, we have to make a
        # new butler repository and transfer only some records to it (this is
        # actually what happens in the production context where we need this
        # functionality).
        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as new_butler_root:
            config = Butler.makeRepo(new_butler_root)
            butler = Butler.from_config(config, writeable=True)
            # Close the second butler before its repository directory is
            # removed, matching how the other butlers in these tests are
            # managed as contexts.
            with butler:
                # We can't even use transfer_dimension_records_from because
                # that's too careful to include everything, including the
                # detector regions.
                for element_name in (
                    "instrument",
                    "physical_filter",
                    "day_obs",
                    "visit",
                    "group",
                    "exposure",
                    "visit_system",
                    "visit_definition",
                    "visit_system_membership",
                    "detector",
                ):
                    butler.registry.insertDimensionData(
                        element_name, *self.butler.query_dimension_records(element_name)
                    )
                task = DefineVisitsTask(config=self.config, butler=butler)
                with self.assertLogs(level=logging.INFO) as cm:
                    task.run(
                        self.records.values(), incremental=False, prefilter=True, check_detector_regions=True
                    )
                self.assertIn("missing detector region records", "\n".join(cm.output))
                self.assertCountEqual(
                    butler.query_dimension_records("visit_detector_region"),
                    self.butler.query_dimension_records("visit_detector_region"),
                )

    def test_incremental_cumulative(self):
        """Define visits incrementally, one batch per exposure."""
        # Define the visits after each exposure.
        self.define_visits(list(self.records.values()), incremental=True)
        self.assertVisits()

    def test_incremental_cumulative_reverse(self):
        """Define visits incrementally with the exposures in reverse order."""
        # In reverse order we should still eventually end up with the right
        # answer.
        with self.assertLogs("lsst.defineVisits.groupExposures", level="WARNING") as cm:
            self.define_visits(list(reversed(self.records.values())), incremental=True)
        self.assertIn("Skipping the multi-snap definition", "\n".join(cm.output))
        self.assertVisits()

    def define_visits_incrementally(self, exposure: DimensionRecord) -> None:
        """Insert one exposure record and run the task on that exposure only.

        Parameters
        ----------
        exposure : `lsst.daf.butler.DimensionRecord`
            Exposure record to insert; the task is then run incrementally
            with a single data ID built from ``exposure.id``.
        """
        if "group" in self.butler.dimensions["exposure"].implied:
            self.butler.registry.syncDimensionData(
                "group", dict(instrument=exposure.instrument, name=exposure.group)
            )
            self.butler.registry.syncDimensionData(
                "day_obs",
                dict(
                    instrument=exposure.instrument,
                    id=exposure.day_obs,
                ),
            )
        self.butler.registry.insertDimensionData("exposure", exposure)
        dataIds = [
            DataCoordinate.standardize(
                instrument="DummyCam", exposure=exposure.id, universe=self.butler.dimensions
            )
        ]
        self.task.run(dataIds, incremental=True)

    def test_incremental(self):
        """Run the task once per exposure, passing only the new exposure."""
        for record in self.records.values():
            self.define_visits_incrementally(record)
        self.assertVisits()

    def test_incremental_reverse(self):
        """As `test_incremental` but with the exposures in reverse order."""
        for record in reversed(self.records.values()):
            self.define_visits_incrementally(record)
        self.assertVisits()

    def testPickleTask(self):
        """Check that the task survives a pickle round trip."""
        stream = pickle.dumps(self.task)
        copy = pickle.loads(stream)
        # The unpickled task carries its own butler; make sure it is closed.
        self.enterContext(copy.butler)
        self.assertEqual(self.task.getFullName(), copy.getFullName())
        self.assertEqual(self.task.log.name, copy.log.name)
        self.assertEqual(self.task.config, copy.config)
        self.assertEqual(self.task.butler._config, copy.butler._config)
        self.assertEqual(list(self.task.butler.collections.defaults), list(copy.butler.collections.defaults))
        self.assertEqual(self.task.butler.run, copy.butler.run)
        self.assertEqual(self.task.universe, copy.universe)
class DefineVisitsRecordsTestCase(DefineVisitsTestCase):
    """Repeat the `DefineVisitsTestCase` tests, passing dimension records to
    the task rather than data IDs.
    """

    use_data_ids = False
class DefineVisitsGroupingTestCase(unittest.TestCase, DefineVisitsBase):
    """Exercise visit definition with the ``by-group-metadata`` grouping."""

    def setUp(self):
        self.setUpExposures()

    def tearDown(self):
        if self.root is None:
            return
        shutil.rmtree(self.root, ignore_errors=True)

    def get_config(self) -> DefineVisitsConfig:
        """Return a task configuration that groups by group metadata."""
        cfg = DefineVisitsTask.ConfigClass()
        cfg.groupExposures.name = "by-group-metadata"
        return cfg

    def assertVisits(self):
        """Verify the registered visits and their exposure membership."""
        found_ids = [record.id for record in self.butler.registry.queryDimensionRecords("visit")]
        self.assertEqual(len(found_ids), 2)

        # The expected visit IDs differ between universes: they are either
        # calculated or taken from the JSON record.
        if "group" in self.butler.dimensions["exposure"].implied:
            first_visit, second_visit = 20220406025653255, 20220406025807181
        else:
            first_visit, second_visit = 2291434132550000, 2291434871810000
        self.assertEqual(set(found_ids), {first_visit, second_visit})

        # Collect exposures per visit so the comparison ignores order.
        exposures_by_visit = {}
        for definition in self.butler.registry.queryDimensionRecords("visit_definition"):
            exposures_by_visit.setdefault(definition.visit, set()).add(definition.exposure)

        expected = {
            first_visit: {2022040500347},
            second_visit: {2022040500348, 2022040500349},
        }
        self.assertEqual(exposures_by_visit, expected)

    def test_defineVisits(self):
        """Define visits from every record at once, non-incrementally."""
        all_records = list(self.records.values())
        # A single batch holding all of the records (list inside a list).
        self.define_visits([all_records], incremental=False)
        self.assertVisits()
class DefineVisitsGroupingRecordsTestCase(DefineVisitsGroupingTestCase):
    """Repeat the grouping tests, passing dimension records to the task
    rather than data IDs.
    """

    use_data_ids = False
class DefineVisitsOneToOneTestCase(unittest.TestCase, DefineVisitsBase):
    """Exercise visit definition with the ``one-to-one`` grouping."""

    def setUp(self):
        self.setUpExposures()

    def tearDown(self):
        if self.root is None:
            return
        shutil.rmtree(self.root, ignore_errors=True)

    def get_config(self) -> DefineVisitsConfig:
        """Return a task configuration that selects one-to-one grouping."""
        cfg = DefineVisitsTask.ConfigClass()
        cfg.groupExposures.name = "one-to-one"
        return cfg

    def assertVisits(self):
        """Verify that each exposure produced a visit with the same ID."""
        found_ids = [record.id for record in self.butler.registry.queryDimensionRecords("visit")]
        self.assertEqual(len(found_ids), 3)

        # For one-to-one the visit ID is the exposure ID.
        expected_ids = {exposure.id for exposure in self.records.values()}
        self.assertEqual(set(found_ids), expected_ids)

        # Every definition must map an exposure ID to an identical visit ID.
        for definition in self.butler.registry.queryDimensionRecords("visit_definition"):
            self.assertEqual(definition.visit, definition.exposure)

    def test_defineVisits(self):
        """Define visits from every record at once, non-incrementally."""
        all_records = list(self.records.values())
        # A single batch holding all of the records (list inside a list).
        self.define_visits([all_records], incremental=False)
        self.assertVisits()

    def test_update_records(self):
        """Check the conflict, skip and forced-update paths after an
        exposure record has been modified.
        """
        all_records = list(self.records.values())
        self.define_visits([all_records], incremental=False)  # list inside a list
        self.assertVisits()

        # Change the target name of exposure 348 and push it to the registry.
        serialized = self.records[348].to_simple()
        serialized.record["target_name"] = "new target"
        self.records[348] = DimensionRecord.from_simple(serialized, universe=self.butler.dimensions)
        self.butler.registry.syncDimensionData("exposure", self.records[348], update=True)

        # Re-running with neither updates nor skipping must fail.
        with self.assertRaises(ConflictingDefinitionError):
            self.task.run(self.records.values())

        skipped = self.task.run(self.records.values(), skip_conflicting=True)
        self.assertEqual(skipped.n_skipped, 3, str(skipped))

        # The visit definition must not have changed.
        where = "visit.seq_num = 348"
        visit_before = self.butler.query_dimension_records("visit", where=where)[0]
        self.assertEqual(visit_before.target_name, "LATISS_E6A_00000040", visit_before)

        # Now force the update.
        updated = self.task.run(self.records.values(), skip_conflicting=True, update_records=True)

        # When updating, every record is reported as updated, even if it was
        # not really changed.
        self.assertEqual(updated.n_skipped, 0, str(updated))
        self.assertEqual(updated.n_fully_updated, 3, str(updated))
        visit_after = self.butler.query_dimension_records("visit", where=where)[0]
        self.assertEqual(visit_after.target_name, "new target", visit_after)
if __name__ == "__main__":
    # Run the tests when this module is executed directly as a script.
    unittest.main()