Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%

391 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 09:59 +0000

# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import tempfile 

29import unittest 

30from abc import ABC, abstractmethod 

31from collections.abc import Iterator 

32from tempfile import TemporaryDirectory 

33from typing import TYPE_CHECKING, Any 

34 

35import astropy.time 

36import pandas 

37import yaml 

38from lsst.dax.apdb import ( 

39 Apdb, 

40 ApdbConfig, 

41 ApdbInsertId, 

42 ApdbSql, 

43 ApdbTableData, 

44 ApdbTables, 

45 IncompatibleVersionError, 

46 VersionTuple, 

47) 

48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

49 

50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

51 

if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Static-typing base for mixin classes: lets type checkers see the
        `unittest.TestCase` assert methods that mixins call.
        """

else:

    class TestCaseMixin:
        """Runtime stand-in base class; real TestCase comes from the concrete
        test class that combines the mixins.
        """

62 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the pointing direction vector.

    Returns
    -------
    region : `lsst.sphgeom.Region`
        Circle with a 0.05 radian field of view centered on the pointing.
    """
    field_of_view = 0.05  # radians
    return Circle(UnitVector3d(*xyz), Angle(field_of_view / 2))

69 

70 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Write a modified copy of a schema definition and yield its path.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    # Read all YAML documents up front so the source file can be closed.
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply the requested edits to every document.
    for document in documents:
        if drop_metadata:
            document["tables"] = [
                table for table in document["tables"] if table["name"] != "metadata"
            ]
        if version is not None:
            if version:
                document["version"] = version
            else:
                # Empty string means "remove the version entirely".
                del document["version"]

    # Dump the edited documents into a temporary file that lives only for
    # the duration of the context.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield output_path

112 

113 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    schema_path: str
    """Location of the Felis schema file."""

    # Number of columns as defined in tests/config/schema.yaml.
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id.
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

        # Test from_uri factory method with the same config.
        with tempfile.NamedTemporaryFile() as tmpfile:
            config.save(tmpfile.name)
            apdb = Apdb.from_uri(tmpfile.name)

            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # Use non-zero months for Forced/Source fetching.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # Get objects by region.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # Get sources by region.
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # Get sources by object ID, non-empty object list.
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # Get forced sources by object ID, empty object list.
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # Get sources by object ID, non-empty object list.
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # Test if a visit has objects/sources.
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # Alternative method not part of the Apdb API.
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # Get sources by region.
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Bug fix: the result was previously discarded, leaving the
            # assertion below checking a stale `res` from an earlier call.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # Set read_sources_months to 0 so that Forced/Sources are None.
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # Get objects by region.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # Get sources by region.
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # Get sources by object ID, empty object list.
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # Get forced sources by object ID, empty object list.
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # Test if a visit has objects/sources.
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # Alternative method not part of the Apdb API.
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # Don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # Make catalog with Objects.
        catalog = makeObjectCatalog(region, 100, visit_time)

        # Store catalog.
        apdb.store(visit_time, catalog)

        # Read it back and check sizes.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        region = _make_region()
        visit_time = self.visit_time
        # Make catalog with no Objects.
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # Have to store Objects first.
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # Save the objects and sources.
        apdb.store(visit_time, objects, sources)

        # Read it back, no ID filtering.
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # Read it back and filter by ID.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # Read it back to get schema.
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # Test if a visit is present.
        # data_factory's ccdVisitId generation corresponds to (0, 0).
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # Alternative method not part of the Apdb API.
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # Have to store Objects first.
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # Read it back and check sizes.
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # Read it back to get schema.
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # Alternative method not part of the Apdb API.
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # Don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:03:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:04:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:05:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:06:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                # Each insert ID corresponds to one visit with `nobj` records
                # per table, unless an explicit count is given.
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # Read it back and check sizes.
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # Try to remove some of those.
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # Don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # Make catalog with SSObjects.
        catalog = makeSSObjectCatalog(100, flags=1)

        # Store catalog.
        apdb.storeSSObjects(catalog)

        # Read it back and check sizes.
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # Check that override works, make catalog with SSObjects, ID = 51-150.
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # Don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # Read it back and filter by ID.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27.
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # Reading at time of last save should read all.
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # One second before 12 months.
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # Reading at later time of last save should only read a subset.
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # Reading at later time of last save should only read a subset.
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # Reading at time of last save should read all.
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # One second before 12 months.
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # Reading at later time of last save should only read a subset.
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # Reading at later time of last save should only read a subset.
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB writes two version-number items plus a frozen JSON config,
        # hence three pre-existing metadata items.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 3)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing."""
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        config.schema_file = self.schema_path
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Empty version string means "remove version"; missing version is
        # interpreted as 0.1.0.
        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.use_insert_id
        apdb = Apdb.from_config(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)
710 

711 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Check that new code can work with old schema without history
        tables.
        """
        # Initialize the database with a schema that lacks history tables.
        config = self.make_instance(use_insert_id=False)
        apdb = Apdb.from_config(config)

        # Re-open the same database with history support turned on.
        config.use_insert_id = True
        apdb = Apdb.from_config(config)

        visit_time = self.visit_time
        region = _make_region()

        # Inserts must still succeed; Objects go in first as usual.
        objects = makeObjectCatalog(region, 100, visit_time)
        apdb.store(
            visit_time,
            objects,
            makeSourceCatalog(objects, visit_time),
            makeForcedSourceCatalog(objects, visit_time),
        )

        # With no history tables present there must be no history.
        self.assertIsNone(apdb.getInsertIds())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as updated_schema:
            config.schema_file = updated_schema
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)