Coverage for python / lsst / dax / apdb / tests / _apdb.py: 12%

618 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:58 +0000

1# This file is part of dax_apdb. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["ApdbSchemaUpdateTest", "ApdbTest", "update_schema_yaml"] 

25 

26import contextlib 

27import logging.config 

28import os 

29import tempfile 

30from abc import ABC, abstractmethod 

31from collections.abc import Iterator 

32from tempfile import TemporaryDirectory 

33from typing import TYPE_CHECKING, Any 

34 

35import astropy.time 

36import felis.datamodel 

37import pandas 

38import yaml 

39 

40from lsst.sphgeom import Angle, Circle, LonLat, Region, UnitVector3d 

41 

42from .. import ( 

43 Apdb, 

44 ApdbConfig, 

45 ApdbReassignDiaSourceToSSObjectRecord, 

46 ApdbReplica, 

47 ApdbTableData, 

48 ApdbTables, 

49 ApdbUpdateRecord, 

50 ApdbWithdrawDiaSourceRecord, 

51 DiaObjectId, 

52 DiaSourceId, 

53 IncompatibleVersionError, 

54 ReplicaChunk, 

55 VersionTuple, 

56) 

57from .data_factory import ( 

58 makeForcedSourceCatalog, 

59 makeObjectCatalog, 

60 makeSourceCatalog, 

61 makeTimestamp, 

62 makeTimestampColumn, 

63) 

64from .utils import TestCaseMixin 

65 

66if TYPE_CHECKING: 

67 from ..pixelization import Pixelization 

68 

69 

# Logging for the tests can be configured from a file: when the
# DAX_APDB_TEST_LOG_CONFIG environment variable is set to a non-empty value
# it is passed to ``logging.config.fileConfig``.
log_config = os.environ.get("DAX_APDB_TEST_LOG_CONFIG")
if log_config:
    logging.config.fileConfig(log_config)

73 

74 

def _make_region(xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular region to use in tests.

    Parameters
    ----------
    xyz : `tuple` [`float`, `float`, `float`]
        Direction of the circle center, passed to ``UnitVector3d``.

    Returns
    -------
    region : `lsst.sphgeom.Region`
        ``Circle`` around ``xyz`` built with an angle of half of 0.0013
        radians.
    """
    center = UnitVector3d(*xyz)
    # Field of view size in radians; the circle angle is half of it.
    fov_radians = 0.0013
    return Circle(center, Angle(fov_radians / 2))

81 

82 

@contextlib.contextmanager
def update_schema_yaml(
    schema_file: str,
    drop_metadata: bool = False,
    version: str | None = None,
) -> Iterator[str]:
    """Write an updated copy of a schema definition and yield its path.

    Parameters
    ----------
    schema_file : `str`
        Path for the existing YAML file with APDB schema.
    drop_metadata : `bool`
        If `True` then remove metadata table from the list of tables.
    version : `str` or `None`
        If non-empty string then set schema version to this string, if empty
        string then remove schema version (if present) from the schema, if
        `None` - don't change the version.

    Yields
    ------
    path : `str`
        Path for the updated schema file. The file is created in a temporary
        directory which is removed when the context manager exits.
    """
    with open(schema_file) as yaml_stream:
        schemas_list = list(yaml.load_all(yaml_stream, Loader=yaml.SafeLoader))

    # Edit YAML contents; every document in the file is updated.
    for schema in schemas_list:
        # Optionally drop metadata table.
        if drop_metadata:
            schema["tables"] = [table for table in schema["tables"] if table["name"] != "metadata"]
        if version is not None:
            if version == "":
                # Use pop() so that a document without a version key does
                # not raise KeyError.
                schema.pop("version", None)
            else:
                schema["version"] = version

    with TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
        output_path = os.path.join(tmpdir, "schema.yaml")
        with open(output_path, "w") as yaml_stream:
            yaml.dump_all(schemas_list, stream=yaml_stream)
        yield output_path

124 

125 

126class ApdbTest(TestCaseMixin, ABC): 

127 """Base class for Apdb tests that can be specialized for concrete 

128 implementation. 

129 

130 This can only be used as a mixin class for a unittest.TestCase and it 

131 calls various assert methods. 

132 """ 

133 

134 visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai") 

135 

136 processing_time = astropy.time.Time("2021-01-01T12:00:00", format="isot", scale="tai") 

137 

138 fsrc_requires_id_list = False 

139 """Should be set to True if getDiaForcedSources requires object IDs""" 

140 

141 enable_replica: bool = False 

142 """Set to true when support for replication is configured""" 

143 

144 use_mjd: bool = True 

145 """If True then timestamp columns are MJD TAI.""" 

146 

147 extra_chunk_columns = 1 

148 """Number of additional columns in chunk tables.""" 

149 

150 meta_row_count = 3 

151 """Initial row count in metadata table.""" 

152 

153 # number of columns as defined in tests/config/schema.yaml 

154 table_column_count = { 

155 ApdbTables.DiaObject: 8, 

156 ApdbTables.DiaObjectLast: 6, 

157 ApdbTables.DiaSource: 12, 

158 ApdbTables.DiaForcedSource: 8, 

159 ApdbTables.SSObject: 3, 

160 } 

161 

@abstractmethod
def make_instance(self, **kwargs: Any) -> ApdbConfig:
    """Create a new database instance and return its configuration.

    Parameters
    ----------
    **kwargs
        Configuration overrides for the new instance (tests pass e.g.
        ``read_sources_months`` and ``read_forced_sources_months``).

    Returns
    -------
    config : `ApdbConfig`
        Configuration of the created instance.
    """
    raise NotImplementedError

166 

@abstractmethod
def getDiaObjects_table(self) -> ApdbTables:
    """Return the table type matching results of ``getDiaObjects``.

    Tests pass the returned value to ``assert_catalog`` to look up the
    expected column count of catalogs returned by ``getDiaObjects``.
    """
    raise NotImplementedError

171 

@abstractmethod
def pixelization(self, config: ApdbConfig) -> Pixelization:
    """Return the pixelization used by the implementation for ``config``.

    ``test_storeMovingObject`` uses it to check that two positions fall
    into different pixels.
    """
    raise NotImplementedError

176 

def assert_catalog(self, catalog: Any, rows: int, table: ApdbTables) -> None:
    """Check that ``catalog`` is a DataFrame with the expected shape.

    Parameters
    ----------
    catalog : `object`
        Object to check, it has to be a ``pandas.DataFrame``.
    rows : `int`
        Expected number of rows.
    table : `ApdbTables`
        APDB table type, selects the expected number of columns from
        ``table_column_count``.
    """
    self.assertIsInstance(catalog, pandas.DataFrame)
    n_rows, n_columns = catalog.shape
    self.assertEqual(n_rows, rows)
    self.assertEqual(n_columns, self.table_column_count[table])

192 

def assert_table_data(self, catalog: Any, rows: int, table: ApdbTables) -> None:
    """Validate catalog type and size.

    Parameters
    ----------
    catalog : `object`
        Expected type of this is `ApdbTableData`.
    rows : `int`
        Expected number of rows in a catalog.
    table : `ApdbTables`
        APDB table type.

    Notes
    -----
    The expected column count is ``table_column_count[table]`` plus
    ``self.extra_chunk_columns`` additional columns (such as the replica
    chunk ID).
    """
    self.assertIsInstance(catalog, ApdbTableData)
    # Rows are only available through an iterator, so count them.
    n_rows = sum(1 for _ in catalog.rows())
    self.assertEqual(n_rows, rows)
    expected_columns = self.table_column_count[table] + self.extra_chunk_columns
    self.assertEqual(len(catalog.column_names()), expected_columns)

214 

def assert_column_types(self, catalog: Any, types: dict[str, felis.datamodel.DataType]) -> None:
    """Check data types of selected columns in a catalog.

    Parameters
    ----------
    catalog : `object`
        Object with a ``column_defs()`` method whose result can be converted
        with ``dict()`` into a mapping from column name to data type.
    types : `dict` [`str`, `felis.datamodel.DataType`]
        Expected data type for each column to check; columns that are not
        listed here are not checked.
    """
    defs_by_name = dict(catalog.column_defs())
    for name, expected_type in types.items():
        self.assertEqual(defs_by_name[name], expected_type)

219 

def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a small circular test region centered on direction ``xyz``.

    Delegates to the module-level ``_make_region`` helper.
    """
    region = _make_region(xyz)
    return region

223 

def test_makeSchema(self) -> None:
    """Test for making APDB schema.

    All expected table definitions must exist both for an instance made
    from a config object and for one made from a saved config file.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    expected_tables = (
        ApdbTables.DiaObject,
        ApdbTables.DiaObjectLast,
        ApdbTables.DiaSource,
        ApdbTables.DiaForcedSource,
        ApdbTables.metadata,
        ApdbTables.SSObject,
        ApdbTables.SSSource,
        ApdbTables.DiaObject_To_Object_Match,
    )
    for table in expected_tables:
        self.assertIsNotNone(apdb.tableDef(table))

    # Test from_uri factory method with the same config.
    with tempfile.NamedTemporaryFile() as tmpfile:
        config.save(tmpfile.name)
        apdb = Apdb.from_uri(tmpfile.name)

    for table in expected_tables:
        self.assertIsNotNone(apdb.tableDef(table))

251 

def test_empty_gets(self) -> None:
    """Test for getting data from empty database.

    All get() methods should return empty results, only useful for
    checking that code is not broken.
    """
    # use non-zero months for Forced/Source fetching
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    res: pandas.DataFrame | None

    # get objects by region
    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 0, self.getDiaObjects_table())

    # get sources by region
    res = apdb.getDiaSources(region, None, visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # get sources by object ID, empty object list
    res = apdb.getDiaSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # get sources by object ID, non-empty object list
    res = apdb.getDiaSources(region, [1, 2, 3], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaSource)

    # get forced sources by object ID, empty object list
    res = apdb.getDiaForcedSources(region, [], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # get forced sources by object ID, non-empty object list
    res = apdb.getDiaForcedSources(region, [1, 2, 3], visit_time)
    self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

    # data_factory's ccdVisitId generation corresponds to (1, 1); nothing
    # is stored yet, so it must not be found. A separate name is used for
    # the boolean result to keep ``res`` typed as a catalog.
    found = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertFalse(found)

    # get forced sources by region; implementations that need an explicit
    # object ID list must raise NotImplementedError.
    if self.fsrc_requires_id_list:
        with self.assertRaises(NotImplementedError):
            apdb.getDiaForcedSources(region, None, visit_time)
    else:
        res = apdb.getDiaForcedSources(region, None, visit_time)
        self.assert_catalog(res, 0, ApdbTables.DiaForcedSource)

301 

def test_empty_gets_0months(self) -> None:
    """Test for getting data from empty database with zero read months.

    With ``read_sources_months`` and ``read_forced_sources_months`` set to
    zero, object queries return an empty DataFrame while source and forced
    source queries return `None`.
    """
    config = self.make_instance(read_sources_months=0, read_forced_sources_months=0)
    apdb = Apdb.from_config(config)

    region = self.make_region()
    when = self.visit_time

    # Objects still come back as an empty catalog.
    objects = apdb.getDiaObjects(region)
    self.assert_catalog(objects, 0, self.getDiaObjects_table())

    # Sources by region and by an empty object ID list are not read.
    self.assertIs(apdb.getDiaSources(region, None, when), None)
    self.assertIs(apdb.getDiaSources(region, [], when), None)

    # Forced sources by an empty object ID list are not read either.
    self.assertIs(apdb.getDiaForcedSources(region, [], when), None)

    # Database is empty, no images exist.
    found = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=when)
    self.assertFalse(found)

335 

def test_storeObjects(self) -> None:
    """Store DiaObjects and read them back by region."""
    # Sources are not needed here.
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    when = self.visit_time

    objects = makeObjectCatalog(region, 100)
    apdb.store(when, objects)

    # Every stored object must be returned.
    stored = apdb.getDiaObjects(region)
    self.assert_catalog(stored, len(objects), self.getDiaObjects_table())

    # TODO: test apdb.contains with generic implementation from DM-41671

356 

def test_storeObjects_empty(self) -> None:
    """Test calling storeObject when there are no objects: see DM-43270."""
    apdb = Apdb.from_config(self.make_instance())
    region = self.make_region()
    when = self.visit_time
    empty_catalog = makeObjectCatalog(region, 0)

    # Storing an empty catalog must log a "No objects" debug message.
    with self.assertLogs("lsst.dax.apdb", level="DEBUG") as captured:
        apdb.store(when, empty_catalog)
    self.assertIn("No objects", "\n".join(captured.output))

369 

def test_storeMovingObject(self) -> None:
    """Store and retrieve DiaObject which changes its position."""
    # Sources are not needed here.
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    pixelization = self.pixelization(config)

    # Two positions on the equator, one degree on each side of the center.
    center_lon, center_lat = 0.0, 0.0
    first_pos = LonLat.fromDegrees(center_lon - 1.0, center_lat)
    second_pos = LonLat.fromDegrees(center_lon + 1.0, center_lat)
    first_vec = UnitVector3d(first_pos)
    second_vec = UnitVector3d(second_pos)

    # The positions have to be in different pixels for this test to work.
    self.assertNotEqual(pixelization.pixel(first_vec), pixelization.pixel(second_vec))

    # Store one object twice, two minutes apart, at the two positions.
    first_time = self.visit_time
    apdb.store(first_time, makeObjectCatalog(first_pos, 1))

    second_time = first_time + astropy.time.TimeDelta(120.0, format="sec")
    apdb.store(second_time, makeObjectCatalog(second_pos, 1))

    # A 1.1 degree circle around the center covers both positions.
    center = UnitVector3d(LonLat.fromDegrees(center_lon, center_lat))
    region = Circle(center, Angle.fromDegrees(1.1))
    self.assertTrue(region.contains(first_vec))
    self.assertTrue(region.contains(second_vec))

    # Only the latest version of the object must be returned.
    latest = apdb.getDiaObjects(region)
    self.assert_catalog(latest, 1, self.getDiaObjects_table())

403 

def test_storeSources(self) -> None:
    """Store DiaSources and read them back by region and by object ID."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # Objects have to be stored together with their sources.
    objects = makeObjectCatalog(region, 100)
    object_ids = list(objects["diaObjectId"])
    sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, sources)

    # Without ID filtering and filtered by all object IDs, every stored
    # source must come back.
    for id_filter in (None, object_ids):
        result = apdb.getDiaSources(region, id_filter, visit_time)
        self.assert_catalog(result, len(sources), ApdbTables.DiaSource)

    # An empty ID list gives an empty catalog with the expected columns.
    result = apdb.getDiaSources(region, [], visit_time)
    self.assert_catalog(result, 0, ApdbTables.DiaSource)

    # data_factory's ccdVisitId generation corresponds to (1, 1), so that
    # visit/detector is present while (2, 42) is not.
    has_visit = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertTrue(has_visit)
    has_visit = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
    self.assertFalse(has_visit)

439 

def test_storeForcedSources(self) -> None:
    """Store DiaForcedSources and read them back by object ID."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # Objects are stored in the same call as their forced sources.
    objects = makeObjectCatalog(region, 100)
    object_ids = list(objects["diaObjectId"])
    forced = makeForcedSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, forced_sources=forced)

    # All forced sources come back for the complete object ID list.
    result = apdb.getDiaForcedSources(region, object_ids, visit_time)
    self.assert_catalog(result, len(forced), ApdbTables.DiaForcedSource)

    # An empty ID list gives an empty catalog with the expected columns.
    result = apdb.getDiaForcedSources(region, [], visit_time)
    self.assert_catalog(result, 0, ApdbTables.DiaForcedSource)

    # data_factory's ccdVisitId generation corresponds to (1, 1), so that
    # visit/detector is present while (2, 42) is not.
    has_visit = apdb.containsVisitDetector(visit=1, detector=1, region=region, visit_time=visit_time)
    self.assertTrue(has_visit)
    has_visit = apdb.containsVisitDetector(visit=2, detector=42, region=region, visit_time=visit_time)
    self.assertFalse(has_visit)

469 

def test_null_integer_type(self) -> None:
    """Test that an integer column containing NULLs has the correct type
    on select.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    objects = makeObjectCatalog(region, 100)
    sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    # NULL out diaObjectId for index labels 0 through 10 (``.loc`` label
    # slices include both ends).
    sources.loc[0:10, "diaObjectId"] = None

    apdb.store(visit_time, objects, sources)

    # Read back without ID filtering; the column must use the nullable
    # pandas integer type.
    result = apdb.getDiaSources(region, None, visit_time)
    self.assert_catalog(result, len(sources), ApdbTables.DiaSource)
    assert result is not None, "Expecting catalog, not None"
    self.assertEqual(result.dtypes["diaObjectId"], pandas.Int64Dtype())

492 

def test_timestamps(self) -> None:
    """Check that timestamp return type is as expected."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # Cassandra has a millisecond precision, so subtract 1ms to allow for
    # truncated returned values.
    lower_bound = makeTimestamp(self.processing_time, self.use_mjd, -1)
    objects = makeObjectCatalog(region, 100)
    object_ids = list(objects["diaObjectId"])
    forced = makeForcedSourceCatalog(
        objects, visit_time, processing_time=self.processing_time, use_mjd=self.use_mjd
    )
    upper_bound = makeTimestamp(self.processing_time, self.use_mjd)

    apdb.store(visit_time, objects, forced_sources=forced)

    # Read back and check the size.
    result = apdb.getDiaForcedSources(region, object_ids, visit_time)
    assert result is not None
    self.assert_catalog(result, len(forced), ApdbTables.DiaForcedSource)

    column = makeTimestampColumn("time_processed", self.use_mjd)
    self.assertIn(column, result.dtypes)
    if self.use_mjd:
        allowed_dtype_names: tuple[str, ...] = ("float64",)
    else:
        allowed_dtype_names = ("datetime64[ms]", "datetime64[us]", "datetime64[ns]")
    self.assertIn(result.dtypes[column].name, allowed_dtype_names)

    # Every returned timestamp must lie between the two bounds made above.
    self.assertTrue(all(lower_bound <= value <= upper_bound for value in result[column]))

527 

def test_getDiaObjectsForDedup(self) -> None:
    """Test getDiaObjectsForDedup() method.

    Objects are stored for three visits ten minutes apart. The number of
    returned objects depends on the time passed to the method.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region1 = self.make_region((1.0, 1.0, -1.0))
    region2 = self.make_region((-1.0, 1.0, -1.0))
    region3 = self.make_region((-1.0, -1.0, -1.0))
    nobj = 100
    per_visit_objects = (
        makeObjectCatalog(region1, nobj),
        makeObjectCatalog(region2, nobj, start_id=nobj * 2),
        makeObjectCatalog(region3, nobj, start_id=nobj * 4),
    )

    visit_times = [
        astropy.time.Time(isot, format="isot", scale="tai")
        for isot in ("2021-01-01T00:00:00", "2021-01-01T00:10:00", "2021-01-01T00:20:00")
    ]
    for when, objects in zip(visit_times, per_visit_objects):
        apdb.store(when, objects)

    # Without a time argument every stored object is returned.
    self.assertEqual(len(apdb.getDiaObjectsForDedup()), 300)

    # With a visit time, objects stored at earlier visits are excluded.
    for when, expected_count in zip(visit_times, (300, 200, 100)):
        self.assertEqual(len(apdb.getDiaObjectsForDedup(when)), expected_count)

    # Nothing was stored at or after this time.
    later = astropy.time.Time("2021-01-01T00:30:00", format="isot", scale="tai")
    self.assertEqual(len(apdb.getDiaObjectsForDedup(later)), 0)

565 

def test_getDiaSourcesForDiaObjects(self) -> None:
    """Test getDiaSourcesForDiaObjects() method."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    # Monkey-patch APDB instance to set current time.
    apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]

    region1 = self.make_region((1.0, 1.0, -1.0))
    region2 = self.make_region((-1.0, 1.0, -1.0))
    region3 = self.make_region((-1.0, -1.0, -1.0))
    nobj = 100
    per_visit_objects = (
        makeObjectCatalog(region1, nobj),
        makeObjectCatalog(region2, nobj, start_id=nobj * 2),
        makeObjectCatalog(region3, nobj, start_id=nobj * 4),
    )

    visit_times = [
        astropy.time.Time(isot, format="isot", scale="tai")
        for isot in ("2021-01-01T00:00:00", "2021-01-01T00:10:00", "2021-01-01T00:20:00")
    ]

    # Visit number N (starting from 1) gets source IDs from N * 1_000_000.
    for visit_number, (when, objects) in enumerate(zip(visit_times, per_visit_objects), start=1):
        sources = makeSourceCatalog(
            objects, when, start_id=visit_number * 1_000_000, use_mjd=self.use_mjd
        )
        apdb.store(when, objects, sources)

    # Take the first object stored for each region.
    object_ids = [
        DiaObjectId.from_named_tuple(next(objects.itertuples())) for objects in per_visit_objects
    ]

    # Starting from the first visit, one source per selected object.
    catalog = apdb.getDiaSourcesForDiaObjects(object_ids, visit_times[0])
    self.assertEqual(len(catalog), 3)
    self.assertEqual(set(catalog["diaObjectId"]), {1, 200, 400})
    self.assertEqual(set(catalog["diaSourceId"]), {1_000_000, 2_000_000, 3_000_000})

    # Starting from the last visit, only the third object's source remains.
    catalog = apdb.getDiaSourcesForDiaObjects(object_ids, visit_times[2])
    self.assertEqual(len(catalog), 1)
    self.assertEqual(set(catalog["diaObjectId"]), {400})
    self.assertEqual(set(catalog["diaSourceId"]), {3_000_000})

609 

def test_reassignDiaSourcesToDiaObjects(self) -> None:
    """Test reassignDiaSourcesToDiaObjects() method.

    Sources are reassigned once with and once without updating
    ``nDiaSources``; when replication is enabled the number of produced
    update records is checked too.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    # Monkey-patch APDB instance to set current time.
    apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]
    apdb_replica = ApdbReplica.from_config(config)

    visit_time = self.visit_time
    lonlat1 = LonLat.fromDegrees(0.0, 0.0)
    lonlat2 = LonLat.fromDegrees(180.0, 0.0)
    # Regions around lonlat1 and lonlat2 respectively.
    region1 = self.make_region(xyz=(1.0, 0.0, 0.0))
    region2 = self.make_region(xyz=(-1.0, 0.0, 0.0))

    # Store 3 objects and sources at the same position in each region.
    objects = makeObjectCatalog(lonlat1, 3, start_id=100)
    sources = makeSourceCatalog(objects, visit_time, start_id=1000, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, sources)

    objects = makeObjectCatalog(lonlat2, 3, start_id=200)
    sources = makeSourceCatalog(objects, visit_time, start_id=2000, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, sources)

    # Check that the initial state is what we expect.
    objects = apdb.getDiaObjects(region1)
    self.assertEqual(set(objects["diaObjectId"]), {100, 101, 102})
    self.assertEqual(list(objects["nDiaSources"]), [1, 1, 1])
    sources = apdb.getDiaSources(region1, [100, 101, 102], visit_time)
    assert sources is not None
    self.assertEqual(set(sources["diaSourceId"]), {1000, 1001, 1002})
    self.assertEqual(set(sources["diaObjectId"]), {100, 101, 102})

    dia_source_ids = [DiaSourceId.from_named_tuple(row) for row in sources.itertuples()]

    # Reassign sources 1001 and 1002 in region1 to object 100 and
    # increment/decrement nDiaSources (default behavior).
    reassign = {
        dia_source_id: 100
        for dia_source_id in dia_source_ids
        if dia_source_id.diaSourceId in (1001, 1002)
    }
    apdb.reassignDiaSourcesToDiaObjects(reassign)

    # Object 100 now owns all three sources, the other two own none.
    objects = apdb.getDiaObjects(region1)
    self.assertEqual(set(objects["nDiaSources"]), {0, 3})
    sources = apdb.getDiaSources(region1, [100], visit_time)
    assert sources is not None
    self.assertEqual(set(sources["diaSourceId"]), {1000, 1001, 1002})
    self.assertEqual(set(sources["diaObjectId"]), {100})

    sources = apdb.getDiaSources(region2, [201, 202], visit_time)
    assert sources is not None
    self.assertEqual(set(sources["diaSourceId"]), {2001, 2002})
    dia_source_ids = [DiaSourceId.from_named_tuple(row) for row in sources.itertuples()]

    # Reassign sources 2001 and 2002 to object 200 but do not
    # increment/decrement nDiaSources.
    reassign = {
        dia_source_id: 200
        for dia_source_id in dia_source_ids
        if dia_source_id.diaSourceId in (2001, 2002)
    }
    apdb.reassignDiaSourcesToDiaObjects(
        reassign, increment_nDiaSources=False, decrement_nDiaSources=False
    )

    # nDiaSources is unchanged, but all three sources now point to 200.
    objects = apdb.getDiaObjects(region2)
    self.assertEqual(set(objects["nDiaSources"]), {1})
    sources = apdb.getDiaSources(region2, [200], visit_time)
    assert sources is not None
    self.assertEqual(set(sources["diaSourceId"]), {2000, 2001, 2002})
    self.assertEqual(set(sources["diaObjectId"]), {200})

    replica_chunks = apdb_replica.getReplicaChunks()
    if not self.enable_replica:
        self.assertIsNone(replica_chunks)
    else:
        assert replica_chunks is not None

        # There could be one or two chunks.
        self.assertTrue(1 <= len(replica_chunks) <= 2)

        update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks])
        # Two reassignments for region1, three increments/decrements for
        # that region, plus two reassignments for region2 without
        # increments/decrements.
        self.assertEqual(len(update_records), 2 + 3 + 2)

695 

def test_setValidityEnd(self) -> None:
    """Store DiaObjects and truncate validity for some of them.

    Also checks replica update records (when replication is enabled), that
    repeating the truncation for the same objects updates nothing, and the
    handling of non-existing object IDs.
    """
    # don't care about sources.
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    # Monkey-patch APDB instance to set current time.
    apdb._current_time = lambda: self.processing_time  # type: ignore[method-assign]
    apdb_replica = ApdbReplica.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time

    # make catalog with Objects
    catalog = makeObjectCatalog(region, 100)

    # store catalog
    apdb.store(visit_time, catalog)

    # read it back and check sizes
    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 100, self.getDiaObjects_table())

    # Truncate validity of the first 10 objects; they are no longer
    # returned by getDiaObjects.
    object_ids = [DiaObjectId.from_named_tuple(row) for row in catalog.iloc[:10].itertuples()]
    count = apdb.setValidityEnd(object_ids, self.processing_time)
    self.assertEqual(count, 10)

    res = apdb.getDiaObjects(region)
    self.assert_catalog(res, 90, self.getDiaObjects_table())

    replica_chunks = apdb_replica.getReplicaChunks()
    if not self.enable_replica:
        self.assertIsNone(replica_chunks)
    else:
        # Check that there are 10 update records in replica tables.
        assert replica_chunks is not None

        # There could be one or two chunks.
        self.assertTrue(1 <= len(replica_chunks) <= 2)

        update_records = apdb_replica.getUpdateRecordChunks([chunk.id for chunk in replica_chunks])
        self.assertEqual(len(update_records), 10)

    # Repeating the call for the same, already truncated, objects must not
    # update anything.
    count = apdb.setValidityEnd(object_ids, self.processing_time)
    self.assertEqual(count, 0)

    # Two existing objects plus one non-existing object: LookupError is
    # raised on request, otherwise only the existing objects are counted.
    object_ids = [DiaObjectId.from_named_tuple(row) for row in catalog.iloc[10:12].itertuples()]
    object_ids += [DiaObjectId(diaObjectId=1_000_000, ra=0.0, dec=0.0)]
    with self.assertRaises(LookupError):
        apdb.setValidityEnd(object_ids, self.processing_time, raise_on_missing_id=True)

    count = apdb.setValidityEnd(object_ids, self.processing_time)
    self.assertEqual(count, 2)

750 

def test_resetDedup(self) -> None:
    """Test resetDedup method."""
    # Sources are not needed here.
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    objects = makeObjectCatalog(region, 100)

    def tai(isot: str) -> astropy.time.Time:
        # All times in this test are ISOT strings on the TAI scale.
        return astropy.time.Time(isot, format="isot", scale="tai")

    visit_time1 = tai("2021-01-01T00:00:00")
    dedup_time1 = tai("2021-01-01T12:00:00")
    visit_time2 = tai("2021-01-02T00:00:00")
    dedup_time2 = tai("2021-01-02T12:00:00")

    apdb.store(visit_time1, objects)

    # Before the first reset all objects are returned.
    self.assertEqual(len(apdb.getDiaObjectsForDedup()), 100)
    self.assertEqual(len(apdb.getDiaObjectsForDedup(visit_time1)), 100)

    # What remains after a reset is backend-specific.
    apdb.resetDedup(dedup_time1)
    self.assertEqual(len(apdb.getDiaObjectsForDedup(visit_time1)), self._count_after_reset_dedup(100))

    # Store the same objects again for the next visit.
    apdb.store(visit_time2, objects)

    self.assertEqual(len(apdb.getDiaObjectsForDedup()), 100)
    self.assertEqual(len(apdb.getDiaObjectsForDedup(dedup_time1)), 100)

    apdb.resetDedup(dedup_time2)
    self.assertEqual(len(apdb.getDiaObjectsForDedup(dedup_time1)), self._count_after_reset_dedup(100))
    self.assertEqual(len(apdb.getDiaObjectsForDedup()), 0)

796 

def _count_after_reset_dedup(self, count_before: int) -> int:
    """Return how many rows ``getDiaObjectsForDedup()`` returns after
    ``resetDedup()`` was called.

    The answer depends on the backend, so subclasses that run
    ``test_resetDedup`` have to override this. For the SQL backend
    deduplication data comes from a regular table and is not removed by
    ``resetDedup()``.

    Parameters
    ----------
    count_before : `int`
        Number of rows returned before ``resetDedup()`` was called.
    """
    raise NotImplementedError

804 

def test_getChunks(self) -> None:
    """Store and retrieve replica chunks.

    Eight visits are stored so that, with the default 10 minute replica
    chunk window, they fall into four replica chunks of two visits each.
    With replication enabled, per-table data is read back for several chunk
    subsets; then one chunk is deleted and its data must be gone. Without
    replication, chunk queries must be rejected.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    apdb_replica = ApdbReplica.from_config(config)

    region1 = self.make_region((1.0, 1.0, -1.0))
    region2 = self.make_region((-1.0, -1.0, -1.0))
    nobj = 100
    objects1 = makeObjectCatalog(region1, nobj)
    objects2 = makeObjectCatalog(region2, nobj, start_id=nobj * 2)

    # With the default 10 minute replica chunk window these visits should
    # produce 4 replica chunks (pairs of visits share a window).
    visits = [
        (astropy.time.Time("2021-01-01T00:01:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:02:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-01-01T00:11:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:12:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-01-01T00:45:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-01-01T00:46:00", format="isot", scale="tai"), objects2),
        (astropy.time.Time("2021-03-01T00:01:00", format="isot", scale="tai"), objects1),
        (astropy.time.Time("2021-03-01T00:02:00", format="isot", scale="tai"), objects2),
    ]

    # Each visit gets distinct source IDs and a distinct visit number.
    start_id = 0
    for visit_time, objects in visits:
        sources = makeSourceCatalog(objects, visit_time, start_id=start_id, use_mjd=self.use_mjd)
        fsources = makeForcedSourceCatalog(objects, visit_time, visit=start_id, use_mjd=self.use_mjd)
        apdb.store(visit_time, objects, sources, fsources)
        start_id += nobj

    replica_chunks = apdb_replica.getReplicaChunks()
    if not self.enable_replica:
        self.assertIsNone(replica_chunks)

        with self.assertRaisesRegex(ValueError, "APDB is not configured for replication"):
            apdb_replica.getTableDataChunks(ApdbTables.DiaObject, [])

    else:
        assert replica_chunks is not None
        self.assertEqual(len(replica_chunks), 4)

        with self.assertRaisesRegex(ValueError, "does not support replica chunks"):
            apdb_replica.getTableDataChunks(ApdbTables.SSObject, [])

        def _check_chunks(replica_chunks: list[ReplicaChunk], n_records: int) -> None:
            """Check row counts and column types of every replicated table
            for the given chunks.
            """
            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaObject, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaObject)
            validityStartColumn = "validityStartMjdTai" if self.use_mjd else "validityStart"
            validityStartType = (
                felis.datamodel.DataType.double if self.use_mjd else felis.datamodel.DataType.timestamp
            )
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaObjectId": felis.datamodel.DataType.long,
                    validityStartColumn: validityStartType,
                    "ra": felis.datamodel.DataType.double,
                    "dec": felis.datamodel.DataType.double,
                    "parallax": felis.datamodel.DataType.float,
                    "nDiaSources": felis.datamodel.DataType.int,
                },
            )

            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaSource, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaSource)
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaSourceId": felis.datamodel.DataType.long,
                    "visit": felis.datamodel.DataType.long,
                    "detector": felis.datamodel.DataType.short,
                },
            )

            res = apdb_replica.getTableDataChunks(
                ApdbTables.DiaForcedSource, (chunk.id for chunk in replica_chunks)
            )
            self.assert_table_data(res, n_records, ApdbTables.DiaForcedSource)
            self.assert_column_types(
                res,
                {
                    "apdb_replica_chunk": felis.datamodel.DataType.long,
                    "diaObjectId": felis.datamodel.DataType.long,
                    "visit": felis.datamodel.DataType.long,
                    "detector": felis.datamodel.DataType.short,
                },
            )

        # Read it back and check sizes; every chunk holds two visits.
        _check_chunks(replica_chunks, 800)
        _check_chunks(replica_chunks[1:], 600)
        _check_chunks(replica_chunks[1:-1], 400)
        _check_chunks(replica_chunks[2:3], 200)
        _check_chunks([], 0)

        # try to remove some of those
        deleted_chunks = replica_chunks[:1]
        apdb_replica.deleteReplicaChunks(chunk.id for chunk in deleted_chunks)

        # All queries on deleted ids should return empty set.
        _check_chunks(deleted_chunks, 0)

        replica_chunks = apdb_replica.getReplicaChunks()
        assert replica_chunks is not None
        self.assertEqual(len(replica_chunks), 3)

        _check_chunks(replica_chunks, 600)

924 

def test_reassignObjects(self) -> None:
    """Reassign DiaSources and check what ``getDiaSources`` returns.

    ``reassignDiaSources`` takes a mapping keyed by DiaSource ID (values are
    presumably SSObject IDs). Reassigned DiaSources must no longer be
    returned by ``getDiaSources``. A call that references a non-existent
    DiaSource must raise `ValueError`, and the query result must be
    unchanged afterwards.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    region = self.make_region()
    visit_time = self.visit_time
    objects = makeObjectCatalog(region, 100)
    oids = list(objects["diaObjectId"])
    sources = makeSourceCatalog(objects, visit_time, use_mjd=self.use_mjd)
    apdb.store(visit_time, objects, sources)

    # read it back and filter by ID
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources), ApdbTables.DiaSource)

    # Three reassigned DiaSources disappear from the query result.
    apdb.reassignDiaSources({1: 1, 2: 2, 5: 5})
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

    # DiaSource 1000 does not exist, so the call must fail.
    with self.assertRaisesRegex(ValueError, r"do not exist.*\D1000"):
        apdb.reassignDiaSources(
            {
                1000: 1,
                7: 3,
            }
        )
    # Re-query the database: asserting on the previous ``res`` would be
    # trivially true. The valid 7 -> 3 entry must not have been applied.
    res = apdb.getDiaSources(region, oids, visit_time)
    self.assert_catalog(res, len(sources) - 3, ApdbTables.DiaSource)

954 

def test_storeUpdateRecord(self) -> None:
    """Check storing update records and reading them back from the replica.

    Records are written via ``store_update_records``. Without replication
    enabled, storing must raise `TypeError`. With replication, the records
    returned for the chunk must equal the input list, in order.
    """
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    # Times are totally arbitrary.
    first_ns = 2_000_000_000_000_000_000
    second_ns = 2_000_000_001_000_000_000
    # Position/time values shared by every record below.
    common = {"ra": 45.0, "dec": -45.0, "midpointMjdTai": 60000.0}
    records = [
        ApdbReassignDiaSourceToSSObjectRecord(
            update_time_ns=first_ns,
            update_order=0,
            diaSourceId=1,
            ssObjectId=1,
            ssObjectReassocTimeMjdTai=60000.0,
            **common,
        ),
        ApdbWithdrawDiaSourceRecord(
            update_time_ns=first_ns,
            update_order=1,
            diaSourceId=123456,
            timeWithdrawnMjdTai=61000.0,
            **common,
        ),
        ApdbReassignDiaSourceToSSObjectRecord(
            update_time_ns=first_ns,
            update_order=3,
            diaSourceId=2,
            ssObjectId=3,
            ssObjectReassocTimeMjdTai=60000.0,
            **common,
        ),
        ApdbWithdrawDiaSourceRecord(
            update_time_ns=second_ns,
            update_order=0,
            diaSourceId=123456,
            timeWithdrawnMjdTai=61000.0,
            **common,
        ),
    ]

    chunk_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai")
    chunk = ReplicaChunk.make_replica_chunk(chunk_time, 600)

    if self.enable_replica:
        self.store_update_records(apdb, records, chunk)

        apdb_replica = ApdbReplica.from_config(config)
        returned = apdb_replica.getUpdateRecordChunks([chunk.id])

        # Input records are ordered, so the output is expected in the same order.
        self.assertEqual(returned, records)
    else:
        with self.assertRaises(TypeError):
            self.store_update_records(apdb, records, chunk)

1018 

@abstractmethod
def store_update_records(self, apdb: Apdb, records: list[ApdbUpdateRecord], chunk: ReplicaChunk) -> None:
    """Store update records in the database; subclasses must override.

    Parameters
    ----------
    apdb : `Apdb`
        APDB instance to write the records into.
    records : `list` [`ApdbUpdateRecord`]
        Update records to store.
    chunk : `ReplicaChunk`
        Replica chunk to associate the records with.
    """
    raise NotImplementedError

1023 

def test_midpointMjdTai_src(self) -> None:
    """Test for time filtering of DiaSources."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    def tai(iso: str) -> astropy.time.Time:
        # Shorthand for the ISOT/TAI timestamps used throughout this test.
        return astropy.time.Time(iso, format="isot", scale="tai")

    region = self.make_region()
    # Two batches of sources, two seconds apart. 2021-01-01 plus 360 days
    # is 2021-12-27, so the visit times below straddle the window edge.
    t_src1 = tai("2021-01-01T00:00:00")
    t_src2 = tai("2021-01-01T00:00:02")
    one_sec = astropy.time.TimeDelta(1.0, format="sec")

    objects = makeObjectCatalog(region, 100)
    oids = list(objects["diaObjectId"])
    for start_id, src_time in ((0, t_src1), (100, t_src2)):
        sources = makeSourceCatalog(objects, src_time, start_id, use_mjd=self.use_mjd)
        apdb.store(src_time, objects, sources)

    # (arguments following region and object IDs, expected row count)
    cases = [
        # reading at time of last save should read all
        ((t_src2,), 200),
        # one second before 12 months
        ((tai("2021-12-26T23:59:59"),), 200),
        # first batch drops out of the window, second is still inside
        ((tai("2021-12-27T00:00:01"),), 100),
        # both batches are outside the window, nothing is read
        ((tai("2021-12-27T00:00:03"),), 0),
        # Explicit start time replaces the 12 month window; visit time does
        # not matter in this case, so it is set before all data.
        ((t_src1 - one_sec, t_src1 - one_sec), 200),
        ((t_src1 - one_sec, t_src2 - one_sec), 100),
        ((t_src1 - one_sec, t_src2 + one_sec), 0),
    ]
    for args, expected in cases:
        res = apdb.getDiaSources(region, oids, *args)
        self.assert_catalog(res, expected, ApdbTables.DiaSource)

1072 

def test_midpointMjdTai_fsrc(self) -> None:
    """Test for time filtering of DiaForcedSources."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)

    def tai(iso: str) -> astropy.time.Time:
        # Shorthand for the ISOT/TAI timestamps used throughout this test.
        return astropy.time.Time(iso, format="isot", scale="tai")

    region = self.make_region()
    # Two batches of forced sources, two seconds apart. 2021-01-01 plus
    # 360 days is 2021-12-27, so the visit times below straddle the window.
    t_src1 = tai("2021-01-01T00:00:00")
    t_src2 = tai("2021-01-01T00:00:02")
    one_sec = astropy.time.TimeDelta(1.0, format="sec")

    objects = makeObjectCatalog(region, 100)
    oids = list(objects["diaObjectId"])
    for visit, src_time in ((1, t_src1), (2, t_src2)):
        fsources = makeForcedSourceCatalog(objects, src_time, visit, use_mjd=self.use_mjd)
        apdb.store(src_time, objects, forced_sources=fsources)

    # (arguments following region and object IDs, expected row count)
    cases = [
        # reading at time of last save should read all
        ((t_src2,), 200),
        # one second before 12 months
        ((tai("2021-12-26T23:59:59"),), 200),
        # first batch drops out of the window, second is still inside
        ((tai("2021-12-27T00:00:01"),), 100),
        # both batches are outside the window, nothing is read
        ((tai("2021-12-27T00:00:03"),), 0),
        # Explicit start time replaces the 12 month window; visit time does
        # not matter in this case, so it is set before all data.
        ((t_src1 - one_sec, t_src1 - one_sec), 200),
        ((t_src1 - one_sec, t_src2 - one_sec), 100),
        ((t_src1 - one_sec, t_src2 + one_sec), 0),
    ]
    for args, expected in cases:
        res = apdb.getDiaForcedSources(region, oids, *args)
        self.assert_catalog(res, expected, ApdbTables.DiaForcedSource)

1120 

def test_metadata(self) -> None:
    """Check writing, reading, overwriting and deleting metadata items."""
    config = self.make_instance()
    apdb = Apdb.from_config(config)
    metadata = apdb.metadata

    # A fresh APDB already holds ``meta_row_count`` items: version numbers
    # and a frozen JSON config.
    self.assertFalse(metadata.empty())
    self.assertEqual(sum(1 for _ in metadata.items()), self.meta_row_count)

    metadata.set("meta", "data")
    metadata.set("data", "meta")
    self.assertFalse(metadata.empty())
    self.assertGreaterEqual(set(metadata.items()), {("meta", "data"), ("data", "meta")})

    # An existing key can only be replaced with force=True.
    with self.assertRaisesRegex(KeyError, "Metadata key 'meta' already exists"):
        metadata.set("meta", "data1")
    metadata.set("meta", "data2", force=True)
    self.assertGreaterEqual(set(metadata.items()), {("meta", "data2"), ("data", "meta")})

    # delete() reports whether the key was present.
    self.assertTrue(metadata.delete("meta"))
    self.assertIsNone(metadata.get("meta"))
    self.assertFalse(metadata.delete("meta"))

    # get() returns the supplied default for a missing key.
    self.assertEqual(metadata.get("data"), "meta")
    self.assertEqual(metadata.get("meta", "meta"), "meta")

1150 

def test_schemaVersionFromYaml(self) -> None:
    """Check that the schema version is read from the YAML schema file.

    The default schema reports 0.1.1. An empty version string in the YAML
    yields 0.1.0, and an explicit version string is used as-is.
    """
    config = self.make_instance()
    default_schema = config.schema_file
    apdb = Apdb.from_config(config)
    self.assertEqual(apdb.schema.schemaVersion(), VersionTuple(0, 1, 1))

    # (version string written into the YAML, expected schema version)
    overrides = (
        ("", VersionTuple(0, 1, 0)),
        ("99.0.0", VersionTuple(99, 0, 0)),
    )
    for version, expected in overrides:
        with update_schema_yaml(default_schema, version=version) as schema_file:
            apdb = Apdb.from_config(self.make_instance(schema_file=schema_file))
            self.assertEqual(apdb.schema.schemaVersion(), expected)

1173 

def test_config_freeze(self) -> None:
    """Check that frozen config values stored in the database take precedence.

    ``enable_replica`` is flipped in the config before the APDB is
    instantiated. ``getConfig()`` must still report ``self.enable_replica``.
    """
    config = self.make_instance()

    # `enable_replica` is the only parameter that every implementation
    # freezes, so it is the one this generic test can check.
    config.enable_replica = not self.enable_replica
    frozen_config = Apdb.from_config(config).getConfig()
    self.assertEqual(frozen_config.enable_replica, self.enable_replica)

1184 

1185 

1186class ApdbSchemaUpdateTest(TestCaseMixin, ABC): 

1187 """Base class for unit tests that verify how schema changes work.""" 

1188 

1189 visit_time = astropy.time.Time("2021-01-01T00:00:00", format="isot", scale="tai") 

1190 

@abstractmethod
def make_instance(self, **kwargs: Any) -> ApdbConfig:
    """Create the configuration used by the schema-update tests.

    Every call must return a configuration that points to the same
    database instance (i.e. the same ``db_url``). For sqlite this means
    on-disk storage is required.

    Parameters
    ----------
    **kwargs : `~typing.Any`
        Configuration overrides, e.g. ``enable_replica``.

    Returns
    -------
    config : `ApdbConfig`
        APDB configuration.
    """
    raise NotImplementedError

1200 

def make_region(self, xyz: tuple[float, float, float] = (1.0, 1.0, -1.0)) -> Region:
    """Return a test region built from the ``xyz`` vector.

    This delegates to the module-level ``_make_region`` helper.
    """
    region = _make_region(xyz)
    return region

1204 

def test_schema_add_replica(self) -> None:
    """Check that replication-enabled code works with an old schema that
    has no replica tables.
    """
    # Create the schema with replication disabled.
    config = self.make_instance(enable_replica=False)
    apdb = Apdb.from_config(config)
    apdb_replica = ApdbReplica.from_config(config)

    # Re-open the same database with replication requested.
    config.enable_replica = True
    apdb = Apdb.from_config(config)

    # Storing objects, sources and forced sources should still work.
    region = self.make_region()
    visit_time = self.visit_time
    objects = makeObjectCatalog(region, 100)
    apdb.store(
        visit_time,
        objects,
        makeSourceCatalog(objects, visit_time),
        makeForcedSourceCatalog(objects, visit_time),
    )

    # The schema has no replica tables, so there should be no replica chunks.
    self.assertIsNone(apdb_replica.getReplicaChunks())

1231 

1232 def test_schemaVersionCheck(self) -> None: 

1233 """Check version number compatibility.""" 

1234 config = self.make_instance() 

1235 apdb = Apdb.from_config(config) 

1236 

1237 self.assertEqual(apdb.schema.schemaVersion(), VersionTuple(0, 1, 1)) 

1238 

1239 # Claim that schema version is now 99.0.0, must raise an exception. 

1240 with update_schema_yaml(config.schema_file, version="99.0.0") as schema_file: 

1241 config.schema_file = schema_file 

1242 with self.assertRaises(IncompatibleVersionError): 

1243 apdb = Apdb.from_config(config) 

1244 # Version is checked only when we try to do connect. 

1245 apdb.metadata.items()