Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%

387 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 02:52 -0700

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import tempfile 

29import unittest 

30from abc import ABC, abstractmethod 

31from collections.abc import Iterator 

32from tempfile import TemporaryDirectory 

33from typing import TYPE_CHECKING, Any 

34 

35import astropy.time 

36import pandas 

37import yaml 

38from lsst.dax.apdb import ( 

39 Apdb, 

40 ApdbConfig, 

41 ApdbReplica, 

42 ApdbTableData, 

43 ApdbTables, 

44 IncompatibleVersionError, 

45 ReplicaChunk, 

46 VersionTuple, 

47) 

48from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

49 

50from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

51 

if TYPE_CHECKING:

    # For static type checking we pretend the mixin derives from TestCase so
    # that calls to self.assert*() methods type-check correctly.
    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:

    # At runtime the mixin must NOT inherit from TestCase, otherwise the
    # abstract base classes below would be collected and run as tests.
    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""

62 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Make a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`]
        Direction of the region center as an (x, y, z) vector.

    Returns
    -------
    region : `Region`
        Circular region with a 0.05 radian field of view centered on ``xyz``.
    """
    fov = 0.05  # field-of-view diameter, radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(fov / 2))

69 

70 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    # Read all YAML documents from the existing schema file.
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply requested edits to each document.
    for document in documents:
        if drop_metadata:
            document["tables"] = [
                table for table in document["tables"] if table["name"] != "metadata"
            ]
        if version == "":
            del document["version"]
        elif version is not None:
            document["version"] = version

    # Write the edited documents to a temporary location that lives for the
    # duration of the context.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield updated_path

112 

113 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    schema_path: str
    """Location of the Felis schema file."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 11,
        ApdbTables.DiaForcedSource: 5,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for replica chunk id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

        # Test from_uri factory method with the same config.
        with tempfile.NamedTemporaryFile() as tmpfile:
            config.save(tmpfile.name)
            apdb = Apdb.from_uri(tmpfile.name)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # get sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Capture the result; previously it was discarded and the
            # assertion below re-checked a stale value from an earlier call.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            # Database is empty, no images exist.
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        region = _make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertTrue(res)
            # non-existent image
            res = apdb.containsVisitDetector(visit=2, detector=42)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # data_factory's ccdVisitId generation corresponds to (1, 1)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=1, detector=1)
            self.assertTrue(res)
            # non-existent image
            res = apdb.containsVisitDetector(visit=2, detector=42)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getDiaObjectsChunks([])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                # Verify that all three chunk tables return n_records rows
                # for the given list of chunks.
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getDiaObjectsChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb_replica.getDiaSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb_replica.getDiaForcedSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)

            _check_chunks(replica_chunks, 600)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB should write three or four metadata items with version numbers
        # and a frozen JSON config (one extra item when replication is on).
        self.assertFalse(metadata.empty())
        expected_rows = 4 if self.enable_replica else 3
        self.assertEqual(len(list(metadata.items())), expected_rows)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        config.schema_file = self.schema_path
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.enable_replica
        apdb = Apdb.from_config(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.enable_replica)

704 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Initialize the database without replica tables.
        config = self.make_instance(use_insert_id=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Re-create the APDB client with replication enabled against the
        # same (old) database.
        config.use_insert_id = True
        apdb = Apdb.from_config(config)

        # Storing data must still succeed even though the schema lacks
        # replica tables.
        region = _make_region()
        store_time = self.visit_time
        objects = makeObjectCatalog(region, 100, store_time)
        sources = makeSourceCatalog(objects, store_time)
        forced = makeForcedSourceCatalog(objects, store_time)
        apdb.store(store_time, objects, sources, forced)

        # Old schema cannot record chunks, so none should be reported.
        self.assertIsNone(apdb_replica.getReplicaChunks())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)