Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 24% (143 statements)

# This file is part of ap_association
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")

import numpy as np
import os
import yaml

from lsst.daf.base import DateTime
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connTypes
from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
from lsst.pipe.tasks.parquetTable import ParquetTable
from lsst.pipe.tasks.functors import Column
from lsst.utils.timer import timeMethod


class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask."""
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        doc="Catalog of DiaSources, transformed and standardized for insertion "
            "into the Apdb.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )
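
    # With the default templates (fakesType="", coaddName="deep"), these
    # connections resolve to dataset types such as "deepDiff_diaSrc" and
    # "deepDiff_diaSrcTable".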


class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    flagMap = pexConfig.Field(
        dtype=str,
        doc="YAML file specifying how Science Pipelines flag fields are packed into bits.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        doc="YAML file specifying rules to rename flag names.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Pack the flags into a single integer column named 'flags'. "
            "If False, produce one boolean column per flag instead."
    )
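
    # Usage sketch (hypothetical pipeline config override):
    #     config.doPackFlags = False        # one boolean column per flag
    #     config.doRemoveSkySources = True  # drop sky sources before output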

    def setDefaults(self):
        super().setDefaults()
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")


class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    RunnerClass = pipeBase.ButlerInitializedTaskRunner
    # Needed to create a valid TransformCatalogBaseTask, but unused.
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Get the flag rename rules.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Set up all flag bit packings.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible; however, if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        expId, expBits = butlerQC.quantum.dataId.pack("visit_detector",
                                                      returnMaxBits=True)
        inputs["ccdVisitId"] = expId
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId,
            funcs=None):
        """Convert the input catalog to a ParquetTable/pandas DataFrame and
        run the functors.

        Additionally, add new columns by copying information from the
        exposure into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the catalog's columns.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()

        def getSignificance():
            """Return the significance value of the first peak in each source
            footprint."""
            size = len(diaSourceDf)
            result = np.full(size, np.nan)
            for i in range(size):
                record = diaSourceCat[i]
                if self.config.doRemoveSkySources and record["sky_source"]:
                    continue
                peaks = record.getFootprint().peaks
                if "significance" in peaks.schema:
                    result[i] = peaks[0]["significance"]
            return result

        diaSourceDf["snr"] = getSignificance()

        if self.config.doRemoveSkySources:
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]

        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # Either bit-pack the flags...
            self.bitPackFlags(diaSourceDf)
        else:
            # ...or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor.
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            ParquetTable(dataFrame=diaSourceDf),
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )
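
    # The Struct returned above is what runQuantum persists; with the default
    # connection templates, its ``diaSourceTable`` component is written to the
    # butler as "deepDiff_diaSrcTable".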

    def addUnpackedFlagFunctors(self):
        """Add a Column functor for each of the flags to the internal functor
        dictionary.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})
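
    # Illustrative only (not an actual rule from flag-rename-rules.yaml): a
    # rename rule pair like ["base_PixelFlags_flag", "pixelFlags"] would map
    # the input column "base_PixelFlags_flag_bad" to "pixelFlags_bad".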

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `list` of `int`
            Array of bbox sizes.
        """
        outputBBoxSizes = []
        for record in inputCatalog:
            if self.config.doRemoveSkySources:
                if record["sky_source"]:
                    continue
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to
            # cover the complete DiaSource assuming the centroid is within the
            # bounding box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes.append(bboxSize)

        return outputBBoxSizes
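
    # Worked example with made-up numbers: an integer (inclusive) bbox
    # spanning x = [10, 30], y = [5, 15] (width 21, height 11) with centroid
    # (12, 10) gives bboxSize = ceil(2 * max(18, 2, 5, 5)) = 36, below the
    # cap maxSize = 2 * max(21, 11) = 42, so 36 is kept.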

    def bitPackFlags(self, df):
        """Pack the requested flag columns of the input DataFrame into single
        integer columns.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']] * 2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
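
    # For example (hypothetical flag map): a row with flags at bit 0 and
    # bit 2 set packs to 2**0 + 2**2 = 5 in the output integer column.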


class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Parameters
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names
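
        # output_flag_columns now maps each packed column to its bit names in
        # order, e.g. {'flags': [('base_PixelFlags_flag', bool), ...]}
        # (flag name illustrative).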

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of unsigned int type
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of the integer flags to unpack. Names of packed
            int flags are given by ``flag_map_file``.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Structured array of booleans, one field per flag name.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        for bit_idx, (bit_name, dtypes) in enumerate(bit_names_types):
            masked_bits = np.bitwise_and(input_flag_values, 2**bit_idx)
            output_flags[bit_name] = masked_bits

        return output_flags
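

if __name__ == "__main__":
    # Minimal round-trip sketch for UnpackApdbFlags (assumes the LSST stack
    # is set up so the imports at the top of this module resolve). The
    # two-flag map below is hypothetical; the real map ships with this
    # package as data/association-flag-map.yaml.
    import tempfile

    demo_map = """\
tableName: DiaSource
columns:
  - columnName: flags
    bitList:
      - {name: base_PixelFlags_flag, bit: 0}
      - {name: base_PixelFlags_flag_offimage, bit: 1}
"""
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as handle:
        handle.write(demo_map)

    unpacker = UnpackApdbFlags(handle.name, "DiaSource")
    # 0b10 has only bit 1 set, so only the second flag should unpack as True.
    unpacked = unpacker.unpack(np.array([0b10], dtype=np.uint64), "flags")
    print(unpacked["base_PixelFlags_flag"])           # [False]
    print(unpacked["base_PixelFlags_flag_offimage"])  # [ True]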