Coverage for python/lsst/dax/apdb/tests/_apdb.py: 15%

373 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 03:20 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import tempfile 

29import unittest 

30from abc import ABC, abstractmethod 

31from collections.abc import Iterator 

32from tempfile import TemporaryDirectory 

33from typing import TYPE_CHECKING, Any 

34 

35import astropy.time 

36import pandas 

37import yaml 

38from lsst.dax.apdb import ( 

39 Apdb, 

40 ApdbConfig, 

41 ApdbReplica, 

42 ApdbTableData, 

43 ApdbTables, 

44 IncompatibleVersionError, 

45 ReplicaChunk, 

46 VersionTuple, 

47) 

48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

49 

50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

51 

if TYPE_CHECKING:
    # For type checkers, pretend the mixin derives from TestCase so that the
    # various assert* methods used by the mixins resolve correctly.

    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:
    # At runtime the mixin base is empty so that these mixin classes are not
    # themselves collected and executed as tests.

    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""

61 

62 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`]
        Cartesian components of the pointing direction; defaults to
        ``(1.0, 1.0, -1.0)``.

    Returns
    -------
    region : `lsst.sphgeom.Region`
        Circle with a 0.05 radian field of view centered on the pointing.
    """
    field_of_view = 0.05  # radians
    return Circle(UnitVector3d(*xyz), Angle(field_of_view / 2))

69 

70 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't
        change the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    # Read all YAML documents from the existing schema file.
    with open(schema_file) as input_stream:
        documents = list(yaml.load_all(input_stream, Loader=yaml.SafeLoader))

    for document in documents:
        if drop_metadata:
            # Filter out the metadata table definition.
            document["tables"] = [
                table for table in document["tables"] if table["name"] != "metadata"
            ]
        if version is not None:
            if version:
                document["version"] = version
            else:
                # Empty string means "remove the version key"; a missing key
                # raises KeyError, same as the original `del`.
                del document["version"]

    # Write updated documents into a temporary file that lives for the
    # duration of the `with` block in the caller.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as output_stream:
            yaml.dump_all(documents, stream=output_stream)
        yield updated_path

112 

113 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    # Whether the implementation partitions tables by time; subclasses may
    # override.
    time_partition_tables = False
    # Nominal observation time shared by all tests in this class.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    schema_path: str
    """Location of the Felis schema file."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 11,
        ApdbTables.DiaForcedSource: 5,
        ApdbTables.SSObject: 3,
    }

142 

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it.

        Parameters
        ----------
        **kwargs
            Implementation-specific configuration overrides passed through
            to the concrete config class.
        """
        raise NotImplementedError()

147 

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method.

        Implementations may return either ``DiaObject`` or
        ``DiaObjectLast`` depending on their schema.
        """
        raise NotImplementedError()

152 

153 def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None: 

154 """Validate catalog type and size 

155 

156 Parameters 

157 ---------- 

158 catalog : `object` 

159 Expected type of this is ``pandas.DataFrame``. 

160 rows : `int` 

161 Expected number of rows in a catalog. 

162 table : `ApdbTables` 

163 APDB table type. 

164 """ 

165 self.assertIsInstance(catalog, pandas.DataFrame) 

166 self.assertEqual(catalog.shape[0], rows) 

167 self.assertEqual(catalog.shape[1], self.table_column_count[table]) 

168 

169 def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None: 

170 """Validate catalog type and size 

171 

172 Parameters 

173 ---------- 

174 catalog : `object` 

175 Expected type of this is `ApdbTableData`. 

176 rows : `int` 

177 Expected number of rows in a catalog. 

178 table : `ApdbTables` 

179 APDB table type. 

180 extra_columns : `int` 

181 Count of additional columns expected in ``catalog``. 

182 """ 

183 self.assertIsInstance(catalog, ApdbTableData) 

184 n_rows = sum(1 for row in catalog.rows()) 

185 self.assertEqual(n_rows, rows) 

186 # One extra column for replica chunk id 

187 self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1) 

188 

189 def test_makeSchema(self) -> None: 

190 """Test for making APDB schema.""" 

191 config = self.make_instance() 

192 apdb = Apdb.from_config(config) 

193 

194 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject)) 

195 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast)) 

196 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource)) 

197 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource)) 

198 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata)) 

199 

200 # Test from_uri factory method with the same config. 

201 with tempfile.NamedTemporaryFile() as tmpfile: 

202 config.save(tmpfile.name) 

203 apdb = Apdb.from_uri(tmpfile.name) 

204 

205 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject)) 

206 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast)) 

207 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource)) 

208 self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource)) 

209 self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata)) 

210 

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1);
        # nothing has been stored yet so the visit must be absent.
        res = apdb.containsVisitDetector(visit=1, detector=1)
        self.assertFalse(res)

        # get forced sources by region; some implementations require an
        # explicit object ID list and raise instead.
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

260 

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region; objects are not affected by the months
        # settings, so an empty catalog (not None) is expected.
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # Database is empty, no images exist.
        res = apdb.containsVisitDetector(visit=1, detector=1)
        self.assertFalse(res)

294 

295 def test_storeObjects(self) -> None: 

296 """Store and retrieve DiaObjects.""" 

297 # don't care about sources. 

298 config = self.make_instance() 

299 apdb = Apdb.from_config(config) 

300 

301 region = _make_region() 

302 visit_time = self.visit_time 

303 

304 # make catalog with Objects 

305 catalog = makeObjectCatalog(region, 100, visit_time) 

306 

307 # store catalog 

308 apdb.store(visit_time, catalog) 

309 

310 # read it back and check sizes 

311 res = apdb.getDiaObjects(region) 

312 self.assert_catalog(res, len(catalog), self.getDiaObjects_table()) 

313 

314 # TODO: test apdb.contains with generic implementation from DM-41671 

315 

316 def test_storeObjects_empty(self) -> None: 

317 """Test calling storeObject when there are no objects: see DM-43270.""" 

318 config = self.make_instance() 

319 apdb = Apdb.from_config(config) 

320 region = _make_region() 

321 visit_time = self.visit_time 

322 # make catalog with no Objects 

323 catalog = makeObjectCatalog(region, 0, visit_time) 

324 

325 with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm: 

326 apdb.store(visit_time, catalog) 

327 self.assertIn("No objects", "\n".join(cm.output)) 

328 

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back with an empty ID list to get schema only
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (1, 1)
        res = apdb.containsVisitDetector(visit=1, detector=1)
        self.assertTrue(res)
        # non-existent image
        res = apdb.containsVisitDetector(visit=2, detector=42)
        self.assertFalse(res)

364 

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back with an empty ID list to get schema only
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        res = apdb.containsVisitDetector(visit=1, detector=1)
        self.assertTrue(res)
        # non-existent image
        res = apdb.containsVisitDetector(visit=2, detector=42)
        self.assertFalse(res)

394 

    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        # Two disjoint regions so the two object catalogs do not overlap.
        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        # Store each visit with distinct source/forced-source IDs.
        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            # Replication disabled: no chunks, and chunk queries must raise.
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getDiaObjectsChunks([])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                # Verify that all three chunk tables return `n_records` rows;
                # by default expect nobj rows per chunk.
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getDiaObjectsChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb_replica.getDiaSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb_replica.getDiaForcedSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)

            _check_chunks(replica_chunks, 600)

469 

470 def test_storeSSObjects(self) -> None: 

471 """Store and retrieve SSObjects.""" 

472 # don't care about sources. 

473 config = self.make_instance() 

474 apdb = Apdb.from_config(config) 

475 

476 # make catalog with SSObjects 

477 catalog = makeSSObjectCatalog(100, flags=1) 

478 

479 # store catalog 

480 apdb.storeSSObjects(catalog) 

481 

482 # read it back and check sizes 

483 res = apdb.getSSObjects() 

484 self.assert_catalog(res, len(catalog), ApdbTables.SSObject) 

485 

486 # check that override works, make catalog with SSObjects, ID = 51-150 

487 catalog = makeSSObjectCatalog(100, 51, flags=2) 

488 apdb.storeSSObjects(catalog) 

489 res = apdb.getSSObjects() 

490 self.assert_catalog(res, 150, ApdbTables.SSObject) 

491 self.assertEqual(len(res[res["flags"] == 1]), 50) 

492 self.assertEqual(len(res[res["flags"] == 2]), 100) 

493 

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # Reassigning 3 sources to SSObjects removes them from the DiaSource
        # query results.
        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        # Reassigning a non-existing source ID must raise and leave the
        # stored data unchanged.
        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

526 

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27; the visit times straddle
        # the default 12-month source read window.
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        # Store two batches of 100 sources, two seconds apart.
        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

563 

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # Visit times straddle the default 12-month forced-source read
        # window (2021-01-01 plus 360 days is 2021-12-27).
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        # Store two batches of 100 forced sources, two seconds apart.
        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

599 

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB writes version numbers and a frozen JSON config: three items,
        # plus one more (replica version) when replication is enabled.
        self.assertFalse(metadata.empty())
        expected_rows = 4 if self.enable_replica else 3
        self.assertEqual(len(list(metadata.items())), expected_rows)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        # Setting an existing key without force is an error.
        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        # force=True overwrites the existing value.
        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        # delete() returns True only when the key existed.
        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        # get() with a default returns the default for a missing key.
        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

630 

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            metadata = apdb.metadata

            # With no metadata table reads return nothing and writes raise.
            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above
        # context without metadata table, here we use schema config which
        # includes metadata table.
        config.schema_file = self.schema_path
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

655 

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        # Default test schema declares version 0.1.1.
        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # A schema with no version key defaults to 0.1.0.
        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(), VersionTuple(0, 1, 0)  # type: ignore[attr-defined]
            )

        # An explicit version string is used verbatim.
        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(
                apdb._schema.schemaVersion(), VersionTuple(99, 0, 0)  # type: ignore[attr-defined]
            )

676 

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations. Flip it after instance creation and verify that
        # the frozen (stored) value wins over the in-memory override.
        config.use_insert_id = not self.enable_replica
        apdb = Apdb.from_config(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.enable_replica)

687 

688 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    # Nominal observation time used by all tests in this class.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Make schema without replica tables.
        config = self.make_instance(use_insert_id=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Make APDB instance configured for replication.
        config.use_insert_id = True
        apdb = Apdb.from_config(config)

        # Try to insert something, should work OK.
        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # There should be no replica chunks because the database itself was
        # created without replica tables.
        replica_chunks = apdb_replica.getReplicaChunks()
        self.assertIsNone(replica_chunks)

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)