Coverage for python/lsst/pipe/tasks/fit_coadd_multiband.py: 45%

84 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-14 10:01 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by this module.
__all__ = [
    "CoaddMultibandFitConfig",
    "CoaddMultibandFitSubConfig",
    "CoaddMultibandFitSubTask",
    "CoaddMultibandFitTask",
]

26 

27from .fit_multiband import CatalogExposure, CatalogExposureConfig 

28 

29import lsst.afw.table as afwTable 

30from lsst.meas.base import SkyMapIdGeneratorConfig 

31from lsst.meas.extensions.scarlet.io import updateCatalogFootprints 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34import lsst.pipe.base.connectionTypes as cT 

35 

36import astropy 

37from abc import ABC, abstractmethod 

38from pydantic import Field 

39from pydantic.dataclasses import dataclass 

40from typing import Iterable 

41 

# Default values for the "{name_coadd}" and "{name_method}" placeholders used
# in the connection dataset names defined in this module.
CoaddMultibandFitBaseTemplates = dict(
    name_coadd="deep",
    name_method="multiprofit",
)

46 

47 

@dataclass(frozen=True, kw_only=True, config=CatalogExposureConfig)
class CatalogExposureInputs(CatalogExposure):
    """A `CatalogExposure` extended with a table of PSF fit parameters.

    Instances are built in `CoaddMultibandFitTask.runQuantum`, one per
    data ID, and handed to the fitting subtask.
    """

    table_psf_fits: astropy.table.Table = Field(
        title="A table of PSF fit parameters for each source",
    )

    def get_catalog(self):
        """Return the ``catalog`` attribute of this object."""
        return self.catalog

54 

55 

class CoaddMultibandFitConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates=CoaddMultibandFitBaseTemplates,
):
    """Dataset connections for `CoaddMultibandFitTask`.

    The quantum is defined per (tract, patch, skymap). The connections that
    also carry a ``band`` dimension are declared with ``multiple=True`` and
    are filtered and re-ordered by `adjustQuantum`.
    """

    cat_ref = cT.Input(
        doc="Reference multiband source catalog",
        name="{name_coadd}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cats_meas = cT.Input(
        doc="Deblended single-band source catalogs",
        name="{name_coadd}Coadd_meas",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    coadds = cT.Input(
        doc="Exposures on which to run fits",
        name="{name_coadd}Coadd_calexp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    models_psf = cT.Input(
        doc="Input PSF model parameter catalog",
        # Consider allowing independent psf fit method
        name="{name_coadd}Coadd_psfs_{name_method}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    # ``cT`` is the module alias for ``lsst.pipe.base.connectionTypes``.
    models_scarlet = cT.Input(
        doc="Multiband scarlet models produced by the deblender",
        name="{name_coadd}Coadd_scarletModelData",
        storageClass="ScarletModelData",
        dimensions=("tract", "patch", "skymap"),
    )
    cat_output = cT.Output(
        doc="Output source model fit parameter catalog",
        name="{name_coadd}Coadd_objects_{name_method}",
        storageClass="ArrowTable",
        dimensions=("tract", "patch", "skymap"),
    )

    def adjustQuantum(self, inputs, outputs, label, data_id):
        """Restrict band-dependent inputs to the bands the subtask needs.

        Every input connection whose dimensions include ``band`` is checked
        against the bands reported by ``self.config.get_band_sets()`` and
        then reduced to exactly those bands, in a consistent order (fitted
        bands first, followed by read-only bands not already listed).

        Parameters
        ----------
        inputs : `dict`
            Maps each input (regular or prerequisite) connection name to a
            tuple of the connection instance and a collection of its
            `DatasetRef` objects. The nested collections must not be mutated
            in place; the outer dictionary is updated by this method before
            it is passed on to `super`.
        outputs : `Mapping`
            Output datasets, structured like ``inputs``.
        label : `str`
            Label of this task in the pipeline (for diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID of this quantum (for diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Same form as ``inputs``, containing only the connections with a
            ``band`` dimension, each holding one `DatasetRef` per needed
            band.
        adjusted_outputs : `Mapping`
            Updated output datasets; always empty for this task.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if any band-dependent connection lacks one or more of the
            needed bands.
        """
        bands_fit, bands_read_only = self.config.get_band_sets()
        # Fitted bands come first; read-only bands are appended only if they
        # are not already being fit.
        bands_needed = bands_fit + [band for band in bands_read_only if band not in bands_fit]

        adjusted_inputs = {}
        for connection_name, (connection, dataset_refs) in inputs.items():
            # Connections without a band dimension are left untouched.
            if 'band' not in connection.dimensions:
                continue

            # Note: if several refs share a band, the last one wins.
            datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
            bands_missing = set(bands_needed).difference(datasets_by_band.keys())
            if bands_missing:
                bands_found = set(datasets_by_band.keys())
                raise pipeBase.NoWorkFound(
                    f'DatasetRefs={dataset_refs} have data with bands in the'
                    f' set={bands_found},'
                    f' which is not a superset of the required bands={bands_needed} defined by'
                    f' {self.config.__class__}.fit_coadd_multiband='
                    f'{self.config.fit_coadd_multiband._value.__class__}\'s attributes'
                    f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                    f' Add the required bands={bands_missing}.'
                )

            # Keep just the needed bands, in the order of ``bands_needed``.
            refs_ordered = [datasets_by_band[band] for band in bands_needed]
            adjusted_inputs[connection_name] = (connection, refs_ordered)

        # Let the base class run its own checks on the adjusted inputs; its
        # return value is intentionally not used here.
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        return adjusted_inputs, {}

173 

174 

class CoaddMultibandFitSubConfig(pexConfig.Config):
    """Configuration for implementing fitter subtasks.

    Notes
    -----
    `CoaddMultibandFitConfig.get_band_sets` in this module also reads a
    ``bands_fit`` attribute from the subtask configuration, so concrete
    subclasses need to provide one in addition to `bands_read_only`.
    """

    @abstractmethod
    def bands_read_only(self) -> set:
        """Return the set of bands that the Task needs to read (e.g. for
        defining priors) but not necessarily fit.

        Returns
        -------
        The set of such bands.

        Raises
        ------
        NotImplementedError
            Raised if a subclass does not override this method.
        """
        # `abstractmethod` is only enforced when the class's metaclass
        # derives from `abc.ABCMeta`. NOTE(review): that is not visible here
        # for `pexConfig.Config` - confirm. Raising explicitly ensures a
        # missing override fails loudly instead of silently returning None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement bands_read_only()"
        )

187 

188 

class CoaddMultibandFitSubTask(pipeBase.Task, ABC):
    """Abstract interface for subtasks that fit deblended sources using
    data from multiple bands.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to the `lsst.pipe.base.Task` constructor.
    """

    ConfigClass = CoaddMultibandFitSubConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self,
        catexps: Iterable[CatalogExposureInputs],
        cat_ref: afwTable.SourceCatalog,
    ) -> pipeBase.Struct:
        """Fit models to deblended sources from multi-band inputs.

        Parameters
        ----------
        catexps : `typing.Iterable` [`CatalogExposureInputs`]
            Catalog-exposure pairs with metadata, each in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subclasses may place further requirements on the inputs, such as:

        - Passing only one catexp per band;
        - Catalogs containing HeavyFootprints with deblended images;
        - Fitting only a subset of the sources.

        If any requirement is not met, the subtask should fail as soon as
        possible.

        NOTE(review): `CoaddMultibandFitTask.run` in this module calls this
        method as ``run(catalog_multi=cat_ref, catexps=catexps)`` and reads
        the ``output`` attribute of the result. That differs from the
        ``cat_ref`` parameter name and ``cat_output`` attribute documented
        here - confirm which convention implementations should follow.
        """

232 

233 

class CoaddMultibandFitConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=CoaddMultibandFitConnections,
):
    """Configure a CoaddMultibandFitTask, including a configurable fitting
    subtask.
    """

    fit_coadd_multiband = pexConfig.ConfigurableField(
        target=CoaddMultibandFitSubTask,
        doc="Task to fit sources using multiple bands",
    )
    idGenerator = SkyMapIdGeneratorConfig.make_field()

    def get_band_sets(self):
        """Get the bands required by the fit_coadd_multiband subtask.

        Each collection of bands is de-duplicated while preserving the
        order in which the subtask configuration yields them. Lists are
        returned rather than sets because
        `CoaddMultibandFitConnections.adjustQuantum` concatenates
        ``bands_fit`` with another list.

        Returns
        -------
        bands_fit : `list`
            The unique bands that the subtask will fit, in order.
        bands_read_only : `list`
            The unique bands that the subtask will only read data
            (measurement catalog and exposure) for, in order.

        Raises
        ------
        RuntimeError
            Raised if the subtask configuration has no ``bands_fit``
            attribute.
        """
        try:
            bands_fit = self.fit_coadd_multiband.bands_fit
        except AttributeError:
            raise RuntimeError(f'{__class__}.fit_coadd_multiband must have bands_fit attribute') from None
        bands_read_only = self.fit_coadd_multiband.bands_read_only()
        # dict.fromkeys drops duplicates and keeps first-seen order.
        return list(dict.fromkeys(bands_fit)), list(dict.fromkeys(bands_read_only))

263 

264 

class CoaddMultibandFitTask(pipeBase.PipelineTask):
    """Fit deblended exposures in multiple bands simultaneously.

    The task assumes, but does not itself enforce, one exposure per band
    (presumably a coadd); the configurable `fit_coadd_multiband` subtask may
    optionally enforce it.
    """

    ConfigClass = CoaddMultibandFitConfig
    _DefaultName = "CoaddMultibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("fit_coadd_multiband")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, pair them up per data ID, run the fit and store
        the result.

        Parameters
        ----------
        butlerQC
            Quantum context used to ``get`` the inputs and ``put`` the
            outputs; its ``quantum.dataId`` seeds the ID generator.
        inputRefs
            Input references; the ``cats_meas``, ``coadds`` and
            ``models_psf`` attributes are read here.
        outputRefs
            Output references passed through to ``butlerQC.put``.
        """
        inputs = butlerQC.get(inputRefs)
        id_tp = self.config.idGenerator.apply(butlerQC.quantum.dataId).catalog_id

        def by_data_id(name):
            # Key each loaded object by the data ID of its reference, so the
            # three band-dependent inputs can be matched to one another.
            return {ref.dataId: obj for ref, obj in zip(getattr(inputRefs, name), inputs[name])}

        cats = by_data_id("cats_meas")
        exps = by_data_id("coadds")
        models_psf = by_data_id("models_psf")
        models_scarlet = inputs["models_scarlet"]

        # NOTE(review): the union is indexed into cats, exps and models_psf
        # below, so a data ID missing from any of them raises KeyError. This
        # presumably relies on adjustQuantum giving all three the same bands.
        catexps_by_band = {}
        for dataId in set(cats).union(set(exps)):
            catalog, exposure = cats[dataId], exps[dataId]
            band = dataId['band']
            updateCatalogFootprints(
                modelData=models_scarlet,
                catalog=catalog,
                band=band,
                imageForRedistribution=exposure,
                removeScarletData=True,
                updateFluxColumns=False,
            )
            catexps_by_band[band] = CatalogExposureInputs(
                catalog=catalog,
                exposure=exposure,
                table_psf_fits=models_psf[dataId],
                dataId=dataId,
                id_tract_patch=id_tp,
            )

        # Only the fitted bands (first element of get_band_sets) are passed
        # on, in the configured order.
        bands_fit = self.config.get_band_sets()[0]
        catexps = [catexps_by_band[band] for band in bands_fit]
        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'])
        butlerQC.put(outputs, outputRefs)

    def run(self, catexps: list[CatalogExposure], cat_ref: afwTable.SourceCatalog) -> pipeBase.Struct:
        """Fit sources from a reference catalog using data from multiple
        exposures in the same region (patch).

        Parameters
        ----------
        catexps : `list` [`CatalogExposure`]
            Catalog-exposure pairs, each in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see
        `CoaddMultibandFitSubTask.run`.
        """
        # NOTE(review): the subtask is called with ``catalog_multi`` and its
        # result's ``output`` attribute is read, whereas the abstract
        # `CoaddMultibandFitSubTask.run` declares ``cat_ref`` and documents
        # ``cat_output`` - confirm against the concrete subtask in use.
        result = self.fit_coadd_multiband.run(catalog_multi=cat_ref, catexps=catexps)
        return pipeBase.Struct(cat_output=result.output)