Coverage for python/lsst/cp/pipe/linearity.py: 12%

262 statements  

coverage.py v7.4.3, created at 2024-03-09 10:48 +0000

1# This file is part of cp_pipe. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23__all__ = ["LinearitySolveTask", "LinearitySolveConfig"] 

24 

25import numpy as np 

26import lsst.afw.image as afwImage 

27import lsst.pipe.base as pipeBase 

28import lsst.pipe.base.connectionTypes as cT 

29import lsst.pex.config as pexConfig 

30 

31from lsstDebug import getDebugFrame 

32from lsst.ip.isr import (Linearizer, IsrProvenance) 

33 

34from .utils import (funcPolynomial, irlsFit, AstierSplineLinearityFitter, 

35 extractCalibDate) 

36 

37 

38def ptcLookup(datasetType, registry, quantumDataId, collections): 

39 """Butler lookup function to allow PTC to be found. 

40 

41 Parameters 

42 ---------- 

43 datasetType : `lsst.daf.butler.DatasetType` 

44 Dataset type to look up. 

45 registry : `lsst.daf.butler.Registry` 

46 Registry for the data repository being searched. 

47 quantumDataId : `lsst.daf.butler.DataCoordinate` 

48 Data ID for the quantum of the task this dataset will be passed to. 

49 This must include an "instrument" key, and should also include any 

50 keys that are present in ``datasetType.dimensions``. If it has an 

51 ``exposure`` or ``visit`` key, that's a sign that this function is 

52 not actually needed, as those come with the temporal information that 

53 would allow a real validity-range lookup. 

54 collections : `lsst.daf.butler.registry.CollectionSearch` 

55 Collections passed by the user when generating a QuantumGraph; 

56 these are searched (with ``findFirst=False``) for the PTC dataset. 

57 

58 Returns 

59 ------- 

60 refs : `list` [ `DatasetRef` ] 

61 A zero- or single-element list containing the matching 

62 dataset, if one was found. 

63 

64 Raises 

65 ------ 

66 RuntimeError 

67 Raised if more than one PTC reference is found. 

68 """ 

69 refs = list(registry.queryDatasets(datasetType, dataId=quantumDataId, collections=collections, 

70 findFirst=False)) 

71 if len(refs) >= 2: 

72 raise RuntimeError("Too many PTC connections found. Incorrect collections supplied?") 

73 

74 return refs 

75 

76 

77class LinearitySolveConnections(pipeBase.PipelineTaskConnections, 

78 dimensions=("instrument", "detector")): 

79 dummy = cT.Input( 

80 name="raw", 

81 doc="Dummy exposure.", 

82 storageClass='Exposure', 

83 dimensions=("instrument", "exposure", "detector"), 

84 multiple=True, 

85 deferLoad=True, 

86 ) 

87 

88 camera = cT.PrerequisiteInput( 

89 name="camera", 

90 doc="Camera Geometry definition.", 

91 storageClass="Camera", 

92 dimensions=("instrument", ), 

93 isCalibration=True, 

94 ) 

95 

96 inputPtc = cT.PrerequisiteInput( 

97 name="ptc", 

98 doc="Input PTC dataset.", 

99 storageClass="PhotonTransferCurveDataset", 

100 dimensions=("instrument", "detector"), 

101 isCalibration=True, 

102 lookupFunction=ptcLookup, 

103 ) 

104 

105 inputPhotodiodeCorrection = cT.Input( 

106 name="pdCorrection", 

107 doc="Input photodiode correction.", 

108 storageClass="IsrCalib", 

109 dimensions=("instrument", ), 

110 isCalibration=True, 

111 ) 

112 

113 outputLinearizer = cT.Output( 

114 name="linearity", 

115 doc="Output linearity measurements.", 

116 storageClass="Linearizer", 

117 dimensions=("instrument", "detector"), 

118 isCalibration=True, 

119 ) 

120 

121 def __init__(self, *, config=None): 
super().__init__(config=config)

122 if not config.applyPhotodiodeCorrection: 

123 del self.inputPhotodiodeCorrection 

124 

125 

126class LinearitySolveConfig(pipeBase.PipelineTaskConfig, 

127 pipelineConnections=LinearitySolveConnections): 

128 """Configuration for solving the linearity from PTC dataset. 

129 """ 

130 linearityType = pexConfig.ChoiceField( 

131 dtype=str, 

132 doc="Type of linearizer to construct.", 

133 default="Squared", 

134 allowed={ 

135 "LookupTable": "Create a lookup table solution.", 

136 "Polynomial": "Create an arbitrary polynomial solution.", 

137 "Squared": "Create a single order squared solution.", 

138 "Spline": "Create a spline based solution.", 

139 "None": "Create a dummy solution.", 

140 } 

141 ) 

142 polynomialOrder = pexConfig.RangeField( 

143 dtype=int, 

144 doc="Degree of polynomial to fit. Must be at least 2.", 

145 default=3, 

146 min=2, 

147 ) 

148 splineKnots = pexConfig.Field( 

149 dtype=int, 

150 doc="Number of spline knots to use in fit.", 

151 default=10, 

152 ) 

153 maxLookupTableAdu = pexConfig.Field( 

154 dtype=int, 

155 doc="Maximum DN value for a LookupTable linearizer.", 

156 default=2**18, 

157 ) 

158 maxLinearAdu = pexConfig.Field( 

159 dtype=float, 

160 doc="Maximum DN value to use to estimate linear term.", 

161 default=20000.0, 

162 ) 

163 minLinearAdu = pexConfig.Field( 

164 dtype=float, 

165 doc="Minimum DN value to use to estimate linear term.", 

166 default=30.0, 

167 ) 

168 nSigmaClipLinear = pexConfig.Field( 

169 dtype=float, 

170 doc="Maximum deviation from linear solution for Poissonian noise.", 

171 default=5.0, 

172 ) 

173 ignorePtcMask = pexConfig.Field( 

174 dtype=bool, 

175 doc="Ignore the expIdMask set by the PTC solver?", 

176 default=False, 

177 ) 

178 usePhotodiode = pexConfig.Field( 

179 dtype=bool, 

180 doc="Use the photodiode info instead of the raw expTimes?", 

181 default=False, 

182 ) 

183 photodiodeIntegrationMethod = pexConfig.ChoiceField( 

184 dtype=str, 

185 doc="Integration method for photodiode monitoring data.", 

186 default="DIRECT_SUM", 

187 allowed={ 

188 "DIRECT_SUM": ("Use numpy's trapz integrator on all photodiode " 

189 "readout entries"), 

190 "TRIMMED_SUM": ("Use numpy's trapz integrator, clipping the " 

191 "leading and trailing entries, which are " 

192 "nominally at zero baseline level."), 

193 "CHARGE_SUM": ("Treat the current values as integrated charge " 

194 "over the sampling interval and simply sum " 

195 "the values, after subtracting a baseline level."), 

196 }, 

197 # TODO: remove on DM-40065. 

198 deprecated="This config has been moved to cpExtractPtcTask, and will be removed after v26.", 

199 ) 

200 photodiodeCurrentScale = pexConfig.Field( 

201 dtype=float, 

202 doc="Scale factor to apply to photodiode current values for the " 

203 "``CHARGE_SUM`` integration method.", 

204 default=-1.0, 

205 # TODO: remove on DM-40065. 

206 deprecated="This config has been moved to cpExtractPtcTask, and will be removed after v26.", 

207 ) 

208 applyPhotodiodeCorrection = pexConfig.Field( 

209 dtype=bool, 

210 doc="Calculate and apply a correction to the photodiode readings?", 

211 default=False, 

212 ) 

213 splineGroupingColumn = pexConfig.Field( 

214 dtype=str, 

215 doc="Column to use for grouping together points for Spline mode, to allow " 

216 "for different proportionality constants. If not set, no grouping " 

217 "will be done.", 

218 default=None, 

219 optional=True, 

220 ) 

221 splineGroupingMinPoints = pexConfig.Field( 

222 dtype=int, 

223 doc="Minimum number of linearity points to allow grouping together points " 

224 "for Spline mode with splineGroupingColumn. This configuration is here " 

225 "to prevent misuse of the Spline code and to avoid over-fitting.", 

226 default=100, 

227 ) 

228 splineFitMinIter = pexConfig.Field( 

229 dtype=int, 

230 doc="Minimum number of iterations for spline fit.", 

231 default=3, 

232 ) 

233 splineFitMaxIter = pexConfig.Field( 

234 dtype=int, 

235 doc="Maximum number of iterations for spline fit.", 

236 default=20, 

237 ) 

238 splineFitMaxRejectionPerIteration = pexConfig.Field( 

239 dtype=int, 

240 doc="Maximum number of rejections per iteration for spline fit.", 

241 default=5, 

242 ) 

243 

244 

245class LinearitySolveTask(pipeBase.PipelineTask): 

246 """Fit the linearity from the PTC dataset. 

247 """ 

248 

249 ConfigClass = LinearitySolveConfig 

250 _DefaultName = 'cpLinearitySolve' 

251 

252 def runQuantum(self, butlerQC, inputRefs, outputRefs): 

253 """Ensure that the input and output dimensions are passed along. 

254 

255 Parameters 

256 ---------- 

257 butlerQC : `lsst.daf.butler.QuantumContext` 

258 Butler to operate on. 

259 inputRefs : `lsst.pipe.base.InputQuantizedConnection` 

260 Input data refs to load. 

261 outputRefs : `lsst.pipe.base.OutputQuantizedConnection` 

262 Output data refs to persist. 

263 """ 

264 inputs = butlerQC.get(inputRefs) 

265 

266 # Use the dimensions to set calib/provenance information. 

267 inputs['inputDims'] = dict(inputRefs.inputPtc.dataId.required) 

268 

269 # Add calibration provenance info to header. 

270 kwargs = dict() 

271 reference = getattr(inputRefs, "inputPtc", None) 

272 

273 if reference is not None and hasattr(reference, "run"): 

274 runKey = "PTC_RUN" 

275 runValue = reference.run 

276 idKey = "PTC_UUID" 

277 idValue = str(reference.id) 

278 dateKey = "PTC_DATE" 

279 calib = inputs.get("inputPtc", None) 

280 dateValue = extractCalibDate(calib) 

281 

282 kwargs[runKey] = runValue 

283 kwargs[idKey] = idValue 

284 kwargs[dateKey] = dateValue 

285 

286 self.log.info("Using input PTC from run %s.", reference.run) 

287 

288 outputs = self.run(**inputs) 

289 outputs.outputLinearizer.updateMetadata(setDate=False, **kwargs) 

290 

291 butlerQC.put(outputs, outputRefs) 

292 

293 def run(self, inputPtc, dummy, camera, inputDims, 

294 inputPhotodiodeCorrection=None): 

295 """Fit non-linearity to PTC data, returning the correct Linearizer 

296 object. 

297 

298 Parameters 

299 ---------- 

300 inputPtc : `lsst.ip.isr.PhotonTransferCurveDataset` 

301 Pre-measured PTC dataset. 

302 dummy : `list` [`lsst.afw.image.Exposure`] 

303 The exposures used to select the appropriate PTC dataset. 

304 In almost all circumstances, one of the input exposures 

305 used to generate the PTC dataset is the best option. 

306 camera : `lsst.afw.cameraGeom.Camera` 

307 Camera geometry. 

308 inputDims : `lsst.daf.butler.DataCoordinate` or `dict` 

309 DataIds to use to populate the output calibration. 

310 inputPhotodiodeCorrection : `lsst.ip.isr.PhotodiodeCorrection`, optional 

311 Pre-measured photodiode correction used in the case when 

312 ``applyPhotodiodeCorrection=True``. 

313 

314 Returns 

315 ------- 

316 results : `lsst.pipe.base.Struct` 

317 The results struct containing: 

318 

319 ``outputLinearizer`` 

320 Final linearizer calibration (`lsst.ip.isr.Linearizer`). 

321 ``outputProvenance`` 

322 Provenance data for the new calibration 

323 (`lsst.ip.isr.IsrProvenance`). 

324 

325 Notes 

326 ----- 

327 This task currently fits only polynomial-defined corrections, 

328 where the correction coefficients are defined such that: 

329 :math:`corrImage = uncorrImage + \\sum_i c_i uncorrImage^{2 + i}`. 

330 These :math:`c_i` are defined in terms of the direct polynomial fit: 

331 :math:`meanVector \\sim P(x=timeVector) = \\sum_j k_j x^j`, 

332 such that :math:`c_{j-2} = -k_j/(k_1^j)` in units of :math:`DN^{1-j}` 

333 (cf. Eq. 37 of arXiv:2003.05978). `config.polynomialOrder` or 

334 `config.splineKnots` defines the maximum order of :math:`x^j` to fit. 

335 As :math:`k_0` and :math:`k_1` are degenerate with bias level and gain, 

336 they are not included in the non-linearity correction. 
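
As a concrete illustration of these definitions, a quadratic fit
:math:`\\mu = k_0 + k_1 t + k_2 t^2` yields the single ``Squared``
correction coefficient :math:`c_0 = -k_2 / k_1^2`, so the applied
correction is :math:`corrImage = uncorrImage - (k_2 / k_1^2)\\, uncorrImage^2`.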

337 """ 

338 if len(dummy) == 0: 

339 self.log.warning("No dummy exposure found.") 

340 

341 detector = camera[inputDims['detector']] 

342 if self.config.linearityType == 'LookupTable': 

343 table = np.zeros((len(detector), self.config.maxLookupTableAdu), dtype=np.float32) 

344 tableIndex = 0 

345 else: 

346 table = None 

347 tableIndex = None # This will fail if we increment it. 

348 

349 # Initialize the linearizer. 

350 linearizer = Linearizer(detector=detector, table=table, log=self.log) 

351 linearizer.updateMetadataFromExposures([inputPtc]) 

352 if self.config.usePhotodiode and self.config.applyPhotodiodeCorrection: 

353 abscissaCorrections = inputPhotodiodeCorrection.abscissaCorrections 

354 

355 if self.config.linearityType == 'Spline': 

356 if self.config.splineGroupingColumn is not None: 

357 if self.config.splineGroupingColumn not in inputPtc.auxValues: 

358 raise ValueError(f"Config requests grouping by {self.config.splineGroupingColumn}, " 

359 "but this column is not available in inputPtc.auxValues.") 

360 groupingValue = inputPtc.auxValues[self.config.splineGroupingColumn] 

361 else: 

362 groupingValue = np.ones(len(inputPtc.rawMeans[inputPtc.ampNames[0]]), dtype=int) 

363 # fitOrder must still be set in Spline mode so that fillBadAmp can size its arrays. 

364 fitOrder = self.config.splineKnots 

365 else: 

366 fitOrder = self.config.polynomialOrder 

367 

368 for i, amp in enumerate(detector): 

369 ampName = amp.getName() 

370 if ampName in inputPtc.badAmps: 

371 linearizer = self.fillBadAmp(linearizer, fitOrder, inputPtc, amp) 

372 self.log.warning("Amp %s in detector %s has no usable PTC information. Skipping!", 

373 ampName, detector.getName()) 

374 continue 

375 

376 # Check for too few points. 

377 if self.config.linearityType == "Spline" \ 

378 and self.config.splineGroupingColumn is not None \ 

379 and len(inputPtc.inputExpIdPairs[ampName]) < self.config.splineGroupingMinPoints: 

380 raise RuntimeError( 

381 "The input PTC has too few points to reliably run with PD grouping. " 

382 "The recommended course of action is to set splineGroupingColumn to None. " 

383 "If you really know what you are doing, you may reduce " 

384 "config.splineGroupingMinPoints.") 

385 

386 if (len(inputPtc.expIdMask[ampName]) == 0) or self.config.ignorePtcMask: 

387 self.log.warning("Mask not found for %s in detector %s in fit. Using all points.", 

388 ampName, detector.getName()) 

389 mask = np.ones(len(inputPtc.rawExpTimes[ampName]), dtype=bool) 

390 else: 

391 mask = inputPtc.expIdMask[ampName].copy() 

392 

393 if self.config.usePhotodiode: 

394 modExpTimes = inputPtc.photoCharges[ampName].copy() 

395 # Make sure any exposure pairs that do not have photodiode data 

396 # are masked. 

397 mask[~np.isfinite(modExpTimes)] = False 

398 

399 # Get the photodiode correction. 

400 if self.config.applyPhotodiodeCorrection: 

401 for j, pair in enumerate(inputPtc.inputExpIdPairs[ampName]): 

402 try: 

403 correction = abscissaCorrections[str(pair)] 

404 except KeyError: 

405 correction = 0.0 

406 modExpTimes[j] += correction 

407 

408 inputAbscissa = modExpTimes 

409 else: 

410 inputAbscissa = inputPtc.rawExpTimes[ampName].copy() 

411 

412 inputOrdinate = inputPtc.rawMeans[ampName].copy() 

413 

414 mask &= (inputOrdinate < self.config.maxLinearAdu) 

415 mask &= (inputOrdinate > self.config.minLinearAdu) 

416 

417 if mask.sum() < 2: 

418 linearizer = self.fillBadAmp(linearizer, fitOrder, inputPtc, amp) 

419 self.log.warning("Amp %s in detector %s does not have enough points for the fit. Skipping!", 

420 ampName, detector.getName()) 

421 continue 

422 

423 if self.config.linearityType != 'Spline': 

424 linearFit, linearFitErr, chiSq, weights = irlsFit([0.0, 100.0], inputAbscissa[mask], 

425 inputOrdinate[mask], funcPolynomial) 

426 

427 # Convert this proxy-to-flux fit into an expected linear flux 

428 linearOrdinate = linearFit[0] + linearFit[1] * inputAbscissa 

429 # Exclude low end outliers. 

430 # This is compared to the original values. 

431 threshold = self.config.nSigmaClipLinear * np.sqrt(abs(inputOrdinate)) 

432 
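# Reject points whose deviation from the linear fit exceeds
# nSigmaClipLinear times sqrt(signal), i.e. a threshold that scales like
# the expected Poisson noise in the flat-pair means.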

433 mask[np.abs(inputOrdinate - linearOrdinate) >= threshold] = False 

434 

435 if mask.sum() < 2: 

436 linearizer = self.fillBadAmp(linearizer, fitOrder, inputPtc, amp) 

437 self.log.warning("Amp %s in detector %s does not have enough points in the linear ordinate. " 

438 "Skipping!", ampName, detector.getName()) 

439 continue 

440 

441 self.debugFit('linearFit', inputAbscissa, inputOrdinate, linearOrdinate, mask, ampName) 

442 

443 # Do fits 

444 if self.config.linearityType in ['Polynomial', 'Squared', 'LookupTable']: 

445 polyFit = np.zeros(fitOrder + 1) 

446 polyFit[1] = 1.0 

447 polyFit, polyFitErr, chiSq, weights = irlsFit(polyFit, linearOrdinate[mask], 

448 inputOrdinate[mask], funcPolynomial) 

449 

450 # Convert the polynomial fit coefficients k_j into correction coefficients c_j = -k_j / k_1**j, dropping the constant and linear terms. 

451 k1 = polyFit[1] 

452 linearityCoeffs = np.array( 

453 [-coeff/(k1**order) for order, coeff in enumerate(polyFit)] 

454 )[2:] 

455 significant = np.where(np.abs(linearityCoeffs) > 1e-10) 

456 self.log.info("Significant polynomial fits: %s", significant) 

457 

458 modelOrdinate = funcPolynomial(polyFit, linearOrdinate) 

459 

460 self.debugFit( 

461 'polyFit', 

462 inputAbscissa[mask], 

463 inputOrdinate[mask], 

464 modelOrdinate[mask], 

465 None, 

466 ampName, 

467 ) 

468 

469 if self.config.linearityType == 'Squared': 

470 # The first term is the squared term. 

471 linearityCoeffs = linearityCoeffs[0: 1] 

472 elif self.config.linearityType == 'LookupTable': 

473 # Use linear part to get time at which signal is 

474 # maxAduForLookupTableLinearizer DN 

475 tMax = (self.config.maxLookupTableAdu - polyFit[0])/polyFit[1] 

476 timeRange = np.linspace(0, tMax, self.config.maxLookupTableAdu) 

477 signalIdeal = polyFit[0] + polyFit[1]*timeRange 

478 signalUncorrected = funcPolynomial(polyFit, timeRange) 

479 lookupTableRow = signalIdeal - signalUncorrected # LinearizerLookupTable has correction 

480 
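# The lookup table row stores (ideal - uncorrected) signal; the
# linearity coefficients record the table row index and a column
# offset (zero here).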

481 linearizer.tableData[tableIndex, :] = lookupTableRow 

482 linearityCoeffs = np.array([tableIndex, 0]) 

483 tableIndex += 1 

484 elif self.config.linearityType in ['Spline']: 

485 # This is a spline fit with photodiode data based on a model 

486 # from Pierre Astier. 

487 # This model fits a spline with (optional) nuisance parameters 

488 # to allow for different linearity coefficients with different 

489 # photodiode settings. The minimization is a least-squares 

490 # fit with the residual of 

491 # Sum[(S(mu_i) + mu_i)/(k_j * D_i) - 1]**2, where S(mu_i) is 

492 # an Akima Spline function of mu_i, the observed flat-pair 

493 # mean; D_i is the photodiode measurement corresponding to 

494 # that flat pair; and k_j is a constant of proportionality 

495 # indexed over j because it is allowed to 

496 # differ between photodiode settings (e.g. 

497 # CCOBCURR). 

498 

499 # The fit has additional constraints to ensure that the spline 

500 # goes through the (0, 0) point, as well as a normalization 

501 # condition so that the average of the spline over the full 

502 # range is 0. The normalization ensures that the spline only 

503 # fits deviations from linearity, rather than the linear 

504 # function itself which is degenerate with the gain. 

505 

506 nodes = np.linspace(0.0, np.max(inputOrdinate[mask]), self.config.splineKnots) 

507 

508 fitter = AstierSplineLinearityFitter( 

509 nodes, 

510 groupingValue, 

511 inputAbscissa, 

512 inputOrdinate, 

513 mask=mask, 

514 log=self.log, 

515 ) 

516 p0 = fitter.estimate_p0() 

517 pars = fitter.fit( 

518 p0, 

519 min_iter=self.config.splineFitMinIter, 

520 max_iter=self.config.splineFitMaxIter, 

521 max_rejection_per_iteration=self.config.splineFitMaxRejectionPerIteration, 

522 n_sigma_clip=self.config.nSigmaClipLinear, 

523 ) 
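# The fitted parameter vector is unpacked below as the spline values
# at each node (the first len(nodes) entries) followed by one
# proportionality constant per grouping.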

524 

525 # Confirm that the first parameter is 0, and set it to 

526 # exactly zero. 

527 if not np.isclose(pars[0], 0): 

528 raise RuntimeError("Programmer error! First spline parameter must " 

529 "be consistent with zero.") 

530 pars[0] = 0.0 

531 

532 linearityCoeffs = np.concatenate([nodes, pars[0: len(nodes)]]) 

533 linearFit = np.array([0.0, np.mean(pars[len(nodes):])]) 

534 

535 # We modify the inputOrdinate according to the linearity fits 

536 # here, for proper residual computation. 

537 for j, group_index in enumerate(fitter.group_indices): 

538 inputOrdinate[group_index] /= (pars[len(nodes) + j] / linearFit[1]) 

539 

540 linearOrdinate = linearFit[1] * inputOrdinate 

541 # For the spline fit, reuse the "polyFit -> fitParams" 

542 # field to record the linear coefficients for the groups. 

543 polyFit = pars[len(nodes):] 

544 polyFitErr = np.zeros_like(polyFit) 

545 chiSq = np.nan 

546 

547 # Update mask based on what the fitter rejected. 

548 mask = fitter.mask 

549 else: 

550 polyFit = np.zeros(1) 

551 polyFitErr = np.zeros(1) 

552 chiSq = np.nan 

553 linearityCoeffs = np.zeros(1) 

554 

555 linearizer.linearityType[ampName] = self.config.linearityType 

556 linearizer.linearityCoeffs[ampName] = linearityCoeffs 

557 linearizer.linearityBBox[ampName] = amp.getBBox() 

558 linearizer.fitParams[ampName] = polyFit 

559 linearizer.fitParamsErr[ampName] = polyFitErr 

560 linearizer.fitChiSq[ampName] = chiSq 

561 linearizer.linearFit[ampName] = linearFit 

562 
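# Evaluate the newly fitted correction by packing the input ordinate
# into a one-row image, applying the linearizer correction in place,
# and reading the corrected values back out as the model.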

563 image = afwImage.ImageF(len(inputOrdinate), 1) 

564 image.array[:, :] = inputOrdinate 

565 linearizeFunction = linearizer.getLinearityTypeByName(linearizer.linearityType[ampName]) 

566 linearizeFunction()( 

567 image, 

568 **{'coeffs': linearizer.linearityCoeffs[ampName], 

569 'table': linearizer.tableData, 

570 'log': linearizer.log} 

571 ) 

572 linearizeModel = image.array[0, :] 

573 

574 # The residuals that we record are the final residuals compared to 

575 # a linear model, after everything has been (properly?) linearized. 

576 if mask.sum() < 2: 

577 self.log.warning("Amp %s in detector %s does not have enough points in the linear ordinate " 

578 "for residuals. Skipping!", ampName, detector.getName()) 

579 residuals = np.full_like(linearizeModel, np.nan) 

580 else: 

581 postLinearFit, _, _, _ = irlsFit( 

582 [0.0, 100.0], 

583 inputAbscissa[mask], 

584 linearizeModel[mask], 

585 funcPolynomial, 

586 ) 

587 residuals = linearizeModel - (postLinearFit[0] + postLinearFit[1] * inputAbscissa) 

588 # We set masked residuals to nan. 

589 residuals[~mask] = np.nan 

590 

591 linearizer.fitResiduals[ampName] = residuals 

592 

593 self.debugFit( 

594 'solution', 

595 inputOrdinate[mask], 

596 linearOrdinate[mask], 

597 linearizeModel[mask], 

598 None, 

599 ampName, 

600 ) 

601 

602 linearizer.hasLinearity = True 

603 linearizer.validate() 

604 linearizer.updateMetadata(camera=camera, detector=detector, filterName='NONE') 

605 linearizer.updateMetadata(setDate=True, setCalibId=True) 

606 provenance = IsrProvenance(calibType='linearizer') 

607 

608 return pipeBase.Struct( 

609 outputLinearizer=linearizer, 

610 outputProvenance=provenance, 

611 ) 

612 

613 def fillBadAmp(self, linearizer, fitOrder, inputPtc, amp): 

614 # Need to fill linearizer with empty values 

615 # if the amp is non-functional 

616 ampName = amp.getName() 

617 nEntries = 1 

618 pEntries = 1 

619 if self.config.linearityType in ['Polynomial']: 

620 nEntries = fitOrder + 1 

621 pEntries = fitOrder + 1 

622 elif self.config.linearityType in ['Spline']: 

623 nEntries = fitOrder * 2 

624 elif self.config.linearityType in ['Squared', 'None']: 

625 nEntries = 1 

626 pEntries = fitOrder + 1 

627 elif self.config.linearityType in ['LookupTable']: 

628 nEntries = 2 

629 pEntries = fitOrder + 1 

630 

631 linearizer.linearityType[ampName] = "None" 

632 linearizer.linearityCoeffs[ampName] = np.zeros(nEntries) 

633 linearizer.linearityBBox[ampName] = amp.getBBox() 

634 linearizer.fitParams[ampName] = np.zeros(pEntries) 

635 linearizer.fitParamsErr[ampName] = np.zeros(pEntries) 

636 linearizer.fitChiSq[ampName] = np.nan 

637 linearizer.fitResiduals[ampName] = np.zeros(len(inputPtc.expIdMask[ampName])) 

638 linearizer.linearFit[ampName] = np.zeros(2) 

639 return linearizer 

640 

641 def debugFit(self, stepname, xVector, yVector, yModel, mask, ampName): 

642 """Debug method for linearity fitting. 

643 

644 Parameters 

645 ---------- 

646 stepname : `str` 

647 A label to use to check if we care to debug at a given 

648 line of code. 

649 xVector : `numpy.array`, (N,) 

650 The values to use as the independent variable in the 

651 linearity fit. 

652 yVector : `numpy.array`, (N,) 

653 The values to use as the dependent variable in the 

654 linearity fit. 

655 yModel : `numpy.array`, (N,) 

656 The values to use as the linearized result. 

657 mask : `numpy.array` [`bool`], (N,), optional 

658 A mask to indicate which entries of ``xVector`` and 

659 ``yVector`` to keep. 

660 ampName : `str` 

661 Amplifier name to lookup linearity correction values. 

662 """ 

663 frame = getDebugFrame(self._display, stepname) 

664 if frame: 

665 import matplotlib.pyplot as plt 

666 fig, axs = plt.subplots(2) 

667 

668 if mask is None: 

669 mask = np.ones_like(xVector, dtype=bool) 

670 

671 fig.suptitle(f"{stepname} {ampName} {self.config.linearityType}") 

672 if stepname == 'linearFit': 

673 axs[0].set_xlabel("Input Abscissa (time or mondiode)") 

674 axs[0].set_ylabel("Input Ordinate (flux)") 

675 axs[1].set_xlabel("Linear Ordinate (linear flux)") 

676 axs[1].set_ylabel("Flux Difference: (input - linear)") 

677 elif stepname in ('polyFit', 'splineFit'): 

678 axs[0].set_xlabel("Linear Abscissa (linear flux)") 

679 axs[0].set_ylabel("Input Ordinate (flux)") 

680 axs[1].set_xlabel("Linear Ordinate (linear flux)") 

681 axs[1].set_ylabel("Flux Difference: (input - full model fit)") 

682 elif stepname == 'solution': 

683 axs[0].set_xlabel("Input Abscissa (time or mondiode)") 

684 axs[0].set_ylabel("Linear Ordinate (linear flux)") 

685 axs[1].set_xlabel("Model flux (linear flux)") 

686 axs[1].set_ylabel("Flux Difference: (linear - model)") 

687 

688 axs[0].set_yscale('log') 

689 axs[0].set_xscale('log') 

690 axs[0].scatter(xVector, yVector) 

691 axs[0].scatter(xVector[~mask], yVector[~mask], c='red', marker='x') 

692 axs[1].set_xscale('log') 

693 

694 axs[1].scatter(yModel[mask], yVector[mask] - yModel[mask]) 

695 fig.tight_layout() 

696 fig.show() 

697 

698 prompt = "Press Enter or c to continue [chpx]..." 

699 while True: 

700 ans = input(prompt).lower() 

701 if ans in ("", " ", "c",): 

702 break 

703 elif ans in ("p", ): 

704 import pdb 

705 pdb.set_trace() 

706 elif ans in ("h", ): 

707 print("[h]elp [c]ontinue [p]db e[x]it") 

708 elif ans in ('x', ): 

709 exit() 

710 plt.close()
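

# The sketch below is illustrative only and not part of the original
# module: it shows one way to run the solver directly on an in-memory
# PTC dataset, outside the pipeline framework. The ``ptc`` and
# ``camera`` arguments (a PhotonTransferCurveDataset and an
# lsst.afw.cameraGeom.Camera) are assumed to be supplied by the caller.
def _exampleLinearitySolve(ptc, camera, detectorId=0):
    config = LinearitySolveConfig()
    config.linearityType = "Squared"
    task = LinearitySolveTask(config=config)
    # An empty ``dummy`` list only triggers a warning; in the pipeline it
    # is used solely to select the appropriate PTC dataset.
    results = task.run(inputPtc=ptc, dummy=[], camera=camera,
                       inputDims={"detector": detectorId})
    return results.outputLinearizer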