Coverage for tests/test_dimension_record_containers.py: 10%
221 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-07 02:46 -0700
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import copy
import os
import unittest

import pyarrow as pa
import pyarrow.parquet as pq

from lsst.daf.butler import DimensionRecordSet, DimensionRecordTable, YamlRepoImportBackend
from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory
# YAML files whose dimension records are imported into the in-memory
# registry by DimensionRecordContainersTestCase.setUpClass.
DIMENSION_DATA_FILES = [
    os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry", "base.yaml")),
    os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry", "spatial.yaml")),
]
class DimensionRecordContainersTestCase(unittest.TestCase):
    """Tests for the DimensionRecordTable and DimensionRecordSet classes.

    A single in-memory SQLite registry is populated once per class from the
    YAML files in DIMENSION_DATA_FILES; all tests read from the resulting
    record tuples rather than hitting the database again.
    """

    @classmethod
    def setUpClass(cls):
        # Create an in-memory SQLite database and Registry just to import the
        # YAML data.
        config = RegistryConfig()
        config["db"] = "sqlite://"
        registry = _RegistryFactory(config).create_from_config()
        for data_file in DIMENSION_DATA_FILES:
            with open(data_file) as stream:
                backend = YamlRepoImportBackend(stream, registry)
                backend.register()
                backend.load(datastore=None)
        # Immutable per-element record tuples shared by all tests.
        cls.records = {
            element: tuple(registry.queryDimensionRecords(element))
            for element in ("visit", "skymap", "patch")
        }
        cls.universe = registry.dimensions
        # Expanded data IDs carry their dimension records with them; used by
        # test_record_set_update_from_data_coordinates.
        cls.data_ids = list(registry.queryDataIds(["visit", "patch"]).expanded())

    def test_record_table_schema_visit(self):
        """Test that the Arrow schema for 'visit' has the right types,
        including dictionary encoding.
        """
        schema = DimensionRecordTable.make_arrow_schema(self.universe["visit"])
        # Low-cardinality string columns should be dictionary-encoded.
        self.assertEqual(schema.field("instrument").type, pa.dictionary(pa.int32(), pa.string()))
        self.assertEqual(schema.field("id").type, pa.uint64())
        self.assertEqual(schema.field("physical_filter").type, pa.dictionary(pa.int32(), pa.string()))
        self.assertEqual(schema.field("name").type, pa.string())
        self.assertEqual(schema.field("observation_reason").type, pa.dictionary(pa.int32(), pa.string()))

    def test_record_table_schema_skymap(self):
        """Test that the Arrow schema for 'skymap' has the right types,
        including dictionary encoding.
        """
        schema = DimensionRecordTable.make_arrow_schema(self.universe["skymap"])
        self.assertEqual(schema.field("name").type, pa.string())
        self.assertEqual(schema.field("hash").type, pa.binary())

    def test_empty_record_table_visit(self):
        """Test methods on a table that was initialized with no records.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table = DimensionRecordTable(self.universe["visit"])
        self.assertEqual(len(table), 0)
        self.assertEqual(list(table), [])
        self.assertEqual(table.element, self.universe["visit"])
        with self.assertRaises(IndexError):
            table[0]
        # Columns exist but are empty on an empty table.
        self.assertEqual(len(table.column("instrument")), 0)
        self.assertEqual(len(table.column("id")), 0)
        self.assertEqual(len(table.column("physical_filter")), 0)
        self.assertEqual(len(table.column("name")), 0)
        self.assertEqual(len(table.column("day_obs")), 0)
        table.extend(self.records["visit"])
        self.assertCountEqual(table, self.records["visit"])
        self.assertEqual(
            table.to_arrow().schema, DimensionRecordTable.make_arrow_schema(self.universe["visit"])
        )

    def test_empty_record_table_skymap(self):
        """Test methods on a table that was initialized with no records.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table = DimensionRecordTable(self.universe["skymap"])
        self.assertEqual(len(table), 0)
        self.assertEqual(list(table), [])
        self.assertEqual(table.element, self.universe["skymap"])
        with self.assertRaises(IndexError):
            table[0]
        self.assertEqual(len(table.column("name")), 0)
        self.assertEqual(len(table.column("hash")), 0)
        table.extend(self.records["skymap"])
        self.assertCountEqual(table, self.records["skymap"])
        self.assertEqual(
            table.to_arrow().schema, DimensionRecordTable.make_arrow_schema(self.universe["skymap"])
        )

    def test_full_record_table_visit(self):
        """Test methods on a table that was initialized with an iterable.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table = DimensionRecordTable(self.universe["visit"], self.records["visit"])
        self.assertEqual(len(table), 2)
        self.assertEqual(table[0], self.records["visit"][0])
        self.assertEqual(table[1], self.records["visit"][1])
        self.assertEqual(list(table), list(self.records["visit"]))
        self.assertEqual(table.element, self.universe["visit"])
        # Column values below match the fixture content of base.yaml.
        self.assertEqual(table.column("instrument")[0].as_py(), "Cam1")
        self.assertEqual(table.column("instrument")[1].as_py(), "Cam1")
        self.assertEqual(table.column("id")[0].as_py(), 1)
        self.assertEqual(table.column("id")[1].as_py(), 2)
        self.assertEqual(table.column("name")[0].as_py(), "1")
        self.assertEqual(table.column("name")[1].as_py(), "2")
        self.assertEqual(table.column("physical_filter")[0].as_py(), "Cam1-G")
        self.assertEqual(table.column("physical_filter")[1].as_py(), "Cam1-R1")
        self.assertEqual(table.column("day_obs")[0].as_py(), 20210909)
        self.assertEqual(table.column("day_obs")[1].as_py(), 20210909)
        # Slicing returns records, like slicing a sequence of records.
        self.assertEqual(list(table[:1]), list(self.records["visit"][:1]))
        self.assertEqual(list(table[1:]), list(self.records["visit"][1:]))
        # Tables are not sets: extending with duplicates keeps duplicates.
        table.extend(self.records["visit"])
        self.assertCountEqual(table, self.records["visit"] + self.records["visit"])

    def test_full_record_table_skymap(self):
        """Test methods on a table that was initialized with an iterable.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table = DimensionRecordTable(self.universe["skymap"], self.records["skymap"])
        self.assertEqual(len(table), 1)
        self.assertEqual(table[0], self.records["skymap"][0])
        self.assertEqual(list(table), list(self.records["skymap"]))
        self.assertEqual(table.element, self.universe["skymap"])
        self.assertEqual(table.column("name")[0].as_py(), "SkyMap1")
        self.assertEqual(table.column("hash")[0].as_py(), b"notreallyahashofanything!")
        table.extend(self.records["skymap"])
        self.assertCountEqual(table, self.records["skymap"] + self.records["skymap"])

    def test_record_table_parquet_visit(self):
        """Test round-tripping a dimension record table through Parquet.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table1 = DimensionRecordTable(self.universe["visit"], self.records["visit"])
        stream = pa.BufferOutputStream()
        pq.write_table(table1.to_arrow(), stream)
        # Reconstruct from the Arrow table alone; the element is recovered
        # from the schema, so only the universe needs to be passed.
        table2 = DimensionRecordTable(
            universe=self.universe, table=pq.read_table(pa.BufferReader(stream.getvalue()))
        )
        self.assertEqual(list(table1), list(table2))

    def test_record_table_parquet_skymap(self):
        """Test round-tripping a dimension record table through Parquet.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table1 = DimensionRecordTable(self.universe["skymap"], self.records["skymap"])
        stream = pa.BufferOutputStream()
        pq.write_table(table1.to_arrow(), stream)
        table2 = DimensionRecordTable(
            universe=self.universe, table=pq.read_table(pa.BufferReader(stream.getvalue()))
        )
        self.assertEqual(list(table1), list(table2))

    def test_record_chunk_init(self):
        """Test constructing a DimensionRecordTable from an iterable in chunks.

        We use 'patch' records for this test because there are enough of them
        to have multiple chunks.
        """
        table1 = DimensionRecordTable(self.universe["patch"], self.records["patch"], batch_size=5)
        self.assertEqual(len(table1), 12)
        # 12 records with batch_size=5 should produce batches of 5, 5, and 2.
        self.assertEqual([len(batch) for batch in table1.to_arrow().to_batches()], [5, 5, 2])
        self.assertEqual(list(table1), list(self.records["patch"]))

    def test_record_set_const(self):
        """Test attributes and methods of `DimensionRecordSet` that do not
        modify the set.

        We use 'patch' records for this test because there are enough of them
        to do nontrivial set-operation tests.
        """
        element = self.universe["patch"]
        records = self.records["patch"]
        set1 = DimensionRecordSet(element, records[:7])
        # The element may be given by name if the universe is also given.
        self.assertEqual(set1, DimensionRecordSet("patch", records[:7], universe=self.universe))
        # DimensionRecordSets do not compare as equal with other set types,
        # even with the same content.
        self.assertNotEqual(set1, set(records[:7]))
        # A name alone without a universe is an error.
        with self.assertRaises(TypeError):
            DimensionRecordSet("patch", records[:7])
        self.assertEqual(set1.element, self.universe["patch"])
        self.assertEqual(len(set1), 7)
        self.assertEqual(list(set1), list(records[:7]))
        # Containment accepts both records and their data IDs.
        self.assertIn(records[4], set1)
        self.assertIn(records[5].dataId, set1)
        self.assertNotIn(self.records["visit"][0], set1)
        # Set-comparison methods require the same element on both sides and
        # raise ValueError otherwise.
        self.assertTrue(set1.issubset(DimensionRecordSet(element, records[:8])))
        self.assertFalse(set1.issubset(DimensionRecordSet(element, records[1:6])))
        with self.assertRaises(ValueError):
            set1.issubset(DimensionRecordSet(self.universe["tract"]))
        self.assertTrue(set1.issuperset(DimensionRecordSet(element, records[1:6])))
        self.assertFalse(set1.issuperset(DimensionRecordSet(element, records[:8])))
        with self.assertRaises(ValueError):
            set1.issuperset(DimensionRecordSet(self.universe["tract"]))
        self.assertTrue(set1.isdisjoint(DimensionRecordSet(element, records[7:])))
        self.assertFalse(set1.isdisjoint(DimensionRecordSet(element, records[5:8])))
        with self.assertRaises(ValueError):
            set1.isdisjoint(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(
            set1.intersection(DimensionRecordSet(element, records[5:])),
            DimensionRecordSet(element, records[5:7]),
        )
        # NOTE(review): duplicated assertion kept from the original; it was
        # possibly intended to exercise an operator form (e.g. ``&``) instead.
        self.assertEqual(
            set1.intersection(DimensionRecordSet(element, records[5:])),
            DimensionRecordSet(element, records[5:7]),
        )
        with self.assertRaises(ValueError):
            set1.intersection(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(
            set1.difference(DimensionRecordSet(element, records[5:])),
            DimensionRecordSet(element, records[:5]),
        )
        with self.assertRaises(ValueError):
            set1.difference(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(
            set1.union(DimensionRecordSet(element, records[5:9])),
            DimensionRecordSet(element, records[:9]),
        )
        with self.assertRaises(ValueError):
            set1.union(DimensionRecordSet(self.universe["tract"]))
        # Lookup by data ID, by required values, and the failure modes.
        self.assertEqual(set1.find(records[0].dataId), records[0])
        with self.assertRaises(LookupError):
            set1.find(self.records["patch"][8].dataId)
        with self.assertRaises(ValueError):
            set1.find(self.records["visit"][0].dataId)
        self.assertEqual(set1.find_with_required_values(records[0].dataId.required_values), records[0])

    def test_record_set_add(self):
        """Test DimensionRecordSet.add."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set1.add(self.records["patch"][2])
        # Adding a record for the wrong element raises and leaves the set
        # unchanged.
        with self.assertRaises(ValueError):
            set1.add(self.records["visit"][0])
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe))
        # Re-adding an existing record is a no-op.
        set1.add(self.records["patch"][2])
        self.assertEqual(list(set1), list(self.records["patch"][:3]))

    def test_record_set_find_or_add(self):
        """Test DimensionRecordSet.find and find_with_required_values with
        a 'or_add' callback.
        """
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        # A miss invokes the callback and inserts its result.
        set1.find(self.records["patch"][2].dataId, or_add=lambda _c, _r: self.records["patch"][2])
        with self.assertRaises(ValueError):
            set1.find(self.records["visit"][0].dataId, or_add=lambda _c, _r: self.records["visit"][0])
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe))
        set1.find_with_required_values(
            self.records["patch"][3].dataId.required_values, or_add=lambda _c, _r: self.records["patch"][3]
        )
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:4], universe=self.universe))

    def test_record_set_update_from_data_coordinates(self):
        """Test DimensionRecordSet.update_from_data_coordinates."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set1.update_from_data_coordinates(self.data_ids)
        # Every expanded data ID's 'patch' record must now be in the set.
        for data_id in self.data_ids:
            self.assertIn(data_id.records["patch"], set1)

    def test_record_set_discard(self):
        """Test DimensionRecordSet.discard."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        # These discards should do nothing.
        set1.discard(self.records["patch"][2])
        self.assertEqual(set1, set2)
        set1.discard(self.records["patch"][2].dataId)
        self.assertEqual(set1, set2)
        # Wrong-element discards raise but must not modify the set.
        with self.assertRaises(ValueError):
            set1.discard(self.records["visit"][0])
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.discard(self.records["visit"][0].dataId)
        self.assertEqual(set1, set2)
        # These ones should remove a record from each set.
        set1.discard(self.records["patch"][1])
        set2.discard(self.records["patch"][1].dataId)
        self.assertEqual(set1, set2)
        self.assertNotIn(self.records["patch"][1], set1)
        self.assertNotIn(self.records["patch"][1], set2)

    def test_record_set_remove(self):
        """Test DimensionRecordSet.remove."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        # These removes should raise with strong exception safety.
        with self.assertRaises(KeyError):
            set1.remove(self.records["patch"][2])
        self.assertEqual(set1, set2)
        with self.assertRaises(KeyError):
            set1.remove(self.records["patch"][2].dataId)
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.remove(self.records["visit"][0])
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.remove(self.records["visit"][0].dataId)
        self.assertEqual(set1, set2)
        # These ones should remove a record from each set.
        set1.remove(self.records["patch"][1])
        set2.remove(self.records["patch"][1].dataId)
        self.assertEqual(set1, set2)
        self.assertNotIn(self.records["patch"][1], set1)
        self.assertNotIn(self.records["patch"][1], set2)

    def test_record_set_pop(self):
        """Test DimensionRecordSet.pop."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        record1 = set1.pop()
        set2.remove(record1)
        self.assertNotIn(record1, set1)
        self.assertEqual(set1, set2)
        record2 = set1.pop()
        set2.remove(record2)
        self.assertNotIn(record2, set1)
        self.assertEqual(set1, set2)
        # Two pops exhaust the two-record set.
        self.assertFalse(set1)
# Allow the test module to be run directly as a script.
if __name__ == "__main__":
    unittest.main()