lsst.pipe.tasks  21.0.0-51-gd3b42663+3149e5afc4
fit_multiband.py
Go to the documentation of this file.
1 # This file is part of pipe_tasks.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 
22 __all__ = ["CatalogExposure", "MultibandFitSubTask", "MultibandFitConfig", "MultibandFitTask"]
23 
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Set

import lsst.afw.image as afwImage
import lsst.afw.table as afwTable
import lsst.daf.butler as dafButler
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
33 
34 
@dataclass(frozen=True)
class CatalogExposure:
    """Immutable bundle of a catalog, an exposure and metadata for one dataId.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`, optional
        The catalog for this dataId; may be `None`.
    exposure : `lsst.afw.image.Exposure`, optional
        The exposure for this dataId; may be `None`.
    dataId : `lsst.daf.butler.DataCoordinate`
        Data identifier; it must contain a ``'band'`` key.
    metadata : `dict`, optional
        Free-form metadata; defaults to a new empty `dict` per instance.

    Raises
    ------
    ValueError
        Raised on construction if ``'band'`` is not in ``dataId``.
    """
    catalog: Optional[afwTable.SourceCatalog]
    exposure: Optional[afwImage.Exposure]
    dataId: dafButler.DataCoordinate
    metadata: Dict = field(default_factory=dict)

    def __post_init__(self):
        # Reject a band-less dataId up front so that `band` can never fail
        # with a KeyError later on.
        if 'band' not in self.dataId:
            raise ValueError(f'dataId={self.dataId} must have a band')

    @property
    def band(self) -> str:
        """The value of ``dataId['band']``."""
        return self.dataId['band']

    @property
    def calib(self) -> Optional[afwImage.PhotoCalib]:
        """The result of ``exposure.getPhotoCalib()``, or `None` if there is
        no exposure.
        """
        return None if self.exposure is None else self.exposure.getPhotoCalib()
53 
54 
# Default substitutions for the ``{name_...}`` placeholders that appear in the
# connection dataset names; handed to the connections class as its
# ``defaultTemplates``.
multibandFitBaseTemplates = dict(
    name_input_coadd="deep",
    name_output_coadd="deep",
    name_output_cat="fit",
)
60 
61 
class MultibandFitConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates=multibandFitBaseTemplates,
):
    """Connections for multiband fitting: one reference catalog per patch,
    per-band measurement catalogs and coadds as inputs, and a single
    per-patch output catalog, plus the matching schema init-datasets.
    """
    cat_ref = cT.Input(
        doc="Reference multiband source catalog",
        name="{name_input_coadd}Coadd_ref",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cats_meas = cT.Input(
        doc="Deblended single-band source catalogs",
        name="{name_input_coadd}Coadd_meas",
        storageClass="SourceCatalog",
        multiple=True,
        dimensions=("tract", "patch", "band", "skymap"),
    )
    coadds = cT.Input(
        doc="Exposures on which to run fits",
        name="{name_input_coadd}Coadd_calexp",
        storageClass="ExposureF",
        multiple=True,
        dimensions=("tract", "patch", "band", "skymap"),
    )
    cat_output = cT.Output(
        doc="Measurement multi-band catalog",
        name="{name_output_coadd}Coadd_{name_output_cat}",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )
    cat_ref_schema = cT.InitInput(
        doc="Schema associated with a ref source catalog",
        storageClass="SourceCatalog",
        name="{name_input_coadd}Coadd_ref_schema",
    )
    cat_output_schema = cT.InitOutput(
        doc="Output of the schema used in deblending task",
        name="{name_output_coadd}Coadd_{name_output_cat}_schema",
        storageClass="SourceCatalog"
    )

    def adjustQuantum(self, datasetRefMap):
        """Validates the `lsst.daf.butler.DatasetRef` bands against the
        subtask's list of bands to fit and drops unnecessary bands.

        Parameters
        ----------
        datasetRefMap : `NamedKeyDict`
            Mapping from dataset type to a `set` of
            `lsst.daf.butler.DatasetRef` objects

        Returns
        -------
        datasetRefMap : `NamedKeyDict`
            Modified mapping of input with possibly adjusted
            `lsst.daf.butler.DatasetRef` objects.

        Raises
        ------
        ValueError
            Raised if any of the per-band datasets have an inconsistent band
            set, or if the band set to fit is not a subset of the data bands.

        """
        datasetRefMap = super().adjustQuantum(datasetRefMap)
        # Check which bands are going to be fit
        bands_fit, bands_read_only = self.config.get_band_sets()
        bands_needed = bands_fit.union(bands_read_only)

        # Band set of the first banded dataset type seen; every later banded
        # dataset type must match it exactly.
        bands_data = None
        bands_extra = set()

        for type_d, ref_d in datasetRefMap.items():
            # Datasets without bands in their dimensions should be fine
            if 'band' not in type_d.dimensions:
                continue
            bands_set = {dref.dataId['band'] for dref in ref_d}
            if bands_data is None:
                bands_data = bands_set
                if not bands_needed.issubset(bands_data):
                    raise ValueError(
                        f'Datarefs={ref_d} have data with bands in the set={bands_set},'
                        f' which is not a superset of the required bands={bands_needed} defined by '
                        f'{self.config.__class__}.fit_multiband='
                        f'{self.config.fit_multiband._value.__class__}\'s attributes'
                        f' bands_fit={bands_fit} and bands_read_only()={bands_read_only}.'
                        f' Add the required bands={bands_needed.difference(bands_data)}.'
                    )
                # Empty when the data bands are exactly the needed bands.
                bands_extra = bands_data.difference(bands_needed)
            elif bands_set != bands_data:
                raise ValueError(
                    f'Datarefs={ref_d} have data with bands in the set={bands_set}'
                    f' which differs from the previous={bands_data}); bandsets must be identical.'
                )
            if bands_extra:
                # Walk a snapshot: removing from the container that is being
                # iterated raises RuntimeError for a set and silently skips
                # elements for a list.
                for dref in tuple(ref_d):
                    if dref.dataId['band'] in bands_extra:
                        ref_d.remove(dref)
        return datasetRefMap
163 
164 
class MultibandFitSubConfig(pexConfig.Config):
    """Base config for the subtask of MultibandFitTask; it hosts methods
    whose return values may depend on several config settings at once.

    """

    def bands_read_only(self) -> Set:
        """Report the bands that must be read (e.g. to define priors) but
        that need not be fit.

        Returns
        -------
        The set of such bands; this base implementation returns an empty set.
        """
        no_bands: Set = set()
        return no_bands
179 
180 
class MultibandFitSubTask(pipeBase.Task, ABC):
    """Abstract base for the subtask that MultibandFitTask delegates the
    multiband fitting of deblended sources to.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Schema of the reference source catalog, intended as the starting
        point for the output schema. This base constructor does not use it.
    **kwargs
        Forwarded unchanged to the `lsst.pipe.base.Task` constructor.
    """

    ConfigClass = MultibandFitSubConfig

    def __init__(self, schema: afwTable.Schema, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def run(
        self,
        catexps: Iterable[CatalogExposure],
        cat_ref: afwTable.SourceCatalog,
    ) -> pipeBase.Struct:
        """Fit the sources of a reference catalog with data from several
        exposures covering the same patch.

        Parameters
        ----------
        catexps : `typing.Iterable [CatalogExposure]`
            Catalog-exposure pairs, each for a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            The reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct carrying the output measurement catalog.
            NOTE(review): `MultibandFitTask.run` reads the catalog from the
            struct's ``output`` attribute; confirm the attribute name that
            implementations are expected to provide.

        Notes
        -----
        Implementations may place additional requirements on their inputs,
        for example:
            - Passing only one catexp per band;
            - Catalogs containing HeavyFootprints with deblended images;
            - Fitting only a subset of the sources.
        A subtask whose requirements are not met should fail as early as
        it can.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def schema(self) -> afwTable.Schema:
        """Schema from which MultibandFitTask builds its output-schema
        catalog.
        """
        raise NotImplementedError
235 
236 
class MultibandFitConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=MultibandFitConnections,
):
    """Configuration class for the MultibandFitTask, containing a
    configurable subtask that does all fitting.
    """
    fit_multiband = pexConfig.ConfigurableField(
        target=MultibandFitSubTask,
        doc="Task to fit sources using multiple bands",
    )

    def get_band_sets(self):
        """Get the set of bands required by the fit_multiband subtask.

        Returns
        -------
        bands_fit : `set`
            The set of bands that the subtask will fit.
        bands_read_only : `set`
            The set of bands that the subtask will only read data
            (measurement catalog and exposure) for.

        Raises
        ------
        RuntimeError
            Raised if the fit_multiband config has no ``bands_fit``
            attribute.
        """
        try:
            bands_fit = self.fit_multiband.bands_fit
        except AttributeError:
            # `from None` hides the AttributeError so that only the
            # configuration problem is reported.
            raise RuntimeError(f'{__class__}.fit_multiband must have bands_fit attribute') from None
        bands_read_only = self.fit_multiband.bands_read_only()
        # Copy into fresh sets regardless of the container types returned.
        return set(bands_fit), set(bands_read_only)
265 
266 
class MultibandFitTask(pipeBase.PipelineTask):
    """PipelineTask that pairs per-band catalogs with per-band exposures and
    hands them, together with a reference catalog, to the configurable
    ``fit_multiband`` subtask.

    Parameters
    ----------
    initInputs : `dict`
        Init-input datasets; ``initInputs["cat_ref_schema"].schema`` is
        passed to the subtask as its input schema.
    **kwargs
        Additional arguments passed to the `lsst.pipe.base.PipelineTask`
        constructor.
    """
    ConfigClass = MultibandFitConfig
    _DefaultName = "multibandFit"

    def __init__(self, initInputs, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        self.makeSubtask("fit_multiband", schema=initInputs["cat_ref_schema"].schema)
        # Must carry the same name as the `cat_output_schema` connection;
        # runQuantum compares run outputs against it.
        self.cat_output_schema = afwTable.SourceCatalog(self.fit_multiband.schema)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, pair catalogs with exposures by dataId, run the
        fit and write the output catalog.

        Raises
        ------
        RuntimeError
            Raised, after the output has been written, if the output
            catalog's schema differs from the init-output schema.
        """
        inputs = butlerQC.get(inputRefs)
        input_refs_objs = [(inputRefs.cats_meas, inputs['cats_meas']), (inputRefs.coadds, inputs['coadds'])]
        # Key each loaded object by the dataId of the ref it was loaded from.
        cats, exps = [
            {dRef.dataId: obj for dRef, obj in zip(refs, objs)}
            for refs, objs in input_refs_objs
        ]
        # A dataId present in only one of the two inputs still produces a
        # CatalogExposure, with the missing member set to None.
        dataIds = set(cats).union(set(exps))
        catexps = [
            CatalogExposure(catalog=cats.get(dataId), exposure=exps.get(dataId), dataId=dataId)
            for dataId in dataIds
        ]
        outputs = self.run(catexps=catexps, cat_ref=inputs['cat_ref'])
        butlerQC.put(outputs, outputRefs)
        # Validate the output catalog's schema and raise if inconsistent (after output to allow debugging)
        if outputs.cat_output.schema != self.cat_output_schema.schema:
            raise RuntimeError(f'{__class__}.config.fit_multiband.run schema != initOutput schema:'
                               f' {outputs.cat_output.schema} vs {self.cat_output_schema.schema}')

    def run(self, catexps: List[CatalogExposure], cat_ref: afwTable.SourceCatalog) -> pipeBase.Struct:
        """Fit sources from a reference catalog using data from multiple
        exposures in the same region (patch).

        Parameters
        ----------
        catexps : `typing.List [CatalogExposure]`
            A list of catalog-exposure pairs in a given band.
        cat_ref : `lsst.afw.table.SourceCatalog`
            A reference source catalog to fit.

        Returns
        -------
        retStruct : `lsst.pipe.base.Struct`
            A struct with a cat_output attribute containing the output
            measurement catalog.

        Notes
        -----
        Subtasks may have further requirements; see `MultibandFitSubTask.run`.
        """
        # The subtask's result is read from its `output` attribute and
        # re-exposed under the name of the `cat_output` connection.
        cat_output = self.fit_multiband.run(catexps, cat_ref).output
        retStruct = pipeBase.Struct(cat_output=cat_output)
        return retStruct
Optional[afwImage.PhotoCalib] calib(self)
pipeBase.Struct run(self, Iterable[CatalogExposure] catexps, afwTable.SourceCatalog cat_ref)
def __init__(self, afwTable.Schema schema, **kwargs)
pipeBase.Struct run(self, List[CatalogExposure] catexps, afwTable.SourceCatalog cat_ref)
def runQuantum(self, butlerQC, inputRefs, outputRefs)
def __init__(self, initInputs, **kwargs)