Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 19%

151 statements  


# This file is part of ap_association
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")

import numpy as np
import os
import yaml

from lsst.daf.base import DateTime
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connTypes
from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
from lsst.pipe.tasks.functors import Column
from lsst.utils.timer import timeMethod


class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        doc="Catalog of DiaSources, transformed and standardized for insertion into the Apdb.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )


class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying how Science Pipelines flag fields are packed into bits.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
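    # The flag map is a yaml stream of documents, each shaped roughly like
    # the sketch below (flag names and bit numbers here are illustrative;
    # see data/association-flag-map.yaml for the real mapping):
    #
    #   tableName: DiaSource
    #   columns:
    #     - columnName: flags
    #       bitList:
    #         - name: base_PixelFlags_flag
    #           bit: 0
    #         - name: base_PixelFlags_flag_offimage
    #           bit: 1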

    flagRenameMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying rules for renaming flag columns.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Pack the flags into a single integer column named 'flags'. "
            "If False, instead produce one boolean column per flag."
    )

    def setDefaults(self):
        super().setDefaults()
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")


class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused.
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Get the flag rename rules.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Set up all flag bit packings.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schema.
        # Output schemas are flexible; however, if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        expId, expBits = butlerQC.quantum.dataId.pack("visit_detector",
                                                      returnMaxBits=True)
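        # dataId.pack returns the packed integer and, with returnMaxBits=True,
        # the number of bits it occupies; only the packed id is used here.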

        inputs["ccdVisitId"] = expId
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId,
            funcs=None):
        """Convert the input catalog to a pandas DataFrame and run the functors.

        Additionally, add new columns that carry information from the
        exposure into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.
        funcs : `lsst.pipe.tasks.functors.Functors`
            Functors to apply to the catalog's columns.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`).
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]
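            # The afw catalog is filtered as well so that the per-record
            # loops in getSignificance and computeBBoxSizes below stay
            # aligned with the DataFrame rows.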

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add a Column functor for each of the flags to the internal functor
        dictionary.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})
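        # renameCol applies the 'flag_rename_rules' document loaded in
        # __init__ to map input schema flag names to output column names;
        # see data/flag-rename-rules.yaml for the rule format.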

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to
            # cover the complete DiaSource assuming the centroid is within the
            # bounding box.
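            # For example, a footprint spanning x in [10, 30] with centroid
            # x = 14 has a largest centroid-to-edge offset of |30 - 14| = 16,
            # so bboxSize = ceil(2 * 16) = 32 below, subject to the maxSize
            # cap.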

            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack the requested flag columns into single integer columns.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
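        # Worked example of the packing: a row with bits 0 and 2 set (and
        # bit 1 clear) packs to 1*2**0 + 0*2**1 + 1*2**2 = 5.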


class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Parameters
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names
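        # output_flag_columns now maps each packed column name to a list of
        # (flag name, bool) pairs, i.e. a numpy structured dtype
        # specification consumed by unpack().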

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of integer flags to unpack. Names of packed int
            flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Structured array of booleans, with one field per flag.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        for bit_idx, (bit_name, dtypes) in enumerate(bit_names_types):
            masked_bits = np.bitwise_and(input_flag_values, 2**bit_idx)
            output_flags[bit_name] = masked_bits

        return output_flags
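    # Usage sketch (the flag name is illustrative; real names come from the
    # flag map):
    #
    #   unpacker = UnpackApdbFlags(flag_map_file, "DiaSource")
    #   flags = unpacker.unpack(dia_sources["flags"], "flags")
    #   flagged = flags["base_PixelFlags_flag"]  # boolean array, one per row
    #
    # Note that unpack() uses the enumeration index as the bit position, so
    # it assumes the flag map lists bits contiguously starting from 0.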

    def flagExists(self, flagName, columnName='flags'):
        """Check if the named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in [c[0] for c in self.output_flag_columns[columnName]]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of the bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagNames` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                bitList = outputFlag['bitList']
                for bit in bitList:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask
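    # Usage sketch: for hypothetical flags mapped to bits 1 and 3,
    # makeFlagBitMask returns 2**1 + 2**3 = np.uint64(10), which can be
    # combined with np.bitwise_and against the packed 'flags' column to
    # select flagged sources.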


def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for peaks without a significance field.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    result = np.full(len(catalog), np.nan)
    for i, record in enumerate(catalog):
        peaks = record.getFootprint().peaks
        if "significance" in peaks.schema:
            result[i] = peaks[0]["significance"]
    return result

437 return result