Coverage for python/lsst/pipe/tasks/drpAssociationPipe.py: 34%

67 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-28 02:53 -0800

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Pipeline for running DiaSource association in a DRP context. 

23""" 

24 

# Public API of this module: the task plus its config and connections.
__all__ = [
    "DrpAssociationPipeTask",
    "DrpAssociationPipeConfig",
    "DrpAssociationPipeConnections",
]

28 

29import numpy as np 

30import pandas as pd 

31 

32import lsst.geom as geom 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35from lsst.skymap import BaseSkyMap 

36 

37from .coaddBase import makeSkyInfo 

38from .simpleAssociation import SimpleAssociationTask 

39 

40 

class DrpAssociationPipeConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap"),
    defaultTemplates={
        "coaddName": "deep",
        "warpTypeSuffix": "",
        "fakesType": "",
    },
):
    """Butler connections for `DrpAssociationPipeTask`.

    Each quantum is defined per (tract, patch, skymap). It reads the
    per-(instrument, visit, detector) DiaSource tables together with the
    skymap, and writes one associated DiaSource table and one DiaObject
    table per (tract, patch).
    """

    # Per-detector DiaSource tables. Loading is deferred, so the task
    # receives handles and reads each catalog itself.
    diaSourceTables = pipeBase.connectionTypes.Input(
        doc="Set of catalogs of calibrated DiaSources.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
        deferLoad=True,
        multiple=True,
    )

    # Skymap from which the patch geometry and tract WCS are derived.
    skyMap = pipeBase.connectionTypes.Input(
        doc="Input definition of geometry/bbox and projection/wcs for coadded exposures",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    # DiaSources that fall inside the patch, after association.
    assocDiaSourceTable = pipeBase.connectionTypes.Output(
        doc="Catalog of DiaSources covering the patch and associated with a DiaObject.",
        name="{fakesType}{coaddName}Diff_assocDiaSrcTable",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )

    # DiaObjects produced by the association step.
    diaObjectTable = pipeBase.connectionTypes.Output(
        doc="Catalog of DiaObjects created from spatially associating DiaSources.",
        name="{fakesType}{coaddName}Diff_diaObjTable",
        storageClass="DataFrame",
        dimensions=("tract", "patch"),
    )

75 

76 

class DrpAssociationPipeConfig(
        pipeBase.PipelineTaskConfig,
        pipelineConnections=DrpAssociationPipeConnections):
    """Configuration for `DrpAssociationPipeTask`.
    """
    associator = pexConfig.ConfigurableField(
        target=SimpleAssociationTask,
        doc="Task used to associate DiaSources with DiaObjects.",
    )
    # The two literals below are concatenated by the parser; the first must
    # end with a sentence break or the rendered help text runs together.
    doAddDiaObjectCoords = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pull diaObject's average coordinate as coord_ra and coord_dec. "
            "Duplicates information, but needed for bulk ingest into qserv."
    )
    doWriteEmptyTables = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="If True, construct and write out empty diaSource and diaObject "
            "tables. If False, raise NoWorkFound"
    )

96 

97 

class DrpAssociationPipeTask(pipeBase.PipelineTask):
    """Driver pipeline for loading DiaSource catalogs in a patch/tract
    region and associating them.
    """
    ConfigClass = DrpAssociationPipeConfig
    _DefaultName = "drpAssociation"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.makeSubtask('associator')

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the quantum inputs, run association and store the outputs.

        In addition to the datasets fetched from the butler, the tract and
        patch ids and the packed ``tract_patch`` id (with its maximum bit
        count) are taken from the quantum data ID and forwarded to `run`.

        Parameters
        ----------
        butlerQC
            Quantum context used to get the inputs and put the outputs.
        inputRefs
            References to the input datasets of this quantum.
        outputRefs
            References to the output datasets of this quantum.
        """
        inputs = butlerQC.get(inputRefs)

        quantumDataId = butlerQC.quantum.dataId
        inputs["tractId"] = quantumDataId["tract"]
        inputs["patchId"] = quantumDataId["patch"]
        tractPatchId, skymapBits = quantumDataId.pack(
            "tract_patch",
            returnMaxBits=True)
        inputs["tractPatchId"] = tractPatchId
        inputs["skymapBits"] = skymapBits

        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self,
            diaSourceTables,
            skyMap,
            tractId,
            patchId,
            tractPatchId,
            skymapBits):
        """Trim DiaSources to the current Patch and run association.

        Takes in the set of DiaSource catalogs that covers the current patch,
        trims them to the inner bounding box of the patch, and runs
        association on the concatenated DiaSource catalog.

        Parameters
        ----------
        diaSourceTables : `list` of `lsst.daf.butler.DeferredDatasetHandle`
            Set of DiaSource catalogs potentially covering this patch/tract.
        skyMap : `lsst.skymap.BaseSkyMap`
            SkyMap defining the patch/tract
        tractId : `int`
            Id of current tract being processed.
        patchId : `int`
            Id of current patch being processed
        tractPatchId : `int`
            Packed ``tract_patch`` id of the quantum data ID; forwarded
            unchanged to the associator subtask.
        skymapBits : `int`
            Maximum number of bits reported when packing ``tractPatchId``;
            forwarded unchanged to the associator subtask.

        Returns
        -------
        output : `lsst.pipe.base.Struct`
            Results struct with attributes:

            ``assocDiaSourceTable``
                Table of DiaSources with updated value for diaObjectId.
                (`pandas.DataFrame`)
            ``diaObjectTable``
                Table of DiaObjects from matching DiaSources
                (`pandas.DataFrame`).

        Raises
        ------
        lsst.pipe.base.NoWorkFound
            Raised if no DiaSource falls inside the patch and either
            ``config.doWriteEmptyTables`` is `False` or no input catalog
            was available to derive an empty table from.
        """
        self.log.info("Running DRP Association on patch %i, tract %i...",
                      patchId, tractId)

        skyInfo = makeSkyInfo(skyMap, tractId, patchId)

        # Get the patch bounding box.
        innerPatchBox = geom.Box2D(skyInfo.patchInfo.getInnerBBox())

        diaSourceHistory = []
        # Remember the most recently read catalog so that an empty table
        # with the right columns can be built from it; it stays None when
        # there are no input handles at all.
        cat = None
        for catRef in diaSourceTables:
            cat = catRef.get(
                datasetType=self.config.connections.diaSourceTables,
                immediate=True)

            isInTractPatch = self._trimToPatch(cat,
                                               innerPatchBox,
                                               skyInfo.wcs)

            nDiaSrc = isInTractPatch.sum()
            self.log.info(
                "Read DiaSource catalog of length %i from visit %i, "
                "detector %i. Found %i sources within the patch/tract "
                "footprint.",
                len(cat), catRef.dataId["visit"],
                catRef.dataId["detector"], nDiaSrc)

            if nDiaSrc <= 0:
                continue

            cutCat = cat[isInTractPatch]
            diaSourceHistory.append(cutCat)

        if diaSourceHistory:
            diaSourceHistoryCat = pd.concat(diaSourceHistory)
        elif self.config.doWriteEmptyTables and cat is not None:
            # No rows to associate
            self.log.info("Constructing empty table")
            # Construct empty table using last table and dropping all the rows
            diaSourceHistoryCat = cat.drop(cat.index)
        else:
            # Also reached when no catalog was read at all: there is then
            # nothing to take the column layout of an empty table from.
            raise pipeBase.NoWorkFound("Found no overlapping DIASources to associate.")

        self.log.info("Found %i DiaSources overlapping patch %i, tract %i",
                      len(diaSourceHistoryCat), patchId, tractId)

        assocResult = self.associator.run(diaSourceHistoryCat,
                                          tractPatchId,
                                          skymapBits)

        self.log.info("Associated DiaSources into %i DiaObjects",
                      len(assocResult.diaObjects))

        if self.config.doAddDiaObjectCoords:
            assocResult.assocDiaSources = self._addDiaObjectCoords(assocResult.diaObjects,
                                                                   assocResult.assocDiaSources)

        return pipeBase.Struct(
            diaObjectTable=assocResult.diaObjects,
            assocDiaSourceTable=assocResult.assocDiaSources)

    def _addDiaObjectCoords(self, objects, sources):
        """Attach each DiaObject's coordinates to its DiaSources.

        Parameters
        ----------
        objects : `pandas.DataFrame`
            DiaObject table with ``ra`` and ``decl`` columns. Its index is
            matched against ``diaObjectId`` of the sources.
        sources : `pandas.DataFrame`
            Associated DiaSource table with a ``diaObjectId`` column.
            NOTE(review): presumably indexed by ``diaSourceId``, since the
            index is reset and then set to that column -- confirm against
            the associator output.

        Returns
        -------
        df : `pandas.DataFrame`
            The sources with added ``coord_ra`` and ``coord_dec`` columns,
            indexed by ``diaSourceId``. The merge is an inner join, so
            sources without a matching DiaObject row are dropped.
        """
        obj = objects[['ra', 'decl']].rename(columns={"ra": "coord_ra", "decl": "coord_dec"})
        df = pd.merge(sources.reset_index(), obj, left_on='diaObjectId', right_index=True,
                      how='inner').set_index('diaSourceId')
        return df

    def _trimToPatch(self, cat, innerPatchBox, wcs):
        """Build a boolean mask flagging the DiaSources that lie inside the
        patch.

        Parameters
        ----------
        cat : `pandas.DataFrame`
            Catalog of DiaSources to test within patch/tract. Must provide
            ``ra`` and ``decl`` columns, which are read as degrees.
        innerPatchBox : `lsst.geom.Box2D`
            Bounding box of the patch.
        wcs : `lsst.afw.geom.SkyWcs`
            Wcs of the tract.

        Returns
        -------
        isInPatch : `numpy.ndarray`, (N,)
            Booleans representing if the DiaSources are contained within the
            current patch and tract.
        """
        # Only two columns are needed, so walk them directly instead of
        # materialising a Series per row with ``iterrows``. The explicit
        # dtype keeps the result a boolean mask for an empty catalog too.
        isInPatch = np.array(
            [innerPatchBox.contains(
                wcs.skyToPixel(geom.SpherePoint(ra, decl, geom.degrees)))
             for ra, decl in zip(cat["ra"], cat["decl"])],
            dtype=bool)
        return isInPatch