Coverage for python/lsst/dax/apdb/tests/_apdb.py: 14%

382 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-20 11:02 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import unittest 

29from abc import ABC, abstractmethod 

30from collections.abc import Iterator 

31from tempfile import TemporaryDirectory 

32from typing import TYPE_CHECKING, Any 

33 

34import pandas 

35import yaml 

36from lsst.daf.base import DateTime 

37from lsst.dax.apdb import ( 

38 ApdbConfig, 

39 ApdbInsertId, 

40 ApdbSql, 

41 ApdbTableData, 

42 ApdbTables, 

43 IncompatibleVersionError, 

44 VersionTuple, 

45 make_apdb, 

46) 

47from lsst.sphgeom import Angle, Circle, Region, UnitVector3d 

48 

49from .data_factory import makeForcedSourceCatalog, makeObjectCatalog, makeSourceCatalog, makeSSObjectCatalog 

50 

if TYPE_CHECKING:

    class TestCaseMixin(unittest.TestCase):
        """Static-analysis-only base exposing ``unittest.TestCase`` methods
        (``assertEqual``, ``assertRaises``, ...) to the mixin classes.
        """

else:

    class TestCaseMixin:
        """Empty runtime stand-in for the mixin base class.

        At runtime the mixins do not inherit from ``unittest.TestCase``;
        presumably so the abstract mixins are not themselves collected as
        test cases (TODO confirm). Concrete tests combine them with a real
        ``unittest.TestCase``.
        """

60 

61 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Build a small circular sky region for use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Components of the pointing direction, passed to ``UnitVector3d``.

    Returns
    -------
    region : `lsst.sphgeom.Region`
        A ``Circle`` centred on ``xyz`` whose angular radius is half of a
        0.05 radian field of view.
    """
    # The field of view is 0.05 rad; the circle takes its half-width.
    radius = Angle(0.05 / 2)
    return Circle(UnitVector3d(*xyz), radius)

68 

69 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Update schema definition and return name of the new schema file.

    The original file is not modified; an edited copy is written into a
    temporary directory that exists only while the context is active.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config (no-op when there is no
        version), if `None` - don't change the version in config.

    Yields
    ------
    output_path : `str`
        Path for the updated schema file.
    """
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))
    # Edit YAML contents; the file may contain several YAML documents.
    for schema in schemas_list:
        # Optionally drop metadata table.
        if drop_metadata:
            schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
        if version is not None:
            if version == "":
                # pop() rather than del: removing a version from a schema
                # that does not define one must not raise KeyError.
                schema.pop("version", None)
            else:
                schema["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path

111 

112 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    time_partition_tables = False
    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    use_insert_id: bool = False
    """Set to true when support for Insert IDs is configured"""

    allow_visit_query: bool = True
    """Set to true when containsVisitDetector is implemented"""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 8,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 10,
        ApdbTables.DiaForcedSource: 4,
        ApdbTables.SSObject: 3,
    }

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests."""
        raise NotImplementedError()

    @abstractmethod
    def getDiaObjects_table(self) -> ApdbTables:
        """Return type of table returned from getDiaObjects method."""
        raise NotImplementedError()

    def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate catalog type and size

        Parameters
        ----------
        catalog : `object`
            Expected type of this is ``pandas.DataFrame``.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, pandas.DataFrame)
        self.assertEqual(catalog.shape[0], rows)
        self.assertEqual(catalog.shape[1], self.table_column_count[table])

    def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
        """Validate history table data type and size.

        The catalog is expected to have exactly one more column than the
        corresponding table (the insert_id column).

        Parameters
        ----------
        catalog : `object`
            Expected type of this is `ApdbTableData`.
        rows : `int`
            Expected number of rows in a catalog.
        table : `ApdbTables`
            APDB table type.
        """
        self.assertIsInstance(catalog, ApdbTableData)
        n_rows = sum(1 for row in catalog.rows())
        self.assertEqual(n_rows, rows)
        # One extra column for insert_id
        self.assertEqual(len(catalog.column_names()), self.table_column_count[table] + 1)

    def test_makeSchema(self) -> None:
        """Test for making APDB schema."""
        config = self.make_config()
        apdb = make_apdb(config)

        apdb.makeSchema()
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObject))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaObjectLast))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.DiaForcedSource))
        self.assertIsNotNone(apdb.tableDef(ApdbTables.metadata))

    def test_empty_gets(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty results, only useful for
        checking that code is not broken.
        """
        # use non-zero months for Forced/Source fetching
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get sources by object ID, non-empty object list
        res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # get forced sources by object ID, non-empty object list
        res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

        # get forced sources by region (no object ID list)
        if self.fsrc_requires_id_list:
            with self.assertRaises(NotImplementedError):
                apdb.getDiaForcedSources(region, None, visit_time)
        else:
            # Assign the result; otherwise a stale ``res`` from an earlier
            # query would be validated instead.
            res = apdb.getDiaForcedSources(region, None, visit_time)
            self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_empty_gets_0months(self) -> None:
        """Test for getting data from empty database.

        All get() methods should return empty DataFrame or None.
        """
        # set read_sources_months to 0 so that Forced/Sources are None
        config = self.make_config(read_sources_months=0, read_forced_sources_months=0)
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        res: pandas.DataFrame | None

        # get objects by region
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, 0, self.getDiaObjects_table())

        # get sources by region
        res = apdb.getDiaSources(region, None, visit_time)
        self.assertIs(res, None)

        # get sources by object ID, empty object list
        res = apdb.getDiaSources(region, [], visit_time)
        self.assertIs(res, None)

        # get forced sources by object ID, empty object list
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assertIs(res, None)

        # test if a visit has objects/sources
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertFalse(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertFalse(res)

    def test_storeObjects(self) -> None:
        """Store and retrieve DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # make catalog with Objects
        catalog = makeObjectCatalog(region, 100, visit_time)

        # store catalog
        apdb.store(visit_time, catalog)

        # read it back and check sizes
        res = apdb.getDiaObjects(region)
        self.assert_catalog(res, len(catalog), self.getDiaObjects_table())

        # TODO: test apdb.contains with generic implementation from DM-41671

    def test_storeSources(self) -> None:
        """Store and retrieve DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)

        # save the objects and sources
        apdb.store(visit_time, objects, sources)

        # read it back, no ID filtering
        res = apdb.getDiaSources(region, None, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        # read it back to get schema
        res = apdb.getDiaSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

        # test if a visit is present
        # data_factory's ccdVisitId generation corresponds to (0, 0)
        if self.allow_visit_query:
            res = apdb.containsVisitDetector(visit=0, detector=0)
            self.assertTrue(res)
        else:
            with self.assertRaises(NotImplementedError):
                apdb.containsVisitDetector(visit=0, detector=0)

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_storeForcedSources(self) -> None:
        """Store and retrieve DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time

        # have to store Objects first
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        catalog = makeForcedSourceCatalog(objects, visit_time)

        apdb.store(visit_time, objects, forced_sources=catalog)

        # read it back and check sizes
        res = apdb.getDiaForcedSources(region, oids, visit_time)
        self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

        # read it back to get schema
        res = apdb.getDiaForcedSources(region, [], visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

        # TODO: test apdb.contains with generic implementation from DM-41671

        # alternative method not part of the Apdb API
        if isinstance(apdb, ApdbSql):
            res = apdb.containsCcdVisit(1)
            self.assertTrue(res)
            res = apdb.containsCcdVisit(42)
            self.assertFalse(res)

    def test_getHistory(self) -> None:
        """Store and retrieve catalog history."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        visit_time = self.visit_time

        region1 = _make_region((1.0, 1.0, -1.0))
        region2 = _make_region((-1.0, -1.0, -1.0))
        nobj = 100
        objects1 = makeObjectCatalog(region1, nobj, visit_time)
        objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

        visits = [
            (DateTime("2021-01-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:02:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:03:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:04:00", DateTime.TAI), objects2),
            (DateTime("2021-01-01T00:05:00", DateTime.TAI), objects1),
            (DateTime("2021-01-01T00:06:00", DateTime.TAI), objects2),
            (DateTime("2021-03-01T00:01:00", DateTime.TAI), objects1),
            (DateTime("2021-03-01T00:02:00", DateTime.TAI), objects2),
        ]

        start_id = 0
        for visit_time, objects in visits:
            sources = makeSourceCatalog(objects, visit_time, start_id=start_id)
            fsources = makeForcedSourceCatalog(objects, visit_time, ccdVisitId=start_id)
            apdb.store(visit_time, objects, sources, fsources)
            start_id += nobj

        insert_ids = apdb.getInsertIds()
        if not self.use_insert_id:
            self.assertIsNone(insert_ids)

            with self.assertRaisesRegex(ValueError, "APDB is not configured for history retrieval"):
                apdb.getDiaObjectsHistory([])

        else:
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 8)

            def _check_history(insert_ids: list[ApdbInsertId], n_records: int | None = None) -> None:
                # Each insert ID covers one stored visit of ``nobj`` rows
                # unless an explicit record count is given.
                if n_records is None:
                    n_records = len(insert_ids) * nobj
                res = apdb.getDiaObjectsHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaObject)
                res = apdb.getDiaSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaSource)
                res = apdb.getDiaForcedSourcesHistory(insert_ids)
                self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)

            # read it back and check sizes
            _check_history(insert_ids)
            _check_history(insert_ids[1:])
            _check_history(insert_ids[1:-1])
            _check_history(insert_ids[3:4])
            _check_history([])

            # try to remove some of those
            deleted_ids = insert_ids[:2]
            apdb.deleteInsertIds(deleted_ids)

            # All queries on deleted ids should return empty set.
            _check_history(deleted_ids, 0)

            insert_ids = apdb.getInsertIds()
            assert insert_ids is not None
            self.assertEqual(len(insert_ids), 6)

            _check_history(insert_ids)

    def test_storeSSObjects(self) -> None:
        """Store and retrieve SSObjects."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        # make catalog with SSObjects
        catalog = makeSSObjectCatalog(100, flags=1)

        # store catalog
        apdb.storeSSObjects(catalog)

        # read it back and check sizes
        res = apdb.getSSObjects()
        self.assert_catalog(res, len(catalog), ApdbTables.SSObject)

        # check that override works, make catalog with SSObjects, ID = 51-150
        catalog = makeSSObjectCatalog(100, 51, flags=2)
        apdb.storeSSObjects(catalog)
        res = apdb.getSSObjects()
        self.assert_catalog(res, 150, ApdbTables.SSObject)
        self.assertEqual(len(res[res["flags"] == 1]), 50)
        self.assertEqual(len(res[res["flags"] == 2]), 100)

    def test_reassignObjects(self) -> None:
        """Reassign DiaObjects."""
        # don't care about sources.
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        visit_time = self.visit_time
        objects = makeObjectCatalog(region, 100, visit_time)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, visit_time)
        apdb.store(visit_time, objects, sources)

        catalog = makeSSObjectCatalog(100)
        apdb.storeSSObjects(catalog)

        # read it back and filter by ID
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

        apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

        with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
            apdb.reassignDiaSources(
                {
                    1000: 1,
                    7: 3,
                }
            )
        # Re-read after the failed reassignment; checking the previous
        # ``res`` again would not verify that nothing was changed.
        res = apdb.getDiaSources(region, oids, visit_time)
        self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    def test_midpointMjdTai_src(self) -> None:
        """Test for time filtering of DiaSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        # 2021-01-01 plus 360 days is 2021-12-27
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeSourceCatalog(objects, src_time1, 0)
        apdb.store(src_time1, objects, sources)

        sources = makeSourceCatalog(objects, src_time2, 100)
        apdb.store(src_time2, objects, sources)

        # reading at time of last save should read all
        res = apdb.getDiaSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # one second before 12 months
        res = apdb.getDiaSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaSource)

        # first batch falls out of the window, only the second one is read
        res = apdb.getDiaSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaSource)

        # both batches fall out of the window, nothing is read
        res = apdb.getDiaSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaSource)

    def test_midpointMjdTai_fsrc(self) -> None:
        """Test for time filtering of DiaForcedSources."""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()

        region = _make_region()
        src_time1 = DateTime("2021-01-01T00:00:00", DateTime.TAI)
        src_time2 = DateTime("2021-01-01T00:00:02", DateTime.TAI)
        visit_time0 = DateTime("2021-12-26T23:59:59", DateTime.TAI)
        visit_time1 = DateTime("2021-12-27T00:00:01", DateTime.TAI)
        visit_time2 = DateTime("2021-12-27T00:00:03", DateTime.TAI)

        objects = makeObjectCatalog(region, 100, visit_time0)
        oids = list(objects["diaObjectId"])
        sources = makeForcedSourceCatalog(objects, src_time1, 1)
        apdb.store(src_time1, objects, forced_sources=sources)

        sources = makeForcedSourceCatalog(objects, src_time2, 2)
        apdb.store(src_time2, objects, forced_sources=sources)

        # reading at time of last save should read all
        res = apdb.getDiaForcedSources(region, oids, src_time2)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # one second before 12 months
        res = apdb.getDiaForcedSources(region, oids, visit_time0)
        self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

        # first batch falls out of the window, only the second one is read
        res = apdb.getDiaForcedSources(region, oids, visit_time1)
        self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

        # both batches fall out of the window, nothing is read
        res = apdb.getDiaForcedSources(region, oids, visit_time2)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    def test_metadata(self) -> None:
        """Simple test for writing/reading metadata table"""
        config = self.make_config()
        apdb = make_apdb(config)
        apdb.makeSchema()
        metadata = apdb.metadata

        # APDB should write two metadata items with version numbers.
        self.assertFalse(metadata.empty())
        self.assertEqual(len(list(metadata.items())), 2)

        metadata.set("meta", "data")
        metadata.set("data", "meta")

        self.assertFalse(metadata.empty())
        self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

        with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
            metadata.set("meta", "data1")

        metadata.set("meta", "data2", force=True)
        self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

        self.assertTrue(metadata.delete("meta"))
        self.assertIsNone(metadata.get("meta"))
        self.assertFalse(metadata.delete("meta"))

        self.assertEqual(metadata.get("data"), "meta")
        self.assertEqual(metadata.get("meta", "meta"), "meta")

    def test_nometadata(self) -> None:
        """Test case for when metadata table is missing"""
        config = self.make_config()
        # We expect that schema includes metadata table, drop it.
        with update_schema_yaml(config.schema_file, drop_metadata=True) as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            apdb.makeSchema()
            metadata = apdb.metadata

            self.assertTrue(metadata.empty())
            self.assertEqual(list(metadata.items()), [])
            with self.assertRaisesRegex(RuntimeError, "Metadata table does not exist"):
                metadata.set("meta", "data")

            self.assertTrue(metadata.empty())
            self.assertIsNone(metadata.get("meta"))

    def test_schemaVersionFromYaml(self) -> None:
        """Check version number handling for reading schema from YAML."""
        config = self.make_config()
        default_schema = config.schema_file
        apdb = make_apdb(config)
        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))

        with update_schema_yaml(default_schema, version="") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 0))

        with update_schema_yaml(default_schema, version="99.0.0") as schema_file:
            config = self.make_config(schema_file=schema_file)
            apdb = make_apdb(config)
            self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(99, 0, 0))

675 

676 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Mixin with tests of APDB behavior when the schema changes between
    the moment a database is created and the moment it is reopened.
    """

    visit_time = DateTime("2021-01-01T00:00:00", DateTime.TAI)

    @abstractmethod
    def make_config(self, **kwargs: Any) -> ApdbConfig:
        """Return an APDB configuration for these tests.

        Every call must refer to one and the same database (identical
        ``db_url``), so for sqlite the database has to live on disk rather
        than in memory.
        """
        raise NotImplementedError()

    def test_schema_add_history(self) -> None:
        """Verify that an APDB configured for history (insert IDs) still
        works with an older schema that has no history tables.
        """
        # Create the schema with history tables disabled.
        make_apdb(self.make_config(use_insert_id=False)).makeSchema()

        # Reopen the same database with history enabled.
        apdb = make_apdb(self.make_config(use_insert_id=True))

        # Storing must succeed; objects have to be stored first.
        when = self.visit_time
        dia_objects = makeObjectCatalog(_make_region(), 100, when)
        apdb.store(
            when,
            dia_objects,
            makeSourceCatalog(dia_objects, when),
            makeForcedSourceCatalog(dia_objects, when),
        )

        # Without history tables no insert IDs are reported.
        self.assertIsNone(apdb.getInsertIds())

    def test_schemaVersionCheck(self) -> None:
        """Verify that a schema version mismatch is detected."""
        config = self.make_config()
        apdb = make_apdb(config)

        self.assertEqual(apdb.apdbSchemaVersion(), VersionTuple(0, 1, 1))
        apdb.makeSchema()

        # Pretend the schema definition has moved on to 99.0.0; opening the
        # already created database must then fail.
        with update_schema_yaml(config.schema_file, version="99.0.0") as bumped_schema:
            bumped_config = self.make_config(schema_file=bumped_schema)
            with self.assertRaises(IncompatibleVersionError):
                make_apdb(bumped_config)