Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 22%

136 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-07 03:46 -0700

1# This file is part of ap_association 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")

26 

27import numpy as np 

28import os 

29import yaml 

30 

31from lsst.daf.base import DateTime 

32import lsst.pex.config as pexConfig 

33import lsst.pipe.base as pipeBase 

34import lsst.pipe.base.connectionTypes as connTypes 

35from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig 

36from lsst.pipe.tasks.parquetTable import ParquetTable 

37from lsst.pipe.tasks.functors import Column 

38from lsst.utils.timer import timeMethod 

39 

40 

class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask."""
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        # Previously a placeholder doc string ("."); now describes the
        # dataset actually produced by TransformDiaSourceCatalogTask.run.
        doc="Catalog of DiaSources with calibrated values and renamed "
            "columns, stored as a DataFrame.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )

67 

68 

class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Config for TransformDiaSourceCatalogTask."""
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        # Fixed duplicated word in the help text ("specifying specifying").
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Do pack the flags into one integer column named 'flags'."
            "If False, instead produce one boolean column per flag."
    )

    def setDefaults(self):
        super().setDefaults()
        # Use the DiaSource functor definitions shipped with ap_association.
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")

103 

104 

class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused.
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Unpacked (per-flag boolean) output requested: load the rules
            # used to rename the individual flag columns.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Reads the flag map yaml named by ``config.flagMap`` and stores the
        ``DiaSource`` table's column/bit definitions in
        ``self.bit_pack_columns``.

        Raises
        ------
        KeyError
            Raised if a flag named in the flag map is missing from the input
            DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError as e:
                    # Chain the original lookup failure so the traceback
                    # shows where the schema search failed.
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name']) from e

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        # Pack visit+detector into a single integer id for the output table.
        expId, expBits = butlerQC.quantum.dataId.pack("visit_detector",
                                                      returnMaxBits=True)
        inputs["ccdVisitId"] = expId
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId,
            funcs=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.
        funcs : `lsst.pipe.tasks.functors.Functors`, optional
            Functors to apply to the catalog's columns. If `None`, the
            functors loaded at construction (``self.funcs``) are used.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        if funcs is None:
            # Fall back to the functors built in __init__. Previously the
            # ``funcs`` argument was documented but silently ignored.
            funcs = self.funcs

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Drop sky sources from both views so the derived columns below
            # stay aligned with the DataFrame rows.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        # Placeholder ids; association fills these in downstream.
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # Either bitpack the flags ...
            self.bitPackFlags(diaSourceDf)
        else:
            # ... or add the individual flag functors,
            self.addUnpackedFlagFunctors(funcs)
            # and remove the packed flag functor.
            if 'flags' in funcs.funcDict:
                del funcs.funcDict['flags']

        df = self.transform(band,
                            ParquetTable(dataFrame=diaSourceDf),
                            funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self, funcs=None):
        """Add Column functor for each of the flags to the given functor
        dictionary.

        Parameters
        ----------
        funcs : `lsst.pipe.tasks.functors.Functors`, optional
            Functor dictionary to update. If `None`, ``self.funcs`` is
            updated, preserving the original behavior.
        """
        if funcs is None:
            funcs = self.funcs
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the
        detection footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the
            # footprint bounding box. This is the largest footprint we
            # should need to cover the complete DiaSource assuming the
            # centroid is within the bounding box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into; modified in
            place.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value

300 

301 

class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Parameters
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Map each packed column name to the structured-dtype spec of its
        # boolean output fields, and record each flag's declared bit
        # position so unpacking does not depend on list order.
        self.output_flag_columns = {}
        self.bit_positions = {}

        for column in self.bit_pack_columns:
            names = []
            bits = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
                bits.append(bit["bit"])
            self.output_flag_columns[column["columnName"]] = names
            self.bit_positions[column["columnName"]] = bits

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of integer flags to unpack. Names of packed int
            flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy named tuple of booleans.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        for (bit_name, _), bit_position in zip(bit_names_types,
                                               self.bit_positions[flag_name]):
            # Use the bit position declared in the flag map rather than the
            # list index (the previous 2**bit_idx); the two differ whenever
            # the map skips or reorders bit numbers, which silently unpacked
            # the wrong bits.
            masked_bits = np.bitwise_and(input_flag_values, 2**bit_position)
            output_flags[bit_name] = masked_bits

        return output_flags

357 

358 

def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for sources without a usable peak significance.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint; NaN when
        the footprint has no peaks or its peak schema lacks a
        ``significance`` field.
    """
    result = np.full(len(catalog), np.nan)
    for i, record in enumerate(catalog):
        peaks = record.getFootprint().peaks
        # Guard against an empty peak list: peaks[0] would raise, and the
        # documented contract for such sources is NaN.
        if len(peaks) > 0 and "significance" in peaks.schema:
            result[i] = peaks[0]["significance"]
    return result