lsst.pipe.tasks  19.0.0-28-g53bcf5f6+7
postprocess.py
1 # This file is part of pipe_tasks
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 
22 import functools
23 import pandas as pd
24 from collections import defaultdict
25 
26 import lsst.pex.config as pexConfig
27 import lsst.pipe.base as pipeBase
28 from lsst.pipe.base import CmdLineTask, ArgumentParser
29 from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer
30 
31 from .parquetTable import ParquetTable
32 from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
33 from .functors import CompositeFunctor, RAColumn, DecColumn, Column
34 
35 
36 def flattenFilters(df, filterDict, noDupCols=['coord_ra', 'coord_dec'], camelCase=False):
37  """Flattens a dataframe with multilevel column index
38  """
39  newDf = pd.DataFrame()
40  for filt, filtShort in filterDict.items():
41  subdf = df[filt]
42  columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
43  newColumns = {c: columnFormat.format(filtShort, c)
44  for c in subdf.columns if c not in noDupCols}
45  cols = list(newColumns.keys())
46  newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)
47 
48  newDf = pd.concat([subdf[noDupCols], newDf], axis=1)
49  return newDf
50 
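# Illustrative sketch (not part of the original file; the filter names are
# hypothetical): flattenFilters prefixes each per-band column with the short
# filter name from `filterDict`, while `noDupCols` appear once, unprefixed.
#
#     >>> import pandas as pd
#     >>> columns = pd.MultiIndex.from_tuples(
#     ...     [(f, c) for f in ('HSC-G', 'HSC-R')
#     ...      for c in ('coord_ra', 'coord_dec', 'PsfFlux')],
#     ...     names=('filter', 'column'))
#     >>> df = pd.DataFrame([[10., -5., 1., 10., -5., 2.]], columns=columns)
#     >>> flattenFilters(df, {'HSC-G': 'g', 'HSC-R': 'r'}, camelCase=True).columns.tolist()
#     ['coord_ra', 'coord_dec', 'gPsfFlux', 'rPsfFlux']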
51 
52 class WriteObjectTableConfig(pexConfig.Config):
53  priorityList = pexConfig.ListField(
54  dtype=str,
55  default=[],
56  doc="Priority-ordered list of bands for the merge."
57  )
58  engine = pexConfig.Field(
59  dtype=str,
60  default="pyarrow",
61  doc="Parquet engine for writing (pyarrow or fastparquet)"
62  )
63  coaddName = pexConfig.Field(
64  dtype=str,
65  default="deep",
66  doc="Name of coadd"
67  )
68 
69  def validate(self):
70  pexConfig.Config.validate(self)
71  if len(self.priorityList) == 0:
72  raise RuntimeError("No priority list provided")
73 
74 
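# Illustrative config override sketch (the filter names are hypothetical, not
# taken from this file): `validate` rejects an empty priority list, so a config
# file would typically set something like
#
#     config.priorityList = ["HSC-I", "HSC-R", "HSC-Z", "HSC-Y", "HSC-G"]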
75 class WriteObjectTableTask(CmdLineTask):
76  """Write filter-merged source tables to parquet
77  """
78  _DefaultName = "writeObjectTable"
79  ConfigClass = WriteObjectTableConfig
80  RunnerClass = MergeSourcesRunner
81 
82  # Names of table datasets to be merged
83  inputDatasets = ('forced_src', 'meas', 'ref')
84 
85  # Tag of output dataset written by `MergeSourcesTask.write`
86  outputDataset = 'obj'
87 
88  def __init__(self, butler=None, schema=None, **kwargs):
89  # It is a shame that this class can't use the default init for CmdLineTask,
90  # but to do so would require its own special task runner, which is many
91  # more lines of specialization, so this is how it is for now.
92  CmdLineTask.__init__(self, **kwargs)
93 
94  def runDataRef(self, patchRefList):
95  """!
96  @brief Merge coadd sources from multiple bands. Calls @ref `run` which must be defined in
97  subclasses that inherit from MergeSourcesTask.
98  @param[in] patchRefList list of data references for each filter
99  """
100  catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
101  dataId = patchRefList[0].dataId
102  mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
103  self.write(patchRefList[0], mergedCatalog)
104 
105  @classmethod
106  def _makeArgumentParser(cls):
107  """Create a suitable ArgumentParser.
108 
109  We will use the ArgumentParser to get a list of data
110  references for patches; the RunnerClass will sort them into lists
111  of data references for the same patch.
112 
113  Uses the first entry of self.inputDatasets, rather than
114  self.inputDataset.
115  """
116  return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])
117 
118  def readCatalog(self, patchRef):
119  """Read input catalogs
120 
121  Read all the input datasets given by the 'inputDatasets'
122  attribute.
123 
124  Parameters
125  ----------
126  patchRef : `lsst.daf.persistence.ButlerDataRef`
127  Data reference for patch
128 
129  Returns
130  -------
131  Tuple consisting of filter name and a dict of catalogs, keyed by
132  dataset name
133  """
134  filterName = patchRef.dataId["filter"]
135  catalogDict = {}
136  for dataset in self.inputDatasets:
137  catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
138  self.log.info("Read %d sources from %s for filter %s: %s" %
139  (len(catalog), dataset, filterName, patchRef.dataId))
140  catalogDict[dataset] = catalog
141  return filterName, catalogDict
142 
143  def run(self, catalogs, tract, patch):
144  """Merge multiple catalogs.
145 
146  Parameters
147  ----------
148  catalogs : `dict`
149  Mapping from filter names to dict of catalogs.
150  tract : int
151  tractId to use for the tractId column
152  patch : str
153  patchId to use for the patchId column
154 
155  Returns
156  -------
157  catalog : `lsst.pipe.tasks.parquetTable.ParquetTable`
158  Merged dataframe, with each column prefixed by
159  `filter_tag(filt)`, wrapped in the parquet writer shim class.
160  """
161 
162  dfs = []
163  for filt, tableDict in catalogs.items():
164  for dataset, table in tableDict.items():
165  # Convert afwTable to pandas DataFrame
166  df = table.asAstropy().to_pandas().set_index('id', drop=True)
167 
168  # Sort columns by name, to ensure matching schema among patches
169  df = df.reindex(sorted(df.columns), axis=1)
170  df['tractId'] = tract
171  df['patchId'] = patch
172 
173  # Make columns a 3-level MultiIndex
174  df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
175  names=('dataset', 'filter', 'column'))
176  dfs.append(df)
177 
178  catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
179  return ParquetTable(dataFrame=catalog)
180 
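# Illustrative sketch of the merged column structure (the filter and column
# names below are hypothetical): each input table contributes a block of
# columns keyed by (dataset, filter, column), e.g.
#
#     ('meas',       'HSC-G', 'base_PsfFlux_instFlux')
#     ('forced_src', 'HSC-G', 'base_PsfFlux_instFlux')
#     ('ref',        'HSC-G', 'detect_isPrimary')
#
# so a single per-band measurement column can be selected from the result with
#
#     catalog.toDataFrame()[('meas', 'HSC-G', 'base_PsfFlux_instFlux')]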
181  def write(self, patchRef, catalog):
182  """Write the output.
183 
184  Parameters
185  ----------
186  catalog : `ParquetTable`
187  Catalog to write
188  patchRef : `lsst.daf.persistence.ButlerDataRef`
189  Data reference for patch
190  """
191  patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
192  # since the filter isn't actually part of the data ID for the dataset we're saving,
193  # it's confusing to see it in the log message, even if the butler simply ignores it.
194  mergeDataId = patchRef.dataId.copy()
195  del mergeDataId["filter"]
196  self.log.info("Wrote merged catalog: %s" % (mergeDataId,))
197 
198  def writeMetadata(self, dataRefList):
199  """No metadata to write, and not sure how to write it for a list of dataRefs.
200  """
201  pass
202 
203 
204 class PostprocessAnalysis(object):
205  """Calculate columns from ParquetTable
206 
207  This object manages and organizes an arbitrary set of computations
208  on a catalog. The catalog is defined by a
209  `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
210  `deepCoadd_obj` dataset, and the computations are defined by a collection
211  of `lsst.pipe.tasks.functors.Functor` objects (or, equivalently,
212  a `CompositeFunctor`).
213 
214  After the object is initialized, accessing the `.df` attribute (which
215  holds the `pandas.DataFrame` containing the results of the calculations) triggers
216  computation of said dataframe.
217 
218  One of the conveniences of using this object is the ability to define a desired common
219  filter for all functors. This enables the same functor collection to be passed to
220  several different `PostprocessAnalysis` objects without having to change the original
221  functor collection, since the `filt` keyword argument of this object triggers an
222  overwrite of the `filt` property for all functors in the collection.
223 
224  This object also allows a list of refFlags to be passed, and defines a set of default
225  refFlags that are always included even if not requested.
226 
227  If a list of `ParquetTable` objects is passed, rather than a single one, then the
228  calculations will be mapped over all the input catalogs. In principle, it should
229  be straightforward to parallelize this activity, but initial tests have failed
230  (see TODO in code comments).
231 
232  Parameters
233  ----------
234  parq : `lsst.pipe.tasks.parquetTable.ParquetTable` (or list of such)
235  Source catalog(s) for computation
236 
237  functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
238  Computations to do (functors that act on `parq`).
239  If a dict, the output
240  DataFrame will have columns keyed accordingly.
241  If a list, the column keys will come from the
242  `.shortname` attribute of each functor.
243 
244  filt : `str` (optional)
245  Filter in which to calculate. If provided,
246  this will overwrite any existing `.filt` attribute
247  of the provided functors.
248 
249  flags : `list` (optional)
250  List of flags (per-band) to include in output table.
251 
252  refFlags : `list` (optional)
253  List of refFlags (only reference band) to include in output table.
254 
255 
256  """
257  _defaultRefFlags = ('calib_psf_used', 'detect_isPrimary')
258  _defaultFuncs = (('coord_ra', RAColumn()),
259  ('coord_dec', DecColumn()))
260 
261  def __init__(self, parq, functors, filt=None, flags=None, refFlags=None):
262  self.parq = parq
263  self.functors = functors
264 
265  self.filt = filt
266  self.flags = list(flags) if flags is not None else []
267  self.refFlags = list(self._defaultRefFlags)
268  if refFlags is not None:
269  self.refFlags += list(refFlags)
270 
271  self._df = None
272 
273  @property
274  def defaultFuncs(self):
275  funcs = dict(self._defaultFuncs)
276  return funcs
277 
278  @property
279  def func(self):
280  additionalFuncs = self.defaultFuncs
281  additionalFuncs.update({flag: Column(flag, dataset='ref') for flag in self.refFlags})
282  additionalFuncs.update({flag: Column(flag, dataset='meas') for flag in self.flags})
283 
284  if isinstance(self.functors, CompositeFunctor):
285  func = self.functors
286  else:
287  func = CompositeFunctor(self.functors)
288 
289  func.funcDict.update(additionalFuncs)
290  func.filt = self.filt
291 
292  return func
293 
294  @property
295  def noDupCols(self):
296  return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']
297 
298  @property
299  def df(self):
300  if self._df is None:
301  self.compute()
302  return self._df
303 
304  def compute(self, dropna=False, pool=None):
305  # map over multiple parquet tables
306  if type(self.parq) in (list, tuple):
307  if pool is None:
308  dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
309  else:
310  # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
311  dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
312  self._df = pd.concat(dflist)
313  else:
314  self._df = self.func(self.parq, dropna=dropna)
315 
316  return self._df
317 
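# Illustrative usage sketch (not part of the original file; the butler call,
# data ID, and functor choices are assumptions):
#
#     parq = butler.get('deepCoadd_obj', tract=9813, patch='4,4')
#     analysis = PostprocessAnalysis(parq,
#                                    {'ra': RAColumn(), 'dec': DecColumn()},
#                                    filt='HSC-G',
#                                    flags=['base_PixelFlags_flag_saturated'])
#     df = analysis.df   # first access triggers the computation and caches it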
318 
319 class TransformCatalogBaseConfig(pexConfig.Config):
320  coaddName = pexConfig.Field(
321  dtype=str,
322  default="deep",
323  doc="Name of coadd"
324  )
325  functorFile = pexConfig.Field(
326  dtype=str,
327  doc='Path to YAML file specifying functors to be computed',
328  default=None
329  )
330 
331 
332 class TransformCatalogBaseTask(CmdLineTask):
333  """Base class for transforming/standardizing a catalog by applying functors
334  that convert units and apply calibrations.
335 
336  The purpose of this task is to perform a set of computations on
337  an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
338  results to a new dataset (which needs to be declared in an `outputDataset`
339  attribute).
340 
341  The calculations to be performed are defined in a YAML file that specifies
342  a set of functors to be computed, provided as
343  a `--functorFile` config parameter. An example of such a YAML file
344  is the following:
345 
346  funcs:
347  psfMag:
348  functor: Mag
349  args:
350  - base_PsfFlux
351  filt: HSC-G
352  dataset: meas
353  cmodel_magDiff:
354  functor: MagDiff
355  args:
356  - modelfit_CModel
357  - base_PsfFlux
358  filt: HSC-G
359  gauss_magDiff:
360  functor: MagDiff
361  args:
362  - base_GaussianFlux
363  - base_PsfFlux
364  filt: HSC-G
365  count:
366  functor: Column
367  args:
368  - base_InputCount_value
369  filt: HSC-G
370  deconvolved_moments:
371  functor: DeconvolvedMoments
372  filt: HSC-G
373  dataset: forced_src
374  refFlags:
375  - calib_psfUsed
376  - merge_measurement_i
377  - merge_measurement_r
378  - merge_measurement_z
379  - merge_measurement_y
380  - merge_measurement_g
381  - base_PixelFlags_flag_inexact_psfCenter
382  - detect_isPrimary
383 
384  The names for each entry under "funcs" will become the names of columns in the
385  output dataset. All the functors referenced are defined in `lsst.pipe.tasks.functors`.
386  Positional arguments to be passed to each functor are in the `args` list,
387  and any additional entries for each column other than "functor" or "args" (e.g., `'filt'`,
388  `'dataset'`) are treated as keyword arguments to be passed to the functor initialization.
389 
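 As an illustrative equivalence (an assumption, not code from this module), the
 `psfMag` entry above corresponds roughly to

     from lsst.pipe.tasks.functors import CompositeFunctor, Mag
     funcs = CompositeFunctor({'psfMag': Mag('base_PsfFlux', filt='HSC-G', dataset='meas')})

 which is roughly what `CompositeFunctor.from_file` builds when it reads the YAML.
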
390  The "refFlags" entry is a shortcut for a set of `Column` functors that keep the
391  original column names and are taken from the `'ref'` dataset.
392 
393  The "flags" entry will be expanded out per band.
394 
395  Note, if `'filter'` is provided as part of the `dataId` when running this task (even though
396  `deepCoadd_obj` does not use `'filter'`), then this will override the `filt` kwargs
397  provided in the YAML file, and the calculations will be done in that filter.
398 
399  This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
400  to organize and execute the calculations.
401 
402  """
403  @property
404  def _DefaultName(self):
405  raise NotImplementedError('Subclass must define "_DefaultName" attribute')
406 
407  @property
408  def outputDataset(self):
409  raise NotImplementedError('Subclass must define "outputDataset" attribute')
410 
411  @property
412  def inputDataset(self):
413  raise NotImplementedError('Subclass must define "inputDataset" attribute')
414 
415  @property
416  def ConfigClass(self):
417  raise NotImplementedError('Subclass must define "ConfigClass" attribute')
418 
419  def runDataRef(self, patchRef):
420  parq = patchRef.get()
421  dataId = patchRef.dataId
422  funcs = self.getFunctors()
423  self.log.info("Transforming/standardizing the catalog of %s", dataId)
424  df = self.run(parq, funcs=funcs, dataId=dataId)
425  self.write(df, patchRef)
426  return df
427 
428  def run(self, parq, funcs=None, dataId=None):
429  """Do postprocessing calculations
430 
431  Takes a `ParquetTable` object and dataId,
432  returns a dataframe with results of postprocessing calculations.
433 
434  Parameters
435  ----------
436  parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
437  ParquetTable from which calculations are done.
438  funcs : `lsst.pipe.tasks.functors.Functors`
439  Functors to apply to the table's columns
440  dataId : dict, optional
441  Used to add a `patchId` column to the output dataframe.
442 
443  Returns
444  -------
445  `pandas.DataFrame`
446 
447  """
448  filt = dataId.get('filter', None)
449  return self.transform(filt, parq, funcs, dataId).df
450 
451  def getFunctors(self):
452  funcs = CompositeFunctor.from_file(self.config.functorFile)
453  funcs.update(dict(PostprocessAnalysis._defaultFuncs))
454  return funcs
455 
456  def getAnalysis(self, parq, funcs=None, filt=None):
457  # Avoids disk access if funcs is passed
458  if funcs is None:
459  funcs = self.getFunctors()
460  analysis = PostprocessAnalysis(parq, funcs, filt=filt)
461  return analysis
462 
463  def transform(self, filt, parq, funcs, dataId):
464  analysis = self.getAnalysis(parq, funcs=funcs, filt=filt)
465  df = analysis.df
466  if dataId is not None:
467  for key, value in dataId.items():
468  df[key] = value
469 
470  return pipeBase.Struct(
471  df=df,
472  analysis=analysis
473  )
474 
475  def write(self, df, parqRef):
476  parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)
477 
478  def writeMetadata(self, dataRef):
479  """No metadata to write.
480  """
481  pass
482 
483 
484 class TransformObjectCatalogConfig(TransformCatalogBaseConfig):
485  filterMap = pexConfig.DictField(
486  keytype=str,
487  itemtype=str,
488  default={},
489  doc=("Dictionary mapping full filter name to short one for column name munging. "
490  "These filters determine the output columns no matter what filters the "
491  "input data actually contain.")
492  )
493  camelCase = pexConfig.Field(
494  dtype=bool,
495  default=True,
496  doc=("Write per-filter column names with camelCase, else underscore. "
497  "For example: gPsfFlux instead of g_PsfFlux.")
498  )
499  multilevelOutput = pexConfig.Field(
500  dtype=bool,
501  default=False,
502  doc=("Whether results dataframe should have a multilevel column index (True) or be flat "
503  "and name-munged (False).")
504  )
505 
506 
507 class TransformObjectCatalogTask(TransformCatalogBaseTask):
508  """Compute Flattened Object Table as defined in the DPDD
509 
510  Do the same set of postprocessing calculations on all bands
511 
512  This is identical to `TransformCatalogBaseTask`, except that it performs the
513  specified functor calculations for all filters present in the
514  input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
515  by the YAML file will be superseded.
516  """
517  _DefaultName = "transformObjectCatalog"
518  ConfigClass = TransformObjectCatalogConfig
519 
520  inputDataset = 'deepCoadd_obj'
521  outputDataset = 'objectTable'
522 
523  @classmethod
524  def _makeArgumentParser(cls):
525  parser = ArgumentParser(name=cls._DefaultName)
526  parser.add_id_argument("--id", cls.inputDataset,
527  ContainerClass=CoaddDataIdContainer,
528  help="data ID, e.g. --id tract=12345 patch=1,2")
529  return parser
530 
531  def run(self, parq, funcs=None, dataId=None):
532  dfDict = {}
533  analysisDict = {}
534  templateDf = pd.DataFrame()
535  # Perform transform for data of filters that exist in parq and are
536  # specified in config.filterMap
537  for filt in parq.columnLevelNames['filter']:
538  if filt not in self.config.filterMap:
539  self.log.info("Ignoring %s data in the input", filt)
540  continue
541  self.log.info("Transforming the catalog of filter %s", filt)
542  result = self.transform(filt, parq, funcs, dataId)
543  dfDict[filt] = result.df
544  analysisDict[filt] = result.analysis
545  if templateDf.empty:
546  templateDf = result.df
547 
548  # Fill NaNs in columns of other wanted filters
549  for filt in self.config.filterMap:
550  if filt not in dfDict:
551  self.log.info("Adding empty columns for filter %s", filt)
552  dfDict[filt] = pd.DataFrame().reindex_like(templateDf)
553 
554  # This makes a multilevel column index, with filter as first level
555  df = pd.concat(dfDict, axis=1, names=['filter', 'column'])
556 
557  if not self.config.multilevelOutput:
558  noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
559  if dataId is not None:
560  noDupCols += list(dataId.keys())
561  df = flattenFilters(df, self.config.filterMap, noDupCols=noDupCols,
562  camelCase=self.config.camelCase)
563 
564  self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
565  return df
566 
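# Illustrative sketch of the flattened (multilevelOutput=False) output (the
# filter map and column name are hypothetical):
#
#     config = TransformObjectCatalogConfig()
#     config.filterMap = {'HSC-G': 'g', 'HSC-R': 'r'}
#     # A per-band functor column named 'PsfMag' then comes out as 'gPsfMag'
#     # and 'rPsfMag', while noDupCols such as 'coord_ra' and 'coord_dec'
#     # appear once, without a band prefix.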
567 
568 class TractObjectDataIdContainer(CoaddDataIdContainer):
569 
570  def makeDataRefList(self, namespace):
571  """Make self.refList from self.idList
572 
573  Generate a list of data references given tract and/or patch.
574  This was adapted from `TractQADataIdContainer`, which was
575  `TractDataIdContainer` modified to not require "filter".
576  Only existing dataRefs are returned.
577  """
578  def getPatchRefList(tract):
579  return [namespace.butler.dataRef(datasetType=self.datasetType,
580  tract=tract.getId(),
581  patch="%d,%d" % patch.getIndex()) for patch in tract]
582 
583  tractRefs = defaultdict(list) # Data references for each tract
584  for dataId in self.idList:
585  skymap = self.getSkymap(namespace)
586 
587  if "tract" in dataId:
588  tractId = dataId["tract"]
589  if "patch" in dataId:
590  tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
591  tract=tractId,
592  patch=dataId['patch']))
593  else:
594  tractRefs[tractId] += getPatchRefList(skymap[tractId])
595  else:
596  tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
597  for tract in skymap)
598  outputRefList = []
599  for tractRefList in tractRefs.values():
600  existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
601  outputRefList.append(existingRefs)
602 
603  self.refList = outputRefList
604 
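# Illustrative sketch of the data IDs this container accepts (the tract and
# patch values are hypothetical):
#
#     --id tract=9813            # expands to every patch in the tract for which
#                                # the dataset already exists
#     --id tract=9813 patch=4,4  # restricts the list to a single patch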
605 
606 class ConsolidateObjectTableConfig(pexConfig.Config):
607  coaddName = pexConfig.Field(
608  dtype=str,
609  default="deep",
610  doc="Name of coadd"
611  )
612 
613 
614 class ConsolidateObjectTableTask(CmdLineTask):
615  """Write patch-merged source tables to a tract-level parquet file
616  """
617  _DefaultName = "consolidateObjectTable"
618  ConfigClass = ConsolidateObjectTableConfig
619 
620  inputDataset = 'objectTable'
621  outputDataset = 'objectTable_tract'
622 
623  @classmethod
624  def _makeArgumentParser(cls):
625  parser = ArgumentParser(name=cls._DefaultName)
626 
627  parser.add_id_argument("--id", cls.inputDataset,
628  help="data ID, e.g. --id tract=12345",
629  ContainerClass=TractObjectDataIdContainer)
630  return parser
631 
632  def runDataRef(self, patchRefList):
633  df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
634  patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)
635 
636  def writeMetadata(self, dataRef):
637  """No metadata to write.
638  """
639  pass