Coverage for python/lsst/pipe/tasks/postprocess.py : 28%

# This file is part of pipe_tasks
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import functools
import pandas as pd
from collections import defaultdict

import lsst.geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.daf.base as dafBase
from lsst.pipe.base import connectionTypes
import lsst.afw.table as afwTable
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.pipe.base import CmdLineTask, ArgumentParser, DataIdContainer
from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer
from lsst.daf.butler import DeferredDatasetHandle

from .parquetTable import ParquetTable
from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
from .functors import CompositeFunctor, RAColumn, DecColumn, Column


def flattenFilters(df, noDupCols=['coord_ra', 'coord_dec'], camelCase=False, inputBands=None):
    """Flatten a DataFrame with a multilevel column index.
    """
    newDf = pd.DataFrame()
    # band is the level 0 index
    dfBands = df.columns.unique(level=0).values
    for band in dfBands:
        subdf = df[band]
        columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
        newColumns = {c: columnFormat.format(band, c)
                      for c in subdf.columns if c not in noDupCols}
        cols = list(newColumns.keys())
        newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)

    # The band must be present in both input and output, or else the column is all NaN:
    presentBands = dfBands if inputBands is None else list(set(inputBands).intersection(dfBands))
    # Get the unexploded columns from any present band's partition
    noDupDf = df[presentBands[0]][noDupCols]
    newDf = pd.concat([noDupDf, newDf], axis=1)
    return newDf
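

# The following illustrative helper is not part of the original task code; it is a
# minimal sketch showing what `flattenFilters` does to a per-band multilevel DataFrame.
# The band names and column values are invented for the demonstration.
def _exampleFlattenFilters():
    """Demonstrate the column-name munging performed by `flattenFilters`."""
    columns = pd.MultiIndex.from_product([["g", "r"], ["coord_ra", "coord_dec", "PsFlux"]],
                                         names=("band", "column"))
    df = pd.DataFrame([[10.0, -5.0, 1.0, 10.0, -5.0, 2.0]], columns=columns)
    # With camelCase=True the per-band columns become e.g. 'gPsFlux' and 'rPsFlux';
    # with camelCase=False they become 'g_PsFlux' and 'r_PsFlux'. The noDupCols
    # ('coord_ra', 'coord_dec') are kept once, taken from the first present band.
    flat = flattenFilters(df, camelCase=True)
    return flat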


class WriteObjectTableConnections(pipeBase.PipelineTaskConnections,
                                  defaultTemplates={"coaddName": "deep"},
                                  dimensions=("tract", "patch", "skymap")):
    inputCatalogMeas = connectionTypes.Input(
        doc="Catalog of source measurements on the deepCoadd.",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_meas",
        multiple=True
    )
    inputCatalogForcedSrc = connectionTypes.Input(
        doc="Catalog of forced measurements (shape and position parameters held fixed) on the deepCoadd.",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_forced_src",
        multiple=True
    )
    inputCatalogRef = connectionTypes.Input(
        doc="Catalog marking the primary detection (which band provides a good shape and position) "
            "for each detection in deepCoadd_mergeDet.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_ref"
    )
    outputCatalog = connectionTypes.Output(
        doc="A vertical concatenation of the deepCoadd_{ref|meas|forced_src} catalogs, "
            "stored as a DataFrame with a multi-level column index per-patch.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="{coaddName}Coadd_obj"
    )


class WriteObjectTableConfig(pipeBase.PipelineTaskConfig,
                             pipelineConnections=WriteObjectTableConnections):
    engine = pexConfig.Field(
        dtype=str,
        default="pyarrow",
        doc="Parquet engine for writing (pyarrow or fastparquet)"
    )
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class WriteObjectTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write filter-merged source tables to parquet.
    """
    _DefaultName = "writeObjectTable"
    ConfigClass = WriteObjectTableConfig
    RunnerClass = MergeSourcesRunner

    # Names of table datasets to be merged
    inputDatasets = ('forced_src', 'meas', 'ref')

    # Tag of output dataset written by `MergeSourcesTask.write`
    outputDataset = 'obj'

    def __init__(self, butler=None, schema=None, **kwargs):
        # This class cannot use the default CmdLineTask __init__: doing so would
        # require its own specialized task runner, which is many more lines of
        # specialization, so this is how it is for now.
        super().__init__(**kwargs)

    def runDataRef(self, patchRefList):
        """!
        @brief Merge coadd sources from multiple bands. Calls @ref `run` which must be defined in
        subclasses that inherit from MergeSourcesTask.
        @param[in] patchRefList list of data references for each filter
        """
        catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
        dataId = patchRefList[0].dataId
        mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
        self.write(patchRefList[0], ParquetTable(dataFrame=mergedCatalog))

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        measDict = {ref.dataId['band']: {'meas': cat} for ref, cat in
                    zip(inputRefs.inputCatalogMeas, inputs['inputCatalogMeas'])}
        forcedSourceDict = {ref.dataId['band']: {'forced_src': cat} for ref, cat in
                            zip(inputRefs.inputCatalogForcedSrc, inputs['inputCatalogForcedSrc'])}

        catalogs = {}
        for band in measDict.keys():
            catalogs[band] = {'meas': measDict[band]['meas'],
                              'forced_src': forcedSourceDict[band]['forced_src'],
                              'ref': inputs['inputCatalogRef']}
        dataId = butlerQC.quantum.dataId
        df = self.run(catalogs=catalogs, tract=dataId['tract'], patch=dataId['patch'])
        outputs = pipeBase.Struct(outputCatalog=df)
        butlerQC.put(outputs, outputRefs)

    @classmethod
    def _makeArgumentParser(cls):
        """Create a suitable ArgumentParser.

        We will use the ArgumentParser to get a list of data
        references for patches; the RunnerClass will sort them into lists
        of data references for the same patch.

        The parser references the first of ``self.inputDatasets``, rather than
        ``self.inputDataset``.
        """
        return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])

    def readCatalog(self, patchRef):
        """Read input catalogs.

        Read all the input datasets given by the 'inputDatasets'
        attribute.

        Parameters
        ----------
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch

        Returns
        -------
        Tuple consisting of band name and a dict of catalogs, keyed by
        dataset name
        """
        band = patchRef.get(self.config.coaddName + "Coadd_filterLabel", immediate=True).bandLabel
        catalogDict = {}
        for dataset in self.inputDatasets:
            catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
            self.log.info("Read %d sources from %s for band %s: %s" %
                          (len(catalog), dataset, band, patchRef.dataId))
            catalogDict[dataset] = catalog
        return band, catalogDict

    def run(self, catalogs, tract, patch):
        """Merge multiple catalogs.

        Parameters
        ----------
        catalogs : `dict`
            Mapping from filter names to dict of catalogs.
        tract : `int`
            tractId to use for the tractId column.
        patch : `str`
            patchId to use for the patchId column.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Merged dataframe
        """

        dfs = []
        for filt, tableDict in catalogs.items():
            for dataset, table in tableDict.items():
                # Convert afwTable to pandas DataFrame
                df = table.asAstropy().to_pandas().set_index('id', drop=True)

                # Sort columns by name, to ensure matching schema among patches
                df = df.reindex(sorted(df.columns), axis=1)
                df['tractId'] = tract
                df['patchId'] = patch

                # Make columns a 3-level MultiIndex
                df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
                                                       names=('dataset', 'band', 'column'))
                dfs.append(df)

        catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
        return catalog

    def write(self, patchRef, catalog):
        """Write the output.

        Parameters
        ----------
        catalog : `ParquetTable`
            Catalog to write
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch
        """
        patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
        # since the filter isn't actually part of the data ID for the dataset we're saving,
        # it's confusing to see it in the log message, even if the butler simply ignores it.
        mergeDataId = patchRef.dataId.copy()
        del mergeDataId["filter"]
        self.log.info("Wrote merged catalog: %s" % (mergeDataId,))

    def writeMetadata(self, dataRefList):
        """No metadata to write, and not sure how to write it for a list of dataRefs.
        """
        pass
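

# Illustrative sketch (not part of the original task code): a toy construction of the
# three-level ('dataset', 'band', 'column') MultiIndex that `WriteObjectTableTask.run`
# gives the merged deepCoadd_obj DataFrame. The datasets, bands, column names, and
# values below are invented purely for demonstration.
def _exampleObjDataFrameLayout():
    """Build a tiny DataFrame shaped like a deepCoadd_obj patch table."""
    columns = pd.MultiIndex.from_tuples(
        [("meas", "g", "base_PsfFlux_instFlux"),
         ("forced_src", "g", "base_PsfFlux_instFlux"),
         ("ref", "g", "detect_isPrimary")],
        names=("dataset", "band", "column"))
    df = pd.DataFrame([[1.0, 1.1, True]], columns=columns, index=pd.Index([42], name="id"))
    # Selecting df['meas']['g'] returns the per-band measurement columns, which is how
    # downstream functors address a single dataset/band partition of the table.
    return df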


class WriteSourceTableConnections(pipeBase.PipelineTaskConnections,
                                  defaultTemplates={"catalogType": ""},
                                  dimensions=("instrument", "visit", "detector")):
    catalog = connectionTypes.Input(
        doc="Input full-depth catalog of sources produced by CalibrateTask",
        name="{catalogType}src",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector")
    )
    outputCatalog = connectionTypes.Output(
        doc="Catalog of sources, `src` in Parquet format",
        name="{catalogType}source",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector")
    )


class WriteSourceTableConfig(pipeBase.PipelineTaskConfig,
                             pipelineConnections=WriteSourceTableConnections):
    doApplyExternalPhotoCalib = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Add local photoCalib columns from the calexp.photoCalib? Should only be set True when "
             "generating Source Tables from older src tables which do not already have local calib columns")
    )
    doApplyExternalSkyWcs = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Add local WCS columns from the calexp.wcs? Should only be set True when "
             "generating Source Tables from older src tables which do not already have local calib columns")
    )


class WriteSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write source table to parquet.
    """
    _DefaultName = "writeSourceTable"
    ConfigClass = WriteSourceTableConfig

    def runDataRef(self, dataRef):
        src = dataRef.get('src')
        if self.config.doApplyExternalPhotoCalib or self.config.doApplyExternalSkyWcs:
            src = self.addCalibColumns(src, dataRef)

        ccdVisitId = dataRef.get('ccdExposureId')
        result = self.run(src, ccdVisitId=ccdVisitId)
        dataRef.put(result.table, 'source')

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        inputs['ccdVisitId'] = butlerQC.quantum.dataId.pack("visit_detector")
        result = self.run(**inputs).table
        outputs = pipeBase.Struct(outputCatalog=result.toDataFrame())
        butlerQC.put(outputs, outputRefs)

    def run(self, catalog, ccdVisitId=None):
        """Convert the `src` catalog to parquet.

        Parameters
        ----------
        catalog : `afwTable.SourceCatalog`
            Catalog to be converted.
        ccdVisitId : `int`
            ccdVisitId to be added as a column.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``table``
                `ParquetTable` version of the input catalog
        """
        self.log.info("Generating parquet table from src catalog %s", ccdVisitId)
        df = catalog.asAstropy().to_pandas().set_index('id', drop=True)
        df['ccdVisitId'] = ccdVisitId
        return pipeBase.Struct(table=ParquetTable(dataFrame=df))

    def addCalibColumns(self, catalog, dataRef):
        """Add columns with local calibration evaluated at each centroid
        for backwards compatibility with old repos.

        This exists for the purpose of converting old src catalogs
        (which don't have the expected local calib columns) to Source Tables.

        Parameters
        ----------
        catalog : `afwTable.SourceCatalog`
            Catalog to which calib columns will be added.
        dataRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for fetching the calibs from disk.

        Returns
        -------
        newCat : `afwTable.SourceCatalog`
            Source Catalog with requested local calib columns.
        """
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.doReplaceWithNoise = False

        # Just need the WCS or the PhotoCalib attached to an exposure
        exposure = dataRef.get('calexp_sub',
                               bbox=lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(0, 0)))

        mapper = afwTable.SchemaMapper(catalog.schema)
        mapper.addMinimalSchema(catalog.schema, True)
        schema = mapper.getOutputSchema()

        exposureIdInfo = dataRef.get("expIdInfo")
        measureConfig.plugins.names = []
        if self.config.doApplyExternalSkyWcs:
            plugin = 'base_LocalWcs'
            if plugin in schema:
                raise RuntimeError(f"{plugin} already in src catalog. Set doApplyExternalSkyWcs=False")
            else:
                measureConfig.plugins.names.add(plugin)

        if self.config.doApplyExternalPhotoCalib:
            plugin = 'base_LocalPhotoCalib'
            if plugin in schema:
                raise RuntimeError(f"{plugin} already in src catalog. Set doApplyExternalPhotoCalib=False")
            else:
                measureConfig.plugins.names.add(plugin)

        measurement = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        newCat = afwTable.SourceCatalog(schema)
        newCat.extend(catalog, mapper=mapper)
        measurement.run(measCat=newCat, exposure=exposure, exposureId=exposureIdInfo.expId)
        return newCat

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", 'src',
                               help="data ID, e.g. --id visit=12345 ccd=0")
        return parser


class PostprocessAnalysis(object):
    """Calculate columns from a ParquetTable.

    This object manages and organizes an arbitrary set of computations
    on a catalog. The catalog is defined by a
    `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
    `deepCoadd_obj` dataset, and the computations are defined by a collection
    of `lsst.pipe.tasks.functors.Functor` objects (or, equivalently,
    a `CompositeFunctor`).

    After the object is initialized, accessing the `.df` attribute (which
    holds the `pandas.DataFrame` containing the results of the calculations)
    triggers computation of that dataframe.

    One of the conveniences of using this object is the ability to define a desired common
    filter for all functors. This enables the same functor collection to be passed to
    several different `PostprocessAnalysis` objects without having to change the original
    functor collection, since the `filt` keyword argument of this object triggers an
    overwrite of the `filt` property for all functors in the collection.

    This object also allows a list of refFlags to be passed, and defines a set of default
    refFlags that are always included even if not requested.

    If a list of `ParquetTable` objects is passed, rather than a single one, then the
    calculations will be mapped over all the input catalogs. In principle, it should
    be straightforward to parallelize this activity, but initial tests have failed
    (see TODO in code comments).

    Parameters
    ----------
    parq : `lsst.pipe.tasks.ParquetTable` (or list of such)
        Source catalog(s) for computation.

    functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
        Computations to do (functors that act on `parq`).
        If a dict, the output DataFrame will have columns keyed accordingly.
        If a list, the column keys will come from the
        `.shortname` attribute of each functor.

    filt : `str`, optional
        Filter in which to calculate. If provided,
        this will overwrite any existing `.filt` attribute
        of the provided functors.

    flags : `list`, optional
        List of flags (per-band) to include in output table.

    refFlags : `list`, optional
        List of refFlags (only reference band) to include in output table.
    """
    _defaultRefFlags = []
    _defaultFuncs = (('coord_ra', RAColumn()),
                     ('coord_dec', DecColumn()))

    def __init__(self, parq, functors, filt=None, flags=None, refFlags=None):
        self.parq = parq
        self.functors = functors

        self.filt = filt
        self.flags = list(flags) if flags is not None else []
        self.refFlags = list(self._defaultRefFlags)
        if refFlags is not None:
            self.refFlags += list(refFlags)

        self._df = None

    @property
    def defaultFuncs(self):
        funcs = dict(self._defaultFuncs)
        return funcs

    @property
    def func(self):
        additionalFuncs = self.defaultFuncs
        additionalFuncs.update({flag: Column(flag, dataset='ref') for flag in self.refFlags})
        additionalFuncs.update({flag: Column(flag, dataset='meas') for flag in self.flags})

        if isinstance(self.functors, CompositeFunctor):
            func = self.functors
        else:
            func = CompositeFunctor(self.functors)

        func.funcDict.update(additionalFuncs)
        func.filt = self.filt

        return func

    @property
    def noDupCols(self):
        return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']

    @property
    def df(self):
        if self._df is None:
            self.compute()
        return self._df

    def compute(self, dropna=False, pool=None):
        # map over multiple parquet tables
        if type(self.parq) in (list, tuple):
            if pool is None:
                dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
            else:
                # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
                dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
            self._df = pd.concat(dflist)
        else:
            self._df = self.func(self.parq, dropna=dropna)

        return self._df
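

# Illustrative sketch (not part of the original module): a minimal example of driving
# `PostprocessAnalysis` directly. It assumes a `deepCoadd_obj`-style ParquetTable named
# ``parq`` is already in hand (e.g. retrieved from a butler); the functor choices and the
# 'base_PsfFlux_instFlux' column name below are arbitrary and only show the calling pattern.
def _examplePostprocessAnalysis(parq):
    """Compute a few columns from a multilevel object ParquetTable."""
    funcs = {'ra': RAColumn(),
             'dec': DecColumn(),
             'psfFlux': Column('base_PsfFlux_instFlux', dataset='meas')}
    analysis = PostprocessAnalysis(parq, funcs, filt='g',
                                   refFlags=['detect_isPrimary'])
    # Accessing `.df` triggers the computation; the result is a pandas DataFrame
    # with one column per functor plus the requested reference flags.
    return analysis.df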


class TransformCatalogBaseConnections(pipeBase.PipelineTaskConnections,
                                      dimensions=()):
    """Expected Connections for subclasses of TransformCatalogBaseTask.

    Must be subclassed.
    """
    inputCatalog = connectionTypes.Input(
        name="",
        storageClass="DataFrame",
    )
    outputCatalog = connectionTypes.Output(
        name="",
        storageClass="DataFrame",
    )


class TransformCatalogBaseConfig(pipeBase.PipelineTaskConfig,
                                 pipelineConnections=TransformCatalogBaseConnections):
    functorFile = pexConfig.Field(
        dtype=str,
        doc='Path to YAML file specifying functors to be computed',
        default=None,
        optional=True
    )


class TransformCatalogBaseTask(CmdLineTask, pipeBase.PipelineTask):
    """Base class for transforming/standardizing a catalog by applying functors
    that convert units and apply calibrations.

    The purpose of this task is to perform a set of computations on
    an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
    results to a new dataset (which needs to be declared in an `outputDataset`
    attribute).

    The calculations to be performed are defined in a YAML file that specifies
    a set of functors to be computed, provided as
    a `--functorFile` config parameter. An example of such a YAML file
    is the following:

        funcs:
            psfMag:
                functor: Mag
                args:
                    - base_PsfFlux
                filt: HSC-G
                dataset: meas
            cmodel_magDiff:
                functor: MagDiff
                args:
                    - modelfit_CModel
                    - base_PsfFlux
                filt: HSC-G
            gauss_magDiff:
                functor: MagDiff
                args:
                    - base_GaussianFlux
                    - base_PsfFlux
                filt: HSC-G
            count:
                functor: Column
                args:
                    - base_InputCount_value
                filt: HSC-G
            deconvolved_moments:
                functor: DeconvolvedMoments
                filt: HSC-G
                dataset: forced_src
        refFlags:
            - calib_psfUsed
            - merge_measurement_i
            - merge_measurement_r
            - merge_measurement_z
            - merge_measurement_y
            - merge_measurement_g
            - base_PixelFlags_flag_inexact_psfCenter
            - detect_isPrimary

    The names for each entry under "funcs" will become the names of columns in the
    output dataset. All the functors referenced are defined in `lsst.pipe.tasks.functors`.
    Positional arguments to be passed to each functor are in the `args` list,
    and any additional entries for each column other than "functor" or "args" (e.g., `'filt'`,
    `'dataset'`) are treated as keyword arguments to be passed to the functor initialization.

    The "refFlags" entry is a shortcut for a set of `Column` functors with the original
    column names, taken from the `'ref'` dataset.

    The "flags" entry will be expanded out per band.

    This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
    to organize and execute the calculations.
    """

    @property
    def _DefaultName(self):
        raise NotImplementedError('Subclass must define "_DefaultName" attribute')

    @property
    def outputDataset(self):
        raise NotImplementedError('Subclass must define "outputDataset" attribute')

    @property
    def inputDataset(self):
        raise NotImplementedError('Subclass must define "inputDataset" attribute')

    @property
    def ConfigClass(self):
        raise NotImplementedError('Subclass must define "ConfigClass" attribute')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.config.functorFile:
            self.log.info('Loading transform functor definitions from %s',
                          self.config.functorFile)
            self.funcs = CompositeFunctor.from_file(self.config.functorFile)
            self.funcs.update(dict(PostprocessAnalysis._defaultFuncs))
        else:
            self.funcs = None

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        if self.funcs is None:
            raise ValueError("config.functorFile is None. "
                             "Must be a valid path to yaml in order to run Task as a PipelineTask.")
        result = self.run(parq=inputs['inputCatalog'], funcs=self.funcs,
                          dataId=outputRefs.outputCatalog.dataId.full)
        outputs = pipeBase.Struct(outputCatalog=result)
        butlerQC.put(outputs, outputRefs)

    def runDataRef(self, dataRef):
        parq = dataRef.get()
        if self.funcs is None:
            raise ValueError("config.functorFile is None. "
                             "Must be a valid path to yaml in order to run as a CmdLineTask.")
        df = self.run(parq, funcs=self.funcs, dataId=dataRef.dataId)
        self.write(df, dataRef)
        return df

    def run(self, parq, funcs=None, dataId=None, band=None):
        """Do postprocessing calculations.

        Takes a `ParquetTable` object and a dataId, and
        returns a dataframe with the results of the postprocessing calculations.

        Parameters
        ----------
        parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
            ParquetTable from which calculations are done.
        funcs : `lsst.pipe.tasks.functors.Functor`
            Functors to apply to the table's columns.
        dataId : `dict`, optional
            Used to add a `patchId` column to the output dataframe.
        band : `str`, optional
            Filter band that is being processed.

        Returns
        -------
        `pandas.DataFrame`
        """
        self.log.info("Transforming/standardizing the source table dataId: %s", dataId)

        df = self.transform(band, parq, funcs, dataId).df
        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df

    def getFunctors(self):
        return self.funcs

    def getAnalysis(self, parq, funcs=None, band=None):
        if funcs is None:
            funcs = self.funcs
        analysis = PostprocessAnalysis(parq, funcs, filt=band)
        return analysis

    def transform(self, band, parq, funcs, dataId):
        analysis = self.getAnalysis(parq, funcs=funcs, band=band)
        df = analysis.df
        if dataId is not None:
            for key, value in dataId.items():
                df[str(key)] = value

        return pipeBase.Struct(
            df=df,
            analysis=analysis
        )

    def write(self, df, parqRef):
        parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass
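

# Illustrative sketch (not part of the original module): writing a minimal functor file of
# the kind consumed by `TransformCatalogBaseTask` subclasses and loading it with
# `CompositeFunctor.from_file`. The file path, the 'base_PsfFlux_instFlux' column, and the
# functor entries are invented for demonstration; real configurations normally point
# `functorFile` at a curated YAML shipped with an obs package.
def _exampleLoadFunctorFile(path="/tmp/exampleFunctors.yaml"):
    """Write and reload a tiny functor definition file."""
    yamlText = (
        "funcs:\n"
        "    psfFlux:\n"
        "        functor: Column\n"
        "        args:\n"
        "            - base_PsfFlux_instFlux\n"
        "        dataset: meas\n"
        "refFlags:\n"
        "    - detect_isPrimary\n"
    )
    with open(path, "w") as f:
        f.write(yamlText)
    # The loaded CompositeFunctor can be passed to PostprocessAnalysis or assigned to
    # a transform task's `funcs` attribute.
    funcs = CompositeFunctor.from_file(path)
    return funcs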


class TransformObjectCatalogConnections(pipeBase.PipelineTaskConnections,
                                        defaultTemplates={"coaddName": "deep"},
                                        dimensions=("tract", "patch", "skymap")):
    inputCatalog = connectionTypes.Input(
        doc="The vertical concatenation of the deepCoadd_{ref|meas|forced_src} catalogs, "
            "stored as a DataFrame with a multi-level column index per-patch.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="{coaddName}Coadd_obj",
        deferLoad=True,
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-Patch Object Table of columns transformed from the deepCoadd_obj table per the standard "
            "data model.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="objectTable"
    )


class TransformObjectCatalogConfig(TransformCatalogBaseConfig,
                                   pipelineConnections=TransformObjectCatalogConnections):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )
    # TODO: remove in DM-27177
    filterMap = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={},
        doc=("Dictionary mapping full filter name to short one for column name munging. "
             "These filters determine the output columns no matter what filters the "
             "input data actually contain."),
        deprecated=("Coadds are now identified by the band, so this transform is unused. "
                    "Will be removed after v22.")
    )
    outputBands = pexConfig.ListField(
        dtype=str,
        default=None,
        optional=True,
        doc=("These bands and only these bands will appear in the output,"
             " NaN-filled if the input does not include them."
             " If None, then use all bands found in the input.")
    )
    camelCase = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Write per-band column names in camelCase, else with underscores. "
             "For example: gPsFlux instead of g_PsFlux.")
    )
    multilevelOutput = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Whether the results DataFrame should have a multilevel column index (True) or be flat "
             "and name-munged (False).")
    )


class TransformObjectCatalogTask(TransformCatalogBaseTask):
    """Produce a flattened Object Table to match the format specified in
    sdm_schemas.

    Do the same set of postprocessing calculations on all bands.

    This is identical to `TransformCatalogBaseTask`, except that it does the
    specified functor calculations for all filters present in the
    input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
    by the YAML file will be superseded.
    """
    _DefaultName = "transformObjectCatalog"
    ConfigClass = TransformObjectCatalogConfig

    # Used by Gen 2 runDataRef only:
    inputDataset = 'deepCoadd_obj'
    outputDataset = 'objectTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", cls.inputDataset,
                               ContainerClass=CoaddDataIdContainer,
                               help="data ID, e.g. --id tract=12345 patch=1,2")
        return parser

    def run(self, parq, funcs=None, dataId=None, band=None):
        # NOTE: band kwarg is ignored here.
        dfDict = {}
        analysisDict = {}
        templateDf = pd.DataFrame()

        if isinstance(parq, DeferredDatasetHandle):
            columns = parq.get(component='columns')
            inputBands = columns.unique(level=1).values
        else:
            inputBands = parq.columnLevelNames['band']

        outputBands = self.config.outputBands if self.config.outputBands else inputBands

        # Perform transform for data of filters that exist in parq.
        for inputBand in inputBands:
            if inputBand not in outputBands:
                self.log.info("Ignoring %s band data in the input", inputBand)
                continue
            self.log.info("Transforming the catalog of band %s", inputBand)
            result = self.transform(inputBand, parq, funcs, dataId)
            dfDict[inputBand] = result.df
            analysisDict[inputBand] = result.analysis
            if templateDf.empty:
                templateDf = result.df

        # Fill NaNs in columns of other wanted bands
        for filt in outputBands:
            if filt not in dfDict:
                self.log.info("Adding empty columns for band %s", filt)
                dfDict[filt] = pd.DataFrame().reindex_like(templateDf)

        # This makes a multilevel column index, with band as first level
        df = pd.concat(dfDict, axis=1, names=['band', 'column'])

        if not self.config.multilevelOutput:
            noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
            if dataId is not None:
                noDupCols += list(dataId.keys())
            df = flattenFilters(df, noDupCols=noDupCols, camelCase=self.config.camelCase,
                                inputBands=inputBands)

        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df
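

# Illustrative sketch (not part of the original task code): how missing output bands are
# NaN-filled in `TransformObjectCatalogTask.run`. A per-band frame computed from the data
# serves as a template, and `pd.DataFrame().reindex_like(template)` produces an all-NaN
# frame with the same index and columns for a band absent from the input. The band labels
# and columns below are invented for the demonstration.
def _exampleEmptyBandFill():
    """Show the reindex_like pattern used to pad absent output bands."""
    templateDf = pd.DataFrame({"psfFlux": [1.0, 2.0], "psfFluxErr": [0.1, 0.2]},
                              index=pd.Index([1, 2], name="id"))
    dfDict = {"g": templateDf}
    # Pretend 'i' was requested in outputBands but not present in the input:
    dfDict["i"] = pd.DataFrame().reindex_like(templateDf)
    # Concatenating with band as the first column level mirrors the task's behaviour.
    df = pd.concat(dfDict, axis=1, names=["band", "column"])
    return df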


class TractObjectDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        Generate a list of data references given tract and/or patch.
        This was adapted from `TractQADataIdContainer`, which was
        `TractDataIdContainer` modified to not require "filter".
        Only existing dataRefs are returned.
        """
        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=self.datasetType,
                                             tract=tract.getId(),
                                             patch="%d,%d" % patch.getIndex()) for patch in tract]

        tractRefs = defaultdict(list)  # Data references for each tract
        for dataId in self.idList:
            skymap = self.getSkymap(namespace)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
                                                                       tract=tractId,
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)
        outputRefList = []
        for tractRefList in tractRefs.values():
            existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
            outputRefList.append(existingRefs)

        self.refList = outputRefList


class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections,
                                        dimensions=("tract", "skymap")):
    inputCatalogs = connectionTypes.Input(
        doc="Per-Patch objectTables conforming to the standard data model.",
        name="objectTable",
        storageClass="DataFrame",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-tract vertical concatenation of the input objectTables",
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )


class ConsolidateObjectTableConfig(pipeBase.PipelineTaskConfig,
                                   pipelineConnections=ConsolidateObjectTableConnections):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class ConsolidateObjectTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write patch-merged source tables to a tract-level parquet file.

    Concatenates the `objectTable` list into a per-tract `objectTable_tract`.
    """
    _DefaultName = "consolidateObjectTable"
    ConfigClass = ConsolidateObjectTableConfig

    inputDataset = 'objectTable'
    outputDataset = 'objectTable_tract'

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        self.log.info("Concatenating %s per-patch Object Tables",
                      len(inputs['inputCatalogs']))
        df = pd.concat(inputs['inputCatalogs'])
        butlerQC.put(pipeBase.Struct(outputCatalog=df), outputRefs)

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id tract=12345",
                               ContainerClass=TractObjectDataIdContainer)
        return parser

    def runDataRef(self, patchRefList):
        df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
        patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass


class TransformSourceTableConnections(pipeBase.PipelineTaskConnections,
                                      defaultTemplates={"catalogType": ""},
                                      dimensions=("instrument", "visit", "detector")):
    inputCatalog = connectionTypes.Input(
        doc="Wide input catalog of sources produced by WriteSourceTableTask",
        name="{catalogType}source",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        deferLoad=True
    )
    outputCatalog = connectionTypes.Output(
        doc="Narrower, per-detector Source Table transformed and converted per a "
            "specified set of functors",
        name="{catalogType}sourceTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector")
    )


class TransformSourceTableConfig(TransformCatalogBaseConfig,
                                 pipelineConnections=TransformSourceTableConnections):
    pass


class TransformSourceTableTask(TransformCatalogBaseTask):
    """Transform/standardize a source catalog.
    """
    _DefaultName = "transformSourceTable"
    ConfigClass = TransformSourceTableConfig

    inputDataset = 'source'
    outputDataset = 'sourceTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", datasetType=cls.inputDataset,
                               level="sensor",
                               help="data ID, e.g. --id visit=12345 ccd=0")
        return parser

    def runDataRef(self, dataRef):
        """Override to specify band label to run()."""
        parq = dataRef.get()
        funcs = self.getFunctors()
        band = dataRef.get("calexp_filterLabel", immediate=True).bandLabel
        df = self.run(parq, funcs=funcs, dataId=dataRef.dataId, band=band)
        self.write(df, dataRef)
        return df


class ConsolidateVisitSummaryConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=("instrument", "visit",),
                                         defaultTemplates={"calexpType": ""}):
    calexp = connectionTypes.Input(
        doc="Processed exposures used for metadata",
        name="{calexpType}calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        deferLoad=True,
        multiple=True,
    )
    visitSummary = connectionTypes.Output(
        doc=("Per-visit consolidated exposure metadata. These catalogs use "
             "detector id for the id and are sorted for fast lookups of a "
             "detector."),
        name="{calexpType}visitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
    )


class ConsolidateVisitSummaryConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=ConsolidateVisitSummaryConnections):
    """Config for ConsolidateVisitSummaryTask"""
    pass


class ConsolidateVisitSummaryTask(pipeBase.PipelineTask, pipeBase.CmdLineTask):
    """Task to consolidate per-detector visit metadata.

    This task aggregates the following metadata from all the detectors in a
    single visit into an exposure catalog:
    - The visitInfo.
    - The wcs.
    - The photoCalib.
    - The physical_filter and band (if available).
    - The psf size, shape, and effective area at the center of the detector.
    - The corners of the bounding box in right ascension/declination.

    Other quantities such as Detector, Psf, ApCorrMap, and TransmissionCurve
    are not persisted here because of storage concerns, and because of their
    limited utility as summary statistics.

    Tests for this task are performed in ci_hsc_gen3.
    """
    _DefaultName = "consolidateVisitSummary"
    ConfigClass = ConsolidateVisitSummaryConfig

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", "calexp",
                               help="data ID, e.g. --id visit=12345",
                               ContainerClass=VisitDataIdContainer)
        return parser

    def writeMetadata(self, dataRef):
        """No metadata to persist, so override to remove metadata persistence.
        """
        pass

    def writeConfig(self, butler, clobber=False, doBackup=True):
        """No config to persist, so override to remove config persistence.
        """
        pass

    def runDataRef(self, dataRefList):
        visit = dataRefList[0].dataId['visit']

        self.log.debug("Concatenating metadata from %d per-detector calexps (visit %d)" %
                       (len(dataRefList), visit))

        expCatalog = self._combineExposureMetadata(visit, dataRefList, isGen3=False)

        dataRefList[0].put(expCatalog, 'visitSummary', visit=visit)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        dataRefs = butlerQC.get(inputRefs.calexp)
        visit = dataRefs[0].dataId.byName()['visit']

        self.log.debug("Concatenating metadata from %d per-detector calexps (visit %d)" %
                       (len(dataRefs), visit))

        expCatalog = self._combineExposureMetadata(visit, dataRefs)

        butlerQC.put(expCatalog, outputRefs.visitSummary)

    def _combineExposureMetadata(self, visit, dataRefs, isGen3=True):
        """Make a combined exposure catalog from a list of dataRefs.

        These dataRefs must point to exposures with wcs, summaryStats,
        and other visit metadata.

        Parameters
        ----------
        visit : `int`
            Visit identification number.
        dataRefs : `list`
            List of dataRefs in visit. May be list of
            `lsst.daf.persistence.ButlerDataRef` (Gen2) or
            `lsst.daf.butler.DeferredDatasetHandle` (Gen3).
        isGen3 : `bool`, optional
            Specifies if this is a Gen3 list of datarefs.

        Returns
        -------
        visitSummary : `lsst.afw.table.ExposureCatalog`
            Exposure catalog with per-detector summary information.
        """
        schema = self._makeVisitSummarySchema()
        cat = afwTable.ExposureCatalog(schema)
        cat.resize(len(dataRefs))

        cat['visit'] = visit

        for i, dataRef in enumerate(dataRefs):
            if isGen3:
                visitInfo = dataRef.get(component='visitInfo')
                filterLabel = dataRef.get(component='filterLabel')
                summaryStats = dataRef.get(component='summaryStats')
                detector = dataRef.get(component='detector')
                wcs = dataRef.get(component='wcs')
                photoCalib = dataRef.get(component='photoCalib')
                bbox = dataRef.get(component='bbox')
                validPolygon = dataRef.get(component='validPolygon')
            else:
                # Note that we need to read the calexp because there is
                # no magic access to the psf except through the exposure.
                gen2_read_bbox = lsst.geom.BoxI(lsst.geom.PointI(0, 0), lsst.geom.PointI(1, 1))
                exp = dataRef.get(datasetType='calexp_sub', bbox=gen2_read_bbox)
                visitInfo = exp.getInfo().getVisitInfo()
                filterLabel = dataRef.get("calexp_filterLabel")
                summaryStats = exp.getInfo().getSummaryStats()
                wcs = exp.getWcs()
                photoCalib = exp.getPhotoCalib()
                detector = exp.getDetector()
                bbox = dataRef.get(datasetType='calexp_bbox')
                validPolygon = exp.getInfo().getValidPolygon()

            rec = cat[i]
            rec.setBBox(bbox)
            rec.setVisitInfo(visitInfo)
            rec.setWcs(wcs)
            rec.setPhotoCalib(photoCalib)
            rec.setValidPolygon(validPolygon)

            rec['physical_filter'] = filterLabel.physicalLabel if filterLabel.hasPhysicalLabel() else ""
            rec['band'] = filterLabel.bandLabel if filterLabel.hasBandLabel() else ""
            rec.setId(detector.getId())
            rec['psfSigma'] = summaryStats.psfSigma
            rec['psfIxx'] = summaryStats.psfIxx
            rec['psfIyy'] = summaryStats.psfIyy
            rec['psfIxy'] = summaryStats.psfIxy
            rec['psfArea'] = summaryStats.psfArea
            rec['raCorners'][:] = summaryStats.raCorners
            rec['decCorners'][:] = summaryStats.decCorners
            rec['ra'] = summaryStats.ra
            rec['decl'] = summaryStats.decl
            rec['zenithDistance'] = summaryStats.zenithDistance
            rec['zeroPoint'] = summaryStats.zeroPoint
            rec['skyBg'] = summaryStats.skyBg
            rec['skyNoise'] = summaryStats.skyNoise
            rec['meanVar'] = summaryStats.meanVar

        metadata = dafBase.PropertyList()
        metadata.add("COMMENT", "Catalog id is detector id, sorted.")
        # We are looping over existing datarefs, so the following is true
        metadata.add("COMMENT", "Only detectors with data have entries.")
        cat.setMetadata(metadata)

        cat.sort()
        return cat

    def _makeVisitSummarySchema(self):
        """Make the schema for the visitSummary catalog."""
        schema = afwTable.ExposureTable.makeMinimalSchema()
        schema.addField('visit', type='I', doc='Visit number')
        schema.addField('physical_filter', type='String', size=32, doc='Physical filter')
        schema.addField('band', type='String', size=32, doc='Name of band')
        schema.addField('psfSigma', type='F',
                        doc='PSF model second-moments determinant radius (center of chip) (pixel)')
        schema.addField('psfArea', type='F',
                        doc='PSF model effective area (center of chip) (pixel**2)')
        schema.addField('psfIxx', type='F',
                        doc='PSF model Ixx (center of chip) (pixel**2)')
        schema.addField('psfIyy', type='F',
                        doc='PSF model Iyy (center of chip) (pixel**2)')
        schema.addField('psfIxy', type='F',
                        doc='PSF model Ixy (center of chip) (pixel**2)')
        schema.addField('raCorners', type='ArrayD', size=4,
                        doc='Right Ascension of bounding box corners (degrees)')
        schema.addField('decCorners', type='ArrayD', size=4,
                        doc='Declination of bounding box corners (degrees)')
        schema.addField('ra', type='D',
                        doc='Right Ascension of bounding box center (degrees)')
        schema.addField('decl', type='D',
                        doc='Declination of bounding box center (degrees)')
        schema.addField('zenithDistance', type='F',
                        doc='Zenith distance of bounding box center (degrees)')
        schema.addField('zeroPoint', type='F',
                        doc='Mean zeropoint in detector (mag)')
        schema.addField('skyBg', type='F',
                        doc='Average sky background (ADU)')
        schema.addField('skyNoise', type='F',
                        doc='Average sky noise (ADU)')
        schema.addField('meanVar', type='F',
                        doc='Mean variance of the weight plane (ADU**2)')

        return schema


class VisitDataIdContainer(DataIdContainer):
    """DataIdContainer that groups sensor-level ids by visit.
    """

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList.

        Generate a list of data references grouped by visit.

        Parameters
        ----------
        namespace : `argparse.Namespace`
            Namespace used by `lsst.pipe.base.CmdLineTask` to parse command line arguments
        """
        # Group by visits
        visitRefs = defaultdict(list)
        for dataId in self.idList:
            if "visit" in dataId:
                visitId = dataId["visit"]
                # Append all data references matching this data ID to the visit's list
                subset = namespace.butler.subset(self.datasetType, dataId=dataId)
                visitRefs[visitId].extend([dataRef for dataRef in subset])

        outputRefList = []
        for refList in visitRefs.values():
            existingRefs = [ref for ref in refList if ref.datasetExists()]
            if existingRefs:
                outputRefList.append(existingRefs)

        self.refList = outputRefList


class ConsolidateSourceTableConnections(pipeBase.PipelineTaskConnections,
                                        defaultTemplates={"catalogType": ""},
                                        dimensions=("instrument", "visit")):
    inputCatalogs = connectionTypes.Input(
        doc="Input per-detector Source Tables",
        name="{catalogType}sourceTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        multiple=True
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-visit concatenation of Source Table",
        name="{catalogType}sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("instrument", "visit")
    )


class ConsolidateSourceTableConfig(pipeBase.PipelineTaskConfig,
                                   pipelineConnections=ConsolidateSourceTableConnections):
    pass


class ConsolidateSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Concatenate the `sourceTable` list into a per-visit `sourceTable_visit`.
    """
    _DefaultName = 'consolidateSourceTable'
    ConfigClass = ConsolidateSourceTableConfig

    inputDataset = 'sourceTable'
    outputDataset = 'sourceTable_visit'

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        self.log.info("Concatenating %s per-detector Source Tables",
                      len(inputs['inputCatalogs']))
        df = pd.concat(inputs['inputCatalogs'])
        butlerQC.put(pipeBase.Struct(outputCatalog=df), outputRefs)

    def runDataRef(self, dataRefList):
        self.log.info("Concatenating %s per-detector Source Tables", len(dataRefList))
        df = pd.concat([dataRef.get().toDataFrame() for dataRef in dataRefList])
        dataRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id visit=12345",
                               ContainerClass=VisitDataIdContainer)
        return parser

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass

    def writeConfig(self, butler, clobber=False, doBackup=True):
        """No config to write.
        """
        pass