Coverage for python/lsst/dax/apdb/tests/_apdb.py: 17%

276 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-27 09:07 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest"] 

25 

26from abc import ABC, abstractmethod 

27from collections.abc import Callable 

28from typing import TYPE_CHECKING, Any, ContextManager, Optional 

29 

30import pandas 

31from lsst.daf.base import DateTime 

32from lsst.dax.apdb import ApdbConfig, ApdbInsertId, ApdbTableData, ApdbTables, make_apdb 

33from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

34 

35from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

36 

37 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build the circular region that the tests query and populate.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components passed to `UnitVector3d` to define the circle center.

    Returns
    -------
    region : `Region`
        `Circle` centered on ``xyz`` whose opening angle is half of the
        0.05 radian field of view.
    """
    fov = 0.05  # radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(fov / 2))

44 

45 

class ApdbTest(ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate table data type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected number of columns
            in ``table_column_count``. The catalog must have exactly one
            column more than that count (the ``insert_id`` column).
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """

        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Keep the returned catalog; otherwise the assertion below would
            # only re-check the result of the previous query.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """

        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: Optional[pandas.DataFrame]

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""

        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""

        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""

        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # Eight visits alternating between the two regions; each visit is
        # stored separately below.
        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            # one insert ID per store() call above
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId]) -> None:
                # Every insert ID is expected to contribute ``nobj`` records
                # to each of the three history tables.
                n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            apdb.deleteInsertIds(insert_ids[:2])
            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""

        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""

        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # the three reassigned sources are no longer returned for ``oids``
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Re-read after the failed call so that the check below verifies the
        # current database contents rather than the catalog fetched earlier.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at an even later time should not return any sources
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at an even later time should not return any forced sources
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertEqual: Callable[[Any, Any], None]
        assertIs: Callable[[Any, Any], None]
        assertIsInstance: Callable[[Any, Any], None]
        assertIsNone: Callable[[Any], None]
        assertIsNotNone: Callable[[Any], None]
        assertRaises: Callable[[Any], ContextManager]
        assertRaisesRegex: Callable[[Any, Any], ContextManager]

498 

499 

class ApdbSchemaUpdateTest(ABC):
    """Mixin holding unit tests that check behavior across schema changes.

    The class calls ``assertIsNone``, so it has to be combined with
    ``unittest.TestCase`` in a concrete test class.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Every call must return a configuration that points to the same
        database instance (i.e. ``db_url`` must be the same, which also
        means that for sqlite on-disk storage has to be used).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that code configured for history tables still works with an
        old schema which was created without them.
        """

        # Create the schema with history tables disabled.
        apdb = make_apdb(self.make_config(use_insert_id=False))
        apdb.makeSchema()

        # Re-open the same database with history tables enabled in config.
        apdb = make_apdb(self.make_config(use_insert_id=True))

        # Storing data is expected to succeed; objects are generated first
        # because both source catalogs are built from them.
        sky_region = _make_region()
        when = self.visit_time
        dia_objects = makeObjectCatalog(sky_region, 100, when)
        dia_sources = makeSourceCatalog(dia_objects, when)
        dia_forced_sources = makeForcedSourceCatalog(dia_objects, when)
        apdb.store(when, dia_objects, dia_sources, dia_forced_sources)

        # No history is expected to be available.
        self.assertIsNone(apdb.getInsertIds())

    if TYPE_CHECKING:
        # This is a mixin class, some methods from unittest.TestCase declared
        # here to silence mypy.
        assertIsNone: Callable[[Any], None]