Coverage for python / lsst / meas / extensions / multiprofit / rebuild_coadd_multiband.py: 0%

178 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:46 +0000

1# This file is part of meas_extensions_multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["ModelRebuilder", "PatchModelMatches", "PatchCoaddRebuilder"] 

23 

24from functools import cached_property 

25from typing import Iterable 

26 

27import astropy.table 

28import astropy.units as u 

29import lsst.afw.table as afwTable 

30import lsst.daf.butler as dafButler 

31import lsst.gauss2d.fit as g2f 

32import lsst.geom as geom 

33from lsst.meas.extensions.scarlet.io import updateCatalogFootprints 

34from lsst.pipe.base import QuantumContext, QuantumGraph 

35from lsst.pipe.tasks.fit_coadd_multiband import ( 

36 CoaddMultibandFitBaseTemplates, 

37 CoaddMultibandFitInputConnections, 

38 CoaddMultibandFitTask, 

39) 

40from lsst.skymap import BaseSkyMap, TractInfo 

41import numpy as np 

42import pydantic 

43 

44from .fit_coadd_multiband import ( 

45 CatalogExposurePsfs, 

46 CatalogSourceFitterConfigData, 

47 MultiProFitSourceConfig, 

48 MultiProFitSourceTask, 

49) 

50 

# Mapping from astropy angular units to their lsst.geom equivalents.
# Used by astropy_unit_to_geom below to translate column units read from
# astropy tables into units usable with lsst.geom.SpherePoint.
astropy_to_geom_units = {
    u.arcmin: geom.arcminutes,
    u.arcsec: geom.arcseconds,
    u.mas: geom.milliarcseconds,
    u.deg: geom.degrees,
    u.rad: geom.radians,
}

58 

59 

def astropy_unit_to_geom(unit: u.Unit, default=None) -> geom.AngleUnit:
    """Convert an astropy unit to an lsst.geom unit.

    Parameters
    ----------
    unit
        The astropy unit to convert.
    default
        The fallback value to return when no known conversion exists.

    Returns
    -------
    unit_geom
        The equivalent lsst.geom unit, if found (otherwise ``default``).

    Raises
    ------
    ValueError
        Raised if no equivalent unit is found and no default was supplied.
    """
    converted = astropy_to_geom_units.get(unit, default)
    if converted is not None:
        return converted
    raise ValueError(f"{unit=} not found in {astropy_to_geom_units=}")

84 

85 

def find_patches(tract_info: TractInfo, ra_array, dec_array, unit: geom.AngleUnit) -> list[int]:
    """Find the patches containing a list of ra/dec values within a tract.

    Parameters
    ----------
    tract_info
        The TractInfo object for the tract.
    ra_array
        The array of right ascension values.
    dec_array
        The array of declination values (must be same length as ra_array).
    unit
        The unit of the RA/dec values.

    Returns
    -------
    patches
        A list of patches containing the specified RA/dec values.
    """
    wcs = tract_info.wcs
    # Convert each sky coordinate to integer pixel coordinates in the tract.
    pixels = np.array(
        [
            geom.Point2I(wcs.skyToPixel(geom.SpherePoint(ra, dec, units=unit)))
            for ra, dec in zip(ra_array, dec_array, strict=True)
        ]
    )
    # Integer division by the inner patch dimensions yields the patch grid
    # indices in x and y.
    dims = tract_info.patch_inner_dimensions
    indices_x = pixels[:, 0] // dims[0]
    indices_y = pixels[:, 1] // dims[1]
    return [
        tract_info.getSequentialPatchIndexFromPair((x, y))
        for x, y in zip(indices_x, indices_y)
    ]

110 

111 

def get_radec_unit(table: astropy.table.Table, coord_ra: str, coord_dec: str, default=None):
    """Get the RA/dec units for columns in a table.

    Parameters
    ----------
    table
        The table to determine units for.
    coord_ra
        The key of the right ascension column.
    coord_dec
        The key of the declination column.
    default
        The default value to return if no unit is found.

    Returns
    -------
    unit
        The unit of the RA/dec columns or None if none is found.

    Raises
    ------
    ValueError
        Raised if the units are inconsistent.
    """
    unit_ra = astropy_unit_to_geom(table[coord_ra].unit, default=default)
    unit_dec = astropy_unit_to_geom(table[coord_dec].unit, default=default)
    # Mixed units between the two columns would make a single return value
    # meaningless, so refuse rather than pick one.
    if unit_ra != unit_dec:
        units = {coord: table[coord].unit for coord in (coord_ra, coord_dec)}
        raise ValueError(f"Reference table has inconsistent {units=}")
    return unit_ra

143 

144 

class DataLoader(pydantic.BaseModel):
    """A collection of data that can be used to rebuild models."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    catexps: list[CatalogExposurePsfs] = pydantic.Field(
        doc="List of MultiProFit catalog-exposure-psf objects used to fit PSF-convolved models",
    )
    catalog_multi: afwTable.SourceCatalog = pydantic.Field(
        doc="Patch-level multiband reference catalog (deepCoadd_ref)",
    )

    @cached_property
    def channels(self) -> tuple[g2f.Channel, ...]:
        """Return the gauss2d.fit Channel for each catexp's band, in order."""
        channels = tuple(g2f.Channel.get(catexp.band) for catexp in self.catexps)
        return channels

    @classmethod
    def from_butler(
        cls, butler: dafButler.Butler, data_id: dict[str], bands: Iterable[str], name_coadd=None, **kwargs
    ):
        """Construct a DataLoader from a Butler and dataId.

        Parameters
        ----------
        butler
            The butler to load from.
        data_id
            Key-value pairs for the {name_coadd}Coadd_* dataId. It is not
            modified by this method.
        bands
            The list of bands to load. Must contain no duplicates.
        name_coadd
            The prefix of the Coadd datasettype name.
        **kwargs
            Additional keyword arguments to pass to each `butler.get` call
            (e.g. ``collections``).

        Returns
        -------
        data_loader
            An initialized DataLoader.
        """
        bands = tuple(bands)
        if len(set(bands)) != len(bands):
            raise ValueError(f"{bands=} is not a set")
        if name_coadd is None:
            name_coadd = CoaddMultibandFitBaseTemplates["name_coadd"]
        # Work on a copy so that setting the band below does not mutate the
        # caller's dict.
        data_id = dict(data_id)

        catalog_multi = butler.get(
            CoaddMultibandFitInputConnections.cat_ref.name.format(name_coadd=name_coadd), **data_id, **kwargs
        )

        catexps = {}
        for band in bands:
            data_id["band"] = band
            catalog = butler.get(
                CoaddMultibandFitInputConnections.cats_meas.name.format(name_coadd=name_coadd),
                **data_id,
                **kwargs,
            )
            exposure = butler.get(
                CoaddMultibandFitInputConnections.coadds.name.format(name_coadd=name_coadd),
                **data_id,
                **kwargs,
            )
            models_scarlet = butler.get(
                CoaddMultibandFitInputConnections.models_scarlet.name.format(name_coadd=name_coadd),
                **data_id,
                **kwargs,
            )
            # Replace the heavy footprints in the measurement catalog with
            # scarlet model footprints so deblended children can be rebuilt.
            updateCatalogFootprints(
                modelData=models_scarlet,
                catalog=catalog,
                band=data_id["band"],
                imageForRedistribution=exposure,
                removeScarletData=True,
                updateFluxColumns=False,
            )
            # The config and table are harmless dummies
            # Store a per-band copy of the dataId: re-using the shared dict
            # would leave every catexp with the last band fitted.
            catexps[band] = CatalogExposurePsfs(
                catalog=catalog,
                exposure=exposure,
                table_psf_fits=astropy.table.Table(),
                dataId=dict(data_id),
                id_tract_patch=data_id["patch"],
                channel=g2f.Channel.get(band),
                config_fit=MultiProFitSourceConfig(),
            )
        return cls(
            catalog_multi=catalog_multi,
            catexps=list(catexps.values()),
        )

    def load_deblended_object(
        self,
        idx_row: int,
    ) -> list[g2f.ObservationD]:
        """Load a deblended object from catexps.

        Parameters
        ----------
        idx_row
            The index of the object to load.

        Returns
        -------
        observations
            The observations of the object (deblended if it is a child),
            one per catexp, in the same order as `self.catexps`.
        """
        observations = []
        for catexp in self.catexps:
            observations.append(catexp.get_source_observation(catexp.get_catalog()[idx_row]))
        return observations

258 

259 

class ModelRebuilder(DataLoader):
    """A rebuilder of MultiProFit models from their inputs and best-fit
    parameter values.
    """

    fit_results: astropy.table.Table = pydantic.Field(doc="Multiprofit model fit results")
    task_fit: MultiProFitSourceTask = pydantic.Field(doc="The task")

    @cached_property
    def config_data(self) -> CatalogSourceFitterConfigData:
        # Built once from the fit task's config and this loader's channels;
        # cached because the model is frozen.
        config_data = self.make_config_data()
        return config_data

    @classmethod
    def from_quantumGraph(
        cls,
        butler: dafButler.Butler,
        quantumgraph: QuantumGraph,
        dataId: dict = None,
    ):
        """Make a rebuilder from a butler and quantumgraph.

        Parameters
        ----------
        butler
            The butler that the quantumgraph was built for.
        quantumgraph
            The quantum graph file from a CoaddMultibandFitTask using the
            MultiProFitSourceTask.
        dataId
            The dataId for the fit, including skymap, tract and patch.
            If None, the first output quantum is used.

        Returns
        -------
        rebuilder
            A ModelRebuilder instance initialized with the necessary kwargs.
        """
        if dataId is None:
            # No dataId specified: arbitrarily take the first output quantum.
            quantum = next(iter(quantumgraph.outputQuanta)).quantum
        else:
            quantum = None
            # Linear scan for the quantum whose (simplified) dataId matches.
            for node in quantumgraph.outputQuanta:
                if node.quantum.dataId.to_simple().dataId == dataId:
                    quantum = node.quantum
                    break
            if quantum is None:
                raise ValueError(
                    f"{dataId=} not found in {[x.quantum.dataId for x in quantumgraph.outputQuanta]=}"
                )
        # NOTE(review): assumes the graph contains exactly one relevant task;
        # only the first task in the task graph is inspected.
        taskDef = next(iter(quantumgraph.iterTaskGraph()))
        butlerQC = QuantumContext(butler, quantum)
        config = butler.get(f"{taskDef.label}_config")
        # I have no idea what to put for initInputs.
        # quantum.initInputs looks wrong - the values can be lists
        # quantumgraph.initInputRefs(taskDef) returns a list of DatasetRefs...
        # ... but I'm not sure how to map that to connection names?
        task: CoaddMultibandFitTask = taskDef.taskClass(config=config, initInputs={})
        if not isinstance(task, CoaddMultibandFitTask):
            raise ValueError(f"{task=} type={type(task)} !isinstance of {CoaddMultibandFitTask=}")
        task_fit: MultiProFitSourceTask = task.fit_coadd_multiband
        if not isinstance(task_fit, MultiProFitSourceTask):
            raise ValueError(f"{task_fit=} type={type(task_fit)} !isinstance of {MultiProFitSourceTask=}")
        # Re-load the task's inputs exactly as runQuantum would have.
        inputRefs, outputRefs = taskDef.connections.buildDatasetRefs(quantum)
        inputs = butlerQC.get(inputRefs)
        catexps = task.build_catexps(butlerQC, inputRefs, inputs)
        catexps = [task_fit.make_CatalogExposurePsfs(catexp) for catexp in catexps]
        cat_output: astropy.table.Table = butler.get(outputRefs.cat_output, storageClass="ArrowAstropy")
        return cls(
            catexps=catexps,
            task_fit=task_fit,
            catalog_multi=inputs["cat_ref"],
            fit_results=cat_output,
        )

    def make_config_data(self):
        """Make a ConfigData object out of self's channels and fit task
        config.
        """
        config_data = CatalogSourceFitterConfigData(channels=self.channels, config=self.task_fit.config)
        return config_data

    def make_model(
        self,
        idx_row: int,
        config_data: CatalogSourceFitterConfigData = None,
        init: bool = True,
    ) -> g2f.ModelD:
        """Make a ModelD for a single row from the originally fitted catalog.

        Parameters
        ----------
        idx_row
            The index of the row to make a model for.
        config_data
            The model configuration data object. Defaults to
            `self.config_data`.
        init
            Whether to initialize the model parameters as they would have been
            prior to fitting.

        Returns
        -------
        model
            The rebuilt model.
        """
        if config_data is None:
            config_data = self.config_data
        model = self.task_fit.get_model(
            idx_row=idx_row,
            catalog_multi=self.catalog_multi,
            catexps=self.catexps,
            config_data=config_data,
            results=self.fit_results,
            set_flux_limits=False,
        )
        if init:
            # NOTE(review): set_model documents itself as setting *best-fit*
            # values, while the init docstring above says "prior to fitting" —
            # confirm which is intended.
            self.set_model(idx_row, config_data)
        return model

    def set_model(self, idx_row: int, config_data: CatalogSourceFitterConfigData = None) -> None:
        """Set model parameters to the best-fit values for a given row.

        Parameters
        ----------
        idx_row
            The index of the row in the fit parameter table to initialize from.
        config_data
            The model configuration data object. Defaults to
            `self.config_data`.
        """
        if config_data is None:
            config_data = self.config_data
        row = self.fit_results[idx_row]
        prefix = config_data.config.prefix_column
        offsets = {}
        offset_cen = config_data.config.centroid_pixel_offset
        # Output centroid columns have centroid_pixel_offset added when
        # written, so subtract it here to recover the internal parameter
        # values.
        if offset_cen != 0:
            offsets[g2f.CentroidXParameterD] = -offset_cen
            offsets[g2f.CentroidYParameterD] = -offset_cen
        for key, param in config_data.parameters.items():
            param.value = row[f"{prefix}{key}"] + offsets.get(type(param), 0.0)

399 

400 

class PatchModelMatches(pydantic.BaseModel):
    """Storage for MultiProFit tables matched to a reference catalog."""

    # Frozen, and arbitrary_types_allowed because astropy/QuantumGraph types
    # are not pydantic models.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    matches: astropy.table.Table | None = pydantic.Field(doc="Catalogs of matches")
    quantumgraph: QuantumGraph | None = pydantic.Field(doc="Quantum graph for fit task")
    rebuilder: DataLoader | ModelRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder")

409 

410 

class PatchCoaddRebuilder(pydantic.BaseModel):
    """A rebuilder for patch-level coadd catalog/exposure fits."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    # doc= was previously omitted for this field, which passed the docstring
    # as pydantic.Field's first positional argument (the default value)
    # instead of as documentation, silently making the field optional.
    matches: dict[str, PatchModelMatches] = pydantic.Field(doc="Model matches by algorithm name")
    name_model_ref: str = pydantic.Field(doc="The name of the reference model in matches")
    objects: astropy.table.Table = pydantic.Field(doc="Object table")
    objects_multiprofit: astropy.table.Table | None = pydantic.Field(doc="Object table for MultiProFit fits")
    reference: astropy.table.Table = pydantic.Field(doc="Reference object table")

    skymap: str = pydantic.Field(doc="The skymap name")
    tract: int = pydantic.Field(doc="The tract index")
    patch: int = pydantic.Field(doc="The patch index")

    @classmethod
    def from_butler(
        cls,
        butler: dafButler.Butler,
        skymap: str,
        tract: int,
        patch: int,
        collection_merged: str,
        matches: dict[str, QuantumGraph | None],
        bands: Iterable[str] = None,
        name_model_ref: str = None,
        format_collection: str = "{run}",
        load_multiprofit: bool = True,
        dataset_type_ref: str = "truth_summary",
    ):
        """Construct a PatchCoaddRebuilder from a single Butler collection.

        Parameters
        ----------
        butler
            The butler to load from.
        skymap
            The skymap for the collection.
        tract
            The skymap tract id.
        patch
            The skymap patch id.
        collection_merged
            The name of the collection with the merged objectTable(s).
        matches
            A dictionary of model names with corresponding QuantumGraphs.
            These may be None but must be provided for MultiProFit model
            reconstruction to be possible.
        bands
            The list of bands to load data for.
        name_model_ref
            The name of the model to use as a reference. Must be a key in
            `matches`.
        format_collection
            A format string for the output collection(s) defined in the
            `matches` QuantumGraphs.
        load_multiprofit
            Whether to attempt to load an objectTable_tract_multiprofit.
        dataset_type_ref
            The dataset type of the reference catalog.

        Returns
        -------
        rebuilder
            The fully-configured PatchCoaddRebuilder.

        Raises
        ------
        ValueError
            Raised if name_model_ref is None and no entry in `matches` has a
            quantumgraph.
        """
        if name_model_ref is None:
            # Default to the first model that has a quantumgraph.
            for name, quantumgraph in matches.items():
                if quantumgraph is not None:
                    name_model_ref = name
                    break
            if name_model_ref is None:
                raise ValueError("Must supply name_model_ref or at least one matches with a quantumgraph")
        dataId = dict(skymap=skymap, tract=tract, patch=patch)
        objects = butler.get(
            "objectTable_tract", collections=[collection_merged], storageClass="ArrowAstropy", **dataId
        )
        objects = objects[objects["patch"] == patch]
        if load_multiprofit:
            objects_multiprofit = butler.get(
                "objectTable_tract_multiprofit",
                collections=[collection_merged],
                storageClass="ArrowAstropy",
                **dataId,
            )
            objects_multiprofit = objects_multiprofit[objects_multiprofit["patch"] == patch]
        else:
            objects_multiprofit = None
        reference = butler.get(
            dataset_type_ref, collections=[collection_merged], storageClass="ArrowAstropy", **dataId
        )
        skymap_tract = butler.get(BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, skymap=skymap)[tract]
        unit_coord_ref = get_radec_unit(reference, "ra", "dec", default=geom.degrees)
        if "patch" not in reference.columns:
            # No patch column at all: assign patches from ra/dec.
            patches = find_patches(skymap_tract, reference["ra"], reference["dec"], unit=unit_coord_ref)
            reference["patch"] = patches
        elif reference["patch"].dtype != int:
            # the ci_imsim truth_summary still has string patches
            index_patch = skymap_tract[patch].index
            str_patch = f"{index_patch.y},{index_patch.x}"
            reference = reference[
                (reference["patch"] == str_patch) & (reference["is_unique_truth_entry"] == True)  # noqa: E712
            ]
            del reference["patch"]
            reference["patch"] = patch
        reference = reference[reference["patch"] == patch]
        # NOTE(review): pixel positions are computed with geom.degrees even
        # though unit_coord_ref may differ — confirm ra/dec are degrees here.
        points = skymap_tract.wcs.skyToPixel(
            [geom.SpherePoint(row["ra"], row["dec"], units=geom.degrees) for row in reference]
        )
        reference["x"] = [point.x for point in points]
        reference["y"] = [point.y for point in points]
        matches_name = {}
        for name, quantumgraph in matches.items():
            is_mpf = quantumgraph is not None
            matched = butler.get(
                f"matched_{dataset_type_ref}_objectTable_tract{'_multiprofit' if is_mpf else ''}",
                collections=[
                    (
                        format_collection.format(run=quantumgraph.metadata["output"], name=name)
                        if is_mpf
                        else collection_merged
                    )
                ],
                storageClass="ArrowAstropy",
                **dataId,
            )
            # unmatched ref objects don't have a patch set
            # should probably be fixed in diff_matched
            # but need to decide priority on matched - ref first? or target?
            unit_coord_ref = get_radec_unit(
                matched,
                "refcat_ra",
                "refcat_dec",
                default=geom.degrees,
            )
            # Unmatched rows have a masked (or negative) patch but finite
            # reference coordinates; assign their patches from ra/dec.
            unmatched = (
                matched["patch"].mask if np.ma.is_masked(matched["patch"]) else ~(matched["patch"] >= 0)
            ) & np.isfinite(matched["refcat_ra"])
            patches_unmatched = find_patches(
                skymap_tract,
                matched["refcat_ra"][unmatched],
                matched["refcat_dec"][unmatched],
                unit=unit_coord_ref,
            )
            matched["patch"][np.where(unmatched)[0]] = patches_unmatched
            matched = matched[matched["patch"] == patch]
            rebuilder = (
                ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId)
                if is_mpf
                else DataLoader.from_butler(
                    butler, data_id=dataId, bands=bands, collections=[collection_merged]
                )
            )
            matches_name[name] = PatchModelMatches(
                matches=matched, quantumgraph=quantumgraph, rebuilder=rebuilder
            )
        return cls(
            matches=matches_name,
            objects=objects,
            objects_multiprofit=objects_multiprofit,
            reference=reference,
            skymap=skymap,
            tract=tract,
            patch=patch,
            name_model_ref=name_model_ref,
        )