Coverage for python/lsst/meas/extensions/scarlet/scarletDeblendTask.py: 16%


391 statements  

# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21 

22import logging 

23import numpy as np 

24import scarlet 

25from scarlet.psf import ImagePSF, GaussianPSF 

26from scarlet import Blend, Frame, Observation 

27from scarlet.renderer import ConvolutionRenderer 

28from scarlet.initialization import init_all_sources 

29 

30import lsst.pex.config as pexConfig 

31from lsst.pex.exceptions import InvalidParameterError 

32import lsst.pipe.base as pipeBase 

33from lsst.geom import Point2I, Box2I, Point2D 

34import lsst.afw.geom.ellipses as afwEll 

35import lsst.afw.image as afwImage 

36import lsst.afw.detection as afwDet 

37import lsst.afw.table as afwTable 

38from lsst.utils.timer import timeMethod 

39 

40from .source import modelToHeavy 

41 

42# Scarlet and proxmin have a different definition of log levels than the stack, 

43# so even "warnings" occur far more often than we would like. 

44# So for now we only display scarlet and proxmin errors, as all other 

45# scarlet outputs would be considered "TRACE" by our standards. 

46scarletLogger = logging.getLogger("scarlet") 

47scarletLogger.setLevel(logging.ERROR) 

48proxminLogger = logging.getLogger("proxmin") 

49proxminLogger.setLevel(logging.ERROR) 

50 

51__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"] 

52 

53logger = logging.getLogger(__name__) 

54 

55 

class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
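
# A note on the criterion above: the blend is considered converged when the
# most recent change in the loss is small relative to the loss itself, i.e.
# |loss[-2] - loss[-1]| < f_rel * |loss[-1]|.
# For example, with f_rel = 1e-4 and a final loss of -5000, a change smaller
# than 0.5 between the last two iterations counts as converged.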

def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    return psf.computeShape().getDeterminantRadius() * 2.35
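
# The factor 2.35 above approximates the Gaussian sigma-to-FWHM conversion,
# FWHM = 2*sqrt(2*ln(2))*sigma ~= 2.3548*sigma, applied to the determinant
# radius of the PSF shape. For example, a PSF with a determinant radius of
# 2 pixels has a FWHM of roughly 4.7 pixels.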

def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF Kernel Image is computed for each band
    and combined into a (filter, y, x) array.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : array
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
                psfs.append(psf)
            else:
                psf = single.getPsf().computeKernelImage(position)
                psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        - The footprint of the parent to deblend

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask

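# The boolean mask returned above is combined with the deblender weights in
# `deblend` below, so that pixels outside the footprint carry zero weight:
#     mask = getFootprintMask(footprint, mExposure)
#     weights *= ~mask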

def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.

    Returns
    -------
    isPseudo : `bool`
        `True` if the source is flagged in any of the pseudo columns.
    """
    isPseudo = False
    for col in pseudoColumns:
        try:
            isPseudo |= source[col]
        except KeyError:
            pass
    return isPseudo

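# Illustrative example: with the default configuration, a peak or source
# record carrying `merge_peak_sky=True` is treated as a pseudo source and
# excluded from the centers passed to scarlet:
#     isPseudoSource(peak, ["merge_peak_sky", "sky_source"])  # -> True
# Records that lack both columns simply return False.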

def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        - The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        - Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The scarlet blend class that contains all of the information
        about the parameters and results from scarlet
    skipped : `list` of `int`
        The indices of any children that failed to initialize
        and were skipped.
    spectrumInit : `bool`
        Whether or not all of the sources were initialized by jointly
        fitting their SEDs. This provides a better initialization
        but can create memory issues when a blend is too large or
        contains too many sources.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
        badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
        mask = mExposure.mask[:, bbox].array & badPixels
        weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # "point" is included so that the NotImplementedError below is reachable
    assert config.sourceModel in ["single", "double", "compact", "point", "fit"]

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value.
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit

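# Illustrative sketch (not part of the pipeline): `deblend` can be run on a
# single parent footprint outside of the Task, assuming a `MultibandExposure`
# and a merged catalog are already in hand. The names `mExposure` and
# `catalog` below are placeholders, not values defined in this module.
#     config = ScarletDeblendConfig()
#     parent = catalog[0]
#     blend, skipped, spectrumInit = deblend(mExposure, parent.getFootprint(), config)
#     print(len(blend.sources), "sources modeled;", len(skipped), "skipped")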

class ScarletDeblendConfig(pexConfig.Config):
    """MultibandDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are
    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-4,
                                    doc=("Change in the loss function between "
                                         "iterations to exit fitter"))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "  for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)")
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "The improved initialization is only used when "
            "peaks*area < `maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes whose pixels are given zero weight during deblending")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=0,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff` is non-positive then there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fallback to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources."
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")

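# Illustrative sketch: like any `lsst.pex.config.Config`, the deblender
# configuration can be instantiated and customized directly. The values
# below are examples, not recommendations.
#     config = ScarletDeblendConfig()
#     config.sourceModel = "single"          # one component per source
#     config.maxIter = 200
#     config.maskLimits = {"NO_DATA": 0.25}  # skip parents >25% NO_DATA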

class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task has no return value; it only modifies the SourceCatalog in-place.
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing, but
            # we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge before '
                                                        'config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the "
                                                   "deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #                (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
        #                )

    @timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : `dict`
            Keys are the names of the filters and the values are
            `lsst.afw.table.SourceCatalog`s.
            These are catalogs with heavy footprints that contain the
            templates created by the multiband deblender.
        """
        return self.deblend(mExposure, mergedSources)

    @timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.SourceCatalog`s.
            These are catalogs with heavy footprints that contain the
            templates created by the multiband deblender.
        """
        import time

        # Cull footprints if required by ci
        if self.config.useCiLimits:
            self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                          len(catalog))
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                                 "indicated by ciDeblendChildRange. Adjust this range to include more "
                                 "parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
        nextLogTime = time.time() + self.config.loggingInterval

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
            # Run the deblender
            blendError = None
            try:
                t0 = time.time()
                # Build the parameter lists with the same ordering
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                tf = time.time()
                runtime = (tf-t0)*1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)
                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
                # Re-insert place holders for skipped sources
                # to propagate them in the catalog so
                # that the peaks stay consistent
                for k in skipped:
                    scarletSources.insert(k, None)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warn("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            logL = blend.loss[-1]-blend.observations[0].log_norm
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped:
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                mHeavy = modelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )

            # Log a message if it has been a while since the last log.
            if (currentTime := time.time()) > nextLogTime:
                nextLogTime = currentTime + self.config.loggingInterval
                self.log.verbose("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

        # Clear the cached values in scarlet to clear out memory
        scarlet.cache.Cache._cache = {}

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog)-nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)
            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    fp = catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                      "creating %d children, for a total of %d sources",
                      nParents, nDeblendedParents, nChildren, len(catalog))
        return catalogs

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio.
        These may be disabled independently by configuring them to be
        non-positive.

        This is principally intended to get rid of satellite streaks, which the
        deblender or other downstream processing can have trouble dealing with
        (e.g., multiple large HeavyFootprints can chew up memory).
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
                return True
        return False

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            unmaskedSpan = footprint.spans.intersectNot(_mask)  # spanset of unmasked pixels
            if (size - unmaskedSpan.getArea())/size > limit:
                return True
        return False
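
    # Worked example for the limit check above: with
    # maskLimits={"NO_DATA": 0.25}, a footprint covering 1000 pixels of
    # which 300 carry the NO_DATA bit has a masked fraction of
    # (1000 - 700)/1000 = 0.3 > 0.25, so the parent is skipped.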

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `lsst.afw.table.Key`
            The flag column to set in order to mark the reason
            for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # The flag records a convergence *failure*, so it is the
        # inverse of `converged`.
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint.
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `lsst.geom.Point2I`
            The origin of the parent bounding box, used to convert
            the model-frame center into image coordinates.

        Returns
        -------
        centerFlux
            The flux at the center of the model in each band,
            indexable by filter name; `numpy.nan` values are returned
            when the center could not be determined.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}