Coverage for tests/test_dimension_record_containers.py: 10%

221 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-26 02:47 -0700

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28import copy 

29import os 

30import unittest 

31 

32import pyarrow as pa 

33import pyarrow.parquet as pq 

34from lsst.daf.butler import DimensionRecordSet, DimensionRecordTable, YamlRepoImportBackend 

35from lsst.daf.butler.registry import RegistryConfig, _RegistryFactory 

36 

# Absolute, normalized paths to the YAML files that seed the test registry.
# "base.yaml" provides core dimension records; "spatial.yaml" adds the
# region-bearing records (patches, tracts) the spatial tests rely on.
DIMENSION_DATA_FILES = [
    os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry", "base.yaml")),
    os.path.normpath(os.path.join(os.path.dirname(__file__), "data", "registry", "spatial.yaml")),
]

41 

42 

class DimensionRecordContainersTestCase(unittest.TestCase):
    """Tests for the DimensionRecordTable and DimensionRecordSet classes."""

    @classmethod
    def setUpClass(cls):
        # Create an in-memory SQLite database and Registry just to import the
        # YAML data; the containers under test never touch the database again.
        config = RegistryConfig()
        config["db"] = "sqlite://"
        registry = _RegistryFactory(config).create_from_config()
        for data_file in DIMENSION_DATA_FILES:
            # NOTE(review): register/load kept inside the `with` so the stream
            # is guaranteed open while the backend may still read it — confirm
            # against YamlRepoImportBackend's contract.
            with open(data_file) as stream:
                backend = YamlRepoImportBackend(stream, registry)
                backend.register()
                backend.load(datastore=None)
        # Tuples of DimensionRecord keyed by element name.  tuple() alone is
        # sufficient; the original tuple(list(...)) double conversion was
        # redundant.
        cls.records = {
            element: tuple(registry.queryDimensionRecords(element))
            for element in ("visit", "skymap", "patch")
        }
        cls.universe = registry.dimensions
        # Fully expanded data IDs spanning visit x patch, used by the
        # update_from_data_coordinates test.
        cls.data_ids = list(registry.queryDataIds(["visit", "patch"]).expanded())

    def test_record_table_schema_visit(self):
        """Test that the Arrow schema for 'visit' has the right types,
        including dictionary encoding.
        """
        schema = DimensionRecordTable.make_arrow_schema(self.universe["visit"])
        self.assertEqual(schema.field("instrument").type, pa.dictionary(pa.int32(), pa.string()))
        self.assertEqual(schema.field("id").type, pa.uint64())
        self.assertEqual(schema.field("physical_filter").type, pa.dictionary(pa.int32(), pa.string()))
        self.assertEqual(schema.field("name").type, pa.string())
        self.assertEqual(schema.field("observation_reason").type, pa.dictionary(pa.int32(), pa.string()))

    def test_record_table_schema_skymap(self):
        """Test that the Arrow schema for 'skymap' has the right types,
        including dictionary encoding.
        """
        schema = DimensionRecordTable.make_arrow_schema(self.universe["skymap"])
        self.assertEqual(schema.field("name").type, pa.string())
        self.assertEqual(schema.field("hash").type, pa.binary())

    def test_empty_record_table_visit(self):
        """Test methods on a table that was initialized with no records.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table = DimensionRecordTable(self.universe["visit"])
        self.assertEqual(len(table), 0)
        self.assertEqual(list(table), [])
        self.assertEqual(table.element, self.universe["visit"])
        with self.assertRaises(IndexError):
            table[0]
        self.assertEqual(len(table.column("instrument")), 0)
        self.assertEqual(len(table.column("id")), 0)
        self.assertEqual(len(table.column("physical_filter")), 0)
        self.assertEqual(len(table.column("name")), 0)
        self.assertEqual(len(table.column("day_obs")), 0)
        table.extend(self.records["visit"])
        self.assertCountEqual(table, self.records["visit"])
        self.assertEqual(
            table.to_arrow().schema, DimensionRecordTable.make_arrow_schema(self.universe["visit"])
        )

    def test_empty_record_table_skymap(self):
        """Test methods on a table that was initialized with no records.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table = DimensionRecordTable(self.universe["skymap"])
        self.assertEqual(len(table), 0)
        self.assertEqual(list(table), [])
        self.assertEqual(table.element, self.universe["skymap"])
        with self.assertRaises(IndexError):
            table[0]
        self.assertEqual(len(table.column("name")), 0)
        self.assertEqual(len(table.column("hash")), 0)
        table.extend(self.records["skymap"])
        self.assertCountEqual(table, self.records["skymap"])
        self.assertEqual(
            table.to_arrow().schema, DimensionRecordTable.make_arrow_schema(self.universe["skymap"])
        )

    def test_full_record_table_visit(self):
        """Test methods on a table that was initialized with an iterable.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table = DimensionRecordTable(self.universe["visit"], self.records["visit"])
        self.assertEqual(len(table), 2)
        self.assertEqual(table[0], self.records["visit"][0])
        self.assertEqual(table[1], self.records["visit"][1])
        self.assertEqual(list(table), list(self.records["visit"]))
        self.assertEqual(table.element, self.universe["visit"])
        self.assertEqual(table.column("instrument")[0].as_py(), "Cam1")
        self.assertEqual(table.column("instrument")[1].as_py(), "Cam1")
        self.assertEqual(table.column("id")[0].as_py(), 1)
        self.assertEqual(table.column("id")[1].as_py(), 2)
        self.assertEqual(table.column("name")[0].as_py(), "1")
        self.assertEqual(table.column("name")[1].as_py(), "2")
        self.assertEqual(table.column("physical_filter")[0].as_py(), "Cam1-G")
        self.assertEqual(table.column("physical_filter")[1].as_py(), "Cam1-R1")
        self.assertEqual(table.column("day_obs")[0].as_py(), 20210909)
        self.assertEqual(table.column("day_obs")[1].as_py(), 20210909)
        # Slicing yields new tables with the expected subsets.
        self.assertEqual(list(table[:1]), list(self.records["visit"][:1]))
        self.assertEqual(list(table[1:]), list(self.records["visit"][1:]))
        # Unlike a set, a table permits duplicate records.
        table.extend(self.records["visit"])
        self.assertCountEqual(table, self.records["visit"] + self.records["visit"])

    def test_full_record_table_skymap(self):
        """Test methods on a table that was initialized with an iterable.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table = DimensionRecordTable(self.universe["skymap"], self.records["skymap"])
        self.assertEqual(len(table), 1)
        self.assertEqual(table[0], self.records["skymap"][0])
        self.assertEqual(list(table), list(self.records["skymap"]))
        self.assertEqual(table.element, self.universe["skymap"])
        self.assertEqual(table.column("name")[0].as_py(), "SkyMap1")
        self.assertEqual(table.column("hash")[0].as_py(), b"notreallyahashofanything!")
        table.extend(self.records["skymap"])
        self.assertCountEqual(table, self.records["skymap"] + self.records["skymap"])

    def test_record_table_parquet_visit(self):
        """Test round-tripping a dimension record table through Parquet.

        We use 'visit' records for this test because they have both timespans
        and regions, and those are the tricky column types for interoperability
        with Arrow.
        """
        table1 = DimensionRecordTable(self.universe["visit"], self.records["visit"])
        stream = pa.BufferOutputStream()
        pq.write_table(table1.to_arrow(), stream)
        table2 = DimensionRecordTable(
            universe=self.universe, table=pq.read_table(pa.BufferReader(stream.getvalue()))
        )
        self.assertEqual(list(table1), list(table2))

    def test_record_table_parquet_skymap(self):
        """Test round-tripping a dimension record table through Parquet.

        We use 'skymap' records for this test because that's the only one with
        a "hash" column.
        """
        table1 = DimensionRecordTable(self.universe["skymap"], self.records["skymap"])
        stream = pa.BufferOutputStream()
        pq.write_table(table1.to_arrow(), stream)
        table2 = DimensionRecordTable(
            universe=self.universe, table=pq.read_table(pa.BufferReader(stream.getvalue()))
        )
        self.assertEqual(list(table1), list(table2))

    def test_record_chunk_init(self):
        """Test constructing a DimensionRecordTable from an iterable in chunks.

        We use 'patch' records for this test because there are enough of them
        to have multiple chunks.
        """
        table1 = DimensionRecordTable(self.universe["patch"], self.records["patch"], batch_size=5)
        self.assertEqual(len(table1), 12)
        # 12 records at batch_size=5 should yield batches of 5, 5, and 2.
        self.assertEqual([len(batch) for batch in table1.to_arrow().to_batches()], [5, 5, 2])
        self.assertEqual(list(table1), list(self.records["patch"]))

    def test_record_set_const(self):
        """Test attributes and methods of `DimensionRecordSet` that do not
        modify the set.

        We use 'patch' records for this test because there are enough of them
        to do nontrivial set-operation tests.
        """
        element = self.universe["patch"]
        records = self.records["patch"]
        set1 = DimensionRecordSet(element, records[:7])
        self.assertEqual(set1, DimensionRecordSet("patch", records[:7], universe=self.universe))
        # DimensionRecordSets do not compare as equal with other set types,
        # even with the same content.
        self.assertNotEqual(set1, set(records[:7]))
        # Constructing from an element name requires a universe.
        with self.assertRaises(TypeError):
            DimensionRecordSet("patch", records[:7])
        self.assertEqual(set1.element, self.universe["patch"])
        self.assertEqual(len(set1), 7)
        self.assertEqual(list(set1), list(records[:7]))
        self.assertIn(records[4], set1)
        self.assertIn(records[5].dataId, set1)
        self.assertNotIn(self.records["visit"][0], set1)
        self.assertTrue(set1.issubset(DimensionRecordSet(element, records[:8])))
        self.assertFalse(set1.issubset(DimensionRecordSet(element, records[1:6])))
        # Set operations between different elements are rejected.
        with self.assertRaises(ValueError):
            set1.issubset(DimensionRecordSet(self.universe["tract"]))
        self.assertTrue(set1.issuperset(DimensionRecordSet(element, records[1:6])))
        self.assertFalse(set1.issuperset(DimensionRecordSet(element, records[:8])))
        with self.assertRaises(ValueError):
            set1.issuperset(DimensionRecordSet(self.universe["tract"]))
        self.assertTrue(set1.isdisjoint(DimensionRecordSet(element, records[7:])))
        self.assertFalse(set1.isdisjoint(DimensionRecordSet(element, records[5:8])))
        with self.assertRaises(ValueError):
            set1.isdisjoint(DimensionRecordSet(self.universe["tract"]))
        # The original test repeated this exact intersection assertion twice;
        # the duplicate has been removed.
        self.assertEqual(
            set1.intersection(DimensionRecordSet(element, records[5:])),
            DimensionRecordSet(element, records[5:7]),
        )
        with self.assertRaises(ValueError):
            set1.intersection(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(
            set1.difference(DimensionRecordSet(element, records[5:])),
            DimensionRecordSet(element, records[:5]),
        )
        with self.assertRaises(ValueError):
            set1.difference(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(
            set1.union(DimensionRecordSet(element, records[5:9])),
            DimensionRecordSet(element, records[:9]),
        )
        with self.assertRaises(ValueError):
            set1.union(DimensionRecordSet(self.universe["tract"]))
        self.assertEqual(set1.find(records[0].dataId), records[0])
        with self.assertRaises(LookupError):
            set1.find(self.records["patch"][8].dataId)
        with self.assertRaises(ValueError):
            set1.find(self.records["visit"][0].dataId)
        self.assertEqual(set1.find_with_required_values(records[0].dataId.required_values), records[0])

    def test_record_set_add(self):
        """Test DimensionRecordSet.add."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set1.add(self.records["patch"][2])
        # Adding a record for a different element must fail.
        with self.assertRaises(ValueError):
            set1.add(self.records["visit"][0])
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe))
        # Re-adding an existing record is a no-op.
        set1.add(self.records["patch"][2])
        self.assertEqual(list(set1), list(self.records["patch"][:3]))

    def test_record_set_find_or_add(self):
        """Test DimensionRecordSet.find and find_with_required_values with
        a 'or_add' callback.
        """
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set1.find(self.records["patch"][2].dataId, or_add=lambda _c, _r: self.records["patch"][2])
        with self.assertRaises(ValueError):
            set1.find(self.records["visit"][0].dataId, or_add=lambda _c, _r: self.records["visit"][0])
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:3], universe=self.universe))

        set1.find_with_required_values(
            self.records["patch"][3].dataId.required_values, or_add=lambda _c, _r: self.records["patch"][3]
        )
        self.assertEqual(set1, DimensionRecordSet("patch", self.records["patch"][:4], universe=self.universe))

    def test_record_set_update_from_data_coordinates(self):
        """Test DimensionRecordSet.update_from_data_coordinates."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set1.update_from_data_coordinates(self.data_ids)
        for data_id in self.data_ids:
            self.assertIn(data_id.records["patch"], set1)

    def test_record_set_discard(self):
        """Test DimensionRecordSet.discard."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        # These discards should do nothing.
        set1.discard(self.records["patch"][2])
        self.assertEqual(set1, set2)
        set1.discard(self.records["patch"][2].dataId)
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.discard(self.records["visit"][0])
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.discard(self.records["visit"][0].dataId)
        self.assertEqual(set1, set2)
        # These ones should remove a record from each set.
        set1.discard(self.records["patch"][1])
        set2.discard(self.records["patch"][1].dataId)
        self.assertEqual(set1, set2)
        self.assertNotIn(self.records["patch"][1], set1)
        self.assertNotIn(self.records["patch"][1], set2)

    def test_record_set_remove(self):
        """Test DimensionRecordSet.remove."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        # These removes should raise with strong exception safety.
        with self.assertRaises(KeyError):
            set1.remove(self.records["patch"][2])
        self.assertEqual(set1, set2)
        with self.assertRaises(KeyError):
            set1.remove(self.records["patch"][2].dataId)
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.remove(self.records["visit"][0])
        self.assertEqual(set1, set2)
        with self.assertRaises(ValueError):
            set1.remove(self.records["visit"][0].dataId)
        self.assertEqual(set1, set2)
        # These ones should remove a record from each set.
        set1.remove(self.records["patch"][1])
        set2.remove(self.records["patch"][1].dataId)
        self.assertEqual(set1, set2)
        self.assertNotIn(self.records["patch"][1], set1)
        self.assertNotIn(self.records["patch"][1], set2)

    def test_record_set_pop(self):
        """Test DimensionRecordSet.pop."""
        set1 = DimensionRecordSet("patch", self.records["patch"][:2], universe=self.universe)
        set2 = copy.deepcopy(set1)
        record1 = set1.pop()
        set2.remove(record1)
        self.assertNotIn(record1, set1)
        self.assertEqual(set1, set2)
        record2 = set1.pop()
        set2.remove(record2)
        self.assertNotIn(record2, set1)
        self.assertEqual(set1, set2)
        # Both records popped; the set is now empty (and falsy).
        self.assertFalse(set1)

366 

367 

# Standard script entry point: run this module's tests directly.
if __name__ == "__main__":
    unittest.main()