Coverage for python / lsst / dax / apdb / tests / _apdb.py: 15%

441 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 08:49 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import os 

28import tempfile 

29from abc import ABC, abstractmethod 

30from collections.abc import Iterator 

31from tempfile import TemporaryDirectory 

32from typing import TYPE_CHECKING, Any 

33 

34import astropy.time 

35import felis.datamodel 

36import pandas 

37import yaml 

38 

39from lsst.sphgeom import Angle, Circle, LonLat, Region, UnitVector3d 

40 

41from .. import ( 

42 Apdb, 

43 ApdbConfig, 

44 ApdbReassignDiaSourceRecord, 

45 ApdbReplica, 

46 ApdbTableData, 

47 ApdbTables, 

48 ApdbUpdateRecord, 

49 ApdbWithdrawDiaSourceRecord, 

50 IncompatibleVersionError, 

51 ReplicaChunk, 

52 VersionTuple, 

53) 

54from .data_factory import ( 

55 makeForcedSourceCatalog, 

56 makeObjectCatalog, 

57 makeSourceCatalog, 

58 makeTimestampNow, 

59) 

60from .utils import TestCaseMixin 

61 

62if TYPE_CHECKING: 

63 from ..pixelization import Pixelization 

64 

65 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular test region centered on ``xyz``.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Pointing direction; does not need to be normalized.
    """
    field_of_view = 0.0013  # radians
    center = UnitVector3d(*xyz)
    return Circle(center, Angle(field_of_view / 2))

72 

73 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Rewrite an APDB schema definition into a temporary file.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version from config, if `None` - don't
        change the version in config.

    Yields
    ------
    Path for the updated configuration file; the file lives in a temporary
    directory that is removed when the context exits.
    """
    # Read all YAML documents up front so the input file can be closed.
    with open(schema_file) as stream:
        documents = list(yaml.safe_load_all(stream))

    for document in documents:
        # Optionally drop metadata table.
        if drop_metadata:
            document["tables"] = [tbl for tbl in document["tables"] if tbl["name"] != "metadata"]
        if version == "":
            del document["version"]
        elif version is not None:
            document["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        updated_path = os.path.join(tmpdir, "schema.yaml")
        with open(updated_path, "w") as stream:
            yaml.dump_all(documents, stream=stream)
        yield updated_path

115 

116 

class ApdbTest(TestCaseMixin, ABC):
    """Base class for Apdb tests that can be specialized for concrete
    implementation.

    This can only be used as a mixin class for a unittest.TestCase and it
    calls various assert methods.
    """

    # Fixed visit time shared by most test methods below.
    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    fsrc_requires_id_list = False
    """Should be set to True if getDiaForcedSources requires object IDs"""

    enable_replica: bool = False
    """Set to true when support for replication is configured"""

    use_mjd: bool = True
    """If True then timestamp columns are MJD TAI."""

    extra_chunk_columns = 1
    """Number of additional columns in chunk tables."""

    meta_row_count = 3
    """Initial row count in metadata table."""

    # number of columns as defined in tests/config/schema.yaml
    table_column_count = {
        ApdbTables.DiaObject: 7,
        ApdbTables.DiaObjectLast: 5,
        ApdbTables.DiaSource: 12,
        ApdbTables.DiaForcedSource: 8,
        ApdbTables.SSObject: 3,
    }

150 

@abstractmethod
def make_instance(self, **kwargs: Any) -> ApdbConfig:
    """Make database instance and return configuration for it.

    Parameters
    ----------
    **kwargs : `~typing.Any`
        Configuration overrides; tests in this module pass options such as
        ``read_sources_months`` and ``schema_file``.

    Returns
    -------
    config : `ApdbConfig`
        Configuration usable with `Apdb.from_config`.
    """
    raise NotImplementedError()

155 

@abstractmethod
def getDiaObjects_table(self) -> ApdbTables:
    """Return type of table returned from getDiaObjects method.

    Implementations may serve `getDiaObjects` from either ``DiaObject`` or
    ``DiaObjectLast``, which have different column counts.
    """
    raise NotImplementedError()

160 

@abstractmethod
def pixelization(self, config: ApdbConfig) -> Pixelization:
    """Return pixelization used by implementation.

    Parameters
    ----------
    config : `ApdbConfig`
        Configuration returned by `make_instance`.
    """
    raise NotImplementedError()

165 

def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
    """Check that ``catalog`` is a DataFrame with the expected shape.

    Parameters
    ----------
    catalog : `object`
        Expected type of this is ``pandas.DataFrame``.
    rows : `int`
        Expected number of rows in a catalog.
    table : `ApdbTables`
        APDB table type, used to look up the expected column count.
    """
    self.assertIsInstance(catalog, pandas.DataFrame)
    n_rows, n_columns = catalog.shape
    self.assertEqual(n_rows, rows)
    # Column counts are fixed by tests/config/schema.yaml.
    self.assertEqual(n_columns, self.table_column_count[table])

181 

def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
    """Validate catalog type and size.

    Parameters
    ----------
    catalog : `object`
        Expected type of this is `ApdbTableData`.
    rows : `int`
        Expected number of rows in a catalog.
    table : `ApdbTables`
        APDB table type; ``self.extra_chunk_columns`` extra columns are
        expected on top of the regular column count for this table.
    """
    self.assertIsInstance(catalog, ApdbTableData)
    n_rows = sum(1 for row in catalog.rows())
    self.assertEqual(n_rows, rows)
    # One extra column for replica chunk id
    self.assertEqual(
        len(catalog.column_names()), self.table_column_count[table] + self.extra_chunk_columns
    )

203 

def assert_column_types(self, catalog: Any, types: dict[str, felis.datamodel.DataType]) -> None:
    """Check felis data types of selected catalog columns.

    Parameters
    ----------
    catalog : `object`
        Object with a ``column_defs()`` method yielding (name, type) pairs;
        presumably `ApdbTableData` — see `assert_table_data`.
    types : `dict` [`str`, `felis.datamodel.DataType`]
        Expected type per column name; columns absent from this mapping are
        not checked.
    """
    column_defs = dict(catalog.column_defs())
    for column, datatype in types.items():
        self.assertEqual(column_defs[column], datatype)

208 

def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Make a region to use in tests"""
    # Thin wrapper over the module-level helper.
    region = _make_region(xyz)
    return region

212 

def test_makeSchema(self) -> None:
    """Test for making APDB schema.

    Checks that every table defined in the test schema has a table
    definition, both for an instance built via `Apdb.from_config` and one
    re-loaded through the `Apdb.from_uri` factory method.
    """
    # All tables whose definitions must exist after schema creation;
    # a single tuple replaces the two identical copy-pasted assertion lists.
    expected_tables = (
        ApdbTables.DiaObject,
        ApdbTables.DiaObjectLast,
        ApdbTables.DiaSource,
        ApdbTables.DiaForcedSource,
        ApdbTables.metadata,
        ApdbTables.SSObject,
        ApdbTables.SSSource,
        ApdbTables.DiaObject_To_Object_Match,
    )

    config = self.make_instance()
    apdb = Apdb.from_config(config)
    for table in expected_tables:
        self.assertIsNotNone(apdb.tableDef(table))

    # Test from_uri factory method with the same config.
    with tempfile.NamedTemporaryFile() as tmpfile:
        config.save(tmpfile.name)
        apdb = Apdb.from_uri(tmpfile.name)

    # tableDef needs no backing file, so checking after cleanup is fine.
    for table in expected_tables:
        self.assertIsNotNone(apdb.tableDef(table))

240 

def test_empty_gets(self) -> None:
    """Test for getting data from empty database.

    All get() methods should return empty results, only useful for
    checking that code is not broken.
    """
    # use non-zero months for Forced/Source fetching
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    res: pandas.DataFrame | None

    # get objects by region
    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 0, self.getDiaObjects_table())

    # get sources by region (object_ids=None means no ID filtering)
    res = apdb.getDiaSources(region, None, visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # empty object list should still return an empty catalog, not None
    res = apdb.getDiaSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # get sources by object ID, non-empty object list
    res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # get forced sources by object ID, empty object list
    res = apdb.getDiaForcedSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # get sources by object ID, non-empty object list
    res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # data_factory's ccdVisitId generation corresponds to (1, 1)
    res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertFalse(res)

    # get sources by region; some implementations cannot query forced
    # sources without an ID list and must raise instead
    if self.fsrc_requires_id_list:
        with self.assertRaises(NotImplementedError):
            apdb.getDiaForcedSources(region, None, visit_time)
    else:
        res = apdb.getDiaForcedSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

290 

def test_empty_gets_0months(self) -> None:
    """Test for getting data from empty database.

    All get() methods should return empty DataFrame or None.
    """
    # set read_sources_months to 0 so that Forced/Sources are None
    config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    res: pandas.DataFrame | None

    # get objects by region; object queries are unaffected by the
    # zero-month read windows
    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 0, self.getDiaObjects_table())

    # get sources by region; with a zero read window the result is None
    res = apdb.getDiaSources(region, None, visit_time)
    self.assertIs(res, None)

    # get sources by object ID, empty object list
    res = apdb.getDiaSources(region, [], visit_time)
    self.assertIs(res, None)

    # get forced sources by object ID, empty object list
    res = apdb.getDiaForcedSources(region, [], visit_time)
    self.assertIs(res, None)

    # Database is empty, no images exist.
    res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertFalse(res)

324 

def test_storeObjects(self) -> None:
    """Store and retrieve DiaObjects."""
    # Sources are irrelevant for this test.
    apdb = Apdb.from_config(self.make_instance())

    region = self.make_region()
    visit_time = self.visit_time

    # Build and store a catalog of DiaObjects.
    objects = makeObjectCatalog(region, 100, visit_time)
    apdb.store(visit_time, objects)

    # Reading the same region back must return every stored object.
    retrieved = apdb.getDiaObjects(region)
    self.assert_catalog(retrieved, len(objects), self.getDiaObjects_table())

    # TODO: test apdb.contains with generic implementation from DM-41671

345 

def test_storeObjects_empty(self) -> None:
    """Test calling storeObject when there are no objects: see DM-43270."""
    apdb = Apdb.from_config(self.make_instance())
    region = self.make_region()
    visit_time = self.visit_time

    # Build an empty object catalog.
    empty_catalog = makeObjectCatalog(region, 0, visit_time)

    # Storing it must not fail; a debug message should be logged instead.
    with self.assertLogs("lsst.dax.apdb", level="DEBUG") as cm:
        apdb.store(visit_time, empty_catalog)
    self.assertIn("No objects", "\n".join(cm.output))

358 

def test_storeMovingObject(self) -> None:
    """Store and retrieve DiaObject which changes its position."""
    # don't care about sources.
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    pixelization = self.pixelization(config)

    # Two positions two degrees apart on the equator.
    lon_deg, lat_deg = 0.0, 0.0
    lonlat1 = LonLat.fromDegrees(lon_deg - 1.0, lat_deg)
    lonlat2 = LonLat.fromDegrees(lon_deg + 1.0, lat_deg)
    uv1 = UnitVector3d(lonlat1)
    uv2 = UnitVector3d(lonlat2)

    # Check that they fall into different pixels.
    self.assertNotEqual(pixelization.pixel(uv1), pixelization.pixel(uv2))

    # Store one object at two different positions.
    visit_time1 = self.visit_time
    catalog1 = makeObjectCatalog(lonlat1, 1, visit_time1)
    apdb.store(visit_time1, catalog1)

    # Second store is 120 seconds later at the other position.
    visit_time2 = visit_time1 + astropy.time.TimeDelta(120.0, format="sec")
    catalog1 = makeObjectCatalog(lonlat2, 1, visit_time2)
    apdb.store(visit_time2, catalog1)

    # Make region covering both points.
    region = Circle(UnitVector3d(LonLat.fromDegrees(lon_deg, lat_deg)), Angle.fromDegrees(1.1))
    self.assertTrue(region.contains(uv1))
    self.assertTrue(region.contains(uv2))

    # Read it back, must return the latest one.
    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 1, self.getDiaObjects_table())

392 

def test_storeSources(self) -> None:
    """Store and retrieve DiaSources."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # have to store Objects first
    objects = makeObjectCatalog(region, 100, visit_time)
    oids = list(objects["diaObjectId"])
    sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)

    # save the objects and sources
    apdb.store(visit_time, objects, sources)

    # read it back, no ID filtering
    res = apdb.getDiaSources(region, None, visit_time)
    self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

    # read it back and filter by ID
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

    # read it back to get schema; an empty ID list yields an empty
    # catalog with the full set of columns
    res = apdb.getDiaSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # test if a visit is present
    # data_factory's ccdVisitId generation corresponds to (1, 1)
    res = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertTrue(res)
    # non-existent image
    res = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
    self.assertFalse(res)

428 

def test_storeForcedSources(self) -> None:
    """Store and retrieve DiaForcedSources."""
    apdb = Apdb.from_config(self.make_instance())

    region = self.make_region()
    visit_time = self.visit_time

    # DiaObjects must exist before forced sources can reference them.
    objects = makeObjectCatalog(region, 100, visit_time)
    oids = list(objects["diaObjectId"])
    fsources = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)

    apdb.store(visit_time, objects, forced_sources=fsources)

    # Read back with ID filtering and verify the count.
    res = apdb.getDiaForcedSources(region, oids, visit_time)
    self.assert_catalog(res, len(fsources), ApdbTables.DiaForcedSource)

    # An empty ID list returns an empty catalog with the full schema.
    res = apdb.getDiaForcedSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # data_factory's ccdVisitId generation corresponds to (1, 1)
    found = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertTrue(found)
    # non-existent image
    found = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
    self.assertFalse(found)

458 

def test_timestamps(self) -> None:
    """Check that timestamp return type is as expected."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # Cassandra has a millisecond precision, so subtract 1ms to allow for
    # truncated returned values.
    time_before = makeTimestampNow(self.use_mjd, -1)
    objects = makeObjectCatalog(region, 100, visit_time)
    oids = list(objects["diaObjectId"])
    catalog = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    time_after = makeTimestampNow(self.use_mjd)

    apdb.store(visit_time, objects, forced_sources=catalog)

    # read it back and check sizes
    res = apdb.getDiaForcedSources(region, oids, visit_time)
    assert res is not None
    self.assert_catalog(res, len(catalog), ApdbTables.DiaForcedSource)

    # Column name and pandas dtype both depend on the MJD-vs-datetime
    # configuration of the schema.
    time_processed_column = "timeProcessedMjdTai" if self.use_mjd else "time_processed"
    self.assertIn(time_processed_column, res.dtypes)
    dtype = res.dtypes[time_processed_column]
    timestamp_type_name = "float64" if self.use_mjd else "datetime64[ns]"
    self.assertEqual(dtype.name, timestamp_type_name)
    # Verify that returned time is sensible: it must fall inside the
    # [time_before, time_after] window bracketing catalog creation.
    self.assertTrue(all(time_before <= dt <= time_after for dt in res[time_processed_column]))

489 

def test_getChunks(self) -> None:
    """Store and retrieve replica chunks."""
    # don't care about sources.
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    apdb_replica = ApdbReplica.from_config(config)
    visit_time = self.visit_time

    # Two disjoint regions with separate object catalogs.
    region1 = self.make_region((1.0, 1.0, -1.0))
    region2 = self.make_region((-1.0, -1.0, -1.0))
    nobj = 100
    objects1 = makeObjectCatalog(region1, nobj, visit_time)
    objects2 = makeObjectCatalog(region2, nobj, visit_time, start_id=nobj * 2)

    # With the default 10 minutes replica chunk window we should have 4
    # records.
    visits = [
        (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
    ]

    # Distinct start_id per visit keeps source IDs unique across stores.
    start_id = 0
    for visit_time, objects in visits:
        sources = makeSourceCatalog(objects, visit_time, start_id=start_id, use_mjd=self.use_mjd)
        fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id, use_mjd=self.use_mjd)
        apdb.store(visit_time, objects, sources, fsources)
        start_id += nobj

    replica_chunks = apdb_replica.getReplicaChunks()
    if not self.enable_replica:
        # Without replication there are no chunks and chunk queries fail.
        self.assertIsNone(replica_chunks)

        with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
            apdb_replica.getTableDataChunks(ApdbTables.DiaObject, [])

    else:
        assert replica_chunks is not None
        self.assertEqual(len(replica_chunks), 4)

        # SSObject is not a replicated table.
        with self.assertRaisesRegex(ValueError, "does not support replica chunks"):
            apdb_replica.getTableDataChunks(ApdbTables.SSObject, [])

        def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int | None = None) -> None:
            # Verify chunk contents for all three replicated tables;
            # default expectation is nobj records per chunk.
            if n_records is None:
                n_records = len(replica_chunks) * nobj
            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaObject, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaObject)
            # Validity-start column name/type depend on MJD configuration.
            validityStartColumn = "validityStartMjdTai" if self.use_mjd else "validityStart"
            validityStartType = (
                felis.datamodel.DataType.double if self.use_mjd else felis.datamodel.DataType.timestamp
            )
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaObjectId": felis.datamodel.DataType.long,
                    validityStartColumn: validityStartType,
                    "ra": felis.datamodel.DataType.double,
                    "dec": felis.datamodel.DataType.double,
                    "parallax": felis.datamodel.DataType.float,
                    "nDiaSources": felis.datamodel.DataType.int,
                },
            )

            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaSource, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaSource)
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaSourceId": felis.datamodel.DataType.long,
                    "visit": felis.datamodel.DataType.long,
                    "detector": felis.datamodel.DataType.short,
                },
            )

            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaForcedSource, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaObjectId": felis.datamodel.DataType.long,
                    "visit": felis.datamodel.DataType.long,
                    "detector": felis.datamodel.DataType.short,
                },
            )

        # read it back and check sizes
        _check_chunks(replica_chunks, 800)
        _check_chunks(replica_chunks[1:], 600)
        _check_chunks(replica_chunks[1:-1], 400)
        _check_chunks(replica_chunks[2:3], 200)
        _check_chunks([])

        # try to remove some of those
        deleted_chunks = replica_chunks[:1]
        apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

        # All queries on deleted ids should return empty set.
        _check_chunks(deleted_chunks, 0)

        replica_chunks = apdb_replica.getReplicaChunks()
        assert replica_chunks is not None
        self.assertEqual(len(replica_chunks), 3)

        _check_chunks(replica_chunks, 600)

609 

def test_reassignObjects(self) -> None:
    """Reassign DiaObjects."""
    # don't care about sources.
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time
    objects = makeObjectCatalog(region, 100, visit_time)
    oids = list(objects["diaObjectId"])
    sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, sources)

    # read it back and filter by ID
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

    # Reassign three sources to SSObjects; they no longer appear in
    # DiaSource query results (count drops by 3).
    apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    # Reassigning a non-existing source ID (1000) must fail and leave
    # the data unchanged.
    with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
        apdb.reassignDiaSources(
            {
                1000: 1,
                7: 3,
            }
        )
    self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

639 

def test_storeUpdateRecord(self) -> None:
    """Test _storeUpdateRecord() method."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    # Times are totally arbitrary.
    update_time_ns1 = 2_000_000_000_000_000_000
    update_time_ns2 = 2_000_000_001_000_000_000
    # A mix of reassign/withdraw records across two update times; the
    # (update_time_ns, update_order) pairs define the expected ordering.
    records = [
        ApdbReassignDiaSourceRecord(
            update_time_ns=update_time_ns1,
            update_order=0,
            diaSourceId=1,
            diaObjectId=321,
            ssObjectId=1,
            ssObjectReassocTimeMjdTai=60000.0,
            ra=45.0,
            dec=-45.0,
        ),
        ApdbWithdrawDiaSourceRecord(
            update_time_ns=update_time_ns1,
            update_order=1,
            diaSourceId=123456,
            diaObjectId=321,
            timeWithdrawnMjdTai=61000.0,
            ra=45.0,
            dec=-45.0,
        ),
        ApdbReassignDiaSourceRecord(
            update_time_ns=update_time_ns1,
            update_order=3,
            diaSourceId=2,
            diaObjectId=3,
            ssObjectId=3,
            ssObjectReassocTimeMjdTai=60000.0,
            ra=45.0,
            dec=-45.0,
        ),
        ApdbWithdrawDiaSourceRecord(
            update_time_ns=update_time_ns2,
            update_order=0,
            diaSourceId=123456,
            diaObjectId=321,
            timeWithdrawnMjdTai=61000.0,
            ra=45.0,
            dec=-45.0,
        ),
    ]

    update_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
    chunk = ReplicaChunk.make_replica_chunk(update_time, 600)

    if not self.enable_replica:
        # Without replication storing update records must fail.
        with self.assertRaises(TypeError):
            self.store_update_records(apdb, records, chunk)
    else:
        self.store_update_records(apdb, records, chunk)

        apdb_replica = ApdbReplica.from_config(config)
        records_returned = apdb_replica.getUpdateRecordChunks([chunk.id])

        # Input records are ordered, output will be ordered too.
        self.assertEqual(records_returned, records)

703 

@abstractmethod
def store_update_records(self, apdb: Apdb, records: list[ApdbUpdateRecord], chunk: ReplicaChunk) -> None:
    """Store update records in database, must be overridden in subclass.

    Parameters
    ----------
    apdb : `Apdb`
        Database instance to store the records in.
    records : `list` [`ApdbUpdateRecord`]
        Update records to store.
    chunk : `ReplicaChunk`
        Replica chunk that the records belong to.
    """
    raise NotImplementedError()

708 

def test_midpointMjdTai_src(self) -> None:
    """Test for time filtering of DiaSources."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    # 2021-01-01 plus 360 days is 2021-12-27
    src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
    src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
    visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
    visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
    visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
    one_sec = astropy.time.TimeDelta(1.0, format="sec")

    objects = makeObjectCatalog(region, 100, visit_time0)
    oids = list(objects["diaObjectId"])
    # Two batches of 100 sources stored two seconds apart.
    sources = makeSourceCatalog(objects, src_time1, 0, use_mjd=self.use_mjd)
    apdb.store(src_time1, objects, sources)

    sources = makeSourceCatalog(objects, src_time2, 100, use_mjd=self.use_mjd)
    apdb.store(src_time2, objects, sources)

    # reading at time of last save should read all
    res = apdb.getDiaSources(region, oids, src_time2)
    self.assert_catalog(res, 200, ApdbTables.DiaSource)

    # one second before 12 months
    res = apdb.getDiaSources(region, oids, visit_time0)
    self.assert_catalog(res, 200, ApdbTables.DiaSource)

    # reading at later time of last save should only read a subset
    res = apdb.getDiaSources(region, oids, visit_time1)
    self.assert_catalog(res, 100, ApdbTables.DiaSource)

    # reading at later time of last save should only read a subset
    res = apdb.getDiaSources(region, oids, visit_time2)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # Use explicit start time argument instead of 12 month window, visit
    # time does not matter in this case, set it to before all data.
    res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
    self.assert_catalog(res, 200, ApdbTables.DiaSource)

    res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
    self.assert_catalog(res, 100, ApdbTables.DiaSource)

    res = apdb.getDiaSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

757 

def test_midpointMjdTai_fsrc(self) -> None:
    """Test for time filtering of DiaForcedSources."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    # Same time layout as test_midpointMjdTai_src: store times are 360
    # days before the read times.
    src_time1 = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
    src_time2 = astropy.time.Time("2021-01-01T00:00:02", format="isot", scale="tai")
    visit_time0 = astropy.time.Time("2021-12-26T23:59:59", format="isot", scale="tai")
    visit_time1 = astropy.time.Time("2021-12-27T00:00:01", format="isot", scale="tai")
    visit_time2 = astropy.time.Time("2021-12-27T00:00:03", format="isot", scale="tai")
    one_sec = astropy.time.TimeDelta(1.0, format="sec")

    objects = makeObjectCatalog(region, 100, visit_time0)
    oids = list(objects["diaObjectId"])
    sources = makeForcedSourceCatalog(objects, src_time1, 1, use_mjd=self.use_mjd)
    apdb.store(src_time1, objects, forced_sources=sources)

    sources = makeForcedSourceCatalog(objects, src_time2, 2, use_mjd=self.use_mjd)
    apdb.store(src_time2, objects, forced_sources=sources)

    # reading at time of last save should read all
    res = apdb.getDiaForcedSources(region, oids, src_time2)
    self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

    # one second before 12 months
    res = apdb.getDiaForcedSources(region, oids, visit_time0)
    self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

    # reading at later time of last save should only read a subset
    res = apdb.getDiaForcedSources(region, oids, visit_time1)
    self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

    # reading at later time of last save should only read a subset
    res = apdb.getDiaForcedSources(region, oids, visit_time2)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # Use explicit start time argument instead of 12 month window, visit
    # time does not matter in this case, set it to before all data.
    res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time1 - one_sec)
    self.assert_catalog(res, 200, ApdbTables.DiaForcedSource)

    res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 - one_sec)
    self.assert_catalog(res, 100, ApdbTables.DiaForcedSource)

    res = apdb.getDiaForcedSources(region, oids, src_time1 - one_sec, src_time2 + one_sec)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

805 

def test_metadata(self) -> None:
    """Simple test for writing/reading metadata table"""
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    metadata = apdb.metadata

    # APDB should write two or three metadata items with version numbers
    # and a frozen JSON config.
    self.assertFalse(metadata.empty())
    self.assertEqual(len(list(metadata.items())), self.meta_row_count)

    metadata.set("meta", "data")
    metadata.set("data", "meta")

    self.assertFalse(metadata.empty())
    self.assertTrue(set(metadata.items()) >= {("meta", "data"), ("data", "meta")})

    # Setting an existing key without force raises.
    with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
        metadata.set("meta", "data1")

    # force=True overwrites the existing value.
    metadata.set("meta", "data2", force=True)
    self.assertTrue(set(metadata.items()) >= {("meta", "data2"), ("data", "meta")})

    # delete() returns True on success and False for a missing key.
    self.assertTrue(metadata.delete("meta"))
    self.assertIsNone(metadata.get("meta"))
    self.assertFalse(metadata.delete("meta"))

    # get() returns the provided default for a missing key.
    self.assertEqual(metadata.get("data"), "meta")
    self.assertEqual(metadata.get("meta", "meta"), "meta")

835 

836 def test_schemaVersionFromYaml(self) -> None: 

837 """Check version number handling for reading schema from YAML.""" 

838 config = self.make_instance() 

839 default_schema = config.schema_file 

840 apdb = Apdb.from_config(config) 

841 self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1)) # type: ignore[attr-defined] 

842 

843 with update_schema_yaml(default_schema, version="") as schema_file: 

844 config = self.make_instance(schema_file=schema_file) 

845 apdb = Apdb.from_config(config) 

846 self.assertEqual( 

847 apdb._schema.schemaVersion(), # type: ignore[attr-defined] 

848 VersionTuple(0, 1, 0), 

849 ) 

850 

851 with update_schema_yaml(default_schema, version="99.0.0") as schema_file: 

852 config = self.make_instance(schema_file=schema_file) 

853 apdb = Apdb.from_config(config) 

854 self.assertEqual( 

855 apdb._schema.schemaVersion(), # type: ignore[attr-defined] 

856 VersionTuple(99, 0, 0), 

857 ) 

858 

859 def test_config_freeze(self) -> None: 

860 """Test that some config fields are correctly frozen in database.""" 

861 config = self.make_instance() 

862 

863 # `enable_replica` is the only parameter that is frozen in all 

864 # implementations. 

865 config.enable_replica = not self.enable_replica 

866 apdb = Apdb.from_config(config) 

867 frozen_config = apdb.getConfig() 

868 self.assertEqual(frozen_config.enable_replica, self.enable_replica) 

869 

870 

class ApdbSchemaUpdateTest(TestCaseMixin, ABC):
    """Base class for unit tests that verify how schema changes work."""

    visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")

    @abstractmethod
    def make_instance(self, **kwargs: Any) -> ApdbConfig:
        """Make config class instance used in all tests.

        Each call must return a configuration that points at the identical
        database instance (i.e. ``db_url`` must be the same, which for
        sqlite also means it has to use on-disk storage).
        """
        raise NotImplementedError()

    def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
        """Make a region to use in tests"""
        return _make_region(xyz)

    def test_schema_add_replica(self) -> None:
        """Check that new code can work with old schema without replica
        tables.
        """
        # Create the database without replica tables.
        config = self.make_instance(enable_replica=False)
        apdb = Apdb.from_config(config)
        apdb_replica = ApdbReplica.from_config(config)

        # Re-open the same database with replication enabled.
        config.enable_replica = True
        apdb = Apdb.from_config(config)

        # Storing data against the old schema should still work.
        region = self.make_region()
        visit_time = self.visit_time

        # DiaObject rows must be stored before sources that reference them.
        dia_objects = makeObjectCatalog(region, 100, visit_time)
        dia_sources = makeSourceCatalog(dia_objects, visit_time)
        forced_sources = makeForcedSourceCatalog(dia_objects, visit_time)
        apdb.store(visit_time, dia_objects, dia_sources, forced_sources)

        # The old schema has no replica tables, hence no replica chunks.
        self.assertIsNone(apdb_replica.getReplicaChunks())

    def test_schemaVersionCheck(self) -> None:
        """Check version number compatibility."""
        config = self.make_instance()
        apdb = Apdb.from_config(config)

        self.assertEqual(apdb._schema.schemaVersion(), VersionTuple(0, 1, 1))  # type: ignore[attr-defined]

        # Pretend the schema version is now 99.0.0; this must raise.
        with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file:
            config.schema_file = schema_file
            with self.assertRaises(IncompatibleVersionError):
                apdb = Apdb.from_config(config)
                # The version is only checked when a connection is made.
                apdb.metadata.items()