Coverage for python/lsst/ap/association/transformDiaSourceCatalog.py: 20%

153 statements  

coverage.py v7.2.5, created at 2023-05-08 08:52 +0000

# This file is part of ap_association
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ("TransformDiaSourceCatalogConnections",
           "TransformDiaSourceCatalogConfig",
           "TransformDiaSourceCatalogTask",
           "UnpackApdbFlags")

import numpy as np
import os
import yaml

from lsst.daf.base import DateTime
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as connTypes
from lsst.meas.base import DetectorVisitIdGeneratorConfig
from lsst.pipe.tasks.postprocess import TransformCatalogBaseTask, TransformCatalogBaseConfig
from lsst.pipe.tasks.functors import Column
from lsst.utils.timer import timeMethod


class TransformDiaSourceCatalogConnections(pipeBase.PipelineTaskConnections,
                                           dimensions=("instrument", "visit", "detector"),
                                           defaultTemplates={"coaddName": "deep", "fakesType": ""}):
    diaSourceSchema = connTypes.InitInput(
        doc="Schema for DIASource catalog output by ImageDifference.",
        storageClass="SourceCatalog",
        name="{fakesType}{coaddName}Diff_diaSrc_schema",
    )
    diaSourceCat = connTypes.Input(
        doc="Catalog of DiaSources produced during image differencing.",
        name="{fakesType}{coaddName}Diff_diaSrc",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
    diffIm = connTypes.Input(
        doc="Difference image on which the DiaSources were detected.",
        name="{fakesType}{coaddName}Diff_differenceExp",
        storageClass="ExposureF",
        dimensions=("instrument", "visit", "detector"),
    )
    diaSourceTable = connTypes.Output(
        doc="Catalog of DiaSources transformed into a flat table.",
        name="{fakesType}{coaddName}Diff_diaSrcTable",
        storageClass="DataFrame",
        dimensions=("instrument", "visit", "detector"),
    )


class TransformDiaSourceCatalogConfig(TransformCatalogBaseConfig,
                                      pipelineConnections=TransformDiaSourceCatalogConnections):
    flagMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying the mapping of SciencePipelines flag fields "
            "to bit-packed columns.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "association-flag-map.yaml"),
    )
    flagRenameMap = pexConfig.Field(
        dtype=str,
        doc="Yaml file specifying rules to rename flag names.",
        default=os.path.join("${AP_ASSOCIATION_DIR}",
                             "data",
                             "flag-rename-rules.yaml"),
    )
    doRemoveSkySources = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Input DiaSource catalog contains SkySources that should be "
            "removed before storing the output DiaSource catalog."
    )
    doPackFlags = pexConfig.Field(
        dtype=bool,
        default=True,
        doc="Pack the flags into one integer column named 'flags'. "
            "If False, instead produce one boolean column per flag."
    )
    idGenerator = DetectorVisitIdGeneratorConfig.make_field()

    def setDefaults(self):
        super().setDefaults()
        self.functorFile = os.path.join("${AP_ASSOCIATION_DIR}",
                                        "data",
                                        "DiaSource.yaml")

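# A minimal configuration sketch for the config class above; these values are
# illustrative overrides, not the defaults:
#
#     config = TransformDiaSourceCatalogConfig()
#     config.doRemoveSkySources = True   # drop sky sources before output
#     config.doPackFlags = False         # one boolean column per flag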

class TransformDiaSourceCatalogTask(TransformCatalogBaseTask):
    """Transform a DiaSource catalog by calibrating and renaming columns to
    produce a table ready to insert into the Apdb.

    Parameters
    ----------
    initInputs : `dict`
        Must contain ``diaSourceSchema`` as the schema for the input catalog.
    """
    ConfigClass = TransformDiaSourceCatalogConfig
    _DefaultName = "transformDiaSourceCatalog"
    # Needed to create a valid TransformCatalogBaseTask, but unused.
    inputDataset = "deepDiff_diaSrc"
    outputDataset = "deepDiff_diaSrcTable"

    def __init__(self, initInputs, **kwargs):
        super().__init__(**kwargs)
        self.funcs = self.getFunctors()
        self.inputSchema = initInputs['diaSourceSchema'].schema
        self._create_bit_pack_mappings()

        if not self.config.doPackFlags:
            # Get the flag rename rules.
            with open(os.path.expandvars(self.config.flagRenameMap)) as yaml_stream:
                self.rename_rules = list(yaml.safe_load_all(yaml_stream))

    def _create_bit_pack_mappings(self):
        """Set up all flag bit packings.
        """
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(self.config.flagMap)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == 'DiaSource':
                    self.bit_pack_columns = table['columns']
                    break

        # Test that all flags requested are present in the input schema.
        # Output schemas are flexible; however, if names are not specified
        # in the Apdb schema, flag columns will not be persisted.
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            for bit in bitList:
                try:
                    self.inputSchema.find(bit['name'])
                except KeyError:
                    raise KeyError(
                        "Requested column %s not found in input DiaSource "
                        "schema. Please check that the requested input "
                        "column exists." % bit['name'])
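
    # The flag map yaml is parsed as a stream of documents; each document
    # carries a ``tableName`` and a ``columns`` list, and every column has a
    # ``columnName`` plus a ``bitList`` of ``{name, bit}`` entries. An
    # illustrative (hypothetical) snippet of that layout:
    #
    #     tableName: DiaSource
    #     columns:
    #       - columnName: flags
    #         bitList:
    #           - {name: flag_a, bit: 0}
    #           - {name: flag_b, bit: 1}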

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)
        idGenerator = self.config.idGenerator.apply(butlerQC.quantum.dataId)
        inputs["ccdVisitId"] = idGenerator.catalog_id
        inputs["band"] = butlerQC.quantum.dataId["band"]

        outputs = self.run(**inputs)

        butlerQC.put(outputs, outputRefs)

    @timeMethod
    def run(self,
            diaSourceCat,
            diffIm,
            band,
            ccdVisitId):
        """Convert the input catalog to a ParquetTable/Pandas DataFrame and
        run the functors.

        Additionally, add new columns that carry information stripped from
        the exposure into the DiaSource catalog.

        Parameters
        ----------
        diaSourceCat : `lsst.afw.table.SourceCatalog`
            Catalog of sources measured on the difference image.
        diffIm : `lsst.afw.image.Exposure`
            Result of subtracting template and science images.
        band : `str`
            Filter band of the science image.
        ccdVisitId : `int`
            Identifier for this detector+visit.

        Returns
        -------
        results : `lsst.pipe.base.Struct`
            Results struct with components.

            - ``diaSourceTable`` : Catalog of DiaSources with calibrated
              values and renamed columns
              (`lsst.pipe.tasks.ParquetTable` or `pandas.DataFrame`).
        """

        self.log.info(
            "Transforming/standardizing the DiaSource table ccdVisitId: %i",
            ccdVisitId)

        diaSourceDf = diaSourceCat.asAstropy().to_pandas()
        if self.config.doRemoveSkySources:
            diaSourceDf = diaSourceDf[~diaSourceDf["sky_source"]]
            diaSourceCat = diaSourceCat[~diaSourceCat["sky_source"]]

        diaSourceDf["snr"] = getSignificance(diaSourceCat)
        diaSourceDf["bboxSize"] = self.computeBBoxSizes(diaSourceCat)
        diaSourceDf["ccdVisitId"] = ccdVisitId
        diaSourceDf["filterName"] = band
        diaSourceDf["midPointTai"] = diffIm.getInfo().getVisitInfo().getDate().get(system=DateTime.MJD)
        diaSourceDf["diaObjectId"] = 0
        diaSourceDf["ssObjectId"] = 0

        if self.config.doPackFlags:
            # Either bit-pack the flags...
            self.bitPackFlags(diaSourceDf)
        else:
            # ...or add the individual flag functors
            self.addUnpackedFlagFunctors()
            # and remove the packed flag functor.
            if 'flags' in self.funcs.funcDict:
                del self.funcs.funcDict['flags']

        df = self.transform(band,
                            diaSourceDf,
                            self.funcs,
                            dataId=None).df

        return pipeBase.Struct(
            diaSourceTable=df,
        )
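
    # A minimal invocation sketch (not executed here): ``schemaCat``,
    # ``diaSourceCat``, and ``diffIm`` are placeholders that the butler
    # normally supplies.
    #
    #     task = TransformDiaSourceCatalogTask(
    #         initInputs={"diaSourceSchema": schemaCat})
    #     result = task.run(diaSourceCat, diffIm, band="g",
    #                       ccdVisitId=123456)
    #     result.diaSourceTable  # DataFrame ready for the Apdb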

    def addUnpackedFlagFunctors(self):
        """Add a `Column` functor for each of the flags to the internal
        functor dictionary.
        """
        for flag in self.bit_pack_columns[0]['bitList']:
            flagName = flag['name']
            targetName = self.funcs.renameCol(flagName, self.rename_rules[0]['flag_rename_rules'])
            self.funcs.update({targetName: Column(flagName)})

    def computeBBoxSizes(self, inputCatalog):
        """Compute the size of a square bbox that fully contains the
        detection footprint.

        Parameters
        ----------
        inputCatalog : `lsst.afw.table.SourceCatalog`
            Catalog containing detected footprints.

        Returns
        -------
        outputBBoxSizes : `np.ndarray`, (N,)
            Array of bbox sizes.
        """
        # Schema validation requires that this field is int.
        outputBBoxSizes = np.empty(len(inputCatalog), dtype=int)
        for i, record in enumerate(inputCatalog):
            footprintBBox = record.getFootprint().getBBox()
            # Compute twice the size of the largest dimension of the
            # footprint bounding box. This is the largest footprint we
            # should need to cover the complete DiaSource, assuming the
            # centroid is within the bounding box.
            maxSize = 2 * np.max([footprintBBox.getWidth(),
                                  footprintBBox.getHeight()])
            recX = record.getCentroid().x
            recY = record.getCentroid().y
            bboxSize = int(
                np.ceil(2 * np.max(np.fabs([footprintBBox.maxX - recX,
                                            footprintBBox.minX - recX,
                                            footprintBBox.maxY - recY,
                                            footprintBBox.minY - recY]))))
            if bboxSize > maxSize:
                bboxSize = maxSize
            outputBBoxSizes[i] = bboxSize

        return outputBBoxSizes
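
    # Worked example (illustrative numbers): for a footprint bbox spanning
    # x in [10, 20] and y in [12, 18] with a centroid at (14, 15), the
    # largest |edge - centroid| offset is 6 (from maxX), so the candidate
    # size is ceil(2 * 6) = 12; the cap is twice the larger bbox dimension,
    # and the smaller of the two values is returned.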

    def bitPackFlags(self, df):
        """Pack the requested flag columns in ``df`` into single integer
        columns.

        Parameters
        ----------
        df : `pandas.DataFrame`
            DataFrame to read bits from and pack them into.
        """
        for outputFlag in self.bit_pack_columns:
            bitList = outputFlag['bitList']
            value = np.zeros(len(df), dtype=np.uint64)
            for bit in bitList:
                # Hard type the bit arrays.
                value += (df[bit['name']]*2**bit['bit']).to_numpy().astype(np.uint64)
            df[outputFlag['columnName']] = value
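
    # A minimal bit-packing sketch (hypothetical flag names and bit
    # positions), standalone of the yaml map:
    #
    #     import numpy as np
    #     import pandas as pd
    #
    #     df = pd.DataFrame({"flag_a": [True, False],
    #                        "flag_b": [True, True]})
    #     packed = ((df["flag_a"] * 2**0 + df["flag_b"] * 2**1)
    #               .to_numpy().astype(np.uint64))
    #     # packed -> [3, 2]: flag_a sets bit 0, flag_b sets bit 1.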


class UnpackApdbFlags:
    """Class for unpacking bits from integer flag fields stored in the Apdb.

    Parameters
    ----------
    flag_map_file : `str`
        Absolute or relative path to a yaml file specifying mappings of
        flags to integer bits.
    table_name : `str`
        Name of the Apdb table the integer bit data are coming from.
    """

    def __init__(self, flag_map_file, table_name):
        self.bit_pack_columns = []
        flag_map_file = os.path.expandvars(flag_map_file)
        with open(flag_map_file) as yaml_stream:
            table_list = list(yaml.safe_load_all(yaml_stream))
            for table in table_list:
                if table['tableName'] == table_name:
                    self.bit_pack_columns = table['columns']
                    break

        self.output_flag_columns = {}

        for column in self.bit_pack_columns:
            names = []
            for bit in column["bitList"]:
                names.append((bit["name"], bool))
            self.output_flag_columns[column["columnName"]] = names

    def unpack(self, input_flag_values, flag_name):
        """Determine individual boolean flags from an input array of
        unsigned ints.

        Parameters
        ----------
        input_flag_values : array-like of type uint
            Array of integer flags to unpack.
        flag_name : `str`
            Apdb column name of the integer flags to unpack. Names of packed
            int flags are given by the flag_map_file.

        Returns
        -------
        output_flags : `numpy.ndarray`
            Structured array of booleans, one field per flag name.
        """
        bit_names_types = self.output_flag_columns[flag_name]
        output_flags = np.zeros(len(input_flag_values), dtype=bit_names_types)

        for bit_idx, (bit_name, dtype) in enumerate(bit_names_types):
            masked_bits = np.bitwise_and(input_flag_values, 2**bit_idx)
            output_flags[bit_name] = masked_bits

        return output_flags
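
    # A minimal unpacking sketch (the yaml path and flag names are
    # hypothetical):
    #
    #     unpacker = UnpackApdbFlags("association-flag-map.yaml", "DiaSource")
    #     flags = np.array([0, 3], dtype=np.uint64)
    #     unpacked = unpacker.unpack(flags, "flags")
    #     unpacked["flag_a"]  # -> [False, True] if flag_a occupies bit 0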

    def flagExists(self, flagName, columnName='flags'):
        """Check if the named flag is in the bitpacked flag set.

        Parameters
        ----------
        flagName : `str`
            Flag name to search for.
        columnName : `str`, optional
            Name of the bitpacked flag column to search in.

        Returns
        -------
        flagExists : `bool`
            `True` if `flagName` is present in `columnName`.

        Raises
        ------
        ValueError
            Raised if `columnName` is not defined.
        """
        if columnName not in self.output_flag_columns:
            raise ValueError(f'column {columnName} not in flag map: {self.output_flag_columns}')

        return flagName in [c[0] for c in self.output_flag_columns[columnName]]

    def makeFlagBitMask(self, flagNames, columnName='flags'):
        """Return a bitmask corresponding to the supplied flag names.

        Parameters
        ----------
        flagNames : `list` [`str`]
            Flag names to include in the bitmask.
        columnName : `str`, optional
            Name of the bitpacked flag column.

        Returns
        -------
        bitmask : `np.uint64`
            Bitmask corresponding to the supplied flag names given the
            loaded configuration.

        Raises
        ------
        ValueError
            Raised if a flag in `flagNames` is not included in `columnName`.
        """
        bitmask = np.uint64(0)

        for flag in flagNames:
            if not self.flagExists(flag, columnName=columnName):
                raise ValueError(f"flag '{flag}' not included in '{columnName}' flag column")

        for outputFlag in self.bit_pack_columns:
            if outputFlag['columnName'] == columnName:
                bitList = outputFlag['bitList']
                for bit in bitList:
                    if bit['name'] in flagNames:
                        bitmask += np.uint64(2**bit['bit'])

        return bitmask
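
    # A minimal masking sketch (flag names hypothetical): build a mask and
    # drop rows whose packed 'flags' column sets either bit.
    #
    #     mask = unpacker.makeFlagBitMask(["flag_a", "flag_b"])
    #     bad = (df["flags"].to_numpy().astype(np.uint64) & mask) != 0
    #     good = df[~bad]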


def getSignificance(catalog):
    """Return the significance value of the first peak in each source
    footprint, or NaN for peaks without a significance field.

    Parameters
    ----------
    catalog : `lsst.afw.table.SourceCatalog`
        Catalog to process.

    Returns
    -------
    significance : `np.ndarray`, (N,)
        Significance of the first peak in each source footprint.
    """
    result = np.full(len(catalog), np.nan)
    for i, record in enumerate(catalog):
        peaks = record.getFootprint().peaks
        if "significance" in peaks.schema:
            result[i] = peaks[0]["significance"]
    return result