Coverage for python/lsst/pipe/tasks/postprocess.py : 25%

# This file is part of pipe_tasks
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import functools
import pandas as pd
from collections import defaultdict

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.pipe.base import CmdLineTask, ArgumentParser
from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer

from .parquetTable import ParquetTable
from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
from .functors import CompositeFunctor, RAColumn, DecColumn, Column


def flattenFilters(df, filterDict, noDupCols=['coord_ra', 'coord_dec'], camelCase=False):
    """Flatten a dataframe with a multilevel column index.
38 """
39 newDf = pd.DataFrame()
40 for filt, filtShort in filterDict.items():
41 subdf = df[filt]
42 columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
43 newColumns = {c: columnFormat.format(filtShort, c)
44 for c in subdf.columns if c not in noDupCols}
45 cols = list(newColumns.keys())
46 newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)
48 newDf = pd.concat([subdf[noDupCols], newDf], axis=1)
49 return newDf


class WriteObjectTableConfig(pexConfig.Config):
    priorityList = pexConfig.ListField(
        dtype=str,
        default=[],
        doc="Priority-ordered list of bands for the merge."
    )
    engine = pexConfig.Field(
        dtype=str,
        default="pyarrow",
        doc="Parquet engine for writing (pyarrow or fastparquet)"
    )
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )

    def validate(self):
        pexConfig.Config.validate(self)
        if len(self.priorityList) == 0:
            raise RuntimeError("No priority list provided")


class WriteObjectTableTask(CmdLineTask):
    """Write filter-merged source tables to parquet.
    """
    _DefaultName = "writeObjectTable"
    ConfigClass = WriteObjectTableConfig
    RunnerClass = MergeSourcesRunner

    # Names of table datasets to be merged
    inputDatasets = ('forced_src', 'meas', 'ref')

    # Tag of output dataset written by `MergeSourcesTask.write`
    outputDataset = 'obj'

    def __init__(self, butler=None, schema=None, **kwargs):
        # It is a shame that this class can't use the default init for CmdLineTask,
        # but doing so would require its own special task runner, which is many
        # more lines of specialization, so this is how it is for now.
        CmdLineTask.__init__(self, **kwargs)

    def runDataRef(self, patchRefList):
        """!
        @brief Merge coadd sources from multiple bands. Calls @ref `run` which must be defined in
        subclasses that inherit from MergeSourcesTask.
        @param[in] patchRefList list of data references for each filter
        """
        catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
        dataId = patchRefList[0].dataId
        mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
        self.write(patchRefList[0], mergedCatalog)

    @classmethod
    def _makeArgumentParser(cls):
        """Create a suitable ArgumentParser.

        We will use the ArgumentParser to get a list of data
        references for patches; the RunnerClass will sort them into lists
        of data references for the same patch.

        This references the first element of self.inputDatasets,
        rather than self.inputDataset.
        """
        return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])

    def readCatalog(self, patchRef):
        """Read input catalogs.

        Read all the input datasets given by the 'inputDatasets'
        attribute.

        Parameters
        ----------
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch.

        Returns
        -------
        Tuple consisting of filter name and a dict of catalogs, keyed by
        dataset name.
        """
        filterName = patchRef.dataId["filter"]
        catalogDict = {}
        for dataset in self.inputDatasets:
            catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
            self.log.info("Read %d sources from %s for filter %s: %s" %
                          (len(catalog), dataset, filterName, patchRef.dataId))
            catalogDict[dataset] = catalog
        return filterName, catalogDict

    def run(self, catalogs, tract, patch):
        """Merge multiple catalogs.

        Parameters
        ----------
        catalogs : `dict`
            Mapping from filter names to dict of catalogs.
        tract : `int`
            tractId to use for the tractId column.
        patch : `str`
            patchId to use for the patchId column.

        Returns
        -------
        catalog : `lsst.pipe.tasks.parquetTable.ParquetTable`
            Merged dataframe, with each column prefixed by
            `filter_tag(filt)`, wrapped in the parquet writer shim class.
        """
        dfs = []
        for filt, tableDict in catalogs.items():
            for dataset, table in tableDict.items():
                # Convert afwTable to pandas DataFrame
                df = table.asAstropy().to_pandas().set_index('id', drop=True)

                # Sort columns by name, to ensure matching schema among patches
                df = df.reindex(sorted(df.columns), axis=1)
                df['tractId'] = tract
                df['patchId'] = patch

                # Make columns a 3-level MultiIndex
                df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
                                                       names=('dataset', 'filter', 'column'))
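                # Illustrative result (hypothetical filter/column names): the
                # merged table below ends up with columns keyed by tuples like
                #   ('meas', 'HSC-G', 'base_PsfFlux_instFlux')
                #   ('ref', 'HSC-G', 'tractId')
                # which keeps the per-dataset, per-filter columns distinct when
                # the frames are joined.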
                dfs.append(df)

        catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
        return ParquetTable(dataFrame=catalog)

    def write(self, patchRef, catalog):
        """Write the output.

        Parameters
        ----------
        catalog : `ParquetTable`
            Catalog to write.
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch.
        """
        patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
        # since the filter isn't actually part of the data ID for the dataset we're saving,
        # it's confusing to see it in the log message, even if the butler simply ignores it.
        mergeDataId = patchRef.dataId.copy()
        del mergeDataId["filter"]
        self.log.info("Wrote merged catalog: %s" % (mergeDataId,))

    def writeMetadata(self, dataRefList):
        """No metadata to write, and not sure how to write it for a list of dataRefs.
        """
        pass


class PostprocessAnalysis(object):
    """Calculate columns from a ParquetTable.

    This object manages and organizes an arbitrary set of computations
    on a catalog. The catalog is defined by a
    `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
    `deepCoadd_obj` dataset, and the computations are defined by a collection
    of `lsst.pipe.tasks.functors.Functor` objects (or, equivalently,
    a `CompositeFunctor`).

    After the object is initialized, accessing the `.df` attribute (which
    holds the `pandas.DataFrame` containing the results of the calculations)
    triggers computation of said dataframe.

    One of the conveniences of using this object is the ability to define a
    desired common filter for all functors. This enables the same functor
    collection to be passed to several different `PostprocessAnalysis` objects
    without having to change the original functor collection, since the `filt`
    keyword argument of this object triggers an overwrite of the `filt`
    property for all functors in the collection.

    This object also allows a list of refFlags to be passed, and defines a set
    of default refFlags that are always included even if not requested.

    If a list of `ParquetTable` objects is passed, rather than a single one,
    then the calculations will be mapped over all the input catalogs. In
    principle, it should be straightforward to parallelize this activity, but
    initial tests have failed (see TODO in code comments).

    Parameters
    ----------
    parq : `lsst.pipe.tasks.parquetTable.ParquetTable` (or list of such)
        Source catalog(s) for computation.

    functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
        Computations to do (functors that act on `parq`).
        If a dict, the output
        DataFrame will have columns keyed accordingly.
        If a list, the column keys will come from the
        `.shortname` attribute of each functor.

    filt : `str`, optional
        Filter in which to calculate. If provided,
        this will overwrite any existing `.filt` attribute
        of the provided functors.

    flags : `list`, optional
        List of flags (per-band) to include in output table.

    refFlags : `list`, optional
        List of refFlags (only reference band) to include in output table.
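
    Examples
    --------
    A minimal usage sketch; ``parq`` here stands for a `deepCoadd_obj`-style
    `ParquetTable` obtained elsewhere (e.g. from the butler), so the snippet
    is illustrative rather than runnable on its own::

        from lsst.pipe.tasks.functors import RAColumn, DecColumn, Column

        funcs = {'ra': RAColumn(), 'dec': DecColumn(),
                 'psfFlux': Column('base_PsfFlux_instFlux', dataset='meas')}
        analysis = PostprocessAnalysis(parq, funcs, filt='HSC-G')
        df = analysis.df  # accessing .df triggers the computation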
256 """
257 _defaultRefFlags = ('calib_psf_used', 'detect_isPrimary')
258 _defaultFuncs = (('coord_ra', RAColumn()),
259 ('coord_dec', DecColumn()))
261 def __init__(self, parq, functors, filt=None, flags=None, refFlags=None):
262 self.parq = parq
263 self.functors = functors
265 self.filt = filt
266 self.flags = list(flags) if flags is not None else []
267 self.refFlags = list(self._defaultRefFlags)
268 if refFlags is not None:
269 self.refFlags += list(refFlags)
271 self._df = None
273 @property
274 def defaultFuncs(self):
275 funcs = dict(self._defaultFuncs)
276 return funcs
278 @property
279 def func(self):
280 additionalFuncs = self.defaultFuncs
281 additionalFuncs.update({flag: Column(flag, dataset='ref') for flag in self.refFlags})
282 additionalFuncs.update({flag: Column(flag, dataset='meas') for flag in self.flags})
284 if isinstance(self.functors, CompositeFunctor):
285 func = self.functors
286 else:
287 func = CompositeFunctor(self.functors)
289 func.funcDict.update(additionalFuncs)
290 func.filt = self.filt
292 return func
294 @property
295 def noDupCols(self):
296 return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']
298 @property
299 def df(self):
300 if self._df is None:
301 self.compute()
302 return self._df
304 def compute(self, dropna=False, pool=None):
305 # map over multiple parquet tables
306 if type(self.parq) in (list, tuple):
307 if pool is None:
308 dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
309 else:
310 # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
311 dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
312 self._df = pd.concat(dflist)
313 else:
314 self._df = self.func(self.parq, dropna=dropna)
316 return self._df


class TransformCatalogBaseConfig(pexConfig.Config):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )
    functorFile = pexConfig.Field(
        dtype=str,
        doc='Path to YAML file specifying functors to be computed',
        default=None
    )


class TransformCatalogBaseTask(CmdLineTask):
    """Base class for transforming/standardizing a catalog by applying functors
    that convert units and apply calibrations.

    The purpose of this task is to perform a set of computations on
    an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
    results to a new dataset (which needs to be declared in an `outputDataset`
    attribute).

    The calculations to be performed are defined in a YAML file that specifies
    a set of functors to be computed, provided as
    a `--functorFile` config parameter. An example of such a YAML file
    is the following:

        funcs:
            psfMag:
                functor: Mag
                args:
                    - base_PsfFlux
                filt: HSC-G
                dataset: meas
            cmodel_magDiff:
                functor: MagDiff
                args:
                    - modelfit_CModel
                    - base_PsfFlux
                filt: HSC-G
            gauss_magDiff:
                functor: MagDiff
                args:
                    - base_GaussianFlux
                    - base_PsfFlux
                filt: HSC-G
            count:
                functor: Column
                args:
                    - base_InputCount_value
                filt: HSC-G
            deconvolved_moments:
                functor: DeconvolvedMoments
                filt: HSC-G
                dataset: forced_src
        refFlags:
            - calib_psfUsed
            - merge_measurement_i
            - merge_measurement_r
            - merge_measurement_z
            - merge_measurement_y
            - merge_measurement_g
            - base_PixelFlags_flag_inexact_psfCenter
            - detect_isPrimary

    The names for each entry under "funcs" will become the names of columns in
    the output dataset. All the functors referenced are defined in
    `lsst.pipe.tasks.functors`. Positional arguments to be passed to each
    functor are in the `args` list, and any additional entries for each column
    other than "functor" or "args" (e.g., `'filt'`, `'dataset'`) are treated as
    keyword arguments to be passed to the functor initialization.

    The "refFlags" entry is a shortcut for a set of `Column` functors that take
    the named columns, unmodified, from the `'ref'` dataset.

    The "flags" entry is expanded out per band.

    Note, if `'filter'` is provided as part of the `dataId` when running this
    task (even though `deepCoadd_obj` does not use `'filter'`), then this will
    override the `filt` kwargs provided in the YAML file, and the calculations
    will be done in that filter.

    This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
    to organize and execute the calculations.
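
    A minimal sketch of how such a file is turned into functors (this mirrors
    what `getFunctors` does; the path is illustrative)::

        from lsst.pipe.tasks.functors import CompositeFunctor

        funcs = CompositeFunctor.from_file('path/to/functors.yaml')
        funcs.update(dict(PostprocessAnalysis._defaultFuncs))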
402 """
403 @property
404 def _DefaultName(self):
405 raise NotImplementedError('Subclass must define "_DefaultName" attribute')
407 @property
408 def outputDataset(self):
409 raise NotImplementedError('Subclass must define "outputDataset" attribute')
411 @property
412 def inputDataset(self):
413 raise NotImplementedError('Subclass must define "inputDataset" attribute')
415 @property
416 def ConfigClass(self):
417 raise NotImplementedError('Subclass must define "ConfigClass" attribute')
419 def runDataRef(self, patchRef):
420 parq = patchRef.get()
421 dataId = patchRef.dataId
422 funcs = self.getFunctors()
423 self.log.info("Transforming/standardizing the catalog of %s", dataId)
424 df = self.run(parq, funcs=funcs, dataId=dataId)
425 self.write(df, patchRef)
426 return df
428 def run(self, parq, funcs=None, dataId=None):
429 """Do postprocessing calculations
431 Takes a `ParquetTable` object and dataId,
432 returns a dataframe with results of postprocessing calculations.
434 Parameters
435 ----------
436 parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
437 ParquetTable from which calculations are done.
438 funcs : `lsst.pipe.tasks.functors.Functors`
439 Functors to apply to the table's columns
440 dataId : dict, optional
441 Used to add a `patchId` column to the output dataframe.
443 Returns
444 ------
445 `pandas.DataFrame`
447 """
448 filt = dataId.get('filter', None)
449 return self.transform(filt, parq, funcs, dataId).df
451 def getFunctors(self):
452 funcs = CompositeFunctor.from_file(self.config.functorFile)
453 funcs.update(dict(PostprocessAnalysis._defaultFuncs))
454 return funcs
456 def getAnalysis(self, parq, funcs=None, filt=None):
457 # Avoids disk access if funcs is passed
458 if funcs is None:
459 funcs = self.getFunctors()
460 analysis = PostprocessAnalysis(parq, funcs, filt=filt)
461 return analysis
463 def transform(self, filt, parq, funcs, dataId):
464 analysis = self.getAnalysis(parq, funcs=funcs, filt=filt)
465 df = analysis.df
466 if dataId is not None:
467 for key, value in dataId.items():
468 df[key] = value
470 return pipeBase.Struct(
471 df=df,
472 analysis=analysis
473 )
475 def write(self, df, parqRef):
476 parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)
478 def writeMetadata(self, dataRef):
479 """No metadata to write.
480 """
481 pass


class TransformObjectCatalogConfig(TransformCatalogBaseConfig):
    filterMap = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={},
        doc=("Dictionary mapping full filter name to short one for column name munging. "
             "These filters determine the output columns no matter what filters the "
             "input data actually contain.")
    )
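    # For example (an assumed HSC-style mapping, not a default of this task):
    #   filterMap={"HSC-G": "g", "HSC-R": "r"}
    # yields per-filter columns such as gPsfFlux / rPsfFlux (camelCase=True)
    # or g_PsfFlux / r_PsfFlux (camelCase=False).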
    camelCase = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Write per-filter column names with camelCase, else underscore. "
             "For example: gPsfFlux instead of g_PsfFlux.")
    )
    multilevelOutput = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Whether results dataframe should have a multilevel column index (True) or be flat "
             "and name-munged (False).")
    )


class TransformObjectCatalogTask(TransformCatalogBaseTask):
    """Compute Flattened Object Table as defined in the DPDD.

    Do the same set of postprocessing calculations on all bands.

    This is identical to `TransformCatalogBaseTask`, except that it does the
    specified functor calculations for all filters present in the
    input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
    by the YAML file will be superseded.
    """
    _DefaultName = "transformObjectCatalog"
    ConfigClass = TransformObjectCatalogConfig

    inputDataset = 'deepCoadd_obj'
    outputDataset = 'objectTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", cls.inputDataset,
                               ContainerClass=CoaddDataIdContainer,
                               help="data ID, e.g. --id tract=12345 patch=1,2")
        return parser

    def run(self, parq, funcs=None, dataId=None):
        dfDict = {}
        analysisDict = {}
        templateDf = pd.DataFrame()
        # Perform transform for data of filters that exist in parq and are
        # specified in config.filterMap
        for filt in parq.columnLevelNames['filter']:
            if filt not in self.config.filterMap:
                self.log.info("Ignoring %s data in the input", filt)
                continue
            self.log.info("Transforming the catalog of filter %s", filt)
            result = self.transform(filt, parq, funcs, dataId)
            dfDict[filt] = result.df
            analysisDict[filt] = result.analysis
            if templateDf.empty:
                templateDf = result.df

        # Fill NaNs in columns of other wanted filters
        for filt in self.config.filterMap:
            if filt not in dfDict:
                self.log.info("Adding empty columns for filter %s", filt)
                dfDict[filt] = pd.DataFrame().reindex_like(templateDf)

        # This makes a multilevel column index, with filter as first level
        df = pd.concat(dfDict, axis=1, names=['filter', 'column'])

        if not self.config.multilevelOutput:
            noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
            if dataId is not None:
                noDupCols += list(dataId.keys())
            df = flattenFilters(df, self.config.filterMap, noDupCols=noDupCols,
                                camelCase=self.config.camelCase)

        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df


class TractObjectDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        Generate a list of data references given tract and/or patch.
        This was adapted from `TractQADataIdContainer`, which was
        `TractDataIdContainer` modified to not require "filter".
        Only existing dataRefs are returned.
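
        Accepted forms of the ``--id`` argument (tract/patch values are
        illustrative)::

            --id tract=12345                # all patches in that tract
            --id tract=12345 patch=1,2      # a single patch
            --id                            # every tract in the skymap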
577 """
578 def getPatchRefList(tract):
579 return [namespace.butler.dataRef(datasetType=self.datasetType,
580 tract=tract.getId(),
581 patch="%d,%d" % patch.getIndex()) for patch in tract]
583 tractRefs = defaultdict(list) # Data references for each tract
584 for dataId in self.idList:
585 skymap = self.getSkymap(namespace)
587 if "tract" in dataId:
588 tractId = dataId["tract"]
589 if "patch" in dataId:
590 tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
591 tract=tractId,
592 patch=dataId['patch']))
593 else:
594 tractRefs[tractId] += getPatchRefList(skymap[tractId])
595 else:
596 tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
597 for tract in skymap)
598 outputRefList = []
599 for tractRefList in tractRefs.values():
600 existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
601 outputRefList.append(existingRefs)
603 self.refList = outputRefList


class ConsolidateObjectTableConfig(pexConfig.Config):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class ConsolidateObjectTableTask(CmdLineTask):
    """Write patch-merged source tables to a tract-level parquet file.
    """
    _DefaultName = "consolidateObjectTable"
    ConfigClass = ConsolidateObjectTableConfig

    inputDataset = 'objectTable'
    outputDataset = 'objectTable_tract'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id tract=12345",
                               ContainerClass=TractObjectDataIdContainer)
        return parser

    def runDataRef(self, patchRefList):
        df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
        patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass