Coverage for python/lsst/dax/apdb/tests/_apdb.py: 17%

276 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-26 10:23 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest"] 

25 

26from abc import ABC, abstractmethod 

27from collections.abc import Callable 

28from typing import TYPE_CHECKING, Any, ContextManager, Optional 

29 

30import pandas 

31from lsst.daf.base import DateTime 

32from lsst.dax.apdb import ApdbConfig, ApdbInsertId, ApdbTableData, ApdbTables, make_apdb 

33from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

34 

35from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

36 

37 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build the circular sky region shared by the tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the pointing direction, passed to ``UnitVector3d``.

    Returns
    -------
    region : `Region`
        ``Circle`` centered on the pointing direction whose opening angle
        is half of the 0.05 radian field of view.
    """
    fov = 0.05  # radians
    return Circle(UnitVector3d(*xyz), Angle(fov / 2))

44 

45 

class ApdbTest(ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count in
            ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate table data type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type. The catalog is expected to have one column more
            than ``table_column_count[table]`` (the insert_id column).
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Keep the returned catalog so that the assertion below checks
            # this query and not the result of the previous one.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        # Each visit stores ``nobj`` objects, sources and forced sources;
        # ``start_id`` is advanced by ``nobj`` between visits.
        start_id = 0
        for store_time, objects in visits:
            sources = makeSourceCatalog(objects, store_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, store_time, ccdVisitId=start_id)
            apdb.store(store_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            # one insert ID per stored visit
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId]) -> None:
                """Check that every history table returns ``nobj`` records
                per requested insert ID.
                """
                n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            apdb.deleteInsertIds(insert_ids[:2])
            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaSources and check what remains attached to the
        DiaObjects.
        """
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # three reassigned sources should no longer be returned for ``oids``
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Re-read after the failed call so that the check reflects the
        # database state; the failed reassignment is expected to change
        # nothing.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading even later should return nothing
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading even later should return nothing
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertEqual: Callable[[Any, Any], None]
        assertIs: Callable[[Any, Any], None]
        assertIsInstance: Callable[[Any, Any], None]
        assertIsNone: Callable[[Any], None]
        assertIsNotNone: Callable[[Any], None]
        assertRaises: Callable[[Any], ContextManager]
        assertRaisesRegex: Callable[[Any, Any], ContextManager]

491 

492 

class ApdbSchemaUpdateTest(ABC):
    """Mixin with unit tests checking behavior across schema changes.

    The ``TYPE_CHECKING`` block at the end of the class declares the
    ``unittest.TestCase`` assert method used by the tests.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Every call has to return a configuration pointing at the very same
        database instance (i.e. identical ``db_url``; for sqlite this
        implies on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that history-enabled code can run on top of an old schema
        that was created without history tables.
        """
        # Create the schema with history tables disabled.
        old_config = self.make_config(use_insert_id=False)
        apdb = make_apdb(old_config)
        apdb.makeSchema()

        # Open the same database again, now configured for history tables.
        new_config = self.make_config(use_insert_id=True)
        apdb = make_apdb(new_config)

        # Storing data is expected to succeed; Objects are made first as
        # the other catalogs are built from them.
        sky_region = _make_region()
        when = self.visit_time
        dia_objects = makeObjectCatalog(sky_region, 100, when)
        dia_sources = makeSourceCatalog(dia_objects, when)
        dia_forced_sources = makeForcedSourceCatalog(dia_objects, when)
        apdb.store(when, dia_objects, dia_sources, dia_forced_sources)

        # No history is expected to be reported.
        self.assertIsNone(apdb.getInsertIds())

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertIsNone: Callable[[Any], None]