# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging
import time

import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.initialization import init_all_sources

import lsst.log
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image.utils
import lsst.afw.image as afwImage
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable

from .source import modelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
# So for now we only display scarlet and proxmin errors, as all other
# scarlet outputs would be considered "TRACE" by our standards.
scarletLogger = logging.getLogger("scarlet")
scarletLogger.setLevel(logging.ERROR)
proxminLogger = logging.getLogger("proxmin")
proxminLogger.setLevel(logging.ERROR)

__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]

logger = lsst.log.Log.getLogger("meas.deblender.deblend")


class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged

    The blend is considered converged when the absolute change in the
    loss between the last two iterations is smaller than `f_rel` times
    the magnitude of the current loss.
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
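
# A worked example of the convergence test above (hypothetical numbers, not
# from a real fit): with blend.loss = [-2001.0, -2000.5] and f_rel = 1e-3,
# deltaLoss = 0.5 while the threshold is 1e-3 * 2000.5 ~= 2.0, so the blend
# would be considered converged.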


def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    return psf.computeShape().getDeterminantRadius() * 2.35
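
# The factor 2.35 above approximates the Gaussian sigma-to-FWHM conversion
# FWHM = 2*sqrt(2*ln(2)) * sigma ~= 2.3548 * sigma, with the determinant
# radius of the PSF shape playing the role of sigma.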


def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF Kernel Image is computed for each band
    and combined into a (filter, y, x) array.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789),
    which is why its first argument is named `self`.

    Parameters
    ----------
    self : `lsst.afw.image.MultibandExposure`
        The multiband exposure with the PSF models to evaluate.
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : `lsst.afw.image.MultibandImage`
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
            else:
                psf = single.getPsf().computeKernelImage(position)
            psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    # Use the union of the single band bounding boxes so that all of the
    # per-band kernel images can be stacked into one multiband image.
    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
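
# A minimal usage sketch (assuming `mExposure` is an existing
# `lsst.afw.image.MultibandExposure` whose bands include "g"):
#
#     psfImage = _computePsfImage(mExposure, Point2D(1000.5, 2000.5))
#     gKernel = psfImage["g"].array  # (ny, nx) PSF kernel image in the g band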


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask
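
# Usage sketch (assuming `footprint` and `mExposure` as above): combined with
# inverse-variance weights, the mask zeroes out everything outside the parent:
#
#     fpMask = getFootprintMask(footprint, mExposure)
#     weights = 1 / mExposure.variance[:, footprint.getBBox()].array
#     weights *= ~fpMask  # zero weight outside the footprint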


def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.

    Returns
    -------
    isPseudo : `bool`
        True if any of the `pseudoColumns` is set on the source.
        Columns missing from the schema are ignored.
    """
    isPseudo = False
    for col in pseudoColumns:
        try:
            isPseudo |= source[col]
        except KeyError:
            pass
    return isPseudo
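
# Usage sketch: with the default configuration, a record whose
# `merge_peak_sky` or `sky_source` flag is set is treated as a pseudo source
# and is never deblended:
#
#     isPseudoSource(source, config.pseudoColumns)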


def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        - Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The fitted blend containing the deblended sources.
    skipped : `list` of `int`
        Indices of the peaks that could not be initialized.
    spectrumInit : `bool`
        True when the improved spectral initialization was used.
    """
    # The bounding box of the parent footprint
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")
    else:
        raise ValueError("Unrecognized sourceModel {}".format(config.sourceModel))

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value.
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit
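
# A minimal end-to-end sketch of the free function above (assuming a
# MultibandExposure `mExposure` and a parent `footprint` from a merged
# detection catalog):
#
#     config = ScarletDeblendConfig()
#     blend, skipped, spectrumInit = deblend(mExposure, footprint, config)
#     converged = _checkBlendConvergence(blend, config.relativeError)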


class ScarletDeblendConfig(pexConfig.Config):
    """ScarletDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are
    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-4,
                                    doc=("Change in the loss function between "
                                         "iterations to exit fitter"))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)")
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "The improved initialization is only used when the number of peaks times "
            "the footprint area is less than `maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes whose pixels are given zero weight during deblending")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=0,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "A non-positive value means there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fall back to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources."
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")
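
# Illustrative overrides of the config above (hypothetical values, not tuned
# recommendations), e.g. from a pipeline configuration file:
#
#     config = ScarletDeblendConfig()
#     config.maxIter = 200
#     config.sourceModel = "compact"
#     config.maxFootprintArea = 500000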


class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task modifies the input SourceCatalog in place and returns
    per-band catalogs of the deblended children (see `run`).
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.schema.schema.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.schema.schema.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing, but
            # we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]
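
    # A minimal construction sketch (assuming `schema` is the Schema used to
    # build the merged detection catalog `mergedSources`, and `mExposure` is
    # a MultibandExposure covering the same region):
    #
    #     task = ScarletDeblendTask(schema=schema)
    #     templateCatalogs = task.run(mExposure, mergedSources)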

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge before '
                                                        'config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #                (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
        #                )

    @pipeBase.timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : `dict`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            multiband templates created by the deblender.
        """
        return self.deblend(mExposure, mergedSources)

    @pipeBase.timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            multiband templates created by the deblender.
        """
        # Cull footprints if required by ci
        if self.config.useCiLimits:
            self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                          len(catalog))
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                                 "indicated by ciDeblendChildRange. Adjust this range to include more "
                                 "parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
        nextLogTime = time.time() + self.config.loggingInterval

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
            # Run the deblender
            blendError = None
            try:
                t0 = time.time()
                # Build the parameter lists with the same ordering
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                tf = time.time()
                runtime = (tf - t0) * 1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)
                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
                # Re-insert placeholders for skipped sources
                # to propagate them in the catalog so
                # that the peaks stay consistent
                for k in skipped:
                    scarletSources.insert(k, None)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warn("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            logL = blend.loss[-1] - blend.observations[0].log_norm
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped:
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                mHeavy = modelToHeavy(scarletSource, filters, xy0=bbox.getMin(),
                                      observation=blend.observations[0])
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )
            # Log a message if it has been a while since the last log.
            if (currentTime := time.time()) > nextLogTime:
                nextLogTime = currentTime + self.config.loggingInterval
                self.log.verbose("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog) - nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)
            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    fp = catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                      "creating %d children, for a total of %d sources",
                      nParents, nDeblendedParents, nChildren, len(catalog))
        return catalogs
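
    # Sketch of consuming the per-band results (continuing the hypothetical
    # variables from the construction sketch above):
    #
    #     catalogs = task.deblend(mExposure, catalog)
    #     gCatalog = catalogs["g"]  # children carry g-band heavy footprints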

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio.
        These may be disabled independently by configuring them to be
        non-positive.

        This is principally intended to get rid of satellite streaks, which the
        deblender or other downstream processing can have trouble dealing with
        (e.g., multiple large HeavyFootprints can chew up memory).
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio * axes.getA():
                return True
        return False
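
    # Illustrative size limits (hypothetical values): skip parents covering
    # more than 100,000 pixels or more than 2,000 pixels on a side:
    #
    #     config.maxFootprintArea = 100000
    #     config.maxFootprintSize = 2000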

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            unmaskedSpan = footprint.spans.intersectNot(_mask)  # spanset of unmasked pixels
            if (size - unmaskedSpan.getArea()) / size > limit:
                return True
        return False
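
    # Illustrative maskLimits setting (hypothetical value): skip any parent
    # whose footprint has more than half of its pixels flagged NO_DATA:
    #
    #     config.maskLimits = {"NO_DATA": 0.5}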

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `lsst.afw.table.Key`
            The schema key of the flag used to mark the reason for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # Note the negation: this column records convergence *failure*
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.source.source.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint.
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))
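
    # Illustrative columnInheritance override (hypothetical mapping; the
    # destination column must already exist in the output schema):
    #
    #     config.columnInheritance["deblend_runtime"] = "deblend_parentRuntime"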

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `lsst.geom.Point2I`
            The origin of the parent bounding box, used to convert the
            scarlet source center into image coordinates.

        Returns
        -------
        centerFlux : array-like or `dict`
            The flux in each band at the center of the model,
            or NaN for each band if the center could not be determined.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}