# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import unittest
import tempfile

import lsst.afw.image as afwImage
import lsst.meas.extensions.scarlet as mes
import lsst.scarlet.lite as scl
import lsst.utils.tests
import numpy as np
from lsst.afw.detection import Footprint, GaussianPsf, InvalidPsfError, PeakTable, Psf
from lsst.afw.geom import SpanSet
from lsst.afw.table import SourceCatalog, SourceTable, SchemaMapper
import lsst.daf.butler
from lsst.daf.butler import Butler, Config, DatasetType, StorageClass, FileDataset, DatasetRef
from lsst.daf.butler.tests import makeTestRepo, makeTestCollection
from lsst.geom import Extent2I, Point2D, Point2I
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.deconvolveExposureTask import DeconvolveExposureTask
from lsst.pipe.base import NoWorkFound, Struct
from utils import initData, SersicModel, PsfModel

TESTDIR = os.path.abspath(os.path.dirname(__file__))


class TestDeblend(lsst.utils.tests.TestCase):
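    """Tests of ScarletDeblendTask and DeconvolveExposureTask on a
    simulated multiband coadd.
    """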

    def setUp(self):
        self.modelPsf = scl.utils.integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        psfRadius = 20
        psfShape = (2 * psfRadius + 1, 2 * psfRadius + 1)
        self.psfs = [
            GaussianPsf(psfShape[1], psfShape[0], 1.0),
            GaussianPsf(psfShape[1], psfShape[0], 1.2),
            GaussianPsf(psfShape[1], psfShape[0], 1.4),
        ]
        self.imagePsf = np.asarray(
            [psf.computeImage(psf.getAveragePosition()).array for psf in self.psfs]
        ).astype(np.float32)
        self.imagePsf /= self.imagePsf.sum(axis=(1, 2))[:, None, None]
        self.bands = tuple("gri")

        self.models = [
            # Isolated source
            PsfModel(
                center=(30, 15),
                spectrum=np.array([8, 2, 1]),
                bands=self.bands,
            ),
            # Two source blend
            SersicModel(
                center=(40, 20),
                major=5,
                minor=2,
                radius=15,
                theta=-np.pi/4,
                n=1,
                spectrum=np.array([2, 4, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(12, 20),
                spectrum=np.array([1, 2, 8]),
                bands=self.bands,
            ),
            # Three source blend
            SersicModel(
                center=(25, 70),
                major=5,
                minor=2,
                radius=20,
                theta=np.pi/48,
                n=1,
                spectrum=np.array([2, 5, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(32, 60),
                spectrum=np.array([1, 2, 8]),
                bands=self.bands,
            ),
            PsfModel(
                center=(16, 80),
                spectrum=np.array([8, 2, 1]),
                bands=self.bands,
            ),
            # Large blend
            SersicModel(
                center=(70, 70),
                major=5,
                minor=2,
                radius=25,
                theta=0,
                n=1,
                spectrum=np.array([2, 10, 18]),
                bands=self.bands,
            ),
            SersicModel(
                center=(85, 85),
                major=5,
                minor=2,
                radius=25,
                theta=np.pi/2,
                n=1,
                spectrum=np.array([5, 10, 20]),
                bands=self.bands,
            ),
        ]

    def scarlet_image_to_exposure(
        self,
        image: scl.Image,
        noise: np.ndarray,
    ) -> afwImage.MultibandExposure:
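        """Convert a scarlet lite image into a ``MultibandExposure``.

        The squared noise fills the variance plane and the per-band PSFs
        from ``setUp`` are attached to each single-band exposure.
        """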

        masked_image = afwImage.MultibandMaskedImage.fromArrays(
            image.bands, image.data, None, noise**2
        )
        coadds = [
            afwImage.Exposure(img, dtype=img.image.array.dtype) for img in masked_image
        ]
        mCoadd = afwImage.MultibandExposure.fromExposures(image.bands, coadds)
        for b, coadd in enumerate(mCoadd):
            coadd.setPsf(self.psfs[b])
        return mCoadd

    def initialize_data(
        self,
        models,
        deconvolveConfig=None,
        deblendConfig=None,
        doDetect: bool = True,
    ):
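        """Build the simulated images, coadd, and tasks used by the tests.

        Returns a `Struct` with the noise-free and noisy images, the
        multiband coadd, and the configured detection, deconvolution,
        and deblending tasks; a detection catalog is included when
        ``doDetect`` is True.
        """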

        if deconvolveConfig is None:
            deconvolveConfig = DeconvolveExposureTask.ConfigClass()
        if deblendConfig is None:
            deblendConfig = ScarletDeblendTask.ConfigClass()
        # Generate the data for the test
        deconvolved, convolved = initData(models, self.modelPsf, self.imagePsf)
        # Fix the random seed so that the noise field is reproducible,
        # then add noise to the image
        np.random.seed(0)
        noise = 0.05 * (np.random.rand(*convolved.shape).astype(np.float32) - 0.5)
        noisyImage = convolved.copy()
        noisyImage._data += noise
        # Create the multiband coadd
        mCoadd = self.scarlet_image_to_exposure(noisyImage, noise)
        # Initialize tasks
        inputSchema = SourceTable.makeMinimalSchema()
        table = SourceTable.make(inputSchema)
        detectionTask = SourceDetectionTask(schema=inputSchema)
        schemaMapper = SchemaMapper(inputSchema)
        schemaMapper.addMinimalSchema(inputSchema)
        schema = schemaMapper.getOutputSchema()
        deconvolveTask = DeconvolveExposureTask(config=deconvolveConfig)
        deblendTask = ScarletDeblendTask(schema=schema, config=deblendConfig)

        result = Struct(
            deconvolved=deconvolved,
            convolved=convolved,
            noise=noise,
            noisyImage=noisyImage,
            mCoadd=mCoadd,
            detectionTask=detectionTask,
            deconvolveTask=deconvolveTask,
            deblendTask=deblendTask,
        )

        if doDetect:
            # Generate a detection catalog
            detectionResult = detectionTask.run(table, mCoadd["r"])
            table = SourceCatalog.Table.make(schema)
            catalog = SourceCatalog(table)
            catalog.extend(detectionResult.sources, schemaMapper)
            result.catalog = catalog

        return result

    def deconvolve(self, data: Struct):
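        """Deconvolve each band of the coadd in ``data`` and reassemble
        the results into a single ``MultibandExposure``.
        """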

        deconvolvedCoadds = []
        deconvolveTask = data.deconvolveTask
        if deconvolveTask.config.useFootprints:
            catalog = data.catalog
        else:
            catalog = None
        for coadd in data.mCoadd:
            deconvolvedCoadd = deconvolveTask.run(coadd, catalog).deconvolved
            deconvolvedCoadds.append(deconvolvedCoadd)
        mDeconvolved = afwImage.MultibandExposure.fromExposures(self.bands, deconvolvedCoadds)
        return mDeconvolved

    def test_default_deconvolve(self):
        data = self.initialize_data(self.models)
        deconvolved = self.deconvolve(data)

        diff = data.deconvolved.data - deconvolved.image.array
        # Sersic models are sharply peaked at the center,
        # so ignore a 3x3 region around each source center
        for model in self.models:
            yc, xc = model.center
            for x in (-1, 0, 1):
                for y in (-1, 0, 1):
                    diff[:, yc+y, xc+x] = 0
        self.assertTrue(np.max(diff[:2]) < 10*np.std(data.noise))
        self.assertTrue(np.max(diff[2]) < 20*np.std(data.noise))

        context = mes.scarletDeblendTask.ScarletDeblendContext.build(
            data.mCoadd,
            deconvolved,
            data.catalog,
            data.deblendTask.ConfigClass()
        )

        self.assertEqual(len(context.footprints), 4)

    def test_catalog_free_deconvolve(self):
        config = DeconvolveExposureTask.ConfigClass()
        config.useFootprints = False
        data = self.initialize_data(self.models, deconvolveConfig=config)
        deconvolved = self.deconvolve(data)

        diff = data.deconvolved.data - deconvolved.image.array
        # Sersic models are sharply peaked at the center,
        # so ignore a 3x3 region around each source center
        for model in self.models:
            yc, xc = model.center
            for x in (-1, 0, 1):
                for y in (-1, 0, 1):
                    diff[:, yc+y, xc+x] = 0
        self.assertTrue(np.max(diff[:2]) < 10*np.std(data.noise))
        self.assertTrue(np.max(diff[2]) < 20*np.std(data.noise))

    def test_footprints(self):
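        """Deblend the simulated coadd and check the outputs end to end.

        This validates isolated-source bookkeeping, parent/child
        relationships, and that the HeavyFootprint models match the
        scarlet models, both with and without flux re-distribution.
        """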

        data = self.initialize_data(self.models)
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)

        catalog = result.deblendedCatalog
        objectParents = result.objectParents
        modelData = result.scarletModelData
        observedPsf = modelData.metadata["psf"]
        modelPsf = modelData.metadata["model_psf"]

        # Check that isolated sources are handled correctly
        isolated = catalog[(catalog["parent"] == 0)]
        self.assertEqual(len(isolated), len(modelData.isolated))
        for sid, source in modelData.isolated.items():
            catalog_footprint = catalog.find(sid).getFootprint()
            isolated_array = catalog_footprint.spans.asArray()
            np.testing.assert_array_equal(source.span_array, isolated_array)

            # Check that the origin is correct
            self.assertTupleEqual(source.origin[::-1], tuple(catalog_footprint.getBBox().getMin()))

        # Verify that the isolated parent flag is being set
        isolatedParents = objectParents[
            (objectParents["parent"] == 0)
            & (objectParents["deblend_nPeaks"] == 1)
        ]
        self.assertEqual(np.sum(objectParents["deblend_skipped_isolatedParent"]), len(isolatedParents))

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution, to cover every way that catalog
        # footprints can be loaded.
        for useFlux in [False, True]:
            for band in self.bands:
                bandIndex = self.bands.index(band)
                coadd = data.mCoadd[band]

                if useFlux:
                    imageForRedistribution = coadd
                else:
                    imageForRedistribution = None

                mes.io.updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                    updateFluxColumns=True,
                )

                # Check that the number of deblended children is consistent
                parents = objectParents[
                    (objectParents["parent"] == 0) & (objectParents["deblend_nPeaks"] > 1)]
                self.assertEqual(
                    np.sum(parents["deblend_nChild"]), len(catalog) - len(isolated)
                )

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]

                    # Extract the parent blend data
                    parentBlendData = modelData.blends[parent.getId()]
                    parentFootprint = parent.getFootprint()
                    x0, y0 = parentFootprint.getBBox().getMin()
                    width, height = parentFootprint.getBBox().getDimensions()
                    yx0 = (y0, x0)

                    for child in children:
                        fp = child.getFootprint()
                        img = fp.extractImage(fill=0.0)
                        # Check that the flux at the center is correct.
                        # Note: this only works in this test image because
                        # the detected peak is in the same location as the
                        # scarlet peak. If the peak is shifted, the flux
                        # value will be correct but deblend_peak_center
                        # will not be the correct location.
                        px = child.get("deblend_peak_center_x")
                        py = child.get("deblend_peak_center_y")
                        flux = img[Point2I(px, py)]
                        self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                        # Check that the peak positions match the catalog entry
                        peaks = fp.getPeaks()
                        self.assertEqual(px, peaks[0].getIx())
                        self.assertEqual(py, peaks[0].getIy())

                        # Load the data to check against the HeavyFootprint
                        blendData = parentBlendData.children[child["deblend_blendId"]]
                        # We need to set an observation in order to
                        # convolve the model.
                        modelBox = scl.Box((height, width), origin=(y0, x0))
                        observation = scl.Observation.empty(
                            bands=("dummy",),
                            psfs=observedPsf[bandIndex][None, :, :],
                            model_psf=modelPsf[None, :, :],
                            bbox=modelBox,
                            dtype=np.float32,
                        )
                        blend = mes.io.monochromaticDataToScarlet(
                            blendData=blendData,
                            bandIndex=bandIndex,
                            observation=observation,
                        )

                        # Get the scarlet model for the source
                        source = next(
                            src for src in blend.sources if src.metadata["id"] == child.getId()
                        )
                        self.assertEqual(source.center[1], px)
                        self.assertEqual(source.center[0], py)

                        if useFlux:
                            assert imageForRedistribution is not None
                            # Get the flux re-weighted model and test against
                            # the HeavyFootprint. The HeavyFootprint needs to
                            # be projected onto the image of the
                            # flux-redistributed model, since the
                            # HeavyFootprint may trim rows or columns.
                            _images = imageForRedistribution[
                                parentFootprint.getBBox()
                            ].image.array
                            blend.observation.images = scl.Image(
                                _images[None, :, :],
                                yx0=yx0,
                                bands=("dummy",),
                            )
                            blend.observation.weights = scl.Image(
                                parentFootprint.spans.asArray()[None, :, :],
                                yx0=yx0,
                                bands=("dummy",),
                            )
                            blend.conserve_flux()
                            model = source.flux_weighted_image.data[0]
                            my0, mx0 = source.flux_weighted_image.yx0
                            image = afwImage.ImageF(model, xy0=Point2I(mx0, my0))
                            fp.insert(image)
                            np.testing.assert_almost_equal(image.array, model)
                        else:
                            # Get the model for the source and test
                            # against the HeavyFootprint
                            bbox = fp.getBBox()
                            bbox = mes.utils.bboxToScarletBox(bbox)
                            model = blend.observation.convolve(
                                source.get_model().project(bbox=bbox), mode="real"
                            ).data[0]
                            np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        maxId = np.max(objectParents["id"])
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), 1)
            if src["parent"] > 0:
                # Check that source IDs are greater than the max parent ID
                self.assertGreater(src["id"], maxId)

        # Ensure that sources are sorted by parent ID
        np.testing.assert_array_equal(sorted(catalog["parent"]), catalog["parent"])

        # Check that the catalog matches the expected results
        self.assertEqual(len(catalog), len(self.models))

    def test_skipped(self):
        # Use tight configs to force skipping a three source footprint
        # and a "large" footprint
        config = ScarletDeblendTask.ConfigClass()
        config.maxFootprintArea = 2000
        config.maxNumberOfPeaks = 2
        config.catchFailures = False

        data = self.initialize_data(self.models, deblendConfig=config)
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)

        catalog = result.objectParents
        parents = catalog[catalog["parent"] == 0]
        self.assertEqual(np.sum(parents["deblend_skipped"]), 2)
        self.assertEqual(np.sum(parents["deblend_skipped_parentTooBig"]), 1)
        self.assertEqual(np.sum(parents["deblend_skipped_tooManyPeaks"]), 1)

    def test_persistence(self):
        # Test that the model data is persisted correctly
        data = self.initialize_data(self.models)
        repo = self._setup_butler()
        mDeconvolved = self.deconvolve(data)
        result = data.deblendTask.run(data.mCoadd, mDeconvolved, data.catalog)
        modelData = result.scarletModelData
        bands = modelData.metadata["bands"]
        butler = makeTestCollection(repo, uniqueId="test_run1")
        butler.put(modelData, "scarlet_model_data", dataId={})
        modelData2 = butler.get("scarlet_model_data", dataId={})
        model_psf = modelData.metadata["model_psf"][None, :, :]
        model_psf2 = modelData2.metadata["model_psf"][None, :, :]
        np.testing.assert_almost_equal(model_psf2, model_psf)
        psf = modelData.metadata["psf"]
        psf2 = modelData2.metadata["psf"]
        np.testing.assert_almost_equal(psf2, psf)
        self.assertEqual(len(modelData2.blends), len(modelData.blends))

        for parentId in modelData.blends.keys():
            nChildren = len(modelData.blends[parentId].children)
            self.assertEqual(nChildren, len(modelData2.blends[parentId].children))
            for blendId in modelData.blends[parentId].children:
                blendData1 = modelData.blends[parentId].children[blendId]
                blendData2 = modelData2.blends[parentId].children[blendId]
                self._test_blend(blendData1, blendData2, model_psf, psf, bands)

        for sourceId in modelData.isolated.keys():
            isolatedData1 = modelData.isolated[sourceId]
            isolatedData2 = modelData2.isolated[sourceId]
            self.assertTupleEqual(isolatedData1.origin, isolatedData2.origin)
            np.testing.assert_array_equal(
                isolatedData1.span_array,
                isolatedData2.span_array,
            )

        # Test extracting a single blend
        # (parentId is the last key from the loop above)
        modelData2 = butler.get("scarlet_model_data", dataId={}, parameters={"blend_id": parentId})
        self.assertEqual(len(modelData2.blends), 1)

        for blendId, blendData1 in modelData.blends[parentId].children.items():
            blendData2 = modelData2.blends[parentId].children[blendId]
            self._test_blend(blendData1, blendData2, model_psf, psf, bands)

        # Test extracting two blends
        modelData2 = butler.get("scarlet_model_data", dataId={}, parameters={"blend_id": [2, 3]})
        self.assertEqual(len(modelData2.blends), 2)
        for parentId in [2, 3]:
            parentData1 = modelData.blends[parentId]
            parentData2 = modelData2.blends[parentId]
            self.assertEqual(len(parentData1.children), len(parentData2.children))
            for blendId in parentData1.children.keys():
                blendData1 = parentData1.children[blendId]
                blendData2 = parentData2.children[blendId]
                self._test_blend(blendData1, blendData2, model_psf, psf, bands)

    def test_legacy_model(self):
        repo = self._setup_butler()
        storageClass = StorageClass(
            "LsstScarletModelData",
            pytype=mes.io.LsstScarletModelData,
        )
        datasetType = DatasetType(
            "old_scarlet_model_data",
            dimensions=(),
            storageClass=storageClass,
            universe=repo.dimensions,
        )
        ref = DatasetRef(
            datasetType,
            run="test_ingestion",
            dataId={},
        )
        dataset = FileDataset(
            path=os.path.join(TESTDIR, "data", "v29_models.json"),
            formatter="lsst.daf.butler.formatters.json.JsonFormatter",
            refs=[ref],
        )

        # Ingest the legacy model into the butler
        butler = makeTestCollection(repo, uniqueId="ingestion")
        repo.registry.registerDatasetType(datasetType)
        butler.ingest(dataset)

        model = butler.get("old_scarlet_model_data", dataId={})
        self.assertEqual(len(model.blends), 2)

        test = butler.get("old_scarlet_model_data", dataId={}, parameters={"blend_id": 3495976385350991873})
        self.assertEqual(len(test.blends), 1)

    def test_older_legacy_model(self):
        repo = self._setup_butler()
        oldStorageClass = StorageClass(
            "ScarletModelData",
            pytype=lsst.scarlet.lite.io.ScarletModelData,
        )
        oldDatasetType = DatasetType(
            "old_scarlet_model_data",
            dimensions=(),
            storageClass=oldStorageClass,
            universe=repo.dimensions,
        )
        ref = DatasetRef(
            oldDatasetType,
            run="test_ingestion",
            dataId={},
        )
        dataset = FileDataset(
            path=os.path.join(TESTDIR, "data", "v29_models.json"),
            formatter="lsst.daf.butler.formatters.json.JsonFormatter",
            refs=[ref],
        )

        # Ingest the legacy model into the butler
        butler = makeTestCollection(repo, uniqueId="ingestion")
        repo.registry.registerDatasetType(oldDatasetType)
        butler.ingest(dataset)

        # Load the base repo config from the repository
        base_config = Config(os.path.join(self.repo_dir, "butler.yaml"))

        # Load the storage class override config
        override_path = os.path.join(
            os.path.dirname(lsst.daf.butler.__file__),
            "configs",
            "storageClasses.yaml"
        )
        override_config = Config(override_path)

        # Merge the configs (update base with override)
        base_config.update(override_config)

        # Create a Butler with the merged config.
        # The config now contains both the repo info and
        # the storage class overrides.
        newButler = Butler.from_config(base_config, collections=butler.collections)

        model = newButler.get("old_scarlet_model_data", dataId={}, storageClass="LsstScarletModelData")
        self.assertEqual(len(model.blends), 2)
        self.assertEqual(len(model.isolated), 0)

    def _test_blend(self, blendData1, blendData2, model_psf, psf, bands):
        # Test that two ScarletBlendData objects are equal
        # up to machine precision.
        self.assertTupleEqual(blendData1.origin, blendData2.origin)
        self.assertEqual(len(blendData1.sources), len(blendData2.sources))

        # Test that the two blends are equal up to machine precision
        # once converted into scarlet lite Blend objects.
        blend1 = blendData1.minimal_data_to_blend(
            model_psf,
            psf,
            bands,
            dtype=np.float32,
        )
        blend2 = blendData2.minimal_data_to_blend(
            model_psf,
            psf,
            bands,
            dtype=np.float32,
        )
        np.testing.assert_almost_equal(blend1.get_model().data, blend2.get_model().data)

    def _setup_butler(self):
        # Initialize a Butler to test persistence
        repo_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.repo_dir = repo_dir.name
        self.addCleanup(tempfile.TemporaryDirectory.cleanup, repo_dir)
        config = Config()
        config["datastore", "cls"] = "lsst.daf.butler.datastores.fileDatastore.FileDatastore"
        repo = makeTestRepo(repo_dir.name, config=config)
        storageClass = StorageClass(
            "LsstScarletModelData",
            pytype=mes.io.LsstScarletModelData,
            parameters=("blend_id",),
            delegate="lsst.meas.extensions.scarlet.io.ScarletModelDelegate",
        )
        datasetType = DatasetType(
            "scarlet_model_data",
            dimensions=(),
            storageClass=storageClass,
            universe=repo.dimensions,
        )
        repo.registry.registerDatasetType(datasetType)
        return repo


class BadPsf(Psf):
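    """A PSF that can only be evaluated at a single valid point.

    Evaluating the PSF anywhere else raises ``InvalidPsfError``, which is
    used to exercise the fallback logic in the PSF utilities below.
    """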

    def __init__(self, validPoint: Point2D, psf: GaussianPsf):
        self.validPoint = validPoint
        self.psf = psf
        super().__init__()

    def computeKernelImage(self, location: Point2D):
        if location == self.validPoint:
            return self.psf.computeKernelImage(location)
        raise InvalidPsfError(f"Invalid PSF at location {location}")


class TestUtils(lsst.utils.tests.TestCase):
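    """Tests for the ``mes.utils`` helper functions, including box
    conversions and the nearest-PSF computations with invalid PSFs.
    """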

    def setUp(self):
        self.bands = tuple("gri")

    def test_box_transforms(self):
        # Test scarlet Box to BBox
        box = scl.Box((25, 30), (17, 5))
        bbox = mes.utils.scarletBoxToBBox(box)
        x0, y0 = bbox.getMin()
        width = bbox.getWidth()
        height = bbox.getHeight()
        self.assertTupleEqual((y0, x0), box.origin)
        self.assertTupleEqual((height, width), box.shape)

        # Test back to scarlet Box
        newBox = mes.utils.bboxToScarletBox(bbox)
        self.assertTupleEqual(newBox.origin, box.origin)
        self.assertTupleEqual(newBox.shape, box.shape)

    def test_computeNearestPsfGood(self):
        # Test that using a valid PSF works normally
        psf, psfImage = self._generateGoodPsf()
        coadd = self._generateCoadd(psf)

        # Test that computing the PSF works
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, None, "g", Point2D(25, 25))
        np.testing.assert_array_equal(derivedPsf.array, psfImage)
        self.assertEqual(center, Point2D(25, 25))
        self.assertEqual(dist, 0)

    def test_computeNearestPsfRecoverable(self):
        # Test that using a PSF not defined at the initial location
        # will fall back to a valid location.
        psf, psfImage = self._generateGoodPsf()
        coadd = self._generateCoadd(BadPsf(Point2D(1, 1), psf))
        catalog = self._generateCatalog(self.bands, [[(1, 1, 10)]])

        # Test that computing the PSF works after finding a new location.
        # Since the PSF above is *only* defined at (1, 1) it will fail to
        # compute a PSF image at (4, 5) but should fall back to (1, 1).
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, catalog, None, Point2D(4, 5))
        np.testing.assert_array_equal(derivedPsf.array, psfImage)
        self.assertEqual(center, Point2I(1, 1))
        self.assertEqual(dist, 5)

    def test_computeNearestPsfBad(self):
        # Test that a PSF that cannot find a matching location returns None
        psf, _ = self._generateGoodPsf()
        coadd = self._generateCoadd(BadPsf(Point2D(1, 1), psf))
        catalog = self._generateCatalog(self.bands)

        # Test that computing the PSF cannot generate a PSF
        derivedPsf, center, dist = mes.utils.computeNearestPsf(coadd, catalog, None, Point2D(4, 5))
        self.assertIsNone(derivedPsf)
        self.assertIsNone(center)
        self.assertIsNone(dist)

    def test_computeNearestPsfMultiBandGood(self):
        # Test that a valid PSF in every band works normally
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        mCoadd = self._generateMultibandCoadd(psfs, bands)

        # Test that computing the PSF works
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), None)
        np.testing.assert_array_equal(psfArray, psfImage)
        self.assertTupleEqual(newCoadd.bands, bands)

    def test_computeNearestPsfMultiBandRecoverable(self):
        # Test that a PSF at a different location is still recoverable
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs[1] = BadPsf(Point2D(1, 1), psfs[1])
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands, [[(1, 1, 10)]])

        # Test that computing the PSF works because the catalog has a peak
        # at the location of the BadPsf.
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        np.testing.assert_array_equal(psfArray, psfImage)
        self.assertTupleEqual(newCoadd.bands, bands)

    def test_computeNearestPsfMultiBandIncomplete(self):
        # Test that a missing PSF in one band returns a PSF array and
        # an exposure with that band dropped.
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs[1] = BadPsf(Point2D(1, 1), psfs[1])
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        # Test that computing the PSF works for the g- and i-band PSFs,
        # which are not BadPsf.
        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        np.testing.assert_array_equal(psfArray, np.delete(psfImage, 1, axis=0))
        self.assertTupleEqual(newCoadd.bands, tuple("gi"))

    def test_computeNearestPsfMultiBandBad(self):
        # Test that None is returned if none of the PSFs can be computed
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs = [BadPsf(Point2D(1, 1), psf) for psf in psfs]
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        psfArray, newCoadd = mes.utils.computeNearestPsfMultiBand(mCoadd, Point2D(25, 25), catalog)
        self.assertIsNone(psfArray)
        self.assertIsNone(newCoadd)

    def test_buildObservationBadPsfs(self):
        # Test that creating an observation with all bad PSFs
        # raises NoWorkFound
        modelPsf = scl.utils.integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        bands = tuple("gri")
        psfs, psfImage = self._generateMultibandPsf([1.0, 1.2, 1.4])
        psfs = [BadPsf(Point2D(1, 1), psf) for psf in psfs]
        mCoadd = self._generateMultibandCoadd(psfs, bands)
        catalog = self._generateCatalog(self.bands)

        # Test that building the observation fails without a catalog
        with self.assertRaises(NoWorkFound):
            mes.utils.buildObservation(modelPsf, Point2I(25, 25), mCoadd)

        # Test that building the observation fails even with a catalog
        with self.assertRaises(NoWorkFound):
            mes.utils.buildObservation(modelPsf, Point2I(25, 25), mCoadd, catalog=catalog)

    def _generateGoodPsf(self, sigma: float = 1.0):
        # Generate a PSF and an image of the PSF
        psfRadius = 20
        psfShape = (2 * psfRadius + 1, 2 * psfRadius + 1)
        psf = GaussianPsf(psfShape[1], psfShape[0], sigma)
        psfImage = psf.computeImage(psf.getAveragePosition()).array
        return psf, psfImage

    def _generateMultibandPsf(self, sigmas: list[float]):
        # Generate a list of Gaussian PSFs, one per sigma, along with a
        # stacked array of their images
        psfs = []
        psfImages = []
        for sigma in sigmas:
            psf, psfImage = self._generateGoodPsf(sigma)
            psfs.append(psf)
            psfImages.append(psfImage)
        return psfs, np.asarray(psfImages)

    def _generateCoadd(self, psf: Psf):
        # Create an empty exposure
        masked_image = afwImage.MaskedImage(Extent2I(50, 50), dtype=np.float32)
        coadd = afwImage.Exposure(masked_image, dtype=np.float32)
        coadd.setPsf(psf)
        return coadd

    def _generateMultibandCoadd(self, psfs: list[Psf], bands: list[str]):
        # Create an empty multi-band exposure
        coadds = []
        for psf in psfs:
            coadds.append(self._generateCoadd(psf))
        return afwImage.MultibandExposure.fromExposures(bands, coadds)

    def _generateCatalog(self, bands, footprints: list[list[tuple[int, int, int]]] | None = None):
        # Generate a catalog with a source for each footprint
        if footprints is None:
            footprints = []
        schema = SourceTable.makeMinimalSchema()
        peakSchema = PeakTable.makeMinimalSchema()
        for band in bands:
            schema.addField(f"merge_footprint_{band}", type="Flag")
            peakSchema.addField(f"merge_peak_{band}", type="Flag")

        table = SourceTable.make(schema)
        catalog = SourceCatalog(table)

        for peaks in footprints:
            src = catalog.addNew()
            footprint = Footprint(SpanSet(), peakSchema)
            for peak in peaks:
                footprint.addPeak(*peak)
            src.setFootprint(footprint)

            for band in bands:
                src[f"merge_footprint_{band}"] = True
                footprint.peaks[f"merge_peak_{band}"] = True
        return catalog


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()