Coverage for python / lsst / pipe / tasks / fit_coadd_multiband.py: 29%

152 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:30 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module; one exported name per line.
__all__ = [
    "CoaddMultibandFitConfig",
    "CoaddMultibandFitConnections",
    "CoaddMultibandFitSubConfig",
    "CoaddMultibandFitSubTask",
    "CoaddMultibandFitTask",
]

26 

27from .fit_multiband import CatalogExposure, CatalogExposureConfig 

28 

29import lsst.afw.table as afwTable 

30from lsst.meas.base import SkyMapIdGeneratorConfig 

31from lsst.meas.extensions.scarlet.io import updateCatalogFootprints 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34import lsst.pipe.base.connectionTypes as cT 

35 

36import astropy.table 

37from abc import ABC, abstractmethod 

38from pydantic import Field 

39from pydantic.dataclasses import dataclass 

40from typing import Iterable 

41 

# Default values for the ``{...}`` placeholders in the connection dataset
# type names below (e.g. "{name_coadd}Coadd_ref" -> "deepCoadd_ref").
CoaddMultibandFitBaseTemplates = dict(
    name_coadd="deep",
    name_method="multiprofit",
    name_table="objects",
)

47 

48 

@dataclass(frozen=True, kw_only=True, config=CatalogExposureConfig)
class CatalogExposureInputs(CatalogExposure):
    """A `CatalogExposure` that additionally carries a table of PSF fit
    parameters for each source."""

    table_psf_fits: astropy.table.Table = Field(title="A table of PSF fit parameters for each source")

    def get_catalog(self):
        """Return the source catalog stored on this object."""
        source_catalog = self.catalog
        return source_catalog

55 

56 

class CoaddMultibandFitInputConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates=CoaddMultibandFitBaseTemplates,
):
    """Input connections shared by multiband coadd fitting tasks.

    Connections with a ``band`` dimension are filtered and reordered by
    `adjustQuantum` according to the bands configured on the
    ``fit_coadd_multiband`` subtask. Depending on the config, `__init__`
    removes the PSF model connection and either the regular coadds or the
    cell coadds plus their backgrounds.
    """

    cat_ref = cT.Input(
        doc="Reference multiband source catalog",
        name="{name_coadd}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cats_meas = cT.Input(
        doc="Deblended single-band source catalogs",
        name="{name_coadd}Coadd_meas",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    coadds = cT.Input(
        doc="Exposures on which to run fits",
        name="{name_coadd}Coadd_calexp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    coadds_cell = cT.Input(
        doc="Cell-coadd exposures on which to run fits",
        name="{name_coadd}CoaddCell",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    backgrounds = cT.Input(
        doc="Background models to subtract from the coadds_cell",
        name="{name_coadd}Coadd_calexp_background",
        storageClass="Background",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    models_psf = cT.Input(
        doc="Input PSF model parameter catalog",
        # Consider allowing independent psf fit method
        name="{name_coadd}Coadd_psfs_{name_method}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    models_scarlet = cT.Input(
        doc="Multiband scarlet models produced by the deblender",
        name="{name_coadd}Coadd_scarletModelData",
        storageClass="LsstScarletModelData",
        dimensions=("tract", "patch", "skymap"),
    )

    def adjustQuantum(self, inputs, outputs, label, data_id):
        """Validates the `lsst.daf.butler.DatasetRef` bands against the
        subtask's list of bands to fit and drops unnecessary bands.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can be
            assumed to be multi-pass iterable and support `len` and ``in``, but
            it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned; this
            is especially useful for delegating to `super` (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. All inputs involving the 'band'
            dimension are adjusted to put them in consistent order (fit bands
            first, then read-only bands not being fit) and remove unneeded
            bands. If ``allow_missing_bands`` is set, bands missing from any
            band-dimensioned input are dropped from all of them.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets; always empty for this task.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if there are not enough of the right bands to run the task
            on this quantum.
        """
        # Check which bands are going to be fit
        bands_fit, bands_read_only = self.config.get_band_sets()
        # Fit bands first, then any read-only bands that are not also fit.
        bands_needed = bands_fit + [band for band in bands_read_only if band not in bands_fit]
        bands_needed_set = set(bands_needed)

        adjusted_inputs = {}
        inputs_to_adjust = {}
        # Copy: bands_found is narrowed in place below and must not alias
        # bands_needed_set.
        bands_found = set(bands_needed_set)
        for connection_name, (connection, dataset_refs) in inputs.items():
            # Datasets without bands in their dimensions should be fine
            if 'band' in connection.dimensions:
                datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
                bands_set = set(datasets_by_band.keys())
                if self.config.allow_missing_bands:
                    # Keep only bands present in every band-dimensioned input
                    # seen so far; fail as soon as none remain.
                    bands_found &= bands_set
                    if not bands_found:
                        raise pipeBase.NoWorkFound(
                            f'DatasetRefs={dataset_refs} for {connection_name=} leave no needed bands'
                            f' (required bands={bands_needed}) common to all band-dimensioned inputs'
                        )
                # All configured bands are treated as necessary
                elif not bands_needed_set.issubset(bands_set):
                    raise pipeBase.NoWorkFound(
                        f'DatasetRefs={dataset_refs} have data with bands in the'
                        f' set={set(datasets_by_band.keys())},'
                        f' which is not a superset of the required bands={bands_needed} defined by'
                        f' {self.config.__class__}.fit_coadd_multiband='
                        f'{self.config.fit_coadd_multiband._value.__class__}\'s attributes'
                        f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                        f' Add the required bands={set(bands_needed).difference(datasets_by_band.keys())}.'
                    )
                # Adjust all datasets with band dimensions to include just
                # the needed bands, in consistent order.
                inputs_to_adjust[connection_name] = (connection, datasets_by_band)

        if self.config.allow_missing_bands:
            # Drop every needed band (fit or read-only) that is missing from
            # at least one input, preserving the fit-then-read-only order.
            # Any band kept here is present in every datasets_by_band below.
            bands_needed = [band for band in bands_needed if band in bands_found]
            if len(bands_needed) == 0:
                raise pipeBase.NoWorkFound(
                    f'No common bands remaining for inputs {",".join(inputs_to_adjust.keys())}'
                )
        for connection_name, (connection, datasets_by_band) in inputs_to_adjust.items():
            adjusted_inputs[connection_name] = (
                connection,
                [datasets_by_band[band] for band in bands_needed]
            )

        # Delegate to super for more checks.
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        return adjusted_inputs, {}

    def __init__(self, *, config=None):
        super().__init__(config=config)
        assert isinstance(config, CoaddMultibandFitBaseConfig)

        # PSF parameters may instead come from the input catalogs.
        if config.drop_psf_connection:
            del self.models_psf

        # Fit either cell coadds (with their backgrounds) or regular coadds.
        if config.use_cell_coadds:
            del self.coadds
        else:
            del self.coadds_cell
            del self.backgrounds

217 

218 

class CoaddMultibandFitConnections(CoaddMultibandFitInputConnections):
    """Connections for `CoaddMultibandFitTask`: all shared inputs plus the
    per-patch output catalog of fit parameters."""

    cat_output = cT.Output(
        name="{name_coadd}Coadd_{name_table}_{name_method}",
        doc="Output source model fit parameter catalog",
        dimensions=("tract", "patch", "skymap"),
        storageClass="ArrowTable",
    )

226 

227 

class CoaddMultibandFitSubConfig(pexConfig.Config):
    """Configuration shared by fitter subtasks of multiband coadd fitting.

    Concrete subtask configs extend this and may override `bands_read_only`.
    """

    bands_fit = pexConfig.ListField[str](
        doc="list of bandpass filters to fit",
        default=[],
        # Must be non-empty and contain no duplicates, so the empty default
        # presumably has to be overridden for validation to pass.
        listCheck=lambda bands: len(bands) > 0 and len(set(bands)) == len(bands),
    )

    @abstractmethod
    def bands_read_only(self) -> set:
        """Return the bands whose data the Task needs to read (e.g. for
        defining priors) but not necessarily fit.

        Returns
        -------
        bands : `set`
            The set of such bands; empty in this base implementation.

        Notes
        -----
        NOTE(review): this class does not visibly derive from `abc.ABC`, so
        ``abstractmethod`` presumably does not force subclasses to override
        this method; confirm against the `lsst.pex.config.Config` metaclass.
        """
        no_extra_bands: set = set()
        return no_extra_bands

248 

249 

class CoaddMultibandFitSubTask(pipeBase.Task, ABC):
    """Subtask interface for multiband fitting of deblended sources.

    Parameters
    ----------
    **kwargs
        Additional arguments to be passed to the `lsst.pipe.base.Task`
        constructor.
    """
    ConfigClass = CoaddMultibandFitSubConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self, catexps: Iterable[CatalogExposureInputs], cat_ref: afwTable.SourceCatalog
    ) -> pipeBase.Struct:
        """Fit models to deblended sources from multi-band inputs.

        Parameters
        ----------
        catexps : `~collections.abc.Iterable` [`CatalogExposureInputs`]
            Catalog-exposure pairs with metadata, one per band. When
            built by `CoaddMultibandFitBase.build_catexps`, bands without data
            are represented by placeholders with an empty catalog and
            ``exposure=None``.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct whose ``output`` attribute contains the output
            measurement catalog.

        Notes
        -----
        `CoaddMultibandFitTask.run` calls this method as
        ``run(catalog_multi=cat_ref, catexps=catexps, **kwargs)`` and reads
        the ``output`` attribute of the returned struct, so implementations
        must accept the reference catalog via the keyword ``catalog_multi``.
        NOTE(review): that keyword does not match the ``cat_ref`` parameter
        declared here; consider aligning the two once implementations are
        audited.

        Subclasses may have further requirements on the input parameters,
        including:
        - Passing only one catexp per band;
        - Catalogs containing HeavyFootprints with deblended images;
        - Fitting only a subset of the sources.
        If any requirements are not met, the subtask should fail as soon as
        possible.
        """

293 

294 

class CoaddMultibandFitBaseConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=CoaddMultibandFitInputConnections,
):
    """Base class for multiband fitting."""

    allow_missing_bands = pexConfig.Field[bool](
        doc="Whether to still fit even if some bands are missing",
        default=True,
    )
    drop_psf_connection = pexConfig.Field[bool](
        doc="Whether to drop the PSF model connection, e.g. because PSF parameters are in the input catalog",
        default=False,
    )
    fit_coadd_multiband = pexConfig.ConfigurableField(
        target=CoaddMultibandFitSubTask,
        doc="Task to fit sources using multiple bands",
    )
    use_cell_coadds = pexConfig.Field[bool](
        doc="Use cell coadd images for object fitting?",
        default=False,
    )
    idGenerator = SkyMapIdGeneratorConfig.make_field()

    def get_band_sets(self):
        """Get the bands required by the fit_coadd_multiband subtask.

        Returns
        -------
        bands_fit : `list` [`str`]
            The bands that the subtask will fit, with duplicates removed and
            first-occurrence order preserved.
        bands_read_only : `list` [`str`]
            The bands that the subtask will only read data (measurement
            catalog and exposure) for, with duplicates removed, in the
            iteration order of the subtask config's ``bands_read_only()``
            result (arbitrary if that result is a `set`). May overlap with
            ``bands_fit``.

        Raises
        ------
        RuntimeError
            Raised if the subtask config has no ``bands_fit`` attribute.
        """
        try:
            bands_fit = self.fit_coadd_multiband.bands_fit
        except AttributeError:
            raise RuntimeError(f'{__class__}.fit_coadd_multiband must have bands_fit attribute') from None
        bands_read_only = self.fit_coadd_multiband.bands_read_only()
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return tuple(list(dict.fromkeys(bands)) for bands in (bands_fit, bands_read_only))

336 

337 

class CoaddMultibandFitConfig(
    CoaddMultibandFitBaseConfig,
    pipelineConnections=CoaddMultibandFitConnections,
):
    """Configuration for a CoaddMultibandFitTask.

    Same fields as `CoaddMultibandFitBaseConfig`, but bound to
    `CoaddMultibandFitConnections`, which adds the ``cat_output`` output.
    """

343 

344 

class CoaddMultibandFitBase:
    """Base class for tasks that fit or rebuild multiband models.

    This class only implements data reconstruction: it turns a quantum's
    loaded inputs into per-band `CatalogExposureInputs` objects.
    """

    def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInputs]:
        """Build one catalog-exposure bundle per fitted band.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.QuantumContext`
            Butler quantum context; only ``butlerQC.quantum.dataId`` is read,
            to derive the tract/patch catalog ID.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input dataset references; each per-band loaded object is keyed by
            the data ID of its matching reference.
        inputs : `dict`
            Loaded input objects keyed by connection name.

        Returns
        -------
        catexps : `list` [`CatalogExposureInputs`]
            One entry per band in the configured ``bands_fit``, in that order.
            Bands without data get a placeholder with an empty catalog,
            ``exposure=None`` and an empty PSF table.

        Raises
        ------
        RuntimeError
            Raised if there were no catalog/exposure inputs to build from.
        ValueError
            Raised by ``zip(..., strict=True)`` if a connection's references
            and loaded objects differ in length.

        Notes
        -----
        Each measurement catalog is passed to `updateCatalogFootprints` with
        the scarlet models (presumably updating its footprints in place).
        With ``use_cell_coadds``, each cell coadd is stitched into a single
        exposure and its background image subtracted.
        """
        id_tract_patch = self.config.idGenerator.apply(butlerQC.quantum.dataId).catalog_id
        use_cells = self.config.use_cell_coadds
        has_psf_models = "models_psf" in inputs

        # Index every per-band connection by data ID so the connections can
        # be matched up regardless of the order their datasets arrived in.
        keys = ["cats_meas", *(("coadds_cell", "backgrounds") if use_cells else ("coadds",))]
        if has_psf_models:
            keys.append("models_psf")
        refs_and_objs = {key: (getattr(inputRefs, key), inputs[key]) for key in keys}
        by_data_id = {
            key: dict(zip((ref.dataId for ref in refs), objs, strict=True))
            for key, (refs, objs) in refs_and_objs.items()
        }

        catalogs = by_data_id["cats_meas"]
        if use_cells:
            cells = by_data_id["coadds_cell"]
            exposures = {}
            for data_id, background in by_data_id["backgrounds"].items():
                exposure = cells[data_id].stitch().asExposure()
                exposure.image -= background.getImage()
                exposures[data_id] = exposure
        else:
            exposures = by_data_id["coadds"]
        psf_tables = by_data_id["models_psf"] if has_psf_models else None
        scarlet_models = inputs["models_scarlet"]

        catexp_by_band = {}
        data_id = None
        for data_id in set(catalogs).union(set(exposures)):
            catalog = catalogs[data_id]
            exposure = exposures[data_id]
            band = data_id["band"]
            updateCatalogFootprints(
                modelData=scarlet_models,
                catalog=catalog,
                band=band,
                imageForRedistribution=exposure,
                removeScarletData=False,
                updateFluxColumns=False,
            )
            catexp_by_band[band] = CatalogExposureInputs(
                catalog=catalog,
                exposure=exposure,
                table_psf_fits=psf_tables[data_id] if has_psf_models else astropy.table.Table(),
                dataId=data_id,
                id_tract_patch=id_tract_patch,
            )
        # Only possible when there were no inputs at all, but guard anyway.
        if data_id is None:
            raise RuntimeError(f"Did not build any catexps for {inputRefs=}")

        catexps = []
        for band in self.config.get_band_sets()[0]:
            catexp = catexp_by_band.get(band)
            if catexp is None:
                # No data for this band: build a placeholder whose data ID is
                # the last one seen with its band overridden. Subtasks are
                # expected to handle such placeholders.
                simple_id = data_id.to_simple(minimal=True)
                simple_id.dataId["band"] = band
                catexp = CatalogExposureInputs(
                    catalog=afwTable.SourceCatalog(),
                    exposure=None,
                    table_psf_fits=astropy.table.Table(),
                    dataId=data_id.from_simple(simple_id, universe=data_id.universe),
                    id_tract_patch=id_tract_patch,
                )
            catexps.append(catexp)
        return catexps

421 

422 

class CoaddMultibandFitTask(CoaddMultibandFitBase, pipeBase.PipelineTask):
    """Fit deblended exposures in multiple bands simultaneously.

    It is generally assumed but not enforced (except optionally by the
    configurable `fit_coadd_multiband` subtask) that there is only one exposure
    per band, presumably a coadd.
    """

    ConfigClass = CoaddMultibandFitConfig
    _DefaultName = "coaddMultibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("fit_coadd_multiband")

    def make_kwargs(self, butlerQC, inputRefs, inputs):
        """Make any kwargs needed to be passed to run.

        This method should be overloaded by subclasses that are configured to
        use a specific subtask that needs additional arguments derived from
        the inputs but do not otherwise need to overload runQuantum.

        Returns
        -------
        kwargs : `dict`
            Extra keyword arguments for `run`; empty in this base class.
        """
        return {}

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read the inputs, build per-band catexps, fit, and write outputs.

        Raises
        ------
        RuntimeError
            Raised if ``allow_missing_bands`` is False but a catexp is None
            (which `adjustQuantum` should already have prevented).
        """
        inputs = butlerQC.get(inputRefs)
        catexps = self.build_catexps(butlerQC, inputRefs, inputs)
        if not self.config.allow_missing_bands and any(catexp is None for catexp in catexps):
            raise RuntimeError(
                f"Got a None catexp with {self.config.allow_missing_bands=}; NoWorkFound should have been"
                f" raised earlier"
            )
        kwargs = self.make_kwargs(butlerQC, inputRefs, inputs)
        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'], **kwargs)
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        catexps: list[CatalogExposure],
        cat_ref: afwTable.SourceCatalog,
        **kwargs
    ) -> pipeBase.Struct:
        """Fit sources from a reference catalog using data from multiple
        exposures in the same region (patch).

        Parameters
        ----------
        catexps : `list` [`CatalogExposure`]
            A list of catalog-exposure pairs in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.
        **kwargs
            Extra keyword arguments forwarded to the subtask's ``run``.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see `CoaddMultibandFitSubTask.run`.
        The subtask receives the reference catalog as ``catalog_multi`` and
        must return a struct with an ``output`` attribute.
        """
        cat_output = self.fit_coadd_multiband.run(catalog_multi=cat_ref, catexps=catexps, **kwargs).output
        retStruct = pipeBase.Struct(cat_output=cat_output)
        return retStruct