Coverage for python / lsst / ap / association / transformDiaSourceCatalog.py: 22%

163 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:53 +0000

1# This file is part of ap_association 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ("TransformDiaSourceCatalogConnections", 

23 "TransformDiaSourceCatalogConfig", 

24 "TransformDiaSourceCatalogTask", 

25 "UnpackApdbFlags") 

26 

27import os 

28import yaml 

29 

30import numpy as np 

31 

32from lsst.resources import ResourcePath 

33from lsst.daf.base import DateTime 

34import lsst.pex.config as pexConfig 

35import lsst.pipe.base as pipeBase 

36import lsst.pipe.base.connectionTypes as connTypes 

37from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig 

38from lsst.pipe.tasks.functors import Column 

39from lsst.utils.timer import timeMethod 

40 

41from lsst.pipe.tasks.schemaUtils import convertDataFrameToSdmSchema, readSdmSchemaFile 

42 

43 

class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    """Butler connections for TransformDiaSourceCatalogTask."""
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_candidateDiaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        # Fixed: previous doc string was a placeholder (".").
        doc="Transformed and standardized DiaSource catalog with calibrated "
            "values and renamed columns.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="ArrowAstropy",
        dimensions=("instrument", "visit", "detector"),
    )

70 

71 

class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    """Configuration for TransformDiaSourceCatalogTask."""
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying SciencePipelines flag fields to bit packs.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        # Fixed: removed duplicated word "specifying specifying".
        doc="Yaml file specifying rules to rename flag names",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog.",
    )
    # TODO: remove on DM-41532
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=False,
        # Fixed: missing space between concatenated sentences
        # ("...'flags'.If False...").
        doc="Do pack the flags into one integer column named 'flags'. "
            "If False, instead produce one boolean column per flag.",
        deprecated="This field is no longer used. Will be removed after v28."
    )
    doUseApdbSchema = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Use the APDB schema to coerce the data types of the output columns.",
        deprecated="This field has been renamed to doUseSchema, and will be "
                   "removed after v30."
    )
    doUseSchema = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Use an existing schema to coerce the data types of the output columns."
    )
    schemaDir = pexConfig.Field(
        dtype=str,
        doc="Path to the directory containing schema definitions.",
        default=os.path.join("${SDM_SCHEMAS_DIR}",
                             "yml"),
    )
    schemaFile = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying the schema of the output catalog.",
        default="apdb.yaml",
    )
    schemaName = pexConfig.Field(
        dtype=str,
        doc="Name of the table in the schema file to read.",
        default="ApdbSchema",
        deprecated="This config is no longer used, and will be removed after v30"
    )

    def setDefaults(self):
        # Point the base-class functor machinery at the DiaSource functor
        # definitions shipped with ap_association.
        super().setDefaults()
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")

137 

138 

class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        # Validate the flag-map file against the input schema up front so a
        # missing column fails at construction time, not mid-run.
        self._create_bit_pack_mappings()
        if self.config.doUseSchema:
            # Load the SDM schema used later to coerce output column dtypes.
            schemaFile = os.path.join(self.config.schemaDir, self.config.schemaFile)
            self.schema = readSdmSchemaFile(schemaFile)
        else:
            self.schema = None

        if not self.config.doPackFlags:
            # get the flag rename rules
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Setup all flag bit packings.

        Reads the yaml flag map named by ``config.flagMap`` and stores the
        ``DiaSource`` table's column/bit definitions in
        ``self.bit_pack_columns``.

        Raises
        ------
        KeyError
            Raised if a flag named in the flag map is absent from the input
            DiaSource schema.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schemas.
        # Output schemas are flexible, however if names are not specified in
        # the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Butler entry point: fetch inputs, add the band from the dataId,
        run the task, and persist the outputs.
        """
        inputs = butlerQC.get(inputRefs)
        # ``band`` is not a connection; it comes from the quantum's dataId.
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            reliability=None):
        """Convert input catalog to ParquetTable/Pandas and run functors.

        Additionally, add new columns for stripping information from the
        exposure and into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        reliability : `lsst.afw.table.SourceCatalog`
            Reliability (e.g. real/bogus) scores, row-matched to
            ``diaSourceCat``.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated values
              and renamed columns.
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`)
        """
        # NOTE(review): ``reliability`` is accepted but not used in this
        # method body — presumably kept for interface compatibility; confirm
        # against callers/subclasses.
        self.log.info(
            "Transforming/standardizing the DiaSource table for visit,detector: %i, %i",
            diffIm.visitInfo.id, diffIm.detector.getId())

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            # Filter both the DataFrame and the afw catalog so the per-record
            # computations below (snr, bboxSize) stay row-matched.
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        # Record the processing time as a float MJD on the TAI scale.
        diaSourceDf["timeProcessedMjdTai"] = DateTime.now().get(system=DateTime.MJD, scale=DateTime.TAI)
        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["visit"] = diffIm.visitInfo.id
        # int16 instead of uint8 because databases don't like unsigned bytes.
        diaSourceDf["detector"] = np.int16(diffIm.detector.getId())
        diaSourceDf["band"] = band
        diaSourceDf["midpointMjdTai"] = diffIm.visitInfo.date.get(system=DateTime.MJD)
        # Placeholder ids; association happens downstream.
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        # TODO: this has been formally deprecated and should be removed too
        if self.config.doPackFlags:
            # either bitpack the flags
            self.bitPackFlags(diaSourceDf)
        else:
            # or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df
        if self.config.doUseSchema:
            # Coerce output column dtypes to match the loaded SDM schema.
            df = convertDataFrameToSdmSchema(self.schema, df, tableName="DiaSource")

        return pipeBase.Struct(
            diaSourceTable=df,
        )

    def addUnpackedFlagFunctors(self):
        """Add Column functor for each of the flags to the internal functor
        dictionary.

        Flag names are mapped to their output names via the rename rules
        loaded from ``config.flagRenameMap``.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the detection
        footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the footprint
            # bounding box. This is the largest footprint we should need to cover
            # the complete DiaSource assuming the centroid is within the bounding
            # box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            # Twice the largest centroid-to-edge distance, rounded up; capped
            # at maxSize in case the centroid lies outside the bounding box.
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes

    def bitPackFlags(self, df):
        """Pack requested flag columns in inputRecord into single columns in
        outputRecord.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value

342 

343 

class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Attributes
    ----------
    flag_map_file : `lsst.resources.ResourcePathExpression`
        Absolute or relative URI to a yaml file specifying mappings of flags
        to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with ResourcePath(flag_map_file, forceDirectory=False).open("r") as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        # Map each packed column name to a dict of {flag name: bit position}.
        self.output_flag_columns = {}
        for column in self.bit_pack_columns:
            names = {}
            for bit in column["bitList"]:
                names[bit["name"]] = bit["bit"]
            self.output_flag_columns[column["columnName"]] = names

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of unsigned
        ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer packed bit flags to unpack.
        flag_name : `str`
            Apdb column name from the loaded file, e.g. "flags".

        Returns
        -------
        output_flags : `numpy.ndarray`
            Numpy structured array of booleans, one column per flag in the
            loaded file.
        """
        # Hoist the repeated dict lookup; iterate name->bit pairs once.
        flag_bits = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values),
                                dtype=[(name, bool) for name in flag_bits])

        for name, bit in flag_bits.items():
            # Any nonzero masked value becomes True in the bool column.
            output_flags[name] = np.bitwise_and(input_flag_values, 2**bit)

        return output_flags

    def flagExists(self, flagName, columnName='flags'):
        """Check if named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        # Membership test directly against the dict's keys; the previous
        # version built a throwaway list first.
        return flagName in self.output_flag_columns[columnName]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the loaded
            configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagName` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        # Validate every requested flag before accumulating any bits.
        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                for bit in outputFlag['bitList']:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask

460 

461 

def getSignificance(catalog):
    """Extract the significance of the first peak in each source footprint.

    Sources whose peak catalog carries no "significance" field yield NaN.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    values = []
    for source in catalog:
        peakList = source.getFootprint().peaks
        if "significance" in peakList.schema:
            values.append(peakList[0]["significance"])
        else:
            values.append(np.nan)
    return np.array(values, dtype=float)