Coverage for python/lsst/pipe/tasks/postprocess.py: 28%

# This file is part of pipe_tasks
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import functools
import pandas as pd
import numpy as np
from collections import defaultdict

import lsst.geom
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.daf.base as dafBase
from lsst.pipe.base import connectionTypes
import lsst.afw.table as afwTable
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.pipe.base import CmdLineTask, ArgumentParser, DataIdContainer
from lsst.coadd.utils.coaddDataIdContainer import CoaddDataIdContainer
from lsst.daf.butler import DeferredDatasetHandle

from .parquetTable import ParquetTable
from .multiBandUtils import makeMergeArgumentParser, MergeSourcesRunner
from .functors import CompositeFunctor, RAColumn, DecColumn, Column


def flattenFilters(df, noDupCols=['coord_ra', 'coord_dec'], camelCase=False, inputBands=None):
    """Flattens a dataframe with multilevel column index
    """
    newDf = pd.DataFrame()
    # band is the level 0 index
    dfBands = df.columns.unique(level=0).values
    for band in dfBands:
        subdf = df[band]
        columnFormat = '{0}{1}' if camelCase else '{0}_{1}'
        newColumns = {c: columnFormat.format(band, c)
                      for c in subdf.columns if c not in noDupCols}
        cols = list(newColumns.keys())
        newDf = pd.concat([newDf, subdf[cols].rename(columns=newColumns)], axis=1)

    # Band must be present in the input and output or else column is all NaN:
    presentBands = dfBands if inputBands is None else list(set(inputBands).intersection(dfBands))
    # Get the unexploded columns from any present band's partition
    noDupDf = df[presentBands[0]][noDupCols]
    newDf = pd.concat([noDupDf, newDf], axis=1)
    return newDf
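
# Editorial usage sketch for flattenFilters (illustrative only; the toy columns below
# are assumptions, not pipeline datasets):
#
#     cols = pd.MultiIndex.from_tuples(
#         [('g', 'PsFlux'), ('g', 'coord_ra'), ('r', 'PsFlux'), ('r', 'coord_ra')],
#         names=('band', 'column'))
#     toy = pd.DataFrame(np.ones((2, 4)), columns=cols)
#     flattenFilters(toy, noDupCols=['coord_ra'], camelCase=True).columns
#     # -> ['coord_ra', 'gPsFlux', 'rPsFlux']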


class WriteObjectTableConnections(pipeBase.PipelineTaskConnections,
                                  defaultTemplates={"coaddName": "deep"},
                                  dimensions=("tract", "patch", "skymap")):
    inputCatalogMeas = connectionTypes.Input(
        doc="Catalog of source measurements on the deepCoadd.",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_meas",
        multiple=True
    )
    inputCatalogForcedSrc = connectionTypes.Input(
        doc="Catalog of forced measurements (shape and position parameters held fixed) on the deepCoadd.",
        dimensions=("tract", "patch", "band", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_forced_src",
        multiple=True
    )
    inputCatalogRef = connectionTypes.Input(
        doc="Catalog marking the primary detection (which band provides a good shape and position) "
            "for each detection in deepCoadd_mergeDet.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="SourceCatalog",
        name="{coaddName}Coadd_ref"
    )
    outputCatalog = connectionTypes.Output(
        doc="A vertical concatenation of the deepCoadd_{ref|meas|forced_src} catalogs, "
            "stored as a DataFrame with a multi-level column index per-patch.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="{coaddName}Coadd_obj"
    )


class WriteObjectTableConfig(pipeBase.PipelineTaskConfig,
                             pipelineConnections=WriteObjectTableConnections):
    engine = pexConfig.Field(
        dtype=str,
        default="pyarrow",
        doc="Parquet engine for writing (pyarrow or fastparquet)"
    )
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class WriteObjectTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write filter-merged source tables to parquet
    """
    _DefaultName = "writeObjectTable"
    ConfigClass = WriteObjectTableConfig
    RunnerClass = MergeSourcesRunner

    # Names of table datasets to be merged
    inputDatasets = ('forced_src', 'meas', 'ref')

    # Tag of output dataset written by `MergeSourcesTask.write`
    outputDataset = 'obj'

    def __init__(self, butler=None, schema=None, **kwargs):
        # It is a shame that this class can't use the default init for CmdLineTask,
        # but to do so would require its own special task runner, which is many
        # more lines of specialization, so this is how it is for now.
        super().__init__(**kwargs)

    def runDataRef(self, patchRefList):
        """!
        @brief Merge coadd sources from multiple bands. Calls @ref `run` which must be defined in
        subclasses that inherit from MergeSourcesTask.
        @param[in] patchRefList list of data references for each filter
        """
        catalogs = dict(self.readCatalog(patchRef) for patchRef in patchRefList)
        dataId = patchRefList[0].dataId
        mergedCatalog = self.run(catalogs, tract=dataId['tract'], patch=dataId['patch'])
        self.write(patchRefList[0], ParquetTable(dataFrame=mergedCatalog))

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        measDict = {ref.dataId['band']: {'meas': cat} for ref, cat in
                    zip(inputRefs.inputCatalogMeas, inputs['inputCatalogMeas'])}
        forcedSourceDict = {ref.dataId['band']: {'forced_src': cat} for ref, cat in
                            zip(inputRefs.inputCatalogForcedSrc, inputs['inputCatalogForcedSrc'])}

        catalogs = {}
        for band in measDict.keys():
            catalogs[band] = {'meas': measDict[band]['meas'],
                              'forced_src': forcedSourceDict[band]['forced_src'],
                              'ref': inputs['inputCatalogRef']}
        dataId = butlerQC.quantum.dataId
        df = self.run(catalogs=catalogs, tract=dataId['tract'], patch=dataId['patch'])
        outputs = pipeBase.Struct(outputCatalog=df)
        butlerQC.put(outputs, outputRefs)

    @classmethod
    def _makeArgumentParser(cls):
        """Create a suitable ArgumentParser.

        We will use the ArgumentParser to get a list of data
        references for patches; the RunnerClass will sort them into lists
        of data references for the same patch.

        References the first of self.inputDatasets, rather than
        self.inputDataset.
        """
        return makeMergeArgumentParser(cls._DefaultName, cls.inputDatasets[0])

    def readCatalog(self, patchRef):
        """Read input catalogs

        Read all the input datasets given by the 'inputDatasets'
        attribute.

        Parameters
        ----------
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch

        Returns
        -------
        Tuple consisting of band name and a dict of catalogs, keyed by
        dataset name
        """
        band = patchRef.get(self.config.coaddName + "Coadd_filterLabel", immediate=True).bandLabel
        catalogDict = {}
        for dataset in self.inputDatasets:
            catalog = patchRef.get(self.config.coaddName + "Coadd_" + dataset, immediate=True)
            self.log.info("Read %d sources from %s for band %s: %s" %
                          (len(catalog), dataset, band, patchRef.dataId))
            catalogDict[dataset] = catalog
        return band, catalogDict

    def run(self, catalogs, tract, patch):
        """Merge multiple catalogs.

        Parameters
        ----------
        catalogs : `dict`
            Mapping from filter names to dict of catalogs.
        tract : `int`
            tractId to use for the tractId column.
        patch : `str`
            patchId to use for the patchId column.

        Returns
        -------
        catalog : `pandas.DataFrame`
            Merged dataframe
        """
        dfs = []
        for filt, tableDict in catalogs.items():
            for dataset, table in tableDict.items():
                # Convert afwTable to pandas DataFrame
                df = table.asAstropy().to_pandas().set_index('id', drop=True)

                # Sort columns by name, to ensure matching schema among patches
                df = df.reindex(sorted(df.columns), axis=1)
                df['tractId'] = tract
                df['patchId'] = patch

                # Make columns a 3-level MultiIndex
                df.columns = pd.MultiIndex.from_tuples([(dataset, filt, c) for c in df.columns],
                                                       names=('dataset', 'band', 'column'))
                dfs.append(df)

        catalog = functools.reduce(lambda d1, d2: d1.join(d2), dfs)
        return catalog
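
    # Editorial note: the merged frame returned by `run` carries a three-level
    # ('dataset', 'band', 'column') MultiIndex, e.g. ('meas', 'g', 'base_PsfFlux_instFlux')
    # or ('ref', 'g', 'detect_isPrimary'), with one row per object id.
    # (These column names are illustrative, not a guaranteed schema.)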

    def write(self, patchRef, catalog):
        """Write the output.

        Parameters
        ----------
        catalog : `ParquetTable`
            Catalog to write
        patchRef : `lsst.daf.persistence.ButlerDataRef`
            Data reference for patch
        """
        patchRef.put(catalog, self.config.coaddName + "Coadd_" + self.outputDataset)
        # since the filter isn't actually part of the data ID for the dataset we're saving,
        # it's confusing to see it in the log message, even if the butler simply ignores it.
        mergeDataId = patchRef.dataId.copy()
        del mergeDataId["filter"]
        self.log.info("Wrote merged catalog: %s" % (mergeDataId,))

    def writeMetadata(self, dataRefList):
        """No metadata to write, and not sure how to write it for a list of dataRefs.
        """
        pass


class WriteSourceTableConnections(pipeBase.PipelineTaskConnections,
                                  dimensions=("instrument", "visit", "detector")):

    catalog = connectionTypes.Input(
        doc="Input full-depth catalog of sources produced by CalibrateTask",
        name="src",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector")
    )
    outputCatalog = connectionTypes.Output(
        doc="Catalog of sources, `src` in Parquet format",
        name="source",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector")
    )


class WriteSourceTableConfig(pipeBase.PipelineTaskConfig,
                             pipelineConnections=WriteSourceTableConnections):
    doApplyExternalPhotoCalib = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Add local photoCalib columns from the calexp.photoCalib? Should only be set True if "
             "generating Source Tables from older src tables which do not already have local calib columns")
    )
    doApplyExternalSkyWcs = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Add local WCS columns from the calexp.wcs? Should only be set True if "
             "generating Source Tables from older src tables which do not already have local calib columns")
    )


class WriteSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write source table to parquet
    """
    _DefaultName = "writeSourceTable"
    ConfigClass = WriteSourceTableConfig

    def runDataRef(self, dataRef):
        src = dataRef.get('src')
        if self.config.doApplyExternalPhotoCalib or self.config.doApplyExternalSkyWcs:
            src = self.addCalibColumns(src, dataRef)

        ccdVisitId = dataRef.get('ccdExposureId')
        result = self.run(src, ccdVisitId=ccdVisitId)
        dataRef.put(result.table, 'source')

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        inputs['ccdVisitId'] = butlerQC.quantum.dataId.pack("visit_detector")
        result = self.run(**inputs).table
        outputs = pipeBase.Struct(outputCatalog=result.toDataFrame())
        butlerQC.put(outputs, outputRefs)

    def run(self, catalog, ccdVisitId=None):
        """Convert `src` catalog to parquet

        Parameters
        ----------
        catalog : `afwTable.SourceCatalog`
            catalog to be converted
        ccdVisitId : `int`
            ccdVisitId to be added as a column

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            ``table``
                `ParquetTable` version of the input catalog
        """
        self.log.info("Generating parquet table from src catalog %s", ccdVisitId)
        df = catalog.asAstropy().to_pandas().set_index('id', drop=True)
        df['ccdVisitId'] = ccdVisitId
        return pipeBase.Struct(table=ParquetTable(dataFrame=df))

    def addCalibColumns(self, catalog, dataRef):
        """Add columns with local calibration evaluated at each centroid

        for backwards compatibility with old repos.
        This exists for the purpose of converting old src catalogs
        (which don't have the expected local calib columns) to Source Tables.

        Parameters
        ----------
        catalog : `afwTable.SourceCatalog`
            catalog to which calib columns will be added
        dataRef : `lsst.daf.persistence.ButlerDataRef`
            for fetching the calibs from disk.

        Returns
        -------
        newCat : `afwTable.SourceCatalog`
            Source Catalog with requested local calib columns
        """
        mapper = afwTable.SchemaMapper(catalog.schema)
        measureConfig = SingleFrameMeasurementTask.ConfigClass()
        measureConfig.doReplaceWithNoise = False

        # Just need the WCS or the PhotoCalib attached to an exposure
        exposure = dataRef.get('calexp_sub',
                               bbox=lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(0, 0)))

        mapper = afwTable.SchemaMapper(catalog.schema)
        mapper.addMinimalSchema(catalog.schema, True)
        schema = mapper.getOutputSchema()

        exposureIdInfo = dataRef.get("expIdInfo")
        measureConfig.plugins.names = []
        if self.config.doApplyExternalSkyWcs:
            plugin = 'base_LocalWcs'
            if plugin in schema:
                raise RuntimeError(f"{plugin} already in src catalog. Set doApplyExternalSkyWcs=False")
            else:
                measureConfig.plugins.names.add(plugin)

        if self.config.doApplyExternalPhotoCalib:
            plugin = 'base_LocalPhotoCalib'
            if plugin in schema:
                raise RuntimeError(f"{plugin} already in src catalog. Set doApplyExternalPhotoCalib=False")
            else:
                measureConfig.plugins.names.add(plugin)

        measurement = SingleFrameMeasurementTask(config=measureConfig, schema=schema)
        newCat = afwTable.SourceCatalog(schema)
        newCat.extend(catalog, mapper=mapper)
        measurement.run(measCat=newCat, exposure=exposure, exposureId=exposureIdInfo.expId)
        return newCat

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", 'src',
                               help="data ID, e.g. --id visit=12345 ccd=0")
        return parser


class PostprocessAnalysis(object):
    """Calculate columns from ParquetTable

    This object manages and organizes an arbitrary set of computations
    on a catalog. The catalog is defined by a
    `lsst.pipe.tasks.parquetTable.ParquetTable` object (or list thereof), such as a
    `deepCoadd_obj` dataset, and the computations are defined by a collection
    of `lsst.pipe.tasks.functor.Functor` objects (or, equivalently,
    a `CompositeFunctor`).

    After the object is initialized, accessing the `.df` attribute (which
    holds the `pandas.DataFrame` containing the results of the calculations) triggers
    computation of said dataframe.

    One of the conveniences of using this object is the ability to define a desired common
    filter for all functors. This enables the same functor collection to be passed to
    several different `PostprocessAnalysis` objects without having to change the original
    functor collection, since the `filt` keyword argument of this object triggers an
    overwrite of the `filt` property for all functors in the collection.

    This object also allows a list of refFlags to be passed, and defines a set of default
    refFlags that are always included even if not requested.

    If a list of `ParquetTable` objects is passed, rather than a single one, then the
    calculations will be mapped over all the input catalogs. In principle, it should
    be straightforward to parallelize this activity, but initial tests have failed
    (see TODO in code comments).

    Parameters
    ----------
    parq : `lsst.pipe.tasks.ParquetTable` (or list of such)
        Source catalog(s) for computation

    functors : `list`, `dict`, or `lsst.pipe.tasks.functors.CompositeFunctor`
        Computations to do (functors that act on `parq`).
        If a dict, the output
        DataFrame will have columns keyed accordingly.
        If a list, the column keys will come from the
        `.shortname` attribute of each functor.

    filt : `str` (optional)
        Filter in which to calculate. If provided,
        this will overwrite any existing `.filt` attribute
        of the provided functors.

    flags : `list` (optional)
        List of flags (per-band) to include in output table.

    refFlags : `list` (optional)
        List of refFlags (only reference band) to include in output table.
    """
    _defaultRefFlags = []
    _defaultFuncs = (('coord_ra', RAColumn()),
                     ('coord_dec', DecColumn()))

    def __init__(self, parq, functors, filt=None, flags=None, refFlags=None):
        self.parq = parq
        self.functors = functors

        self.filt = filt
        self.flags = list(flags) if flags is not None else []
        self.refFlags = list(self._defaultRefFlags)
        if refFlags is not None:
            self.refFlags += list(refFlags)

        self._df = None

    @property
    def defaultFuncs(self):
        funcs = dict(self._defaultFuncs)
        return funcs

    @property
    def func(self):
        additionalFuncs = self.defaultFuncs
        additionalFuncs.update({flag: Column(flag, dataset='ref') for flag in self.refFlags})
        additionalFuncs.update({flag: Column(flag, dataset='meas') for flag in self.flags})

        if isinstance(self.functors, CompositeFunctor):
            func = self.functors
        else:
            func = CompositeFunctor(self.functors)

        func.funcDict.update(additionalFuncs)
        func.filt = self.filt

        return func

    @property
    def noDupCols(self):
        return [name for name, func in self.func.funcDict.items() if func.noDup or func.dataset == 'ref']

    @property
    def df(self):
        if self._df is None:
            self.compute()
        return self._df

    def compute(self, dropna=False, pool=None):
        # map over multiple parquet tables
        if type(self.parq) in (list, tuple):
            if pool is None:
                dflist = [self.func(parq, dropna=dropna) for parq in self.parq]
            else:
                # TODO: Figure out why this doesn't work (pyarrow pickling issues?)
                dflist = pool.map(functools.partial(self.func, dropna=dropna), self.parq)
            self._df = pd.concat(dflist)
        else:
            self._df = self.func(self.parq, dropna=dropna)

        return self._df
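
# Editorial usage sketch (assumptions: `parq` is a deepCoadd_obj ParquetTable already
# retrieved from the butler, and Mag is the magnitude functor from .functors):
#
#     from lsst.pipe.tasks.functors import Mag
#     funcs = {'psfMag': Mag('base_PsfFlux', dataset='meas')}
#     analysis = PostprocessAnalysis(parq, funcs, filt='g', refFlags=['detect_isPrimary'])
#     df = analysis.df  # lazily triggers compute(); columns: psfMag, detect_isPrimary,
#                       # plus the default coord_ra/coord_dec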


class TransformCatalogBaseConnections(pipeBase.PipelineTaskConnections,
                                      dimensions=()):
    """Expected Connections for subclasses of TransformCatalogBaseTask.

    Must be subclassed.
    """
    inputCatalog = connectionTypes.Input(
        name="",
        storageClass="DataFrame",
    )
    outputCatalog = connectionTypes.Output(
        name="",
        storageClass="DataFrame",
    )


class TransformCatalogBaseConfig(pipeBase.PipelineTaskConfig,
                                 pipelineConnections=TransformCatalogBaseConnections):
    functorFile = pexConfig.Field(
        dtype=str,
        doc='Path to YAML file specifying functors to be computed',
        default=None,
        optional=True
    )


class TransformCatalogBaseTask(CmdLineTask, pipeBase.PipelineTask):
    """Base class for transforming/standardizing a catalog

    by applying functors that convert units and apply calibrations.
    The purpose of this task is to perform a set of computations on
    an input `ParquetTable` dataset (such as `deepCoadd_obj`) and write the
    results to a new dataset (which needs to be declared in an `outputDataset`
    attribute).

    The calculations to be performed are defined in a YAML file that specifies
    a set of functors to be computed, provided as
    a `--functorFile` config parameter. An example of such a YAML file
    is the following:

        funcs:
            psfMag:
                functor: Mag
                args:
                    - base_PsfFlux
                filt: HSC-G
                dataset: meas
            cmodel_magDiff:
                functor: MagDiff
                args:
                    - modelfit_CModel
                    - base_PsfFlux
                filt: HSC-G
            gauss_magDiff:
                functor: MagDiff
                args:
                    - base_GaussianFlux
                    - base_PsfFlux
                filt: HSC-G
            count:
                functor: Column
                args:
                    - base_InputCount_value
                filt: HSC-G
            deconvolved_moments:
                functor: DeconvolvedMoments
                filt: HSC-G
                dataset: forced_src
        refFlags:
            - calib_psfUsed
            - merge_measurement_i
            - merge_measurement_r
            - merge_measurement_z
            - merge_measurement_y
            - merge_measurement_g
            - base_PixelFlags_flag_inexact_psfCenter
            - detect_isPrimary

    The names for each entry under "funcs" will become the names of columns in the
    output dataset. All the functors referenced are defined in `lsst.pipe.tasks.functors`.
    Positional arguments to be passed to each functor are in the `args` list,
    and any additional entries for each column other than "functor" or "args" (e.g., `'filt'`,
    `'dataset'`) are treated as keyword arguments to be passed to the functor initialization.

    The "refFlags" entry is a shortcut for a set of `Column` functors with the original
    column names, taken from the `'ref'` dataset.

    The "flags" entry will be expanded out per band.

    This task uses the `lsst.pipe.tasks.postprocess.PostprocessAnalysis` object
    to organize and execute the calculations.
    """
    @property
    def _DefaultName(self):
        raise NotImplementedError('Subclass must define "_DefaultName" attribute')

    @property
    def outputDataset(self):
        raise NotImplementedError('Subclass must define "outputDataset" attribute')

    @property
    def inputDataset(self):
        raise NotImplementedError('Subclass must define "inputDataset" attribute')

    @property
    def ConfigClass(self):
        raise NotImplementedError('Subclass must define "ConfigClass" attribute')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.config.functorFile:
            self.log.info('Loading transform functor definitions from %s',
                          self.config.functorFile)
            self.funcs = CompositeFunctor.from_file(self.config.functorFile)
            self.funcs.update(dict(PostprocessAnalysis._defaultFuncs))
        else:
            self.funcs = None

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        if self.funcs is None:
            raise ValueError("config.functorFile is None. "
                             "Must be a valid path to yaml in order to run Task as a PipelineTask.")
        result = self.run(parq=inputs['inputCatalog'], funcs=self.funcs,
                          dataId=outputRefs.outputCatalog.dataId.full)
        outputs = pipeBase.Struct(outputCatalog=result)
        butlerQC.put(outputs, outputRefs)

    def runDataRef(self, dataRef):
        parq = dataRef.get()
        if self.funcs is None:
            raise ValueError("config.functorFile is None. "
                             "Must be a valid path to yaml in order to run as a CommandlineTask.")
        df = self.run(parq, funcs=self.funcs, dataId=dataRef.dataId)
        self.write(df, dataRef)
        return df

    def run(self, parq, funcs=None, dataId=None, band=None):
        """Do postprocessing calculations

        Takes a `ParquetTable` object and dataId,
        returns a dataframe with results of postprocessing calculations.

        Parameters
        ----------
        parq : `lsst.pipe.tasks.parquetTable.ParquetTable`
            ParquetTable from which calculations are done.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the table's columns
        dataId : dict, optional
            Used to add a `patchId` column to the output dataframe.
        band : `str`, optional
            Filter band that is being processed.

        Returns
        -------
        `pandas.DataFrame`
        """
        self.log.info("Transforming/standardizing the source table dataId: %s", dataId)

        df = self.transform(band, parq, funcs, dataId).df
        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df

    def getFunctors(self):
        return self.funcs

    def getAnalysis(self, parq, funcs=None, band=None):
        if funcs is None:
            funcs = self.funcs
        analysis = PostprocessAnalysis(parq, funcs, filt=band)
        return analysis

    def transform(self, band, parq, funcs, dataId):
        analysis = self.getAnalysis(parq, funcs=funcs, band=band)
        df = analysis.df
        if dataId is not None:
            for key, value in dataId.items():
                df[str(key)] = value

        return pipeBase.Struct(
            df=df,
            analysis=analysis
        )

    def write(self, df, parqRef):
        parqRef.put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass
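
# Editorial note (illustrative, not part of the pipeline): with the YAML example from the
# docstring above loaded via config.functorFile, the output DataFrame would contain the
# columns psfMag, cmodel_magDiff, gauss_magDiff, count and deconvolved_moments, plus the
# requested refFlags and the default coord_ra/coord_dec added from
# PostprocessAnalysis._defaultFuncs. A hypothetical config override might look like:
#
#     config.functorFile = "/path/to/Object.yaml"   # path is illustrative only
#
# If functorFile is left unset, runQuantum()/runDataRef() raise ValueError, as enforced above.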


class TransformObjectCatalogConnections(pipeBase.PipelineTaskConnections,
                                        defaultTemplates={"coaddName": "deep"},
                                        dimensions=("tract", "patch", "skymap")):
    inputCatalog = connectionTypes.Input(
        doc="The vertical concatenation of the deepCoadd_{ref|meas|forced_src} catalogs, "
            "stored as a DataFrame with a multi-level column index per-patch.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="{coaddName}Coadd_obj",
        deferLoad=True,
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-Patch Object Table of columns transformed from the deepCoadd_obj table per the standard "
            "data model.",
        dimensions=("tract", "patch", "skymap"),
        storageClass="DataFrame",
        name="objectTable"
    )


class TransformObjectCatalogConfig(TransformCatalogBaseConfig,
                                   pipelineConnections=TransformObjectCatalogConnections):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )
    # TODO: remove in DM-27177
    filterMap = pexConfig.DictField(
        keytype=str,
        itemtype=str,
        default={},
        doc=("Dictionary mapping full filter name to short one for column name munging. "
             "These filters determine the output columns no matter what filters the "
             "input data actually contain."),
        deprecated=("Coadds are now identified by the band, so this transform is unused. "
                    "Will be removed after v22.")
    )
    outputBands = pexConfig.ListField(
        dtype=str,
        default=None,
        optional=True,
        doc=("These bands and only these bands will appear in the output,"
             " NaN-filled if the input does not include them."
             " If None, then use all bands found in the input.")
    )
    camelCase = pexConfig.Field(
        dtype=bool,
        default=True,
        doc=("Write per-band column names with camelCase, else underscore. "
             "For example: gPsFlux instead of g_PsFlux.")
    )
    multilevelOutput = pexConfig.Field(
        dtype=bool,
        default=False,
        doc=("Whether results dataframe should have a multilevel column index (True) or be flat "
             "and name-munged (False).")
    )


class TransformObjectCatalogTask(TransformCatalogBaseTask):
    """Produce a flattened Object Table to match the format specified in
    sdm_schemas.

    Do the same set of postprocessing calculations on all bands.

    This is identical to `TransformCatalogBaseTask`, except that it does the
    specified functor calculations for all filters present in the
    input `deepCoadd_obj` table. Any specific `"filt"` keywords specified
    by the YAML file will be superseded.
    """
    _DefaultName = "transformObjectCatalog"
    ConfigClass = TransformObjectCatalogConfig

    # Used by Gen 2 runDataRef only:
    inputDataset = 'deepCoadd_obj'
    outputDataset = 'objectTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", cls.inputDataset,
                               ContainerClass=CoaddDataIdContainer,
                               help="data ID, e.g. --id tract=12345 patch=1,2")
        return parser

    def run(self, parq, funcs=None, dataId=None, band=None):
        # NOTE: band kwarg is ignored here.
        dfDict = {}
        analysisDict = {}
        templateDf = pd.DataFrame()

        if isinstance(parq, DeferredDatasetHandle):
            columns = parq.get(component='columns')
            inputBands = columns.unique(level=1).values
        else:
            inputBands = parq.columnLevelNames['band']

        outputBands = self.config.outputBands if self.config.outputBands else inputBands

        # Perform transform for data of filters that exist in parq.
        for inputBand in inputBands:
            if inputBand not in outputBands:
                self.log.info("Ignoring %s band data in the input", inputBand)
                continue
            self.log.info("Transforming the catalog of band %s", inputBand)
            result = self.transform(inputBand, parq, funcs, dataId)
            dfDict[inputBand] = result.df
            analysisDict[inputBand] = result.analysis
            if templateDf.empty:
                templateDf = result.df

        # Fill NaNs in columns of other wanted bands
        for filt in outputBands:
            if filt not in dfDict:
                self.log.info("Adding empty columns for band %s", filt)
                dfDict[filt] = pd.DataFrame().reindex_like(templateDf)

        # This makes a multilevel column index, with band as first level
        df = pd.concat(dfDict, axis=1, names=['band', 'column'])

        if not self.config.multilevelOutput:
            noDupCols = list(set.union(*[set(v.noDupCols) for v in analysisDict.values()]))
            if dataId is not None:
                noDupCols += list(dataId.keys())
            df = flattenFilters(df, noDupCols=noDupCols, camelCase=self.config.camelCase,
                                inputBands=inputBands)

        self.log.info("Made a table of %d columns and %d rows", len(df.columns), len(df))
        return df


class TractObjectDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList

        Generate a list of data references given tract and/or patch.
        This was adapted from `TractQADataIdContainer`, which was
        `TractDataIdContainer` modified to not require "filter".
        Only existing dataRefs are returned.
        """
        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=self.datasetType,
                                             tract=tract.getId(),
                                             patch="%d,%d" % patch.getIndex()) for patch in tract]

        tractRefs = defaultdict(list)  # Data references for each tract
        for dataId in self.idList:
            skymap = self.getSkymap(namespace)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=self.datasetType,
                                                                       tract=tractId,
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)
        outputRefList = []
        for tractRefList in tractRefs.values():
            existingRefs = [ref for ref in tractRefList if ref.datasetExists()]
            outputRefList.append(existingRefs)

        self.refList = outputRefList


class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections,
                                        dimensions=("tract", "skymap")):
    inputCatalogs = connectionTypes.Input(
        doc="Per-Patch objectTables conforming to the standard data model.",
        name="objectTable",
        storageClass="DataFrame",
        dimensions=("tract", "patch", "skymap"),
        multiple=True,
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-tract concatenation of the input objectTables",
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
    )


class ConsolidateObjectTableConfig(pipeBase.PipelineTaskConfig,
                                   pipelineConnections=ConsolidateObjectTableConnections):
    coaddName = pexConfig.Field(
        dtype=str,
        default="deep",
        doc="Name of coadd"
    )


class ConsolidateObjectTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Write patch-merged source tables to a tract-level parquet file

    Concatenates the `objectTable` list into a per-tract `objectTable_tract`.
    """
    _DefaultName = "consolidateObjectTable"
    ConfigClass = ConsolidateObjectTableConfig

    inputDataset = 'objectTable'
    outputDataset = 'objectTable_tract'

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        self.log.info("Concatenating %s per-patch Object Tables",
                      len(inputs['inputCatalogs']))
        df = pd.concat(inputs['inputCatalogs'])
        butlerQC.put(pipeBase.Struct(outputCatalog=df), outputRefs)

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id tract=12345",
                               ContainerClass=TractObjectDataIdContainer)
        return parser

    def runDataRef(self, patchRefList):
        df = pd.concat([patchRef.get().toDataFrame() for patchRef in patchRefList])
        patchRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass


class TransformSourceTableConnections(pipeBase.PipelineTaskConnections,
                                      dimensions=("instrument", "visit", "detector")):

    inputCatalog = connectionTypes.Input(
        doc="Wide input catalog of sources produced by WriteSourceTableTask",
        name="source",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        deferLoad=True
    )
    outputCatalog = connectionTypes.Output(
        doc="Narrower, per-detector Source Table transformed and converted per a "
            "specified set of functors",
        name="sourceTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector")
    )


class TransformSourceTableConfig(TransformCatalogBaseConfig,
                                 pipelineConnections=TransformSourceTableConnections):
    pass


class TransformSourceTableTask(TransformCatalogBaseTask):
    """Transform/standardize a source catalog
    """
    _DefaultName = "transformSourceTable"
    ConfigClass = TransformSourceTableConfig

    inputDataset = 'source'
    outputDataset = 'sourceTable'

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)
        parser.add_id_argument("--id", datasetType=cls.inputDataset,
                               level="sensor",
                               help="data ID, e.g. --id visit=12345 ccd=0")
        return parser

    def runDataRef(self, dataRef):
        """Override to specify band label to run()."""
        parq = dataRef.get()
        funcs = self.getFunctors()
        band = dataRef.get("calexp_filterLabel", immediate=True).bandLabel
        df = self.run(parq, funcs=funcs, dataId=dataRef.dataId, band=band)
        self.write(df, dataRef)
        return df


class ConsolidateVisitSummaryConnections(pipeBase.PipelineTaskConnections,
                                         dimensions=("instrument", "visit",),
                                         defaultTemplates={}):
    calexp = connectionTypes.Input(
        doc="Processed exposures used for metadata",
        name="calexp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
        deferLoad=True,
        multiple=True,
    )
    visitSummary = connectionTypes.Output(
        doc=("Per-visit consolidated exposure metadata. These catalogs use "
             "detector id for the id and are sorted for fast lookups of a "
             "detector."),
        name="visitSummary",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
    )


class ConsolidateVisitSummaryConfig(pipeBase.PipelineTaskConfig,
                                    pipelineConnections=ConsolidateVisitSummaryConnections):
    """Config for ConsolidateVisitSummaryTask"""
    pass


class ConsolidateVisitSummaryTask(pipeBase.PipelineTask, pipeBase.CmdLineTask):
    """Task to consolidate per-detector visit metadata.

    This task aggregates the following metadata from all the detectors in a
    single visit into an exposure catalog:
    - The visitInfo.
    - The wcs.
    - The photoCalib.
    - The physical_filter and band (if available).
    - The psf size, shape, and effective area at the center of the detector.
    - The corners of the bounding box in right ascension/declination.

    Other quantities such as Detector, Psf, ApCorrMap, and TransmissionCurve
    are not persisted here because of storage concerns, and because of their
    limited utility as summary statistics.

    Tests for this task are performed in ci_hsc_gen3.
    """
    _DefaultName = "consolidateVisitSummary"
    ConfigClass = ConsolidateVisitSummaryConfig

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", "calexp",
                               help="data ID, e.g. --id visit=12345",
                               ContainerClass=VisitDataIdContainer)
        return parser

    def writeMetadata(self, dataRef):
        """No metadata to persist, so override to remove metadata persistence.
        """
        pass

    def writeConfig(self, butler, clobber=False, doBackup=True):
        """No config to persist, so override to remove config persistence.
        """
        pass

    def runDataRef(self, dataRefList):
        visit = dataRefList[0].dataId['visit']

        self.log.debug("Concatenating metadata from %d per-detector calexps (visit %d)" %
                       (len(dataRefList), visit))

        expCatalog = self._combineExposureMetadata(visit, dataRefList, isGen3=False)

        dataRefList[0].put(expCatalog, 'visitSummary', visit=visit)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        dataRefs = butlerQC.get(inputRefs.calexp)
        visit = dataRefs[0].dataId.byName()['visit']

        self.log.debug("Concatenating metadata from %d per-detector calexps (visit %d)" %
                       (len(dataRefs), visit))

        expCatalog = self._combineExposureMetadata(visit, dataRefs)

        butlerQC.put(expCatalog, outputRefs.visitSummary)

    def _combineExposureMetadata(self, visit, dataRefs, isGen3=True):
        """Make a combined exposure catalog from a list of dataRefs.

        Parameters
        ----------
        visit : `int`
            Visit identification number
        dataRefs : `list`
            List of calexp dataRefs in visit. May be list of
            `lsst.daf.persistence.ButlerDataRef` (Gen2) or
            `lsst.daf.butler.DeferredDatasetHandle` (Gen3).
        isGen3 : `bool`, optional
            Specifies if this is a Gen3 list of datarefs.

        Returns
        -------
        visitSummary : `lsst.afw.table.ExposureCatalog`
            Exposure catalog with per-detector summary information.
        """
        schema = afwTable.ExposureTable.makeMinimalSchema()
        schema.addField('visit', type='I', doc='Visit number')
        schema.addField('physical_filter', type='String', size=32, doc='Physical filter')
        schema.addField('band', type='String', size=32, doc='Name of band')
        schema.addField('psfSigma', type='F',
                        doc='PSF model second-moments determinant radius (center of chip) (pixel)')
        schema.addField('psfArea', type='F',
                        doc='PSF model effective area (center of chip) (pixel**2)')
        schema.addField('psfIxx', type='F',
                        doc='PSF model Ixx (center of chip) (pixel**2)')
        schema.addField('psfIyy', type='F',
                        doc='PSF model Iyy (center of chip) (pixel**2)')
        schema.addField('psfIxy', type='F',
                        doc='PSF model Ixy (center of chip) (pixel**2)')
        schema.addField('raCorners', type='ArrayD', size=4,
                        doc='Right Ascension of bounding box corners (degrees)')
        schema.addField('decCorners', type='ArrayD', size=4,
                        doc='Declination of bounding box corners (degrees)')

        cat = afwTable.ExposureCatalog(schema)
        cat.resize(len(dataRefs))

        cat['visit'] = visit

        for i, dataRef in enumerate(dataRefs):
            if isGen3:
                visitInfo = dataRef.get(component='visitInfo')
                filterLabel = dataRef.get(component='filterLabel')
                psf = dataRef.get(component='psf')
                wcs = dataRef.get(component='wcs')
                photoCalib = dataRef.get(component='photoCalib')
                detector = dataRef.get(component='detector')
                bbox = dataRef.get(component='bbox')
                validPolygon = dataRef.get(component='validPolygon')
            else:
                # Note that we need to read the calexp because there is
                # no magic access to the psf except through the exposure.
                gen2_read_bbox = lsst.geom.BoxI(lsst.geom.PointI(0, 0), lsst.geom.PointI(1, 1))
                exp = dataRef.get(datasetType='calexp_sub', bbox=gen2_read_bbox)
                visitInfo = exp.getInfo().getVisitInfo()
                filterLabel = dataRef.get("calexp_filterLabel")
                psf = exp.getPsf()
                wcs = exp.getWcs()
                photoCalib = exp.getPhotoCalib()
                detector = exp.getDetector()
                bbox = dataRef.get(datasetType='calexp_bbox')
                validPolygon = exp.getInfo().getValidPolygon()

            rec = cat[i]
            rec.setBBox(bbox)
            rec.setVisitInfo(visitInfo)
            rec.setWcs(wcs)
            rec.setPhotoCalib(photoCalib)
            rec.setValidPolygon(validPolygon)

            rec['physical_filter'] = filterLabel.physicalLabel if filterLabel.hasPhysicalLabel() else ""
            rec['band'] = filterLabel.bandLabel if filterLabel.hasBandLabel() else ""
            rec.setId(detector.getId())
            shape = psf.computeShape(bbox.getCenter())
            rec['psfSigma'] = shape.getDeterminantRadius()
            rec['psfIxx'] = shape.getIxx()
            rec['psfIyy'] = shape.getIyy()
            rec['psfIxy'] = shape.getIxy()
            im = psf.computeKernelImage(bbox.getCenter())
            # The calculation of effective psf area is taken from
            # meas_base/src/PsfFlux.cc#L112. See
            # https://github.com/lsst/meas_base/blob/
            # 750bffe6620e565bda731add1509507f5c40c8bb/src/PsfFlux.cc#L112
            rec['psfArea'] = np.sum(im.array)/np.sum(im.array**2.)
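            # Editorial note: computeKernelImage is expected to return a kernel image
            # normalized to unit sum, in which case this expression reduces to
            # 1 / sum(p**2), i.e. the effective area N_eff = (sum p)**2 / sum(p**2).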

            sph_pts = wcs.pixelToSky(lsst.geom.Box2D(bbox).getCorners())
            rec['raCorners'][:] = [sph.getRa().asDegrees() for sph in sph_pts]
            rec['decCorners'][:] = [sph.getDec().asDegrees() for sph in sph_pts]

        metadata = dafBase.PropertyList()
        metadata.add("COMMENT", "Catalog id is detector id, sorted.")
        # We are looping over existing datarefs, so the following is true
        metadata.add("COMMENT", "Only detectors with data have entries.")
        cat.setMetadata(metadata)

        cat.sort()
        return cat


class VisitDataIdContainer(DataIdContainer):
    """DataIdContainer that groups sensor-level ids by visit.
    """

    def makeDataRefList(self, namespace):
        """Make self.refList from self.idList

        Generate a list of data references grouped by visit.

        Parameters
        ----------
        namespace : `argparse.Namespace`
            Namespace used by `lsst.pipe.base.CmdLineTask` to parse command line arguments
        """
        # Group by visits
        visitRefs = defaultdict(list)
        for dataId in self.idList:
            if "visit" in dataId:
                visitId = dataId["visit"]
                # append all data references in the subset to this visit's list
                subset = namespace.butler.subset(self.datasetType, dataId=dataId)
                visitRefs[visitId].extend([dataRef for dataRef in subset])

        outputRefList = []
        for refList in visitRefs.values():
            existingRefs = [ref for ref in refList if ref.datasetExists()]
            if existingRefs:
                outputRefList.append(existingRefs)

        self.refList = outputRefList


class ConsolidateSourceTableConnections(pipeBase.PipelineTaskConnections,
                                        dimensions=("instrument", "visit")):
    inputCatalogs = connectionTypes.Input(
        doc="Input per-detector Source Tables",
        name="sourceTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        multiple=True
    )
    outputCatalog = connectionTypes.Output(
        doc="Per-visit concatenation of Source Table",
        name="sourceTable_visit",
        storageClass="DataFrame",
        dimensions=("instrument", "visit")
    )


class ConsolidateSourceTableConfig(pipeBase.PipelineTaskConfig,
                                   pipelineConnections=ConsolidateSourceTableConnections):
    pass


class ConsolidateSourceTableTask(CmdLineTask, pipeBase.PipelineTask):
    """Concatenate `sourceTable` list into a per-visit `sourceTable_visit`
    """
    _DefaultName = 'consolidateSourceTable'
    ConfigClass = ConsolidateSourceTableConfig

    inputDataset = 'sourceTable'
    outputDataset = 'sourceTable_visit'

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        self.log.info("Concatenating %s per-detector Source Tables",
                      len(inputs['inputCatalogs']))
        df = pd.concat(inputs['inputCatalogs'])
        butlerQC.put(pipeBase.Struct(outputCatalog=df), outputRefs)

    def runDataRef(self, dataRefList):
        self.log.info("Concatenating %s per-detector Source Tables", len(dataRefList))
        df = pd.concat([dataRef.get().toDataFrame() for dataRef in dataRefList])
        dataRefList[0].put(ParquetTable(dataFrame=df), self.outputDataset)

    @classmethod
    def _makeArgumentParser(cls):
        parser = ArgumentParser(name=cls._DefaultName)

        parser.add_id_argument("--id", cls.inputDataset,
                               help="data ID, e.g. --id visit=12345",
                               ContainerClass=VisitDataIdContainer)
        return parser

    def writeMetadata(self, dataRef):
        """No metadata to write.
        """
        pass

    def writeConfig(self, butler, clobber=False, doBackup=True):
        """No config to write.
        """
        pass