Coverage for python/lsst/pipe/tasks/fit_coadd_multiband.py: 44%

81 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-24 01:40 -0800

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module. CatalogExposureInputs is exported because it is
# the element type of the abstract CoaddMultibandFitSubTask.run signature, so
# subtask implementations need it; the connections class is exported
# alongside its config and task.
__all__ = [
    "CatalogExposureInputs",
    "CoaddMultibandFitConfig",
    "CoaddMultibandFitConnections",
    "CoaddMultibandFitSubConfig",
    "CoaddMultibandFitSubTask",
    "CoaddMultibandFitTask",
]

26 

27from .fit_multiband import CatalogExposure, CatalogExposureConfig 

28 

29import lsst.afw.table as afwTable 

30from lsst.obs.base import ExposureIdInfo 

31import lsst.pex.config as pexConfig 

32import lsst.pipe.base as pipeBase 

33import lsst.pipe.base.connectionTypes as cT 

34 

35import astropy 

36from abc import ABC, abstractmethod 

37from pydantic import Field 

38from pydantic.dataclasses import dataclass 

39from typing import Iterable 

40 

# Default values for the templated placeholders ({name_coadd}, {name_method})
# used in the dataset names of CoaddMultibandFitConnections.
CoaddMultibandFitBaseTemplates = dict(
    name_coadd="deep",
    name_method="multiprofit",
)

45 

46 

@dataclass(frozen=True, kw_only=True, config=CatalogExposureConfig)
class CatalogExposureInputs(CatalogExposure):
    """A `CatalogExposure` that additionally carries a table of PSF fit
    parameters for each source.
    """
    table_psf_fits: astropy.table.Table = Field(
        title="A table of PSF fit parameters for each source",
    )

    def get_catalog(self):
        """Return this instance's ``catalog`` attribute unchanged."""
        source_catalog = self.catalog
        return source_catalog

53 

54 

class CoaddMultibandFitConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates=CoaddMultibandFitBaseTemplates,
):
    """Butler connections for `CoaddMultibandFitTask`.

    Per-band inputs (those with a ``band`` dimension) are filtered and
    ordered in `adjustQuantum` according to the bands requested by the
    configured ``fit_coadd_multiband`` subtask.
    """
    cat_ref = cT.Input(
        doc="Reference multiband source catalog",
        name="{name_coadd}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cats_meas = cT.Input(
        doc="Deblended single-band source catalogs",
        name="{name_coadd}Coadd_meas",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    coadds = cT.Input(
        doc="Exposures on which to run fits",
        name="{name_coadd}Coadd_calexp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    models_psf = cT.Input(
        doc="Input PSF model parameter catalog",
        # Consider allowing independent psf fit method
        name="{name_coadd}Coadd_psfs_{name_method}",
        storageClass="ArrowAstropy",
        dimensions=("tract", "patch", "band", "skymap"),
        multiple=True,
    )
    # Same connection type as the others; `cT` is lsst.pipe.base.connectionTypes.
    models_scarlet = cT.Input(
        doc="Multiband scarlet models produced by the deblender",
        name="{name_coadd}Coadd_scarletModelData",
        storageClass="ScarletModelData",
        dimensions=("tract", "patch", "skymap"),
    )
    cat_output = cT.Output(
        doc="Output source model fit parameter catalog",
        name="{name_coadd}Coadd_objects_{name_method}",
        storageClass="ArrowTable",
        dimensions=("tract", "patch", "skymap"),
    )

    def adjustQuantum(self, inputs, outputs, label, data_id):
        """Validates the `lsst.daf.butler.DatasetRef` bands against the
        subtask's list of bands to fit and drops unnecessary bands.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can be
            assumed to be multi-pass iterable and support `len` and ``in``, but
            it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned; this
            is especially useful for delegating to `super` (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. All inputs involving the 'band'
            dimension are adjusted to put them in sorted band order and remove
            unneeded bands. Any adjustments made by the base class
            implementation are merged in.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets, as returned by the base class
            implementation; this class itself does not adjust outputs.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if there are not enough of the right bands to run the task
            on this quantum.
        """
        # Check which bands are going to be fit
        bands_fit, bands_read_only = self.config.get_band_sets()
        bands_needed = bands_fit.union(bands_read_only)
        # Iterating a set of strings follows hash order, which can differ
        # between processes; sort once so every band-dimensioned connection
        # gets its refs in the same reproducible order.
        bands_ordered = sorted(bands_needed)

        adjusted_inputs = {}
        for connection_name, (connection, dataset_refs) in inputs.items():
            # Datasets without bands in their dimensions should be fine
            if 'band' not in connection.dimensions:
                continue
            datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
            if not bands_needed.issubset(datasets_by_band.keys()):
                raise pipeBase.NoWorkFound(
                    f'DatasetRefs={dataset_refs} have data with bands in the'
                    f' set={set(datasets_by_band.keys())},'
                    f' which is not a superset of the required bands={bands_needed} defined by'
                    f' {self.config.__class__}.fit_coadd_multiband='
                    f'{self.config.fit_coadd_multiband._value.__class__}\'s attributes'
                    f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                    f' Add the required bands={bands_needed.difference(datasets_by_band.keys())}.'
                )
            # Adjust all datasets with band dimensions to include just
            # the needed bands, in consistent order.
            adjusted_inputs[connection_name] = (
                connection,
                [datasets_by_band[band] for band in bands_ordered],
            )

        # Delegate to super for more checks, on the already-adjusted inputs.
        inputs.update(adjusted_inputs)
        adjusted_inputs_super, adjusted_outputs = super().adjustQuantum(
            inputs, outputs, label, data_id
        )
        # The base class saw the band-filtered inputs, so any adjustment it
        # reports supersedes ours for the same connection.
        adjusted_inputs.update(adjusted_inputs_super)
        return adjusted_inputs, adjusted_outputs

172 

173 

class CoaddMultibandFitSubConfig(pexConfig.Config):
    """Configuration for implementing fitter subtasks.

    Subclasses are expected to provide a ``bands_fit`` attribute (read by
    `CoaddMultibandFitConfig.get_band_sets`) and to override
    `bands_read_only`.
    """
    @abstractmethod
    def bands_read_only(self) -> set:
        """Return the set of bands that the Task needs to read (e.g. for
        defining priors) but not necessarily fit.

        Returns
        -------
        bands : `set`
            The set of such bands.

        Raises
        ------
        NotImplementedError
            Raised if a subclass does not override this method.
        """
        # `abstractmethod` is not enforced here because `pexConfig.Config`
        # does not use `abc.ABCMeta`, so a subclass that forgets to override
        # this can still be instantiated. Fail loudly rather than implicitly
        # returning None, which would later break `set(...)` in
        # CoaddMultibandFitConfig.get_band_sets with an obscure TypeError.
        raise NotImplementedError(
            f"{type(self).__name__} must implement bands_read_only()"
        )

186 

187 

class CoaddMultibandFitSubTask(pipeBase.Task, ABC):
    """Abstract subtask interface for fitting deblended sources using data
    from several bands at once.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to the `lsst.pipe.base.Task` constructor.
    """
    ConfigClass = CoaddMultibandFitSubConfig

    def __init__(self, **kwargs):
        # No subtask-specific setup; everything is handled by Task.
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self,
        catexps: Iterable[CatalogExposureInputs],
        cat_ref: afwTable.SourceCatalog,
    ) -> pipeBase.Struct:
        """Fit models to deblended sources using inputs from multiple bands.

        Parameters
        ----------
        catexps : `typing.List [CatalogExposureInputs]`
            Catalog-exposure pairs, each with metadata for a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            Reference catalog of the sources to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            Struct whose ``cat_output`` attribute holds the output
            measurement catalog.

        Notes
        -----
        Implementations may impose further requirements on their inputs,
        for example:
        - exactly one catexp per band;
        - catalogs containing HeavyFootprints with deblended images;
        - fitting only a subset of the sources.
        An implementation whose requirements are not met should fail as
        early as possible.

        NOTE(review): `CoaddMultibandFitTask.run` in this module invokes the
        subtask as ``run(catalog_multi=..., catexps=...)`` and reads the
        ``output`` attribute of the result, not ``cat_ref`` / ``cat_output``
        as documented here; confirm which convention implementations follow.
        """

231 

232 

class CoaddMultibandFitConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=CoaddMultibandFitConnections,
):
    """Configuration for `CoaddMultibandFitTask`; the actual fitting is
    delegated to the retargetable ``fit_coadd_multiband`` subtask.
    """
    fit_coadd_multiband = pexConfig.ConfigurableField(
        target=CoaddMultibandFitSubTask,
        doc="Task to fit sources using multiple bands",
    )

    def get_band_sets(self):
        """Collect the bands needed by the ``fit_coadd_multiband`` subtask.

        Returns
        -------
        bands_fit : `set`
            Bands the subtask will fit, from its ``bands_fit`` attribute.
        bands_read_only : `set`
            Bands for which the subtask only reads data (measurement catalog
            and exposure), from its ``bands_read_only()`` method.

        Raises
        ------
        RuntimeError
            Raised if the subtask config has no ``bands_fit`` attribute.
        """
        try:
            subtask_config = self.fit_coadd_multiband
            fit = subtask_config.bands_fit
        except AttributeError:
            raise RuntimeError(f'{__class__}.fit_coadd_multiband must have bands_fit attribute') from None
        read_only = subtask_config.bands_read_only()
        return set(fit), set(read_only)

261 

262 

class CoaddMultibandFitTask(pipeBase.PipelineTask):
    """Fit deblended exposures in multiple bands simultaneously.

    It is generally assumed but not enforced (except optionally by the
    configurable `fit_coadd_multiband` subtask) that there is only one exposure
    per band, presumably a coadd.
    """
    ConfigClass = CoaddMultibandFitConfig
    _DefaultName = "CoaddMultibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("fit_coadd_multiband")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Read the inputs, attach deblended footprints, fit, and write the
        output catalog.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler wrapper used to read inputs and write outputs.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            References to this quantum's input datasets.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            References to this quantum's output datasets.

        Raises
        ------
        RuntimeError
            Raised if the measurement catalogs and coadds do not cover the
            same data IDs, or if a PSF model table is missing for one of them.
        """
        inputs = butlerQC.get(inputRefs)
        id_tp = ExposureIdInfo.fromDataId(butlerQC.quantum.dataId, "tract_patch").expId
        # Key every per-band input by its data ID so that catalogs, exposures
        # and PSF models from the same band are matched regardless of the
        # order in which the butler returned them.
        cats, exps, models_psf = (
            {dRef.dataId: obj for dRef, obj in zip(getattr(inputRefs, key), inputs[key])}
            for key in ("cats_meas", "coadds", "models_psf")
        )
        data_ids = set(cats)
        if data_ids != set(exps) or not data_ids.issubset(models_psf):
            bands_by_input = {
                name: sorted(data_id['band'] for data_id in by_id)
                for name, by_id in (("cats_meas", cats), ("coadds", exps), ("models_psf", models_psf))
            }
            raise RuntimeError(
                f"Per-band inputs do not cover matching data IDs; bands by input: {bands_by_input}"
            )
        models_scarlet = inputs["models_scarlet"]
        catexps = []
        # All data IDs share this quantum's tract/patch/skymap, so band
        # uniquely identifies each; sorting by it gives a reproducible order
        # (set iteration order of data IDs can vary between processes).
        for dataId in sorted(data_ids, key=lambda data_id: data_id['band']):
            catalog = cats[dataId]
            exposure = exps[dataId]
            # The same models_scarlet object is reused for every band in this
            # loop. NOTE(review): assuming removeScarletData=True discards the
            # per-blend data it holds (as the name suggests — confirm against
            # meas_extensions_scarlet), doing so after the first band would
            # leave later bands without deblended footprints, so keep it.
            models_scarlet.updateCatalogFootprints(
                catalog=catalog,
                band=dataId['band'],
                psfModel=exposure.getPsf(),
                redistributeImage=exposure.image,
                removeScarletData=False,
                updateFluxColumns=False,
            )
            catexps.append(
                CatalogExposureInputs(
                    catalog=catalog, exposure=exposure, table_psf_fits=models_psf[dataId],
                    dataId=dataId, id_tract_patch=id_tp,
                )
            )
        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'])
        butlerQC.put(outputs, outputRefs)

    def run(self, catexps: list[CatalogExposure], cat_ref: afwTable.SourceCatalog) -> pipeBase.Struct:
        """Fit sources from a reference catalog using data from multiple
        exposures in the same region (patch).

        Parameters
        ----------
        catexps : `typing.List [CatalogExposure]`
            A list of catalog-exposure pairs, each for a single band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see `CoaddMultibandFitSubTask.run`.

        NOTE(review): the subtask is called with ``catalog_multi=`` and its
        result's ``output`` attribute is read, whereas the abstract
        `CoaddMultibandFitSubTask.run` documents ``cat_ref`` and
        ``cat_output``; confirm which convention implementations follow
        before changing either side.
        """
        cat_output = self.fit_coadd_multiband.run(catalog_multi=cat_ref, catexps=catexps).output
        retStruct = pipeBase.Struct(cat_output=cat_output)
        return retStruct