Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%

395 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-01 10:45 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import tempfile 

29import unittest 

30from abc import ABC, abstractmethod 

31from collections.abc import Iterator 

32from tempfile import TemporaryDirectory 

33from typing import TYPE_CHECKING, Any 

34 

35import astropy.time 

36import pandas 

37import yaml 

38from lsst.dax.apdb import ( 

39 Apdb, 

40 ApdbConfig, 

41 ApdbReplica, 

42 ApdbTableData, 

43 ApdbTables, 

44 IncompatibleVersionError, 

45 ReplicaChunk, 

46 VersionTuple, 

47) 

48from lsst.dax.apdb.sql import ApdbSql 

49from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

50 

51from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

52 

if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Static-analysis view of the mixin base.

        Pretending to derive from `unittest.TestCase` lets type checkers
        resolve the ``assert*`` methods used by the mixin classes below.
        """

else:

    class TestCaseMixin:
        """Runtime stand-in for the mixin base class.

        An empty base keeps the mixins from being collected and executed
        as test cases on their own.
        """

62 

63 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular region for use in tests.

    The circle is centered on the direction given by ``xyz`` and has a
    fixed 0.05 radian field of view.
    """
    center = UnitVector3d(*xyz)
    field_of_view = 0.05  # full width, radians
    return Circle(center, Angle(field_of_view / 2))

70 

71 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't change
        the version in config.

    Yields
    ------
    Path for the updated configuration file.
    """
    with open(schema_file) as stream:
        documents = list(yaml.load_all(stream, Loader=yaml.SafeLoader))

    # Apply the requested edits to every YAML document in the file.
    for document in documents:
        if drop_metadata:
            document["tables"] = [t for t in document["tables"] if t["name"] != "metadata"]
        if version is None:
            continue
        if version:
            document["version"] = version
        else:
            # Empty string means "strip the version entirely".
            del document["version"]

    # Write the edited schema to a temporary location that lives for the
    # duration of the context.
    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield updated_path

113 

114 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    schema_path: str
    """Location of the Felis schema file."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make database instance and return configuration for it."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size.

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type, used to look up the expected column count.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for replica chunk id.
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

        # Test from_uri factory method with the same config.
        with tempfile.NamedTemporaryFile() as tmpfile:
            config.save(tmpfile.name)
            apdb = Apdb.from_uri(tmpfile.name)

            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
            self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Fix: the result was previously not captured, so the assertion
            # below re-checked a stale value of ``res``.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeObjects_empty(self) -> None:
        """Test calling storeObject when there are no objects: see DM-43270."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        region = _make_region()
        visit_time = self.visit_time
        # make catalog with no Objects
        catalog = makeObjectCatalog(region, 0, visit_time)

        with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
            apdb.store(visit_time, catalog)
        self.assertIn("No objects", "\n".join(cm.output))

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getChunks(self) -> None:
        """Store and retrieve replica chunks."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        # With the default 10 minutes replica chunk window we should have 4
        # records.
        visits = [
            (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
            (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
            (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        replica_chunks = apdb_replica.getReplicaChunks()
        if not self.enable_replica:
            self.assertIsNone(replica_chunks)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
                apdb_replica.getDiaObjectsChunks([])

        else:
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 4)

            def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
                # Default record count assumes every chunk holds one visit's
                # worth of objects.
                if n_records is None:
                    n_records = len(replica_chunks) * nobj
                res = apdb_replica.getDiaObjectsChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb_replica.getDiaSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb_replica.getDiaForcedSourcesChunks(chunk.id for chunk in replica_chunks)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_chunks(replica_chunks, 800)
            _check_chunks(replica_chunks[1:], 600)
            _check_chunks(replica_chunks[1:-1], 400)
            _check_chunks(replica_chunks[2:3], 200)
            _check_chunks([])

            # try to remove some of those
            deleted_chunks = replica_chunks[:1]
            apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

            # All queries on deleted ids should return empty set.
            _check_chunks(deleted_chunks, 0)

            replica_chunks = apdb_replica.getReplicaChunks()
            assert replica_chunks is not None
            self.assertEqual(len(replica_chunks), 3)

            _check_chunks(replica_chunks, 600)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        region = _make_region()
        src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
        src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
        visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
        visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
        visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading at later time of last save should only read a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_instance()
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata

        # APDB should write two or three metadata items with version numbers
        # and a frozen JSON config.
        self.assertFalse(metadata.empty())
        expected_rows = 4 if self.enable_replica else 3
        self.assertEqual(len(list(metadata.items())), expected_rows)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(self.schema_path, drop_metadata=True) as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when configured schema has metadata, but
        # database is missing it. Database was initialized inside above context
        # without metadata table, here we use schema config which includes
        # metadata table.
        config.schema_file = self.schema_path
        apdb = Apdb.from_config(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_instance()
        default_schema = config.schema_file
        apdb = Apdb.from_config(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_instance(schema_file=schema_file)
            apdb = Apdb.from_config(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_instance()

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.enable_replica
        apdb = Apdb.from_config(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.enable_replica)

715 

716 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return configuration that point to the identical
        database instance on each call (i.e. ``db_url`` must be the same,
        which also means for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Initialize the database from a schema without replica tables.
        config = self.make_instance(use_insert_id=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Re-open the same database with replication turned on in the config.
        config.use_insert_id = True
        apdb = Apdb.from_config(config)

        # Inserting data should still succeed against the old schema.
        region = _make_region()
        visit_time = self.visit_time

        # Objects must be stored before their sources.
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # The old schema has no chunk tables, so nothing was recorded.
        self.assertIsNone(apdb_replica.getReplicaChunks())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)