lsst.pipe.tasks  19.0.0-23-g5d8da22d+1
postprocess.py
1 # This file is part of pipe_tasks
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 
22 import functools
23 import pandas as pd
24 from collections import defaultdict
25 
26 import lsst.pex.config as pexConfig
27 import lsst.pipe.base as pipeBase
28 from lsst.pipe.base import CmdLineTask, ArgumentParser
29 from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer
30 
31 from .parquetTable import ParquetTable
32 from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
33 from .functors import CompositeFunctor, RAColumn, DecColumn, Column
34 
35 
36 def flattenFilters(df, filterDict, noDupCols=['coord_ra', 'coord_dec'], camelCase=False):
37  """Flattens a dataframe with multilevel column index
38  """
39  newDf = pd.DataFrame()
40  for filt, filtShort in filterDict.items():
41  subdf = df[filt]
42  columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
43  newColumns = {c: columnFormat.format(filtShort, c)
44  for c in subdf.columns if c not in noDupCols}
45  cols = list(newColumns.keys())
46  newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)
47 
48  newDf = pd.concat([subdf[noDupCols], newDf], axis=1)
49  return newDf
50 
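As an illustration (not part of the file), here is a minimal use of flattenFilters on a toy two-level (filter, column) table of the kind built by TransformObjectCatalogTask.run further below. The band names and values are made up, and the import assumes the LSST stack module above is available.

import pandas as pd
from lsst.pipe.tasks.postprocess import flattenFilters  # the function defined above

# Toy two-level (filter, column) table for two bands (hypothetical values).
perBand = {
    'HSC-G': pd.DataFrame({'coord_ra': [150.1], 'coord_dec': [2.2], 'PsfFlux': [10.0]}),
    'HSC-R': pd.DataFrame({'coord_ra': [150.1], 'coord_dec': [2.2], 'PsfFlux': [12.0]}),
}
df = pd.concat(perBand, axis=1, names=['filter', 'column'])

flat = flattenFilters(df, {'HSC-G': 'g', 'HSC-R': 'r'}, camelCase=True)
print(list(flat.columns))  # ['coord_ra', 'coord_dec', 'gPsfFlux', 'rPsfFlux']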
51 
52 class WriteObjectTableConfig(pexConfig.Config):
53  priorityList = pexConfig.ListField(
54  dtype=str,
55  default=[],
56  doc="Priority-ordered list of bands for the merge."
57  )
58  engine = pexConfig.Field(
59  dtype=str,
60  default="pyarrow",
61  doc="Parquet engine for writing (pyarrow or fastparquet)"
62  )
63  coaddName = pexConfig.Field(
64  dtype=str,
65  default="deep",
66  doc="Name of coadd"
67  )
68 
69  def validate(self):
70  pexConfig.Config.validate(self)
71  if len(self.priorityList) == 0:
72  raise RuntimeError("No priority list provided")
73 
74 
75 class WriteObjectTableTask(CmdLineTask):
76  """Write filter-merged source tables to parquet
77  """
78  _DefaultName = "writeObjectTable"
79  ConfigClass = WriteObjectTableConfig
80  RunnerClass = MergeSourcesRunner
81 
82  # Names of table datasets to be merged
83  inputDatasets = ('forced_src', 'meas', 'ref')
84 
85  # Tag of output dataset written by `MergeSourcesTask.write`
86  outputDataset = 'obj'
87 
88  def __init__(self, butler=None, schema=None, **kwargs):
89  # This class cannot use the default CmdLineTask init: doing so would require
90  # its own specialized task runner, which would add many more lines of code,
91  # so we keep this minimal init for now.
92  CmdLineTask.__init__(self, **kwargs)
93 
94  def runDataRef(self, patchRefList):
95  """!
96  @brief Merge coadd sources from multiple bands. Calls @ref `run` which must be defined in
97  subclasses that inherit from MergeSourcesTask.
98  @param[in] patchRefList list of data references for each filter
99  """
100  catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
101  dataId = patchRefList[0].dataId
102  mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
103  self.write(patchRefList[0], mergedCatalog)
104 
105  @classmethod
106  def _makeArgumentParser(cls):
107  """Create a suitable ArgumentParser.
108 
109  We will use the ArgumentParser to get a list of data
110  references for patches; the RunnerClass will sort them into lists
111  of data references for the same patch.
112 
113  References the first of self.inputDatasets, rather than
114  self.inputDataset.
115  """
116  return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])
117 
118  def readCatalog(self, patchRef):
119  """Read input catalogs
120 
121  Read all the input datasets given by the 'inputDatasets'
122  attribute.
123 
124  Parameters
125  ----------
126  patchRef : `lsst.daf.persistence.ButlerDataRef`
127  Data reference for patch
128 
129  Returns
130  -------
131  Tuple consisting of filter name and a dict of catalogs, keyed by
132  dataset name
133  """
134  filterName = patchRef.dataId["filter"]
135  catalogDict = {}
136  for dataset in self.inputDatasets:
137  catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
138  self.log.info("Read %d sources from %s for filter %s: %s" %
139  (len(catalog), dataset, filterName, patchRef.dataId))
140  catalogDict[dataset] = catalog
141  return filterName, catalogDict
142 
143  def run(self, catalogs, tract, patch):
144  """Merge multiple catalogs.
145 
146  Parameters
147  ----------
148  catalogs : `dict`
149  Mapping from filter names to dict of catalogs.
150  tract : int
151  tractId to use for the tractId column
152  patch : str
153  patchId to use for the patchId column
154 
155  Returns
156  -------
157  catalog : `lsst.pipe.tasks.parquetTable.ParquetTable`
158  Merged dataframe, with each column prefixed by
159  `filter_tag(filt)`, wrapped in the parquet writer shim class.
160  """
161 
162  dfs = []
163  for filt, tableDict in catalogs.items():
164  for dataset, table in tableDict.items():
165  # Convert afwTable to pandas DataFrame
166  df = table.asAstropy().to_pandas().set_index('id', drop=True)
167 
168  # Sort columns by name, to ensure matching schema among patches
169  df = df.reindex(sorted(df.columns), axis=1)
170  df['tractId'] = tract
171  df['patchId'] = patch
172 
173  # Make columns a 3-level MultiIndex
174  df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
175  names=('dataset', 'filter', 'column'))
176  dfs.append(df)
177 
178  catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
179  return ParquetTable(dataFrame=catalog)
180 
181  def write(self, patchRef, catalog):
182  """Write the output.
183 
184  Parameters
185  ----------
186  catalog : `ParquetTable`
187  Catalog to write
188  patchRef : `lsst.daf.persistence.ButlerDataRef`
189  Data reference for patch
190  """
191  patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
192  # since the filter isn't actually part of the data ID for the dataset we're saving,
193  # it's confusing to see it in the log message, even if the butler simply ignores it.
194  mergeDataId = patchRef.dataId.copy()
195  del mergeDataId["filter"]
196  self.log.info("Wrote merged catalog: %s" % (mergeDataId,))
197 
198  def writeMetadata(self, dataRefList):
199  """No metadata to write, and not sure how to write it for a list of dataRefs.
200  """
201  pass
202 
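A toy pandas sketch (not part of the file) of the three-level (dataset, filter, column) join performed in WriteObjectTableTask.run above; the dataset, filter, and column names are only illustrative.

import functools
import pandas as pd

meas = pd.DataFrame({'base_PsfFlux_instFlux': [1.0, 2.0]}, index=[10, 11])
meas.columns = pd.MultiIndex.from_tuples(
    [('meas', 'HSC-G', c) for c in meas.columns],
    names=('dataset', 'filter', 'column'))

ref = pd.DataFrame({'detect_isPrimary': [True, False]}, index=[10, 11])
ref.columns = pd.MultiIndex.from_tuples(
    [('ref', 'HSC-G', c) for c in ref.columns],
    names=('dataset', 'filter', 'column'))

# Same reduction as in run(): successive joins on the object id index.
merged = functools.reduce(lambda d1, d2: d1.join(d2), [meas, ref])
print(merged[('meas', 'HSC-G', 'base_PsfFlux_instFlux')])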
203 
204 class PostprocessAnalysis(object):
205  """Calculate columns from ParquetTable
206 
207  This object manages and organizes an arbitrary set of computations
208  on a catalog. The catalog is defined by a
209  `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
210  `deepCoadd_obj` dataset, and the computations are defined by a collection
211  of `lsst.pipe.tasks.functors.Functor` objects (or, equivalently,
212  a `CompositeFunctor`).
213 
214  After the object is initialized, the first access to the `.df` attribute (which
215  holds the `pandas.DataFrame` containing the results of the calculations) triggers
216  the computation of that dataframe.
217 
218  One of the conveniences of using this object is the ability to define a desired common
219  filter for all functors. This enables the same functor collection to be passed to
220  several different `PostprocessAnalysis` objects without having to change the original
221  functor collection, since the `filt` keyword argument of this object triggers an
222  overwrite of the `filt` property for all functors in the collection.
223 
224  This object also allows a list of flags to be passed, and defines a set of default
225  flags that are always included even if not requested.
226 
227  If a list of `ParquetTable` objects is passed, rather than a single one, then the
228  calculations will be mapped over all the input catalogs. In principle, it should
229  be straightforward to parallelize this activity, but initial tests have failed
230  (see TODO in code comments).
231 
232  Parameters
233  ----------
234  parq : `lsst.pipe.tasks.parquetTable.ParquetTable` (or list of such)
235  Source catalog(s) for computation
236 
237  functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
238  Computations to do (functors that act on `parq`).
239  If a dict, the output
240  DataFrame will have columns keyed accordingly.
241  If a list, the column keys will come from the
242  `.shortname` attribute of each functor.
243 
244  filt : `str` (optional)
245  Filter in which to calculate. If provided,
246  this will overwrite any existing `.filt` attribute
247  of the provided functors.
248 
249  flags : `list` (optional)
250  List of flags to include in output table.
251  """
252  _defaultFlags = ('calib_psf_used', 'detect_isPrimary')
253  _defaultFuncs = (('coord_ra', RAColumn()),
254  ('coord_dec', DecColumn()))
255 
256  def __init__(self, parq, functors, filt=None, flags=None):
257  self.parq = parq
258  self.functors = functors
259 
260  self.filt = filt
261  self.flags = list(self._defaultFlags)
262  if flags is not None:
263  self.flags += list(flags)
264 
265  self._df = None
266 
267  @property
268  def defaultFuncs(self):
269  funcs = dict(self._defaultFuncs)
270  return funcs
271 
272  @property
273  def func(self):
274  additionalFuncs = self.defaultFuncs
275  additionalFuncs.update({flag: Column(flag) for flag in self.flags})
276 
277  if isinstance(self.functors, CompositeFunctor):
278  func = self.functors
279  else:
280  func = CompositeFunctor(self.functors)
281 
282  func.funcDict.update(additionalFuncs)
283  func.filt = self.filt
284 
285  return func
286 
287  @property
288  def noDupCols(self):
289  return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']
290 
291  @property
292  def df(self):
293  if self._df is None:
294  self.compute()
295  return self._df
296 
297  def compute(self, dropna=False, pool=None):
298  # map over multiple parquet tables
299  if type(self.parq) in (list, tuple):
300  if pool is None:
301  dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
302  else:
303  # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
304  dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
305  self._df = pd.concat(dflist)
306  else:
307  self._df = self.func(self.parq, dropna=dropna)
308 
309  return self._df
310 
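A hedged usage sketch (not part of the file): computing a few columns from a deepCoadd_obj table with this class. The repository path, dataId, and flag column are hypothetical, and the Mag functor is assumed to be importable from lsst.pipe.tasks.functors, as in the YAML example further below.

from lsst.daf.persistence import Butler
from lsst.pipe.tasks.functors import Mag, Column

butler = Butler('/path/to/rerun')                            # hypothetical repo
parq = butler.get('deepCoadd_obj', tract=9813, patch='4,4')  # hypothetical dataId

funcs = {'psfMag': Mag('base_PsfFlux', dataset='meas'),
         'nInput': Column('base_InputCount_value', dataset='meas')}
analysis = PostprocessAnalysis(parq, funcs, filt='HSC-G',
                               flags=['base_PixelFlags_flag'])  # hypothetical flag
df = analysis.df  # first access triggers compute(); coord_ra/coord_dec are added by default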
311 
312 class TransformCatalogBaseConfig(pexConfig.Config):
313  coaddName = pexConfig.Field(
314  dtype=str,
315  default="deep",
316  doc="Name of coadd"
317  )
318  functorFile = pexConfig.Field(
319  dtype=str,
320  doc='Path to YAML file specifying functors to be computed',
321  default=None
322  )
323 
324 
325 class TransformCatalogBaseTask(CmdLineTask):
326  """Base class for transforming/standardizing a catalog
327 
328  by applying functors that convert units and apply calibrations.
329  The purpose of this task is to perform a set of computations on
330  an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
331  results to a new dataset (which needs to be declared in an `outputDataset`
332  attribute).
333 
334  The calculations to be performed are defined in a YAML file that specifies
335  a set of functors to be computed, provided via
336  the `functorFile` config parameter. An example of such a YAML file
337  is the following:
338 
339  funcs:
340      psfMag:
341          functor: Mag
342          args:
343              - base_PsfFlux
344          filt: HSC-G
345          dataset: meas
346      cmodel_magDiff:
347          functor: MagDiff
348          args:
349              - modelfit_CModel
350              - base_PsfFlux
351          filt: HSC-G
352      gauss_magDiff:
353          functor: MagDiff
354          args:
355              - base_GaussianFlux
356              - base_PsfFlux
357          filt: HSC-G
358      count:
359          functor: Column
360          args:
361              - base_InputCount_value
362          filt: HSC-G
363      deconvolved_moments:
364          functor: DeconvolvedMoments
365          filt: HSC-G
366          dataset: forced_src
367  flags:
368      - calib_psfUsed
369      - merge_measurement_i
370      - merge_measurement_r
371      - merge_measurement_z
372      - merge_measurement_y
373      - merge_measurement_g
374      - base_PixelFlags_flag_inexact_psfCenter
375      - detect_isPrimary
376 
377  The names for each entry under "funcs" will become the names of columns in the
378  output dataset. All the functors referenced are defined in `lsst.pipe.tasks.functors`.
379  Positional arguments to be passed to each functor are in the `args` list,
380  and any additional entries for each column other than "functor" or "args" (e.g., `'filt'`,
381  `'dataset'`) are treated as keyword arguments to be passed to the functor initialization.
382 
383  The "flags" entry is a shortcut for a set of `Column` functors that keep the original
384  column names and are taken from the `'ref'` dataset.
385 
386  Note, if `'filter'` is provided as part of the `dataId` when running this task (even though
387  `deepCoadd_obj` does not use `'filter'`), then this will override the `filt` kwargs
388  provided in the YAML file, and the calculations will be done in that filter.
389 
390  This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
391  to organize and execute the calculations.
392 
393  """
394  @property
395  def _DefaultName(self):
396  raise NotImplementedError('Subclass must define "_DefaultName" attribute')
397 
398  @property
399  def outputDataset(self):
400  raise NotImplementedError('Subclass must define "outputDataset" attribute')
401 
402  @property
403  def inputDataset(self):
404  raise NotImplementedError('Subclass must define "inputDataset" attribute')
405 
406  @property
407  def ConfigClass(self):
408  raise NotImplementedError('Subclass must define "ConfigClass" attribute')
409 
410  def runDataRef(self, patchRef):
411  parq = patchRef.get()
412  dataId = patchRef.dataId
413  funcs = self.getFunctors()
414  self.log.info("Transforming/standardizing the catalog of %s", dataId)
415  df = self.run(parq, funcs=funcs, dataId=dataId)
416  self.write(df, patchRef)
417  return df
418 
419  def run(self, parq, funcs=None, dataId=None):
420  """Do postprocessing calculations
421 
422  Takes a `ParquetTable` object and dataId,
423  returns a dataframe with results of postprocessing calculations.
424 
425  Parameters
426  ----------
427  parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
428  ParquetTable from which calculations are done.
429  funcs : `lsst.pipe.tasks.functors.Functors`
430  Functors to apply to the table's columns
431  dataId : dict, optional
432  Used to add a `patchId` column to the output dataframe.
433 
434  Returns
435  -------
436  `pandas.DataFrame`
437 
438  """
439  filt = dataId.get('filter', None)
440  return self.transform(filt, parq, funcs, dataId).df
441 
442  def getFunctors(self):
443  funcs = CompositeFunctor.from_file(self.config.functorFile)
444  funcs.update(dict(PostprocessAnalysis._defaultFuncs))
445  return funcs
446 
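A hedged sketch (not part of the file) of what getFunctors above does with a functor YAML like the one in the class docstring; 'functors.yaml' is a hypothetical path containing that YAML.

from lsst.pipe.tasks.functors import CompositeFunctor

funcs = CompositeFunctor.from_file('functors.yaml')    # hypothetical path
funcs.update(dict(PostprocessAnalysis._defaultFuncs))  # adds coord_ra / coord_dec
# One entry per item under 'funcs' (plus one Column per 'flags' entry); the keys
# become the column names of the output DataFrame when applied to a ParquetTable:
# df = funcs(parq)
print(sorted(funcs.funcDict.keys()))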
447  def getAnalysis(self, parq, funcs=None, filt=None):
448  # Avoids disk access if funcs is passed
449  if funcs is None:
450  funcs = self.getFunctors()
451  analysis = PostprocessAnalysis(parq, funcs, filt=filt)
452  return analysis
453 
454  def transform(self, filt, parq, funcs, dataId):
455  analysis = self.getAnalysis(parq, funcs=funcs, filt=filt)
456  df = analysis.df
457  if dataId is not None:
458  for key, value in dataId.items():
459  df[key] = value
460 
461  return pipeBase.Struct(
462  df=df,
463  analysis=analysis
464  )
465 
466  def write(self, df, parqRef):
467  parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)
468 
469  def writeMetadata(self, dataRef):
470  """No metadata to write.
471  """
472  pass
473 
474 
475 class TransformObjectCatalogConfig(TransformCatalogBaseConfig):
476  filterMap = pexConfig.DictField(
477  keytype=str,
478  itemtype=str,
479  default={},
480  doc=("Dictionary mapping full filter name to short one for column name munging. "
481  "These filters determine the output columns no matter what filters the "
482  "input data actually contain.")
483  )
484  camelCase = pexConfig.Field(
485  dtype=bool,
486  default=True,
487  doc=("Write per-filter column names with camelCase, else underscore. "
488  "For example: gPsfFlux instead of g_PsfFlux.")
489  )
490  multilevelOutput = pexConfig.Field(
491  dtype=bool,
492  default=False,
493  doc=("Whether results dataframe should have a multilevel column index (True) or be flat "
494  "and name-munged (False).")
495  )
496 
497 
498 class TransformObjectCatalogTask(TransformCatalogBaseTask):
499  """Compute Flattened Object Table as defined in the DPDD
500 
501  Do the same set of postprocessing calculations on all bands.
502 
503  This is identical to `TransformCatalogBaseTask`, except that it does the
504  specified functor calculations for all filters present in the
505  input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
506  by the YAML file will be superseded.
507  """
508  _DefaultName = "transformObjectCatalog"
509  ConfigClass = TransformObjectCatalogConfig
510 
511  inputDataset = 'deepCoadd_obj'
512  outputDataset = 'objectTable'
513 
514  @classmethod
515  def _makeArgumentParser(cls):
516  parser = ArgumentParser(name=cls._DefaultName)
517  parser.add_id_argument("--id", cls.inputDataset,
518  ContainerClass=CoaddDataIdContainer,
519  help="data ID, e.g. --id tract=12345 patch=1,2")
520  return parser
521 
522  def run(self, parq, funcs=None, dataId=None):
523  dfDict = {}
524  analysisDict = {}
525  templateDf = pd.DataFrame()
526  # Perform transform for data of filters that exist in parq and are
527  # specified in config.filterMap
528  for filt in parq.columnLevelNames['filter']:
529  if filt not in self.config.filterMap:
530  self.log.info("Ignoring %s data in the input", filt)
531  continue
532  self.log.info("Transforming the catalog of filter %s", filt)
533  result = self.transform(filt, parq, funcs, dataId)
534  dfDict[filt] = result.df
535  analysisDict[filt] = result.analysis
536  if templateDf.empty:
537  templateDf = result.df
538 
539  # Fill NaNs in columns of other wanted filters
540  for filt in self.config.filterMap:
541  if filt not in dfDict:
542  self.log.info("Adding empty columns for filter %s", filt)
543  dfDict[filt] = pd.DataFrame().reindex_like(templateDf)
544 
545  # This makes a multilevel column index, with filter as first level
546  df = pd.concat(dfDict, axis=1, names=['filter', 'column'])
547 
548  if not self.config.multilevelOutput:
549  noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
550  if dataId is not None:
551  noDupCols += list(dataId.keys())
552  df = flattenFilters(df, self.config.filterMap, noDupCols=noDupCols,
553  camelCase=self.config.camelCase)
554 
555  self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
556  return df
557 
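A hedged config-override sketch (not part of the file) for the task above, e.g. the contents of a file passed with --configfile on the command line; the band mapping and functor-file path are hypothetical.

# Bands to appear in the output, and their short names used as column prefixes.
config.filterMap = {'HSC-G': 'g', 'HSC-R': 'r', 'HSC-I': 'i'}
config.camelCase = True          # e.g. gPsfFlux rather than g_PsfFlux
config.multilevelOutput = False  # flatten per-filter columns via flattenFilters()
config.functorFile = 'functors.yaml'  # hypothetical functor specification (see TransformCatalogBaseTask)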
558 
559 class TractObjectDataIdContainer(CoaddDataIdContainer):
560 
561  def makeDataRefList(self, namespace):
562  """Make self.refList from self.idList
563 
564  Generate a list of data references given tract and/or patch.
565  This was adapted from `TractQADataIdContainer`, which was
566  `TractDataIdContainer` modified to not require "filter".
567  Only existing dataRefs are returned.
568  """
569  def getPatchRefList(tract):
570  return [namespace.butler.dataRef(datasetType=self.datasetType,
571  tract=tract.getId(),
572  patch="%d,%d" % patch.getIndex()) for patch in tract]
573 
574  tractRefs = defaultdict(list) # Data references for each tract
575  for dataId in self.idList:
576  skymap = self.getSkymap(namespace)
577 
578  if "tract" in dataId:
579  tractId = dataId["tract"]
580  if "patch" in dataId:
581  tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
582  tract=tractId,
583  patch=dataId['patch']))
584  else:
585  tractRefs[tractId] += getPatchRefList(skymap[tractId])
586  else:
587  tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
588  for tract in skymap)
589  outputRefList = []
590  for tractRefList in tractRefs.values():
591  existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
592  outputRefList.append(existingRefs)
593 
594  self.refList = outputRefList
595 
596 
597 class ConsolidateObjectTableConfig(pexConfig.Config):
598  coaddName = pexConfig.Field(
599  dtype=str,
600  default="deep",
601  doc="Name of coadd"
602  )
603 
604 
605 class ConsolidateObjectTableTask(CmdLineTask):
606  """Write patch-merged source tables to a tract-level parquet file
607  """
608  _DefaultName = "consolidateObjectTable"
609  ConfigClass = ConsolidateObjectTableConfig
610 
611  inputDataset = 'objectTable'
612  outputDataset = 'objectTable_tract'
613 
614  @classmethod
615  def _makeArgumentParser(cls):
616  parser = ArgumentParser(name=cls._DefaultName)
617 
618  parser.add_id_argument("--id", cls.inputDataset,
619  help="data ID, e.g. --id tract=12345",
620  ContainerClass=TractObjectDataIdContainer)
621  return parser
622 
623  def runDataRef(self, patchRefList):
624  df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
625  patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)
626 
627  def writeMetadata(self, dataRef):
628  """No metadata to write.
629  """
630  pass
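Finally, a hedged sketch (not part of the file) of reading the consolidated tract-level table back with a Gen2 butler; the repository path and tract number are hypothetical.

from lsst.daf.persistence import Butler

butler = Butler('/path/to/rerun')                   # hypothetical repo
parq = butler.get('objectTable_tract', tract=9813)  # hypothetical tract
df = parq.toDataFrame()  # one row per object, concatenated over all patches
print(len(df))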