Coverage for python/lsst/pipe/tasks/extended_psf.py: 21%

204 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-11 10:41 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Read preprocessed bright stars and stack to build an extended PSF model.""" 

23 

24__all__ = [ 

25 "FocalPlaneRegionExtendedPsf", 

26 "ExtendedPsf", 

27 "StackBrightStarsConfig", 

28 "StackBrightStarsTask", 

29 "MeasureExtendedPsfConfig", 

30 "MeasureExtendedPsfTask", 

31] 

32 

33from dataclasses import dataclass 

34from typing import List 

35 

36from lsst.afw.fits import Fits, readMetadata 

37from lsst.afw.image import ImageF, MaskedImageF, MaskX 

38from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty 

39from lsst.daf.base import PropertyList 

40from lsst.geom import Extent2I 

41from lsst.pex.config import ChoiceField, Config, ConfigurableField, DictField, Field, ListField 

42from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, Struct, Task 

43from lsst.pipe.base.connectionTypes import Input, Output 

44from lsst.pipe.tasks.assembleCoadd import AssembleCoaddTask 

45 

46 

@dataclass
class FocalPlaneRegionExtendedPsf:
    """Extended PSF model attached to one focal plane region.

    A region is described by the list of detectors it is made of.

    Parameters
    ----------
    extended_psf_image : `lsst.afw.image.MaskedImageF`
        Image of the extended PSF model.
    detector_list : `list` [`int`]
        IDs of the detectors making up the focal plane region over which
        this extended PSF model has been built (and can be used).
    """

    # The model itself.
    extended_psf_image: MaskedImageF
    # Detector IDs that make up the region.
    detector_list: List[int]

64 

65 

class ExtendedPsf:
    """Extended PSF model.

    Each instance may contain a default extended PSF, a set of extended PSFs
    that correspond to different focal plane regions, or both. At this time,
    focal plane regions are always defined as a subset of detectors.

    Parameters
    ----------
    default_extended_psf : `lsst.afw.image.MaskedImageF`, optional
        Extended PSF model to be used as default (or only) extended PSF model.
    """

    def __init__(self, default_extended_psf=None):
        self.default_extended_psf = default_extended_psf
        # Upper-cased region name -> FocalPlaneRegionExtendedPsf.
        self.focal_plane_regions = {}
        # Detector ID -> name of the region that contains it.
        self.detectors_focal_plane_regions = {}

    def add_regional_extended_psf(self, extended_psf_image, region_name, detector_list):
        """Add a new focal plane region, along with its extended PSF, to the
        ExtendedPsf instance.

        Parameters
        ----------
        extended_psf_image : `lsst.afw.image.MaskedImageF`
            Extended PSF model for the region.
        region_name : `str`
            Name of the focal plane region. Will be converted to all-uppercase.
        detector_list : `list` [`int`]
            List of IDs for the detectors that define the focal plane region.

        Raises
        ------
        ValueError
            Raised if a region with the same (upper-cased) name already exists.
        """
        region_name = region_name.upper()
        if region_name in self.focal_plane_regions:
            raise ValueError(f"Region name {region_name} is already used by this ExtendedPsf instance.")
        self.focal_plane_regions[region_name] = FocalPlaneRegionExtendedPsf(
            extended_psf_image=extended_psf_image, detector_list=detector_list
        )
        for det in detector_list:
            self.detectors_focal_plane_regions[det] = region_name

    def __call__(self, detector=None):
        """Return the appropriate extended PSF.

        If the instance contains no extended PSF defined over focal plane
        regions, the default extended PSF will be returned regardless of
        whether a detector ID was passed as argument.

        Parameters
        ----------
        detector : `int`, optional
            Detector ID. If focal plane region PSFs are defined, is used to
            determine which model to return.

        Returns
        -------
        extendedPsfImage : `lsst.afw.image.MaskedImageF`
            The extended PSF model. If this instance contains extended PSFs
            defined over focal plane regions, the extended PSF model for the
            region that contains ``detector`` is returned. If not, the default
            extended PSF is returned.

        Raises
        ------
        ValueError
            Raised if no detector is given and there is no default model.
        """
        if detector is None:
            if self.default_extended_psf is None:
                raise ValueError("No default extended PSF available; please provide detector number.")
            return self.default_extended_psf
        elif not self.focal_plane_regions:
            return self.default_extended_psf
        return self.get_regional_extended_psf(detector=detector)

    def __len__(self):
        """Returns the number of extended PSF models present in the instance.

        Note that if the instance contains both a default model and a set of
        focal plane region models, the length of the instance will be the
        number of regional models, plus one (the default). This is true even
        in the case where the default model is one of the focal plane
        region-specific models.
        """
        n_regions = len(self.focal_plane_regions)
        if self.default_extended_psf is not None:
            n_regions += 1
        return n_regions

    def get_regional_extended_psf(self, region_name=None, detector=None):
        """Returns the extended PSF for a focal plane region.

        The region can be identified either by name, or through a detector ID.

        Parameters
        ----------
        region_name : `str` or `None`, optional
            Name of the region for which the extended PSF should be retrieved.
            Matched case-insensitively, since region names are stored
            upper-cased. Ignored if ``detector`` is provided. Must be provided
            if ``detector`` is None.
        detector : `int` or `None`, optional
            If provided, returns the extended PSF for the focal plane region
            that includes this detector.

        Returns
        -------
        extended_psf_image : `lsst.afw.image.MaskedImageF`
            The extended PSF model of the requested region.

        Raises
        ------
        ValueError
            Raised if neither ``detector`` nor ``region_name`` is provided.
        KeyError
            Raised if the region name or detector ID is unknown to this
            instance.
        """
        if detector is None:
            if region_name is None:
                raise ValueError("One of either a regionName or a detector number must be provided.")
            # add_regional_extended_psf stores names upper-cased, so normalise
            # the lookup key the same way.
            return self.focal_plane_regions[region_name.upper()].extended_psf_image
        return self.focal_plane_regions[self.detectors_focal_plane_regions[detector]].extended_psf_image

    @staticmethod
    def _append_masked_image(filename, masked_image, region):
        """Append the image and mask planes of one model as two HDUs.

        Only the ``image`` and ``mask`` attributes of ``masked_image`` are
        written; each HDU is tagged with ``REGION`` and ``EXTNAME``.
        """
        hdu_metadata = PropertyList()
        hdu_metadata.update({"REGION": region, "EXTNAME": "IMAGE"})
        masked_image.image.writeFits(filename, metadata=hdu_metadata, mode="a")
        hdu_metadata.update({"REGION": region, "EXTNAME": "MASK"})
        masked_image.mask.writeFits(filename, metadata=hdu_metadata, mode="a")

    def write_fits(self, filename):
        """Write this object to a file.

        The primary HDU holds the global metadata (presence of a default
        model, region names, detector list of each region); the image and
        mask of every model are then appended as separate HDUs.

        Parameters
        ----------
        filename : `str`
            Name of file to write.
        """
        # Create primary HDU with global metadata.
        metadata = PropertyList()
        metadata["HAS_DEFAULT"] = self.default_extended_psf is not None
        if self.focal_plane_regions:
            metadata["HAS_REGIONS"] = True
            metadata["REGION_NAMES"] = list(self.focal_plane_regions.keys())
            for region, e_psf_region in self.focal_plane_regions.items():
                metadata[region] = e_psf_region.detector_list
        else:
            metadata["HAS_REGIONS"] = False
        fits_primary = Fits(filename, "w")
        try:
            fits_primary.createEmpty()
            fits_primary.writeMetadata(metadata)
        finally:
            # Release the file handle even if writing the header failed.
            fits_primary.closeFile()
        # Write default extended PSF.
        if self.default_extended_psf is not None:
            self._append_masked_image(filename, self.default_extended_psf, "DEFAULT")
        # Write extended PSF for each focal plane region.
        for region, e_psf_region in self.focal_plane_regions.items():
            self._append_masked_image(filename, e_psf_region.extended_psf_image, region)

    def writeFits(self, filename):
        """Alias for ``write_fits``; for compatibility with the Butler."""
        self.write_fits(filename)

    @classmethod
    def read_fits(cls, filename):
        """Build an instance of this class from a file.

        Parameters
        ----------
        filename : `str`
            Name of the file to read.

        Returns
        -------
        extended_psf : `ExtendedPsf`
            The instance rebuilt from the file contents.

        Raises
        ------
        ValueError
            Raised if the file announces a default model but lacks its image
            or mask HDU, or if the number of regions found in the HDUs does
            not match the number recorded in the primary header.
        """
        # Extract info from metadata.
        global_metadata = readMetadata(filename, hdu=0)
        has_default = global_metadata.getBool("HAS_DEFAULT")
        if global_metadata.getBool("HAS_REGIONS"):
            focal_plane_region_names = global_metadata.getArray("REGION_NAMES")
        else:
            focal_plane_region_names = []
        # The handle is only needed to count the HDUs; close it right away.
        f = Fits(filename, "r")
        try:
            n_extensions = f.countHdus()
        finally:
            f.closeFile()
        plane_readers = {"IMAGE": ImageF, "MASK": MaskX}
        default_parts = {}
        extended_psf_parts = {}
        for j in range(1, n_extensions):
            md = readMetadata(filename, hdu=j)
            region = md["REGION"]
            ext_name = md["EXTNAME"]
            reader = plane_readers.get(ext_name)
            if reader is None:
                # write_fits only produces IMAGE and MASK HDUs; skip anything
                # else instead of storing an unrelated or stale object.
                continue
            plane = reader(filename, hdu=j)
            if has_default and region == "DEFAULT":
                default_parts[ext_name.lower()] = plane
            else:
                extended_psf_parts.setdefault(region, {})[ext_name.lower()] = plane
        # Handle default if present.
        if has_default:
            missing = sorted({"image", "mask"} - default_parts.keys())
            if missing:
                raise ValueError(
                    f"File {filename} is flagged as containing a default extended PSF, but its "
                    f"{' and '.join(missing)} HDU(s) could not be found."
                )
            extended_psf = cls(MaskedImageF(default_parts["image"], default_parts["mask"]))
        else:
            extended_psf = cls()
        # Ensure we recovered an extended PSF for all focal plane regions.
        if len(extended_psf_parts) != len(focal_plane_region_names):
            raise ValueError(
                f"Number of per-region extended PSFs read ({len(extended_psf_parts)}) does not "
                "match with the number of regions recorded in the metadata "
                f"({len(focal_plane_region_names)})."
            )
        # Generate extended PSF regions mappings.
        for r_name in focal_plane_region_names:
            extended_psf_image = MaskedImageF(**extended_psf_parts[r_name])
            detector_list = global_metadata.getArray(r_name)
            extended_psf.add_regional_extended_psf(extended_psf_image, r_name, detector_list)
        return extended_psf

    @classmethod
    def readFits(cls, filename):
        """Alias for ``read_fits``; exists for compatibility with the Butler."""
        return cls.read_fits(filename)

272 

273 

class StackBrightStarsConfig(Config):
    """Configuration parameters for StackBrightStarsTask."""

    subregion_size = ListField(
        dtype=int,
        doc="Size, in pixels, of the subregions over which the stacking will be iteratively performed.",
        default=(100, 100),
    )
    stacking_statistic = ChoiceField(
        dtype=str,
        doc="Type of statistic to use for stacking.",
        default="MEANCLIP",
        allowed={
            "MEAN": "mean",
            "MEDIAN": "median",
            "MEANCLIP": "clipped mean",
        },
    )
    num_sigma_clip = Field(
        dtype=float,
        doc="Sigma for outlier rejection; ignored if stacking_statistic != 'MEANCLIP'.",
        default=4,
    )
    num_iter = Field(
        dtype=int,
        # Refers to the field by its actual name, ``stacking_statistic``.
        doc="Number of iterations of outlier rejection; ignored if stacking_statistic != 'MEANCLIP'.",
        default=3,
    )
    bad_mask_planes = ListField(
        dtype=str,
        doc="Mask planes that define pixels to be excluded from the stacking of the bright star stamps.",
        default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
    )
    do_mag_cut = Field(
        dtype=bool,
        doc="Apply magnitude cut before stacking?",
        default=False,
    )
    mag_limit = Field(
        dtype=float,
        doc="Magnitude limit, in Gaia G; all stars brighter than this value will be stacked.",
        default=18,
    )

317 

318 

class StackBrightStarsTask(Task):
    """Stack bright stars together to build an extended PSF model."""

    ConfigClass = StackBrightStarsConfig
    _DefaultName = "stack_bright_stars"

    def _set_up_stacking(self, example_stamp):
        """Configure stacking statistic and control from config fields.

        Parameters
        ----------
        example_stamp : masked image
            Stamp whose ``mask`` is used to translate the configured mask
            plane names into a bit mask.

        Returns
        -------
        stats_control : `lsst.afw.math.StatisticsControl`
            Control object carrying the clipping parameters and, if any bad
            mask planes are configured, the corresponding bit mask.
        stats_flags
            Statistic selected by ``config.stacking_statistic``, as returned
            by `lsst.afw.math.stringToStatisticsProperty`.
        """
        stats_control = StatisticsControl()
        stats_control.setNumSigmaClip(self.config.num_sigma_clip)
        stats_control.setNumIter(self.config.num_iter)
        if bad_masks := self.config.bad_mask_planes:
            # OR together the bits of every configured bad mask plane.
            and_mask = example_stamp.mask.getPlaneBitMask(bad_masks[0])
            for bm in bad_masks[1:]:
                and_mask |= example_stamp.mask.getPlaneBitMask(bm)
            stats_control.setAndMask(and_mask)
        stats_flags = stringToStatisticsProperty(self.config.stacking_statistic)
        return stats_control, stats_flags

    def run(self, bss_ref_list, region_name=None):
        """Read input bright star stamps and stack them together.

        The stacking is done iteratively over smaller areas of the final model
        image to allow for a great number of bright star stamps to be used.

        Parameters
        ----------
        bss_ref_list : `list` of
                `lsst.daf.butler._deferredDatasetHandle.DeferredDatasetHandle`
            List of available bright star stamps data references. Must
            contain at least one element, as the first one is used to size
            the model image.
        region_name : `str`, optional
            Name of the focal plane region, if applicable. Only used for
            logging purposes, when running over multiple such regions
            (typically from `MeasureExtendedPsfTask`)

        Returns
        -------
        ext_psf : `lsst.afw.image.MaskedImageF`
            The stacked extended PSF model, with the bounding box of the
            first stamp of the first input.
        """
        if region_name:
            region_message = f" for region '{region_name}'."
        else:
            region_message = "."
        self.log.info(
            "Building extended PSF from stamps extracted from %d detector images%s",
            len(bss_ref_list),
            region_message,
        )
        # read in example set of full stamps
        example_bss = bss_ref_list[0].get()
        example_stamp = example_bss[0].stamp_im
        # create model image
        ext_psf = MaskedImageF(example_stamp.getBBox())
        # divide model image into smaller subregions
        subregion_size = Extent2I(*self.config.subregion_size)
        sub_bboxes = AssembleCoaddTask._subBBoxIter(ext_psf.getBBox(), subregion_size)
        # Number of tiles needed to cover the model: ceiling division along
        # each axis (a partial tile at the edge still counts as one).
        # NOTE(review): assumes _subBBoxIter tiles the bbox in steps of
        # subregion_size -- confirm against AssembleCoaddTask.
        dimensions = ext_psf.getDimensions()
        n_subregions = (-(-dimensions[0] // subregion_size[0])) * (-(-dimensions[1] // subregion_size[1]))
        self.log.info(
            "Stacking performed iteratively over approximately %d smaller areas of the final model image.",
            n_subregions,
        )
        # set up stacking statistic
        stats_control, stats_flags = self._set_up_stacking(example_stamp)
        # perform stacking
        for bbox in sub_bboxes:
            all_stars = None
            for bss_ref in bss_ref_list:
                # Only the cutout matching the current subregion is read.
                read_stars = bss_ref.get(parameters={"bbox": bbox})
                if self.config.do_mag_cut:
                    read_stars = read_stars.selectByMag(magMax=self.config.mag_limit)
                if all_stars:
                    all_stars.extend(read_stars)
                else:
                    all_stars = read_stars
            # TODO: DM-27371 add weights to bright stars for stacking
            coadd_sub_bbox = statisticsStack(all_stars.getMaskedImages(), stats_flags, stats_control)
            ext_psf.assign(coadd_sub_bbox, bbox)
        return ext_psf

396 

397 

class MeasureExtendedPsfConnections(PipelineTaskConnections, dimensions=("band", "instrument")):
    """Butler connections used by MeasureExtendedPsfTask."""

    # Per-(visit, detector) stamp collections; loaded lazily and gathered as
    # a list of deferred handles.
    input_brightStarStamps = Input(
        name="brightStarStamps",
        storageClass="BrightStarStamps",
        doc="Input list of bright star collections to be stacked.",
        dimensions=("visit", "detector"),
        multiple=True,
        deferLoad=True,
    )
    # The stacked model, one per band.
    extended_psf = Output(
        name="extended_psf",
        storageClass="ExtendedPsf",
        doc="Extended PSF model built by stacking bright stars.",
        dimensions=("band",),
    )

413 

414 

class MeasureExtendedPsfConfig(PipelineTaskConfig, pipelineConnections=MeasureExtendedPsfConnections):
    """Configuration parameters for MeasureExtendedPsfTask."""

    # Subtask that performs the actual stacking of the stamps.
    stack_bright_stars = ConfigurableField(
        doc="Stack selected bright stars",
        target=StackBrightStarsTask,
    )
    # Leaving this mapping empty selects the single-model mode.
    detectors_focal_plane_regions = DictField(
        doc="Mapping from detector IDs to focal plane region names. If empty, a constant extended PSF model "
        "is built from all selected bright stars.",
        keytype=int,
        itemtype=str,
        default={},
    )

431 

432 

class MeasureExtendedPsfTask(Task):
    """Build and save extended PSF model.

    The model is built by stacking bright star stamps, extracted and
    preprocessed by
    `lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`.

    If a mapping from detector IDs to focal plane regions is provided, a
    different extended PSF model will be built for each focal plane region. If
    not, a single constant extended PSF model is built with all available data.
    """

    ConfigClass = MeasureExtendedPsfConfig
    _DefaultName = "measureExtendedPsf"

    def __init__(self, initInputs=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.makeSubtask("stack_bright_stars")
        # Invert the detector -> region config mapping into region ->
        # detectors. Regions are registered in the order they are first met
        # in the mapping rather than by iterating a set of names, whose order
        # changes between interpreter runs under string hash randomisation;
        # this keeps the order of the regions in the output reproducible.
        self.focal_plane_regions = {}
        for det, region in self.config.detectors_focal_plane_regions.items():
            self.focal_plane_regions.setdefault(region, []).append(det)
        # make no assumption on what detector IDs should be, but if we come
        # across one where there are processed bright stars, but no
        # corresponding focal plane region, make sure we keep track of
        # it (eg to raise a warning only once)
        self.regionless_dets = []

    def select_detector_refs(self, ref_list):
        """Split available sets of bright star stamps according to focal plane
        regions.

        Parameters
        ----------
        ref_list : `list` of
                `lsst.daf.butler._deferredDatasetHandle.DeferredDatasetHandle`
            List of available bright star stamps data references.

        Returns
        -------
        region_ref_list : `dict` [`str`, `list`]
            For each configured region name, the handles whose detector
            belongs to that region (possibly an empty list). Handles whose
            detector is absent from the config mapping are dropped, with a
            warning logged the first time each such detector is seen.
        """
        region_ref_list = {region: [] for region in self.focal_plane_regions}
        for dataset_handle in ref_list:
            det_id = dataset_handle.ref.dataId["detector"]
            if det_id in self.regionless_dets:
                # Already warned about this detector.
                continue
            try:
                region_name = self.config.detectors_focal_plane_regions[det_id]
            except KeyError:
                self.log.warning(
                    "Bright stars were available for detector %d, but it was missing from the %s config "
                    "field, so they will not be used to build any of the extended PSF models.",
                    det_id,
                    "'detectors_focal_plane_regions'",
                )
                self.regionless_dets.append(det_id)
                continue
            region_ref_list[region_name].append(dataset_handle)
        return region_ref_list

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Build the extended PSF model(s) for one quantum and store them.

        With an empty ``detectors_focal_plane_regions`` config, all inputs
        are stacked into a single default model. Otherwise one model is
        built per region that has at least one input; regions without any
        input are skipped with a warning.
        """
        input_data = butlerQC.get(inputRefs)
        bss_ref_list = input_data["input_brightStarStamps"]
        # Handle default case of a single region with empty detector list
        if not self.config.detectors_focal_plane_regions:
            self.log.info(
                "No detector groups were provided to MeasureExtendedPsfTask; computing a single, "
                "constant extended PSF model over all available observations."
            )
            output_e_psf = ExtendedPsf(self.stack_bright_stars.run(bss_ref_list))
        else:
            output_e_psf = ExtendedPsf()
            region_ref_list = self.select_detector_refs(bss_ref_list)
            for region_name, ref_list in region_ref_list.items():
                if not ref_list:
                    # no valid references found
                    self.log.warning(
                        "No valid brightStarStamps reference found for region '%s'; skipping it.",
                        region_name,
                    )
                    continue
                ext_psf = self.stack_bright_stars.run(ref_list, region_name)
                output_e_psf.add_regional_extended_psf(
                    ext_psf, region_name, self.focal_plane_regions[region_name]
                )
        output = Struct(extended_psf=output_e_psf)
        butlerQC.put(output, outputRefs)

517 butlerQC.put(output, outputRefs)