Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 20%

164 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-10 10:38 +0000

1# This file is part of ap_association 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")

26 

27import numpy as np 

28import os 

29import yaml 

30import pandas as pd 

31 

32from lsst.daf.base import DateTime 

33import lsst.pex.config as pexConfig 

34import lsst.pipe.base as pipeBase 

35import lsst.pipe.base.connectionTypes as connTypes 

36from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig 

37from lsst.pipe.tasks.functors import Column 

38from lsst.utils.timer import timeMethod 

39 

40 

class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask.
    """
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    reliability = connTypes.Input(
        doc="Reliability (e.g. real/bogus) classification of diaSourceCat sources (optional).",
        name="{fakesType}{coaddName}RealBogusSources",
        storageClass="Catalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        doc="Transformed and standardized DiaSource catalog, as a DataFrame.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )

    def __init__(self, *, config=None):
        super().__init__(config=config)
        # The reliability scores are only consumed when the config enables
        # them; drop the connection so the butler does not require the
        # dataset to exist.
        if not self.config.doIncludeReliability:
            self.inputs.remove("reliability")

78 

79 

class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Config for TransformDiaSourceCatalogTask.
    """
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    # TODO: remove on DM-41532
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Do pack the flags into one integer column named 'flags'."
            "If False, instead produce one boolean column per flag.",
        deprecated="This field is no longer used. Will be removed after v28."
    )
    doIncludeReliability = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Include the reliability (e.g. real/bogus) classifications in the output."
    )

    def setDefaults(self):
        super().setDefaults()
        # The default functor definitions for DiaSource transformation live
        # in this package's data directory.
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")

121 

122 

class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        # Validates that every flag named in the flag map exists in the
        # input schema; raises KeyError at construction time otherwise.
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # get the flag rename rules
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Loads the 'DiaSource' table entry from the yaml file named by
        ``config.flagMap`` into ``self.bit_pack_columns``, then verifies that
        every flag listed there exists in the input DiaSource schema.

        Raises
        ------
        KeyError
            Raised if a flag requested in the flag map is missing from the
            input DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch the inputs, add the band from the quantum data ID, run the
        task, and persist the outputs.
        """
        inputs = butlerQC.get(inputRefs)
        # ``run`` needs the band, which is only available from the data ID.
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            reliability=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        reliability : `lsst.afw.table.SourceCatalog`, optional
            Reliability (e.g. real/bogus) scores, row-matched to
            ``diaSourceCat``. Only read if ``config.doIncludeReliability``
            is set.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated values
              and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        self.log.info(
            "Transforming/standardizing the DiaSource table for visit,detector: %i, %i",
            diffIm.visitInfo.id, diffIm.detector.getId())

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Filter both the DataFrame and the afw catalog so the
            # per-record computations below stay row-matched.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        # Columns derived from the exposure rather than the source catalog.
        diaSourceDf["time_processed"] = DateTime.now().toPython()
        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["visit"] = diffIm.visitInfo.id
        # int16 instead of uint8 because databases don't like unsigned bytes.
        diaSourceDf["detector"] = np.int16(diffIm.detector.getId())
        diaSourceDf["band"] = band
        diaSourceDf["midpointMjdTai"] = diffIm.visitInfo.date.get(system=DateTime.MJD)
        # Placeholder associations; filled in downstream by the Apdb loader.
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doIncludeReliability:
            reliabilityDf = reliability.asAstropy().to_pandas()
            # This uses the pandas index to match scores with diaSources
            # but it will silently fill with NaNs if they don't match.
            diaSourceDf = pd.merge(diaSourceDf, reliabilityDf,
                                   how="left", on="id", validate="1:1")
            diaSourceDf = diaSourceDf.rename(columns={"score": "reliability"})
            # Warn only if *no* score matched; partial matches pass silently.
            if np.sum(diaSourceDf["reliability"].isna()) == len(diaSourceDf):
                self.log.warning("Reliability identifiers did not match diaSourceIds")
        else:
            diaSourceDf["reliability"] = np.float32(np.nan)

        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.

        Flag names are mapped to their output names via the rename rules
        loaded from ``config.flagRenameMap`` in ``__init__``.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to cover
            # the complete DiaSource assuming the centroid is within the bounding
            # box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            # Twice the largest centroid-to-edge distance; capped at maxSize
            # below in case the centroid falls outside the bounding box.
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into. Modified in
            place: one new uint64 column is added per entry in the flag map.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value

328 

329 

class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Attributes
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of
        flags to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        with open(os.path.expandvars(flag_map_file)) as yaml_stream:
            # Use the first document whose tableName matches; ignore the rest.
            for table in yaml.safe_load_all(yaml_stream):
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Map each packed column name to a {flag name: bit position} dict.
        self.output_flag_columns = {
            column["columnName"]: {bit["name"]: bit["bit"]
                                   for bit in column["bitList"]}
            for column in self.bit_pack_columns
        }

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer packed bit flags to unpack.
        flag_name : `str`
            Apdb column name from the loaded file, e.g. "flags".

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy structured array of booleans, one column per flag in the
            loaded file.
        """
        bit_map = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values),
                                dtype=[(name, bool) for name in bit_map])

        for name, bit in bit_map.items():
            # Nonzero masked values become True on assignment to the bool
            # column.
            output_flags[name] = np.bitwise_and(input_flag_values, 2**bit)

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in self.output_flag_columns[columnName]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the
            loaded configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagNames` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        # Validate every requested flag before accumulating any bits.
        for requested in flagNames:
            if not self.flagExists(requested, columnName=columnName):
                raise ValueError(f"flag '{requested}' not included in '{columnName}' flag column")

        for packedColumn in self.bit_pack_columns:
            if packedColumn['columnName'] == columnName:
                for bit in packedColumn['bitList']:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask

446 

447 

def _firstPeakSignificance(record):
    """Return the significance of the first peak in ``record``'s footprint,
    or NaN when the peak catalog carries no significance field.
    """
    peaks = record.getFootprint().peaks
    if "significance" in peaks.schema:
        return peaks[0]["significance"]
    return np.nan


def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for peaks without a significance field.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    return np.array([_firstPeakSignificance(record) for record in catalog],
                    dtype=float)

467 return result