# This file is part of dax_apdb.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"]

import contextlib
import os
import unittest
from abc import ABC, abstractmethod
from collections.abc import Iterator
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any

import pandas
import yaml
from lsst.daf.base import DateTime
from lsst.dax.apdb import (
    Apdb,
    ApdbConfig,
    ApdbInsertId,
    ApdbSql,
    ApdbTableData,
    ApdbTables,
    IncompatibleVersionError,
    VersionTuple,
    make_apdb,
)
from lsst.sphgeom import Angle, Circle, Region, UnitVector3d

from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog

if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Base class for mixin test classes that use TestCase methods."""

else:

    class TestCaseMixin:
        """Do-nothing definition of mixin base class for regular execution."""


def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Make a region to use in tests"""
    pointing_v = UnitVector3d(*xyz)
    fov = 0.05  # radians
    region = Circle(pointing_v, Angle(fov / 2))
    return region

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If a non-empty string then set the schema version to this string, if
        an empty string then remove the schema version, if `None` then do not
        change the version.

    Yields
    ------
    Path of the updated schema file.
    """
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))
        # Edit YAML contents.
        for schema in schemas_list:
            # Optionally drop metadata table.
            if drop_metadata:
                schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
            if version is not None:
                if version == "":
                    del schema["version"]
                else:
                    schema["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path

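# A minimal usage sketch for ``update_schema_yaml``, mirroring how the tests
# below call it (the yielded path lives in a temporary directory that is
# removed when the context exits, so it must be used before leaving the block):
#
#     with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
#         config = self.make_config(schema_file=schema_file)
#         apdb = make_apdb(config)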

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when contains is implemented"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

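    # A minimal sketch of a concrete specialization (the class and config
    # names below are hypothetical; real subclasses also derive from
    # unittest.TestCase and return an implementation-specific config):
    #
    #     class ApdbSqlTestCase(ApdbTest, unittest.TestCase):
    #         def make_config(self, **kwargs: Any) -> ApdbConfig:
    #             return ApdbSqlConfig(db_url="sqlite://", **kwargs)
    #
    #         def getDiaObjects_table(self) -> ApdbTables:
    #             return ApdbTables.DiaObjectLast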

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column is expected for insert_id.
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get forced sources by region
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at the time of the last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # reading at a later time than the last save should read only a subset
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # reading even later should read no sources at all
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at the time of the last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # reading at a later time than the last save should read only a subset
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # reading even later should read no forced sources at all
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)
        metadata = apdb.metadata

        # APDB should write two metadata items with version numbers and a
        # frozen JSON config.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 3)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        config = self.make_config()
        # We expect that the schema includes the metadata table; drop it.
        with update_schema_yaml(config.schema_file, drop_metadata=True) as schema_file:
            config_nometa = self.make_config(schema_file=schema_file)
            Apdb.makeSchema(config_nometa)
            apdb = make_apdb(config_nometa)
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

        # Also check what happens when the configured schema has a metadata
        # table but the database is missing it. The database was initialized
        # inside the above context without the metadata table; here we use
        # the schema config which includes it.
        apdb = make_apdb(config)
        metadata = apdb.metadata
        self.assertTrue(metadata.empty())

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_config()
        default_schema = config.schema_file
        apdb = make_apdb(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

    def test_config_freeze(self) -> None:
        """Test that some config fields are correctly frozen in database."""
        config = self.make_config()
        Apdb.makeSchema(config)

        # `use_insert_id` is the only parameter that is frozen in all
        # implementations.
        config.use_insert_id = not self.use_insert_id
        apdb = make_apdb(config)
        frozen_config = apdb.config  # type: ignore[attr-defined]
        self.assertEqual(frozen_config.use_insert_id, self.use_insert_id)


class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        This method should return a configuration that points to the
        identical database instance on each call (i.e. ``db_url`` must be the
        same, which also means that for sqlite it has to use on-disk storage).
        """
        raise NotImplementedError()

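    # A minimal sketch of a conforming ``make_config`` (``ApdbSqlConfig``,
    # ``self.tempdir`` and ``TEST_SCHEMA`` are assumptions used only for
    # illustration; the essential point is that every call returns a config
    # with the same on-disk ``db_url``):
    #
    #     def make_config(self, **kwargs: Any) -> ApdbConfig:
    #         config = ApdbSqlConfig()
    #         config.db_url = f"sqlite:///{self.tempdir}/apdb.sqlite3"
    #         config.schema_file = TEST_SCHEMA
    #         config.update(**kwargs)
    #         return config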

    def test_schema_add_history(self) -> None:
        """Check that new code can work with old schema without history
        tables.
        """
        # Make schema without history tables.
        config = self.make_config(use_insert_id=False)
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        # Make APDB instance configured for history tables.
        config = self.make_config(use_insert_id=True)
        apdb = make_apdb(config)

        # Try to insert something, should work OK.
        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        sources = makeSourceCatalog(objects, visit_time)
        fsources = makeForcedSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources, fsources)

        # There should be no history.
        insert_ids = apdb.getInsertIds()
        self.assertIsNone(insert_ids)

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_config()
        Apdb.makeSchema(config)
        apdb = make_apdb(config)

        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        # Claim that schema version is now 99.0.0, must raise an exception.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            with self.assertRaises(IncompatibleVersionError):
                apdb = make_apdb(config)