lsst.pipe.tasks gd61da0c3fd+5e10b5a532
Loading...
Searching...
No Matches
drpAssociationPipe.py
Go to the documentation of this file.
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21
"""Pipeline task for running DiaSource association in a Data Release
Production (DRP) context.
"""
24
# Public names exported by this module.
__all__ = [
    "DrpAssociationPipeTask",
    "DrpAssociationPipeConfig",
    "DrpAssociationPipeConnections",
]
28
29import numpy as np
30import pandas as pd
31
32import lsst.geom as geom
33import lsst.pex.config as pexConfig
34import lsst.pipe.base as pipeBase
35from lsst.skymap import BaseSkyMap
36
37from .coaddBase import makeSkyInfo
38from .simpleAssociation import SimpleAssociationTask
39
40
class DrpAssociationPipeConnections(
        pipeBase.PipelineTaskConnections,
        dimensions=("tract", "patch", "skymap"),
        defaultTemplates={
            "coaddName": "deep",
            "warpTypeSuffix": "",
            "fakesType": "",
        }):
    """Butler connections for the DRP DiaSource association task.

    The connection set is defined over the (tract, patch, skymap)
    dimensions. Dataset type names are built from the ``fakesType`` and
    ``coaddName`` templates; the ``warpTypeSuffix`` template is declared
    but not referenced by any connection name in this class.
    """

    # Inputs ---------------------------------------------------------------

    # Per-(instrument, visit, detector) DiaSource tables. Several are
    # expected per quantum and loading is deferred, so each entry must be
    # read explicitly by the consumer.
    diaSourceTables = pipeBase.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        doc="Set of catalogs of calibrated DiaSources.",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        multiple=True,
        deferLoad=True,
    )
    # Sky map describing the tract/patch geometry.
    skyMap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        storageClass="SkyMap",
        dimensions=("skymap", ),
    )

    # Outputs --------------------------------------------------------------

    # DiaSources restricted to the patch, one table per (tract, patch).
    assocDiaSourceTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_assocDiaSrcTable",
        doc="Catalog of DiaSources covering the patch and associated with a DiaObject.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )
    # DiaObjects built by the association step, one table per (tract, patch).
    diaObjectTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_diaObjTable",
        doc="Catalog of DiaObjects created from spatially associating DiaSources.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )
75
76
class DrpAssociationPipeConfig(
        pipeBase.PipelineTaskConfig,
        pipelineConnections=DrpAssociationPipeConnections):
    """Configuration for the DRP DiaSource association task."""

    # Subtask that performs the actual DiaSource -> DiaObject association.
    associator = pexConfig.ConfigurableField(
        target=SimpleAssociationTask,
        doc="Task used to associate DiaSources with DiaObjects.",
    )
    # NOTE: the two doc fragments are joined by implicit string
    # concatenation, so the first fragment must end with ". " to keep the
    # rendered help text readable.
    doAddDiaObjectCoords = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pull diaObject's average coordinate as coord_ra and coord_dec. "
            "Duplicates information, but needed for bulk ingest into qserv.",
    )
    # Controls what happens when no DiaSource falls inside the patch.
    doWriteEmptyTables = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="If True, construct and write out empty diaSource and diaObject "
            "tables. If False, raise NoWorkFound",
    )
96
97
class DrpAssociationPipeTask(pipeBase.PipelineTask):
    """Driver pipeline for loading DiaSource catalogs in a patch/tract
    region and associating them.
    """
    ConfigClass = DrpAssociationPipeConfig
    _DefaultName = "drpAssociation"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.makeSubtask('associator')

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the quantum inputs, run association, and store the outputs.

        In addition to the datasets fetched through ``butlerQC``, the tract
        and patch identifiers and the packed ``tract_patch`` identifier
        (with its maximum bit count) are taken from the quantum data ID and
        forwarded to `run`.

        Parameters
        ----------
        butlerQC
            Quantum context used to read the inputs and write the outputs.
        inputRefs
            References to the input datasets of this quantum.
        outputRefs
            References to the output datasets of this quantum.
        """
        inputs = butlerQC.get(inputRefs)

        inputs["tractId"] = butlerQC.quantum.dataId["tract"]
        inputs["patchId"] = butlerQC.quantum.dataId["patch"]
        tractPatchId, skymapBits = butlerQC.quantum.dataId.pack(
            "tract_patch",
            returnMaxBits=True)
        inputs["tractPatchId"] = tractPatchId
        inputs["skymapBits"] = skymapBits

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self,
            diaSourceTables,
            skyMap,
            tractId,
            patchId,
            tractPatchId,
            skymapBits):
        """Trim DiaSources to the current Patch and run association.

        Takes in the set of DiaSource catalogs that covers the current patch,
        trims them to the inner bounding box of the patch, and runs
        association on the concatenated DiaSource catalog.

        Parameters
        ----------
        diaSourceTables : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Set of DiaSource catalogs potentially covering this patch/tract.
        skyMap : `lsst.skymap.BaseSkyMap`
            SkyMap defining the patch/tract
        tractId : `int`
            Id of current tract being processed.
        patchId : `int`
            Id of current patch being processed
        tractPatchId : `int`
            Packed tract/patch identifier; passed through to the associator.
        skymapBits : `int`
            Maximum number of bits used by the packed tract/patch
            identifier; passed through to the associator.

        Returns
        -------
        output : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSourceTable``
                Table of DiaSources with updated value for diaObjectId.
                (`pandas.DataFrame`)
            ``diaObjectTable``
                Table of DiaObjects from matching DiaSources
                (`pandas.DataFrame`).

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no DiaSource lies inside the patch and
            ``config.doWriteEmptyTables`` is `False`, or if no input
            catalogs were supplied at all (there is then no table from
            which to derive the columns of an empty output).
        """
        self.log.info("Running DRP Association on patch %i, tract %i...",
                      patchId, tractId)

        skyInfo = makeSkyInfo(skyMap, tractId, patchId)

        # Get the patch bounding box.
        innerPatchBox = geom.Box2D(skyInfo.patchInfo.getInnerBBox())

        diaSourceHistory = []
        # Most recently loaded catalog; kept so that an empty output table
        # can reuse its columns. Stays None when there are no inputs.
        lastCat = None
        for catRef in diaSourceTables:
            cat = catRef.get()
            lastCat = cat

            isInTractPatch = self._trimToPatch(cat,
                                               innerPatchBox,
                                               skyInfo.wcs)

            nDiaSrc = isInTractPatch.sum()
            self.log.info(
                "Read DiaSource catalog of length %i from visit %i, "
                "detector %i. Found %i sources within the patch/tract "
                "footprint.",
                len(cat), catRef.dataId["visit"],
                catRef.dataId["detector"], nDiaSrc)

            if nDiaSrc <= 0:
                continue

            diaSourceHistory.append(cat[isInTractPatch])

        if diaSourceHistory:
            diaSourceHistoryCat = pd.concat(diaSourceHistory)
        elif not self.config.doWriteEmptyTables:
            # No rows to associate
            raise pipeBase.NoWorkFound("Found no overlapping DIASources to associate.")
        elif lastCat is None:
            # No input catalog was loaded, so there are no columns to copy.
            raise pipeBase.NoWorkFound(
                "No input DiaSource tables were provided; cannot construct empty output tables.")
        else:
            self.log.info("Constructing empty table")
            # Construct empty table using last table and dropping all the rows
            diaSourceHistoryCat = lastCat.drop(lastCat.index)

        self.log.info("Found %i DiaSources overlapping patch %i, tract %i",
                      len(diaSourceHistoryCat), patchId, tractId)

        assocResult = self.associator.run(diaSourceHistoryCat,
                                          tractPatchId,
                                          skymapBits)

        self.log.info("Associated DiaSources into %i DiaObjects",
                      len(assocResult.diaObjects))

        if self.config.doAddDiaObjectCoords:
            assocResult.assocDiaSources = self._addDiaObjectCoords(assocResult.diaObjects,
                                                                   assocResult.assocDiaSources)

        return pipeBase.Struct(
            diaObjectTable=assocResult.diaObjects,
            assocDiaSourceTable=assocResult.assocDiaSources)

    def _addDiaObjectCoords(self, objects, sources):
        """Attach each DiaObject's coordinates to its DiaSources.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            DiaObject table with ``ra`` and ``decl`` columns; its index is
            matched against the ``diaObjectId`` column of ``sources``.
        sources : `pandas.DataFrame`
            DiaSource table with a ``diaObjectId`` column. After
            ``reset_index`` it must expose a ``diaSourceId`` column
            (presumably the table is indexed by ``diaSourceId`` -- verify
            against the associator output).

        Returns
        -------
        df : `pandas.DataFrame`
            Copy of ``sources`` indexed by ``diaSourceId`` with added
            ``coord_ra`` and ``coord_dec`` columns. The merge is an inner
            join, so sources whose ``diaObjectId`` is not present in the
            index of ``objects`` are dropped.
        """
        obj = objects[['ra', 'decl']].rename(columns={"ra": "coord_ra", "decl": "coord_dec"})
        df = pd.merge(sources.reset_index(), obj, left_on='diaObjectId', right_index=True,
                      how='inner').set_index('diaSourceId')
        return df

    def _trimToPatch(self, cat, innerPatchBox, wcs):
        """Compute a boolean mask of the DiaSources that lie in the
        patch/tract.

        Parameters
        ----------
        cat : `pandas.DataFrame`
            Catalog of DiaSources to test within patch/tract. Must have
            ``ra`` and ``decl`` columns, interpreted as degrees.
        innerPatchBox : `lsst.geom.Box2D`
            Bounding box of the patch.
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs of the tract.

        Returns
        -------
        isInPatch : `numpy.ndarray`, (N,)
            Booleans representing if the DiaSources are contained within the
            current patch and tract.
        """
        # Iterate over the two coordinate columns directly rather than
        # building a Series per row with ``iterrows``.
        ras = cat["ra"].to_numpy(dtype=float)
        decs = cat["decl"].to_numpy(dtype=float)
        # dtype=bool keeps the result a usable boolean mask even when the
        # catalog is empty (np.array([]) would otherwise be float64).
        isInPatch = np.array(
            [innerPatchBox.contains(
                wcs.skyToPixel(geom.SpherePoint(ra, dec, geom.degrees)))
             for ra, dec in zip(ras, decs)],
            dtype=bool)
        return isInPatch