Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%

402 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-28 10:11 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import unittest 

29from abc import ABC, abstractmethod 

30from collections.abc import Iterator 

31from tempfile import TemporaryDirectory 

32from typing import TYPE_CHECKING, Any 

33 

34import astropy.time 

35import pandas 

36import yaml 

37from lsst.dax.apdb import ( 

38 Apdb, 

39 ApdbConfig, 

40 ApdbInsertId, 

41 ApdbSql, 

42 ApdbTableData, 

43 ApdbTables, 

44 IncompatibleVersionError, 

45 VersionTuple, 

46 make_apdb, 

47) 

48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

49 

50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

51 

# The mixin base class is defined twice: type checkers take the first branch
# and see a ``unittest.TestCase`` subclass, so calls to ``self.assert*`` in
# the mixin test classes below type-check. ``TYPE_CHECKING`` is false at run
# time, so the second branch executes and the mixin is a plain class.
if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:

    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""

61 

62 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build a small circular region to use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components handed to `UnitVector3d` to define the circle center.

    Returns
    -------
    region : `Region`
        `Circle` around the given direction whose angle argument is half of
        the 0.05 radian field of view.
    """
    field_of_view = 0.05  # radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(field_of_view / 2))

69 

70 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    The edited schema is written to ``schema.yaml`` inside a temporary
    directory which is removed when the context manager exits, so the yielded
    path is only valid inside the ``with`` block.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config (a schema without
        version is left as is), if `None` - don't change the version in
        config.

    Yields
    ------
    output_path : `str`
        Path for the updated schema file.
    """
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))
    # Edit YAML contents, the same edits are applied to every document.
    for schema in schemas_list:
        # Optionally drop metadata table.
        if drop_metadata:
            schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
        if version is not None:
            if version == "":
                # pop() instead of del: a schema that has no version already
                # satisfies the request and must not raise KeyError.
                schema.pop("version", None)
            else:
                schema["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path

112 

113 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Parameters
        ----------
        **kwargs
            Configuration overrides; tests in this class pass
            ``read_sources_months``, ``read_forced_sources_months`` and
            ``schema_file``.
        """
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count in
            ``table_column_count``.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate history catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type. The catalog is expected to have one column more
            than ``table_column_count`` lists for this table.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for _ in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Keep the result: the check below must look at this query and
            # not at whatever ``res`` held from an earlier call.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        region = _make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        # NOTE(review): the logger name is specific to one module; confirm
        # that every implementation using this mixin logs through it.
        with self.assertLogs("lsst.dax.apdb.apdbSql", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # Eight stores in total, alternating between the two object catalogs.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:03:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:04:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:05:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:06:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        # Distinct loop variable name so that ``visit_time`` is not shadowed.
        for store_time, objects in visits:
            sources = makeSourceCatalog(objects, store_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, store_time, ccdVisitId=start_id)
            apdb.store(store_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                # By default every insert ID contributes ``nobj`` records to
                # each of the three history tables.
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Query again so that the check reflects database contents after the
        # failed call; the failed reassignment must not change anything.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # between the 360-day marks of the two saves, only the second batch
        # is expected
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # past the 360-day mark of both saves, nothing is expected
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # between the 360-day marks of the two saves, only the second batch
        # is expected
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # past the 360-day mark of both saves, nothing is expected
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        metadata = apdb.metadata

        # APDB should write two metadata items with version numbers plus a
        # frozen JSON config, three items in total.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 3)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        # Overwriting an existing key is refused unless it is forced.
        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        # delete() reports whether the key was present.
        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        config = self.make_config()
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(config.schema_file, drop_metadata=True) as schema_file:
            config_nometa = self.make_config(schema_file=schema_file)
            Apdb.makeSchema(config_nometa)
            apdb = make_apdb(config_nometa)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        apdb = make_apdb(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_config()
        default_schema = config.schema_file
        apdb = make_apdb(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Schema file without a version is expected to report 0.1.0.
        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_config()
        Apdb.makeSchema(config)

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.use_insert_id
        apdb = make_apdb(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)

711 

712 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Mixin with unit tests checking how the code copes with schema
    changes.
    """

    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Every call has to return a configuration that points to the very
        same database instance (i.e. ``db_url`` must be the same, which also
        means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Verify that history-enabled code works against an older schema
        created without history tables.
        """
        # Create the schema with history tables disabled.
        config = self.make_config(use_insert_id=False)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # Open the same database again, now asking for history tables.
        config = self.make_config(use_insert_id=True)
        apdb = make_apdb(config)

        # Storing data is expected to succeed; objects are made first
        # because the other catalogs are derived from them.
        when = self.visit_time
        objects = makeObjectCatalog(_make_region(), 100, when)
        sources = makeSourceCatalog(objects, when)
        forced = makeForcedSourceCatalog(objects, when)
        apdb.store(when, objects, sources, forced)

        # No history is expected for this database.
        self.assertIsNone(apdb.getInsertIds())

    def test_schemaVersionCheck(self) -> None:
        """Verify schema version compatibility checking."""
        base_config = self.make_config()
        Apdb.makeSchema(base_config)
        apdb = make_apdb(base_config)

        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Pretend that the schema file is at version 99.0.0; opening the
        # database with it must fail.
        with update_schema_yaml(base_config.schema_file, version="99.0.0") as schema_file:
            bumped_config = self.make_config(schema_file=schema_file)
            with self.assertRaises(IncompatibleVersionError):
                make_apdb(bumped_config)