Coverage for python/lsst/pipe/tasks/fit_multiband.py: 47%

93 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-17 02:40 -0800

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = [ 

23 "CatalogExposure", "MultibandFitConfig", "MultibandFitSubConfig", "MultibandFitSubTask", 

24 "MultibandFitTask", 

25] 

26 

27from abc import ABC, abstractmethod 

28from dataclasses import dataclass, field 

29import lsst.afw.image as afwImage 

30import lsst.afw.table as afwTable 

31import lsst.daf.butler as dafButler 

32from lsst.obs.base import ExposureIdInfo 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as cT 

36from typing import Dict, Iterable, List, Optional, Set 

37 

38 

@dataclass(frozen=True)
class CatalogExposure:
    """Immutable bundle of a catalog, an exposure and metadata for one dataId.

    The intent is to pair an exposure with the measurement catalog that
    belongs to it. Nothing here verifies that the two actually match, so
    callers are responsible for pairing them sensibly.

    Attributes
    ----------
    catalog : `lsst.afw.table.SourceCatalog` or `None`
        The measurement catalog, if any.
    exposure : `lsst.afw.image.Exposure` or `None`
        The exposure, if any.
    dataId : `lsst.daf.butler.DataCoordinate`
        The data ID shared by the catalog and exposure; must contain a
        ``band`` entry.
    id_tract_patch : `int` or `None`
        Integer identifier, zero by default (presumably a packed
        tract/patch ID - confirm against callers).
    metadata : `dict`
        Free-form metadata; a fresh empty dict per instance by default.

    Raises
    ------
    ValueError
        Raised on construction if ``dataId`` has no ``band`` entry.
    """
    # Field order defines the generated __init__ signature; do not reorder.
    catalog: Optional[afwTable.SourceCatalog]
    exposure: Optional[afwImage.Exposure]
    dataId: dafButler.DataCoordinate
    id_tract_patch: Optional[int] = 0
    metadata: Dict = field(default_factory=dict)

    def __post_init__(self):
        # Every instance must be attributable to a band.
        if 'band' in self.dataId:
            return
        raise ValueError(f'dataId={self.dataId} must have a band')

    @property
    def band(self) -> str:
        """The ``band`` entry of the dataId."""
        return self.dataId['band']

    @property
    def calib(self) -> Optional[afwImage.PhotoCalib]:
        """The exposure's photometric calibration, or `None` if there is no
        exposure.
        """
        exposure = self.exposure
        if exposure is None:
            return None
        return exposure.getPhotoCalib()

63 

64 

# Default values for the name templates used by the connections' dataset
# type names (e.g. "{name_input_coadd}Coadd_ref").
multibandFitBaseTemplates = dict(
    name_input_coadd="deep",
    name_output_coadd="deep",
    name_output_cat="fit",
)

70 

71 

class MultibandFitConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates=multibandFitBaseTemplates,
):
    """Connections for a multiband fitting task.

    Declares a per-patch reference catalog, per-band measurement catalogs
    and exposures as inputs, and a single per-patch output catalog, together
    with the matching init-input and init-output schema catalogs.
    """
    cat_ref = cT.Input(
        doc="Reference multiband source catalog",
        name="{name_input_coadd}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cats_meas = cT.Input(
        doc="Deblended single-band source catalogs",
        name="{name_input_coadd}Coadd_meas",
        storageClass="SourceCatalog",
        multiple=True,
        dimensions=("tract", "patch", "band", "skymap"),
    )
    coadds = cT.Input(
        doc="Exposures on which to run fits",
        name="{name_input_coadd}Coadd_calexp",
        storageClass="ExposureF",
        multiple=True,
        dimensions=("tract", "patch", "band", "skymap"),
    )
    cat_output = cT.Output(
        doc="Measurement multi-band catalog",
        name="{name_output_coadd}Coadd_{name_output_cat}",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cat_ref_schema = cT.InitInput(
        doc="Schema associated with a ref source catalog",
        storageClass="SourceCatalog",
        name="{name_input_coadd}Coadd_ref_schema",
    )
    # NOTE(review): the doc string below mentions a "deblending task"; it
    # looks copied from elsewhere - confirm the intended wording.
    cat_output_schema = cT.InitOutput(
        doc="Output of the schema used in deblending task",
        name="{name_output_coadd}Coadd_{name_output_cat}_schema",
        storageClass="SourceCatalog"
    )

    def adjustQuantum(self, inputs, outputs, label, data_id):
        """Validates the `lsst.daf.butler.DatasetRef` bands against the
        subtask's list of bands to fit and drops unnecessary bands.

        Parameters
        ----------
        inputs : `dict`
            Dictionary whose keys are an input (regular or prerequisite)
            connection name and whose values are a tuple of the connection
            instance and a collection of associated `DatasetRef` objects.
            The exact type of the nested collections is unspecified; it can be
            assumed to be multi-pass iterable and support `len` and ``in``, but
            it should not be mutated in place. In contrast, the outer
            dictionaries are guaranteed to be temporary copies that are true
            `dict` instances, and hence may be modified and even returned; this
            is especially useful for delegating to `super` (see notes below).
        outputs : `Mapping`
            Mapping of output datasets, with the same structure as ``inputs``.
        label : `str`
            Label for this task in the pipeline (should be used in all
            diagnostic messages).
        data_id : `lsst.daf.butler.DataCoordinate`
            Data ID for this quantum in the pipeline (should be used in all
            diagnostic messages).

        Returns
        -------
        adjusted_inputs : `Mapping`
            Mapping of the same form as ``inputs`` with updated containers of
            input `DatasetRef` objects. All inputs involving the 'band'
            dimension are reduced to the needed bands and ordered by sorted
            band name, so the order is identical for every connection and
            reproducible between processes.
        adjusted_outputs : `Mapping`
            Mapping of updated output datasets; always empty for this task.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if there are not enough of the right bands to run the task
            on this quantum.
        """
        # Check which bands are going to be fit
        bands_fit, bands_read_only = self.config.get_band_sets()
        bands_needed = bands_fit.union(bands_read_only)
        # Set iteration order for strings is not reproducible between
        # interpreter runs (hash randomization), so fix the order explicitly.
        bands_ordered = sorted(bands_needed)

        adjusted_inputs = {}
        for connection_name, (connection, dataset_refs) in inputs.items():
            # Datasets without bands in their dimensions should be fine
            if 'band' not in connection.dimensions:
                continue
            datasets_by_band = {dref.dataId['band']: dref for dref in dataset_refs}
            bands_missing = bands_needed.difference(datasets_by_band.keys())
            if bands_missing:
                raise pipeBase.NoWorkFound(
                    f'DatasetRefs={dataset_refs} have data with bands in the'
                    f' set={set(datasets_by_band.keys())},'
                    f' which is not a superset of the required bands={bands_needed} defined by'
                    f' {self.config.__class__}.fit_multiband='
                    f'{self.config.fit_multiband._value.__class__}\'s attributes'
                    f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                    f' Add the required bands={bands_missing}.'
                )
            # Adjust all datasets with band dimensions to include just
            # the needed bands, in consistent (sorted) order.
            adjusted_inputs[connection_name] = (
                connection,
                [datasets_by_band[band] for band in bands_ordered]
            )

        # Delegate to super for more checks; ``inputs`` is a temporary copy
        # (see Parameters), so updating it in place is allowed.
        inputs.update(adjusted_inputs)
        super().adjustQuantum(inputs, outputs, label, data_id)
        return adjusted_inputs, {}

185 

186 

class MultibandFitSubConfig(pexConfig.Config):
    """Base config for subtasks of MultibandFitTask.

    It exists to host methods whose return values may depend on several
    config settings at once.
    """

    def bands_read_only(self) -> Set:
        """Return the bands that must be read (e.g. to define priors) but
        are not necessarily fit.

        Returns
        -------
        bands : `set`
            The set of such bands; this base implementation returns an
            empty set.
        """
        no_bands: Set = set()
        return no_bands

201 

202 

class MultibandFitSubTask(pipeBase.Task, ABC):
    """Abstract base for the fitting subtask used by MultibandFitTask.

    Concrete subclasses perform multiband fitting of deblended sources and
    must implement both `run` and the `schema` property.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        The input schema for the reference source catalog, used to
        initialize the output schema. This base class does not use it;
        it only fixes the constructor signature for subclasses.
    **kwargs
        Additional arguments forwarded to the `lsst.pipe.base.Task`
        constructor.
    """
    ConfigClass = MultibandFitSubConfig

    def __init__(self, schema: afwTable.Schema, **kwargs):
        super().__init__(**kwargs)

    @property
    @abstractmethod
    def schema(self) -> afwTable.Schema:
        """The schema of the subtask's output catalog (abstract)."""
        raise NotImplementedError()

    @abstractmethod
    def run(
        self, catexps: Iterable[CatalogExposure], cat_ref: afwTable.SourceCatalog
    ) -> pipeBase.Struct:
        """Fit the sources of a reference catalog using data from several
        exposures covering the same patch.

        Parameters
        ----------
        catexps : `typing.List [CatalogExposure]`
            Catalog-exposure pairs, each carrying its own band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subclasses may impose further requirements on the inputs,
        including:

        - Passing only one catexp per band;
        - Catalogs containing HeavyFootprints with deblended images;
        - Fitting only a subset of the sources.

        A subtask should fail as early as possible when any of its
        requirements is not met.

        NOTE(review): `MultibandFitTask.run` in this file reads the
        ``output`` attribute of the returned struct rather than
        ``cat_output`` - confirm which name implementations must provide.
        """
        raise NotImplementedError()

257 

258 

class MultibandFitConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MultibandFitConnections,
):
    """Config for MultibandFitTask; the fitting itself is delegated to a
    configurable subtask.
    """
    fit_multiband = pexConfig.ConfigurableField(
        target=MultibandFitSubTask,
        doc="Task to fit sources using multiple bands",
    )

    def get_band_sets(self):
        """Collect the bands that the fit_multiband subtask requires.

        Returns
        -------
        bands_fit : `set`
            The set of bands that the subtask will fit.
        bands_read_only : `set`
            The set of bands for which the subtask will only read data
            (measurement catalog and exposure).

        Raises
        ------
        RuntimeError
            Raised if the subtask config has no ``bands_fit`` attribute.
        """
        try:
            bands_fit = self.fit_multiband.bands_fit
        except AttributeError:
            # Replace the bare AttributeError with a message naming the
            # missing config attribute.
            raise RuntimeError(f'{__class__}.fit_multiband must have bands_fit attribute') from None
        else:
            bands_read_only = self.fit_multiband.bands_read_only()
        return set(bands_fit), set(bands_read_only)

287 

288 

class MultibandFitTask(pipeBase.PipelineTask):
    """Simultaneously fit deblended exposures in several bands.

    The task generally assumes one exposure per band (presumably a coadd).
    It does not enforce this itself, although the configurable
    `fit_multiband` subtask may.
    """
    ConfigClass = MultibandFitConfig
    _DefaultName = "multibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        ref_schema = initInputs["cat_ref_schema"].schema
        self.makeSubtask("fit_multiband", schema=ref_schema)
        # Empty catalog carrying the subtask's schema; runQuantum compares
        # the schema of each output catalog against it.
        self.cat_output_schema = afwTable.SourceCatalog(self.fit_multiband.schema)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs for one quantum, run the fit, write the output
        and then validate the output schema.

        Raises
        ------
        RuntimeError
            Raised, after the output has been written, if the output
            catalog's schema differs from ``self.cat_output_schema.schema``.
        """
        inputs = butlerQC.get(inputRefs)
        id_tp = ExposureIdInfo.fromDataId(butlerQC.quantum.dataId, "tract_patch").expId

        # Key the loaded catalogs and exposures by dataId so they can be
        # paired even if one of the two is missing for some dataId.
        cats = {ref.dataId: cat for ref, cat in zip(inputRefs.cats_meas, inputs['cats_meas'])}
        exps = {ref.dataId: exp for ref, exp in zip(inputRefs.coadds, inputs['coadds'])}

        catexps = []
        for dataId in set(cats) | set(exps):
            catexps.append(
                CatalogExposure(
                    catalog=cats.get(dataId), exposure=exps.get(dataId), dataId=dataId, id_tract_patch=id_tp,
                )
            )

        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'])
        butlerQC.put(outputs, outputRefs)
        # The schema is checked only after writing so that an inconsistent
        # output is still available for debugging.
        schema_out = outputs.cat_output.schema
        if schema_out != self.cat_output_schema.schema:
            raise RuntimeError(f'{__class__}.config.fit_multiband.run schema != initOutput schema:'
                               f' {outputs.cat_output.schema} vs {self.cat_output_schema.schema}')

    def run(self, catexps: List[CatalogExposure], cat_ref: afwTable.SourceCatalog) -> pipeBase.Struct:
        """Fit the sources of a reference catalog using data from several
        exposures covering the same region (patch).

        Parameters
        ----------
        catexps : `typing.List [CatalogExposure]`
            Catalog-exposure pairs, each carrying its own band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see `MultibandFitSubTask.run`.
        """
        # The subtask's result is read through its ``output`` attribute.
        fit_result = self.fit_multiband.run(catexps, cat_ref)
        return pipeBase.Struct(cat_output=fit_result.output)