Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23"""Pipeline for running DiaSource association in a DRP context. 

24""" 

25 

26import numpy as np 

27import pandas as pd 

28 

29import lsst.geom as geom 

30import lsst.pex.config as pexConfig 

31import lsst.pipe.base as pipeBase 

32from lsst.skymap import BaseSkyMap 

33 

34from .coaddBase import makeSkyInfo 

35from .simpleAssociation import SimpleAssociationTask 

36 

# Public names exported by this module.
__all__ = [
    "DrpAssociationPipeTask",
    "DrpAssociationPipeConfig",
    "DrpAssociationPipeConnections",
]

40 

41 

class DrpAssociationPipeConnections(
        pipeBase.PipelineTaskConnections,
        dimensions=("tract", "patch", "skymap"),
        defaultTemplates={
            "coaddName": "deep",
            "warpTypeSuffix": "",
            "fakesType": "",
        }):
    """Input and output dataset connections for the DRP association task.

    The quantum dimensions are (tract, patch, skymap). Inputs are the
    per-visit/detector DiaSource tables plus the skymap definition; outputs
    are per-tract/patch associated DiaSource and DiaObject tables.
    """

    # Per-detector DiaSource tables; loading is deferred so each catalog is
    # only read when explicitly requested from its handle.
    diaSourceTables = pipeBase.connectionTypes.Input(
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        doc="Set of catalogs of calibrated DiaSources.",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        multiple=True,
        deferLoad=True,
    )

    # Skymap that defines the tract/patch geometry.
    skyMap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        storageClass="SkyMap",
        dimensions=("skymap", ),
    )

    # DiaSources within the patch, associated with a DiaObject.
    assocDiaSourceTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_assocDiaSrcTable",
        doc="Catalog of DiaSources covering the patch and associated with a DiaObject.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )

    # DiaObjects produced by the spatial association.
    diaObjectTable = pipeBase.connectionTypes.Output(
        name="{fakesType}{coaddName}Diff_diaObjTable",
        doc="Catalog of DiaObjects created from spatially associating DiaSources.",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )

76 

77 

class DrpAssociationPipeConfig(
        pipeBase.PipelineTaskConfig,
        pipelineConnections=DrpAssociationPipeConnections):
    """Configuration for `DrpAssociationPipeTask`."""

    # Subtask that performs the DiaSource -> DiaObject association.
    associator = pexConfig.ConfigurableField(
        target=SimpleAssociationTask,
        doc="Task used to associate DiaSources with DiaObjects.",
    )
    # When True (and DiaObjects were produced), the task merges each
    # DiaObject's ra/decl onto its DiaSources as coord_ra/coord_dec.
    doAddDiaObjectCoords = pexConfig.Field(
        dtype=bool,
        default=True,
        # The two literals are concatenated; the ". " separator keeps the
        # rendered help text from running the sentences together.
        doc="Do pull diaObject's average coordinate as coord_ra and coord_dec. "
            "Duplicates information, but needed for bulk ingest into qserv."
    )

91 

92 

class DrpAssociationPipeTask(pipeBase.PipelineTask):
    """Driver pipeline for loading DiaSource catalogs in a patch/tract
    region and associating them.
    """
    ConfigClass = DrpAssociationPipeConfig
    _DefaultName = "drpAssociation"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.makeSubtask('associator')

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the quantum inputs, run association, and store the outputs.

        The tract and patch ids, and the packed ``tract_patch`` id with its
        maximum bit count, are taken from the quantum data ID and passed to
        `run` alongside the loaded inputs.
        """
        inputs = butlerQC.get(inputRefs)

        inputs["tractId"] = butlerQC.quantum.dataId["tract"]
        inputs["patchId"] = butlerQC.quantum.dataId["patch"]
        tractPatchId, skymapBits = butlerQC.quantum.dataId.pack(
            "tract_patch",
            returnMaxBits=True)
        inputs["tractPatchId"] = tractPatchId
        inputs["skymapBits"] = skymapBits

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self,
            diaSourceTables,
            skyMap,
            tractId,
            patchId,
            tractPatchId,
            skymapBits):
        """Trim DiaSources to the current Patch and run association.

        Takes in the set of DiaSource catalogs that covers the current patch,
        trims them to the inner bounding box of the patch, and runs
        association on the concatenated DiaSource Catalog.

        Parameters
        ----------
        diaSourceTables : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Set of DiaSource catalogs potentially covering this patch/tract.
        skyMap : `lsst.skymap.BaseSkyMap`
            SkyMap defining the patch/tract
        tractId : `int`
            Id of current tract being processed.
        patchId : `int`
            Id of current patch being processed
        tractPatchId : `int`
            Packed ``tract_patch`` id of the quantum data ID; forwarded
            unchanged to the associator.
        skymapBits : `int`
            Maximum number of bits reported for the packed ``tract_patch``
            id; forwarded unchanged to the associator.

        Returns
        -------
        output : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSourceTable``
                Table of DiaSources with updated value for diaObjectId.
                (`pandas.DataFrame`)
            ``diaObjectTable``
                Table of DiaObjects from matching DiaSources
                (`pandas.DataFrame`).
        """
        self.log.info("Running DRP Association on patch %i, tract %i...",
                      patchId, tractId)

        skyInfo = makeSkyInfo(skyMap, tractId, patchId)

        # Get the patch bounding box.
        innerPatchBox = geom.Box2D(skyInfo.patchInfo.getInnerBBox())

        diaSourceHistory = []
        for catRef in diaSourceTables:
            cat = catRef.get(
                datasetType=self.config.connections.diaSourceTables,
                immediate=True)

            isInTractPatch = self._trimToPatch(cat,
                                               innerPatchBox,
                                               skyInfo.wcs)

            nDiaSrc = isInTractPatch.sum()
            self.log.info(
                "Read DiaSource catalog of length %i from visit %i, "
                "detector %i. Found %i sources within the patch/tract "
                "footprint.",
                len(cat), catRef.dataId["visit"],
                catRef.dataId["detector"], nDiaSrc)

            if nDiaSrc <= 0:
                # Keep an empty frame with the catalog's columns so the
                # concatenation below always has something to concatenate.
                diaSourceHistory.append(pd.DataFrame(columns=cat.columns))
                continue

            cutCat = cat[isInTractPatch]
            diaSourceHistory.append(cutCat)

        diaSourceHistoryCat = pd.concat(diaSourceHistory)
        self.log.info("Found %i DiaSources overlapping patch %i, tract %i",
                      len(diaSourceHistoryCat), patchId, tractId)

        assocResult = self.associator.run(diaSourceHistoryCat,
                                          tractPatchId,
                                          skymapBits)

        self.log.info("Associated DiaSources into %i DiaObjects",
                      len(assocResult.diaObjects))

        if self.config.doAddDiaObjectCoords and not assocResult.diaObjects.empty:
            assocResult.assocDiaSources = self._addDiaObjectCoords(assocResult.diaObjects,
                                                                   assocResult.assocDiaSources)

        return pipeBase.Struct(
            diaObjectTable=assocResult.diaObjects,
            assocDiaSourceTable=assocResult.assocDiaSources)

    def _addDiaObjectCoords(self, objects, sources):
        """Attach each DiaObject's coordinates to its DiaSources.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            DiaObjects with ``ra`` and ``decl`` columns; its index is matched
            against ``diaObjectId`` of ``sources``.
        sources : `pandas.DataFrame`
            DiaSources with a ``diaObjectId`` column.

        Returns
        -------
        df : `pandas.DataFrame`
            ``sources`` with added ``coord_ra`` and ``coord_dec`` columns.
            The merge is an inner join, so sources whose ``diaObjectId`` is
            not in the index of ``objects`` are dropped.
        """
        obj = objects[['ra', 'decl']].rename(columns={"ra": "coord_ra", "decl": "coord_dec"})
        df = pd.merge(sources, obj, left_on='diaObjectId', right_index=True, how='inner')
        return df

    def _trimToPatch(self, cat, innerPatchBox, wcs):
        """Test which DiaSources in a catalog fall inside the patch.

        Parameters
        ----------
        cat : `pandas.DataFrame`
            Catalog of DiaSources to test within patch/tract. Must have
            ``ra`` and ``decl`` columns, interpreted in degrees.
        innerPatchBox : `lsst.geom.Box2D`
            Bounding box of the patch.
        wcs : `lsst.geom.SkyWcs`
            Wcs of the tract.

        Returns
        -------
        isInPatch : `numpy.ndarray`, (N,)
            Booleans representing if the DiaSources are contained within the
            current patch and tract.
        """
        # Iterate only the two coordinate columns instead of materialising a
        # Series per row with iterrows(). The explicit dtype keeps the result
        # a boolean mask even when the catalog is empty.
        isInPatch = np.array(
            [innerPatchBox.contains(
                wcs.skyToPixel(geom.SpherePoint(ra, decl, geom.degrees)))
             for ra, decl in zip(cat["ra"], cat["decl"])],
            dtype=bool)
        return isInPatch