Coverage for python / lsst / pipe / tasks / drpAssociationPipe.py: 21%
173 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:11 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:11 +0000
1# This file is part of pipe_tasks.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Pipeline for running DiaSource association in a DRP context.
23"""
25__all__ = ["DrpAssociationPipeTask",
26 "DrpAssociationPipeConfig",
27 "DrpAssociationPipeConnections"]
29import os
31import astropy.table as tb
32import numpy as np
33import pandas as pd
36from lsst.pipe.tasks.ssoAssociation import SolarSystemAssociationTask
37import lsst.geom as geom
38import lsst.pex.config as pexConfig
39import lsst.pipe.base as pipeBase
40from lsst.meas.base import SkyMapIdGeneratorConfig
41from lsst.skymap import BaseSkyMap
43from .coaddBase import makeSkyInfo
44from .schemaUtils import convertDataFrameToSdmSchema, readSdmSchemaFile
45from .simpleAssociation import SimpleAssociationTask
class DrpAssociationPipeConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates={"coaddName": "deep", "warpTypeSuffix": "", "fakesType": ""},
):
    """Butler connections for `DrpAssociationPipeTask`.

    Inputs are the per-(visit, detector) DiaSource tables, the per-visit
    Solar System object tables and visit summaries, and the skymap.
    Outputs are per-(tract, patch) tables. The Solar System specific
    connections are removed in ``__init__`` when
    ``config.doSolarSystemAssociation`` is False.
    """

    # ---- Inputs ----
    diaSourceTables = pipeBase.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        doc="Set of catalogs of calibrated DiaSources.",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit", "detector"),
        multiple=True,
        deferLoad=True,
    )
    ssObjectTableRefs = pipeBase.connectionTypes.Input(
        name="preloaded_ss_object_visit",
        doc="Reference to catalogs of SolarSolarSystem objects expected to be observable in each "
            "(visit, detector).",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
        minimum=0,
    )
    skyMap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        storageClass="SkyMap",
        dimensions=("skymap", ),
    )
    finalVisitSummaryRefs = pipeBase.connectionTypes.Input(
        name="finalVisitSummary",
        doc="Reference to finalVisitSummary of each exposure, containing visitInfo, bbox, and wcs.",
        storageClass="ExposureCatalog",
        dimensions=("instrument", "visit"),
        multiple=True,
        deferLoad=True,
    )

    # ---- Outputs ----
    assocDiaSourceTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_assocDiaSrcTable",
        doc="Catalog of DiaSources covering the patch and associated with a DiaObject.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )
    associatedSsSources = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_assocSsSrcTable",
        doc="Optional output storing ssSource data computed during association.",
        storageClass="ArrowAstropy",
        dimensions=("tract", "patch"),
    )
    unassociatedSsObjects = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_unassocSsObjTable",
        doc="Expected locations of ssObjects with no associated source.",
        storageClass="ArrowAstropy",
        dimensions=("tract", "patch"),
    )
    diaObjectTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_diaObjTable",
        doc="Catalog of DiaObjects created from spatially associating DiaSources.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)

        if config.doSolarSystemAssociation:
            return

        # Solar System association is disabled: drop every connection that
        # only exists to serve it, in the same order as before.
        solarSystemOnly = (
            "ssObjectTableRefs",
            "associatedSsSources",
            "unassociatedSsObjects",
            "finalVisitSummaryRefs",
        )
        for connectionName in solarSystemOnly:
            delattr(self, connectionName)
class DrpAssociationPipeConfig(
        pipeBase.PipelineTaskConfig,
        pipelineConnections=DrpAssociationPipeConnections):
    """Configuration for `DrpAssociationPipeTask`.

    Selects the association subtasks, toggles Solar System association and
    empty-table output, and optionally names a schema file used to coerce
    the output column types.
    """
    associator = pexConfig.ConfigurableField(
        target=SimpleAssociationTask,
        doc="Task used to associate DiaSources with DiaObjects.",
    )
    solarSystemAssociator = pexConfig.ConfigurableField(
        target=SolarSystemAssociationTask,
        doc="Task used to associate DiaSources with SolarSystemObjects.",
    )
    # The two doc fragments were previously concatenated with no separator
    # ("...coord_decDuplicates information..."); they are now two sentences.
    doAddDiaObjectCoords = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pull diaObject's average coordinate as coord_ra and coord_dec. "
            "Duplicates information, but needed for bulk ingest into qserv."
    )
    doWriteEmptyTables = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="If True, construct and write out empty diaSource and diaObject "
            "tables. If False, raise NoWorkFound"
    )
    doSolarSystemAssociation = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Process SolarSystem objects through the pipeline.",
    )
    doUseSchema = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Use an existing schema to coerce the data types of the output columns."
    )
    # The default keeps the literal "${SDM_SCHEMAS_DIR}" placeholder; nothing
    # in this class expands it.
    schemaDir = pexConfig.Field(
        dtype=str,
        doc="Path to the directory containing schema definitions.",
        default=os.path.join("${SDM_SCHEMAS_DIR}",
                             "yml"),
    )
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying the schema of the output catalog.",
        default="lsstcam.yaml",
    )
    idGenerator = SkyMapIdGeneratorConfig.make_field()
class DrpAssociationPipeTask(pipeBase.PipelineTask):
    """Driver pipeline for loading DiaSource catalogs in a patch/tract
    region and associating them.
    """
    ConfigClass = DrpAssociationPipeConfig
    _DefaultName = "drpAssociation"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.makeSubtask('associator')

        # Only read the schema file when the outputs will be coerced to it.
        if self.config.doUseSchema:
            schemaFile = os.path.join(self.config.schemaDir, self.config.schemaFile)
            self.schema = readSdmSchemaFile(schemaFile)
        else:
            self.schema = None
        if self.config.doSolarSystemAssociation:
            self.makeSubtask("solarSystemAssociator")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        inputs["tractId"] = butlerQC.quantum.dataId["tract"]
        inputs["patchId"] = butlerQC.quantum.dataId["patch"]
        inputs["idGenerator"] = self.config.idGenerator.apply(butlerQC.quantum.dataId)
        if not self.config.doSolarSystemAssociation:
            # These connections are removed when Solar System association is
            # disabled, but ``run`` still requires the arguments.
            inputs["ssObjectTableRefs"] = []
            inputs["finalVisitSummaryRefs"] = []
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self,
            diaSourceTables,
            ssObjectTableRefs,
            skyMap,
            finalVisitSummaryRefs,
            tractId,
            patchId,
            idGenerator=None):
        """Trim DiaSources to the current Patch and run association.

        Takes in the set of DiaSource catalogs that covers the current patch,
        optionally associates them with known Solar System objects, trims
        them to the outer bounding box of the patch, and runs association on
        the concatenated DiaSource catalog.

        Parameters
        ----------
        diaSourceTables : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Set of DiaSource catalogs potentially covering this patch/tract.
        ssObjectTableRefs : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Set of known SSO ephemerides potentially covering this patch/tract.
        skyMap : `lsst.skymap.BaseSkyMap`
            SkyMap defining the patch/tract
        finalVisitSummaryRefs : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Reference to finalVisitSummary of each exposure potentially
            covering this patch/tract, which contain visitInfo, bbox, and wcs
        tractId : `int`
            Id of current tract being processed.
        patchId : `int`
            Id of current patch being processed.
        idGenerator : `lsst.meas.base.IdGenerator`, optional
            Object that generates Object IDs and random number generator seeds.

        Returns
        -------
        output : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSourceTable``
                Table of DiaSources with updated value for diaObjectId.
                (`pandas.DataFrame`)
            ``diaObjectTable``
                Table of DiaObjects from matching DiaSources
                (`pandas.DataFrame`).
            ``associatedSsSources``
                Stacked Solar System source records, or `None` if there are
                none or Solar System association is disabled.
            ``unassociatedSsObjects``
                Stacked Solar System objects without an associated source,
                or `None` if there are none or Solar System association is
                disabled.

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no DiaSources or Solar System sources overlap the patch
            and ``doWriteEmptyTables`` is False, or if there is no input
            catalog from which an empty table could be constructed.
        """
        self.log.info("Running DPR Association on patch %i, tract %i...",
                      patchId, tractId)

        skyInfo = makeSkyInfo(skyMap, tractId, patchId)

        # Get the patch bounding box.
        innerPatchBox = geom.Box2D(skyInfo.patchInfo.getInnerBBox())
        outerPatchBox = geom.Box2D(skyInfo.patchInfo.getOuterBBox())
        innerTractSkyRegion = skyInfo.tractInfo.getInnerSkyRegion()

        # Keep track of our diaCats, ssObject cats, and finalVisitSummaries by their (visit, detector) IDs
        diaIdDict = prepareCatalogDict(diaSourceTables, useVisitDetector=True)
        ssObjectIdDict = prepareCatalogDict(ssObjectTableRefs, useVisitDetector=False)
        finalVisitSummaryIdDict = prepareCatalogDict(finalVisitSummaryRefs, useVisitDetector=False)

        # Group the detectors by visit once, rather than rescanning every
        # (visit, detector) key for each visit.
        detectorsByVisit = {}
        for visit, detector in diaIdDict:
            detectorsByVisit.setdefault(visit, []).append(detector)

        # diaSourceHistory: non-ss diaSources to be made into diaObjects.
        # ssDiaSourceHistory: sso-associated diaSources which skip diaObject creation,
        # but are included in diaSource.
        diaSourceHistory, ssDiaSourceHistory, ssSourceHistory, unassociatedSsObjectHistory = [], [], [], []
        # Zero-row copy of the first catalog seen; used as the stacked catalog
        # when no diaSources survive trimming.
        emptyDiaCat = None
        visits = {v for v, _ in diaIdDict}
        for visit in visits:
            # visit summaries and Solar System catalogs are per-visit, so only
            # load them once for all detectors with that visit
            visitSummary = finalVisitSummaryIdDict[visit].get() if visit in finalVisitSummaryIdDict else None
            ssCat = ssObjectIdDict[visit].get() if visit in ssObjectIdDict else None
            for detector in detectorsByVisit[visit]:
                diaCat = diaIdDict[(visit, detector)].get()
                nDiaSrcIn = len(diaCat)
                nSsSrc, nSsObj = 0, 0
                ssoAssocResult = None
                if (ssCat is not None) and (visitSummary is not None):
                    ssoAssocResult = self.runSolarSystemAssociation(diaCat,
                                                                    ssCat.copy(),
                                                                    visitSummary=visitSummary,
                                                                    patchBbox=innerPatchBox,
                                                                    patchWcs=skyInfo.wcs,
                                                                    innerTractSkyRegion=innerTractSkyRegion,
                                                                    detector=detector,
                                                                    visit=visit,
                                                                    )

                    nSsSrc = len(ssoAssocResult.associatedSsSources)
                    nSsObj = len(ssoAssocResult.unassociatedSsObjects)
                    # If diaSources were associated with Solar System objects,
                    # remove them from the catalog so they won't create new
                    # diaObjects or be associated with other diaObjects.
                    diaCat = ssoAssocResult.unassociatedDiaSources

                if emptyDiaCat is None:
                    # Copy so the template does not keep the full catalog's
                    # column buffers alive.
                    emptyDiaCat = diaCat[:0].copy()

                # Only trim diaSources to the outer bbox of the patch, so that
                # diaSources near the patch boundary can be associated.
                # DiaObjects will be trimmed to the inner patch bbox, and any
                # diaSources associated with dropped diaObjects will also be dropped
                diaInPatch = self._trimToPatch(diaCat.to_pandas(), outerPatchBox, skyInfo.wcs)

                nDiaSrc = diaInPatch.sum()

                self.log.info(
                    "Read DiaSource catalog of length %i from visit %i, "
                    "detector %i. Found %i sources within the patch/tract "
                    "footprint, including %i associated with SSOs.",
                    nDiaSrcIn, visit, detector, nDiaSrc + nSsSrc, nSsSrc)

                if nDiaSrc > 0:
                    diaSourceHistory.append(diaCat[diaInPatch])
                if nSsSrc > 0:
                    ssSourceHistory.append(ssoAssocResult.associatedSsSources)
                    ssDiaSourceHistory.append(ssoAssocResult.associatedSsDiaSources)

                if nSsObj > 0:
                    unassociatedSsObjectHistory.append(ssoAssocResult.unassociatedSsObjects)

        # After looping over all of the detector-level catalogs that overlap the
        # patch, combine them into patch-level catalogs
        diaSourceHistoryCat = self._stackCatalogs(diaSourceHistory, empty=emptyDiaCat)
        ssDiaSourceHistoryCat = self._stackCatalogs(ssDiaSourceHistory)
        if self.config.doSolarSystemAssociation:
            ssSourceHistoryCat = self._stackCatalogs(ssSourceHistory, remove_columns=['ra', 'dec'])
            nSsSrcTotal = len(ssSourceHistoryCat) if ssSourceHistoryCat is not None else 0
            unassociatedSsObjectHistoryCat = self._stackCatalogs(unassociatedSsObjectHistory)
            nSsObjTotal = (len(unassociatedSsObjectHistoryCat)
                           if unassociatedSsObjectHistoryCat is not None else 0)
            self.log.info("Found %i ssSources and %i missing ssObjects in patch %i, tract %i",
                          nSsSrcTotal, nSsObjTotal, patchId, tractId)
        else:
            ssSourceHistoryCat = None
            unassociatedSsObjectHistoryCat = None

        if (not diaSourceHistory) and (not ssSourceHistory):
            if not self.config.doWriteEmptyTables:
                raise pipeBase.NoWorkFound("Found no overlapping DIASources to associate.")
        if diaSourceHistoryCat is None:
            # No input catalog was loaded at all, so there is no schema to
            # build an empty table from. Previously this fell through to
            # ``len(None)`` and failed with a TypeError.
            raise pipeBase.NoWorkFound("No input DiaSource catalogs from which to construct empty tables.")

        self.log.info("Found %i DiaSources overlapping patch %i, tract %i",
                      len(diaSourceHistoryCat), patchId, tractId)

        diaSourceTable = diaSourceHistoryCat.to_pandas()
        # NOTE(review): the result of set_index is discarded (it is not an
        # in-place call), so the table handed to the associator keeps its
        # default index. Left unchanged because making it effective would
        # alter what the associator receives - confirm the intent.
        diaSourceTable.set_index("diaSourceId", drop=False)
        # Now run diaObject association on the stacked remaining diaSources
        # NOTE(review): with doWriteEmptyTables this may be a zero-row table;
        # presumably the associator accepts that - confirm.
        assocResult = self.associator.run(diaSourceTable, idGenerator=idGenerator)

        # Drop any diaObjects that were created outside the inner region of the
        # patch. These will be associated in the overlapping patch instead.
        objectsInTractPatch = self._trimToPatch(assocResult.diaObjects,
                                                innerPatchBox,
                                                skyInfo.wcs,
                                                innerTractSkyRegion=innerTractSkyRegion)
        diaObjects = assocResult.diaObjects[objectsInTractPatch]
        # Instead of dropping diaSources based on their patch, assign them to a
        # patch based on whether their diaObject was inside. This means that
        # some diaSources in the patch catalog will actually have coordinates
        # just outside the patch.
        assocDiaSources = self.dropDiaSourceByDiaObjectId(assocResult.diaObjects[~objectsInTractPatch].index,
                                                          assocResult.assocDiaSources)

        self.log.info("Associated DiaSources into %i DiaObjects", len(diaObjects))

        if self.config.doAddDiaObjectCoords:
            assocDiaSources = self._addDiaObjectCoords(diaObjects, assocDiaSources)
        # Append the SSO-associated diaSources, which skipped diaObject
        # association. Only done for a non-empty stacked table, as before.
        if ssDiaSourceHistoryCat is not None and len(ssDiaSourceHistoryCat) > 0:
            ssDiaSourceHistoryCat = ssDiaSourceHistoryCat.to_pandas().set_index("diaSourceId", drop=True)
            assocDiaSources = pd.concat([assocDiaSources, ssDiaSourceHistoryCat])
        if self.config.doUseSchema:
            diaObjects = convertDataFrameToSdmSchema(self.schema, diaObjects, tableName="DiaObject")
            assocDiaSources = convertDataFrameToSdmSchema(self.schema, assocDiaSources, tableName="DiaSource")
        return pipeBase.Struct(
            diaObjectTable=diaObjects,
            assocDiaSourceTable=assocDiaSources,
            associatedSsSources=ssSourceHistoryCat,
            unassociatedSsObjects=unassociatedSsObjectHistoryCat,
        )

    def _stackCatalogs(self, catalogs, remove_columns=None, empty=None):
        """Stack a list of catalogs.

        Parameters
        ----------
        catalogs : `list` of `astropy.table.Table`
            Input catalogs with the same columns to be combined.
        remove_columns : `list` of `str` or None, optional
            List of column names to drop from the stacked table.
        empty : `object`, optional
            Value returned unchanged when ``catalogs`` is empty.
            ``remove_columns`` is not applied to it.

        Returns
        -------
        `astropy.table.Table`
            The combined catalog, or ``empty`` if there was nothing to stack.
        """
        if not catalogs:
            return empty
        sourceHistory = tb.vstack(catalogs)
        if remove_columns is not None:
            sourceHistory.remove_columns(remove_columns)
        return sourceHistory

    def runSolarSystemAssociation(self, diaCat, ssCat,
                                  visitSummary,
                                  patchBbox,
                                  patchWcs,
                                  innerTractSkyRegion,
                                  detector,
                                  visit,
                                  ):
        """Run Solar System object association and filter the results.

        Parameters
        ----------
        diaCat : `astropy.table.Table`
            Catalog of detected diaSources on the image difference, passed
            straight to the Solar System associator.
        ssCat : `astropy.table.Table`
            Catalog of predicted coordinates of known Solar System objects.
        visitSummary : `lsst.afw.table.ExposureCatalog`
            Table of calibration and metadata for all detectors in a visit.
        patchBbox : `lsst.geom.Box2D`
            Bounding box of the patch.
        patchWcs : `lsst.geom.SkyWcs`
            Wcs of the tract containing the patch.
        innerTractSkyRegion : `lsst.sphgeom.Box`
            Region defining the inner non-overlapping part of a tract.
        detector : `int`
            Detector number of the science exposure.
        visit : `int`
            Visit number of the science exposure.

        Returns
        -------
        ssoAssocResult : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``associatedSsSources``
                Solar System source records from the associator that fall
                inside the patch and inner tract region
                (`astropy.table.Table`).
            ``associatedSsDiaSources``
                DiaSources associated with Solar System objects, restricted
                to those whose ``diaSourceId`` appears in
                ``associatedSsSources`` (`astropy.table.Table`).
            ``unassociatedSsObjects``
                Table of Solar System objects in the patch not associated with
                any DiaSource (`astropy.table.Table`).
            ``unassociatedDiaSources``
                Table of DiaSources not associated with any Solar System object
                (`astropy.table.Table`).
        """
        # Get the exposure metadata from the detector's row in the visitSummary table.
        detectorRow = visitSummary.find(detector)
        ssoAssocResult = self.solarSystemAssociator.run(
            diaCat,
            ssCat,
            visitInfo=detectorRow.visitInfo,
            bbox=detectorRow.getBBox(),
            wcs=detectorRow.wcs,
        )

        ssInTractPatch = self._trimToPatch(ssoAssocResult.associatedSsSources.to_pandas(),
                                           patchBbox,
                                           patchWcs,
                                           innerTractSkyRegion=innerTractSkyRegion)
        associatedSsSources = ssoAssocResult.associatedSsSources[ssInTractPatch].copy()
        assocDiaSrcIds = set(associatedSsSources['diaSourceId'])
        # Force a boolean dtype so that an empty mask is still a valid mask
        # rather than an empty float array.
        diaSrcMask = np.array(
            [diaId in assocDiaSrcIds for diaId in ssoAssocResult.ssoAssocDiaSources['diaSourceId']],
            dtype=bool,
        )
        associatedSsDiaSources = ssoAssocResult.ssoAssocDiaSources[diaSrcMask]

        ssObjInTractPatch = self._trimToPatch(ssoAssocResult.unassociatedSsObjects.to_pandas(),
                                              patchBbox,
                                              patchWcs,
                                              innerTractSkyRegion=innerTractSkyRegion)
        unassociatedSsObjects = ssoAssocResult.unassociatedSsObjects[ssObjInTractPatch].copy()
        # Update the table of Solar System objects that were not found with the
        # visit and detector where they were predicted.
        if len(unassociatedSsObjects) > 0:
            unassociatedSsObjects['visit'] = visit
            unassociatedSsObjects['detector'] = detector

        return pipeBase.Struct(
            associatedSsSources=associatedSsSources,
            associatedSsDiaSources=associatedSsDiaSources,
            unassociatedSsObjects=unassociatedSsObjects,
            unassociatedDiaSources=ssoAssocResult.unAssocDiaSources
        )

    def _addDiaObjectCoords(self, objects, sources):
        """Attach each source's diaObject coordinates as new columns.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            DiaObjects with ``ra`` and ``dec`` columns, indexed by the values
            found in the ``diaObjectId`` column of ``sources``.
        sources : `pandas.DataFrame`
            DiaSources with a ``diaObjectId`` column; ``diaSourceId`` must be
            available as a column after resetting the index.

        Returns
        -------
        df : `pandas.DataFrame`
            The sources with ``coord_ra`` and ``coord_dec`` added, indexed by
            ``diaSourceId``. The join is an inner join, so sources whose
            ``diaObjectId`` is not in ``objects`` are dropped.
        """
        obj = objects[['ra', 'dec']].rename(columns={"ra": "coord_ra", "dec": "coord_dec"})
        df = pd.merge(sources.reset_index(), obj, left_on='diaObjectId', right_index=True,
                      how='inner').set_index('diaSourceId')
        return df

    def _trimToPatch(self, cat, patchBox, wcs, innerTractSkyRegion=None):
        """Build a boolean mask flagging the rows that lie in the
        patch/tract.

        Parameters
        ----------
        cat : `pandas.DataFrame`
            Catalog with ``ra`` and ``dec`` columns in degrees to test
            within patch/tract.
        patchBox : `lsst.geom.Box2D`
            Bounding box of the patch.
        wcs : `lsst.geom.SkyWcs`
            Wcs of the tract.
        innerTractSkyRegion : `lsst.sphgeom.Box`, optional
            Region defining the inner non-overlapping part of a tract. If
            given, rows must also lie inside this region.

        Returns
        -------
        isInPatch : `numpy.ndarray`, (N,)
            Booleans representing if the rows are contained within the
            current patch and tract.
        """
        isInPatch = np.zeros(len(cat), dtype=bool)

        # reset_index gives positional row labels, so ``idx`` can address
        # the mask whatever index ``cat`` arrived with.
        for idx, row in cat.reset_index().iterrows():
            spPoint = geom.SpherePoint(row["ra"], row["dec"], geom.degrees)
            pxCoord = wcs.skyToPixel(spPoint)
            isInPatch[idx] = patchBox.contains(pxCoord)

            if innerTractSkyRegion is not None:
                # The region test is given the coordinates in radians.
                ra_rad = np.deg2rad(row["ra"])
                dec_rad = np.deg2rad(row["dec"])
                isInPatch[idx] &= innerTractSkyRegion.contains(ra_rad, dec_rad)

        return isInPatch

    def dropDiaSourceByDiaObjectId(self, droppedDiaObjectIds, diaSources):
        """Drop diaSources with diaObject IDs in the supplied list.

        Parameters
        ----------
        droppedDiaObjectIds : `pandas.Index` or list-like
            DiaObjectIds to match and drop from the list of diaSources.
        diaSources : `pandas.DataFrame`
            Catalog of diaSources to check and filter.

        Returns
        -------
        filteredDiaSources : `pandas.DataFrame`
            A copy of the input diaSources with any rows matching the listed
            diaObjectIds removed.
        """
        toDrop = diaSources['diaObjectId'].isin(droppedDiaObjectIds)

        # Keep only rows that do NOT match
        return diaSources.loc[~toDrop].copy()
def prepareCatalogDict(dataRefList, useVisitDetector=True):
    """Index data references by their visit, or by visit and detector.

    Parameters
    ----------
    dataRefList : `list` of `lsst.daf.butler.DeferredDatasetHandle`
        The data references to index. Each must expose a ``dataId`` mapping
        with a ``"visit"`` entry, and a ``"detector"`` entry when
        ``useVisitDetector`` is True.
    useVisitDetector : `bool`, optional
        If True, key on the ``(visit, detector)`` tuple. If False, key on
        the visit alone.

    Returns
    -------
    `dict` of `lsst.daf.butler.DeferredDatasetHandle`
        The data references keyed as requested. When several references
        share a key, the last one in ``dataRefList`` is kept.
    """
    if useVisitDetector:
        def makeKey(dataId):
            return (dataId["visit"], dataId["detector"])
    else:
        def makeKey(dataId):
            return dataId["visit"]

    return {makeKey(handle.dataId): handle for handle in dataRefList}