Coverage for python/lsst/meas/algorithms/simple_curve.py: 32%

154 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-17 02:15 -0800

1# 

2# LSST Data Management System 

3# 

4# Copyright 2019 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

# Public API of this module: the abstract base and its concrete curve types.
__all__ = ["Curve", "AmpCurve", "DetectorCurve", "ImageCurve"]

25 

26from scipy.interpolate import interp1d 

27from astropy.table import QTable 

28import astropy.units as u 

29from abc import ABC, abstractmethod 

30import datetime 

31import os 

32import numpy 

33 

34import lsst.afw.cameraGeom.utils as cgUtils 

35from lsst.geom import Point2I 

36 

37 

class Curve(ABC):
    """An abstract class to represent an arbitrary curve with
    interpolation.

    Each concrete subclass sets the class attribute ``mode``. Defining a
    subclass registers it in ``Curve.subclasses`` under that mode, which
    lets `readText` and `readFits` pick the right subclass from the
    ``MODE`` entry of a table's metadata.
    """
    mode = ''
    # Registry of concrete subclasses keyed by their ``mode`` string;
    # filled in by ``__init_subclass__``.
    subclasses = dict()

    def __init__(self, wavelength, efficiency, metadata):
        """Validate and store the curve data.

        Parameters
        ----------
        wavelength : `astropy.units.Quantity`
            Wavelength values; must have a length unit.
        efficiency : `astropy.units.Quantity`
            Efficiency values; must be in percent.
        metadata : `dict`
            Metadata for this curve. It is updated in place with the
            ``MODE`` and ``TYPE`` entries.

        Raises
        ------
        ValueError
            If ``wavelength`` is not a length quantity or ``efficiency`` is
            not a quantity in percent.
        """
        if not (isinstance(wavelength, u.Quantity) and wavelength.unit.physical_type == 'length'):
            raise ValueError('The wavelength must be a quantity with a length sense.')
        if not isinstance(efficiency, u.Quantity) or efficiency.unit != u.percent:
            raise ValueError('The efficiency must be a quantity with units of percent.')
        self.wavelength = wavelength
        self.efficiency = efficiency
        # Make sure needed metadata is set if built directly from the ctor.
        # Note that this modifies the caller's dict in place.
        metadata.update({'MODE': self.mode, 'TYPE': 'QE'})
        self.metadata = metadata

    @classmethod
    @abstractmethod
    def fromTable(cls, table):
        """Class method for constructing a `Curve` object.

        Parameters
        ----------
        table : `astropy.table.QTable`
            Table containing metadata and columns necessary
            for constructing a `Curve` object.

        Returns
        -------
        curve : `Curve`
            A `Curve` subclass of the appropriate type according
            to the table metadata
        """
        pass

    @abstractmethod
    def toTable(self):
        """Convert this `Curve` object to an `astropy.table.QTable`.

        Returns
        -------
        table : `astropy.table.QTable`
            A table object containing the data from this `Curve`.
        """
        pass

    @abstractmethod
    def evaluate(self, detector, position, wavelength, kind='linear', bounds_error=False, fill_value=0):
        """Interpolate the curve at the specified position and wavelength.

        Parameters
        ----------
        detector : `lsst.afw.cameraGeom.Detector`
            Is used to find the appropriate curve given the position for
            curves that vary over the detector. Ignored in the case where
            there is only a single curve per detector.
        position : `lsst.geom.Point2D`
            The position on the detector at which to evaluate the curve.
        wavelength : `astropy.units.Quantity`
            The wavelength(s) at which to make the interpolation.
        kind : `str`, optional
            The type of interpolation to do (default is 'linear').
            See documentation for `scipy.interpolate.interp1d` for
            accepted values.
        bounds_error : `bool`, optional
            Raise error if interpolating outside the range of x?
            (default is False)
        fill_value : `float`, optional
            Fill values outside the range of x with this value
            (default is 0).

        Returns
        -------
        value : `astropy.units.Quantity`
            Interpolated value(s). Number of values returned will match the
            length of `wavelength`.

        Raises
        ------
        ValueError
            If the ``bounds_error`` is changed from the default, it will raise
            a `ValueError` if evaluating outside the bounds of the curve.
        """
        pass

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Register subclasses with the abstract base class.

        Raises
        ------
        ValueError
            If a subclass with the same ``mode`` is already registered.
        """
        super().__init_subclass__(**kwargs)
        if cls.mode in Curve.subclasses:
            raise ValueError(f'Class for mode, {cls.mode}, already defined')
        Curve.subclasses[cls.mode] = cls

    @abstractmethod
    def __eq__(self, other):
        """Define equality for this class"""
        pass

    def compare_metadata(self, other,
                         keys_to_compare=('MODE', 'TYPE', 'CALIBDATE', 'INSTRUME', 'OBSTYPE', 'DETECTOR')):
        """Compare metadata in this object to another.

        Parameters
        ----------
        other : `Curve`
            The object with which to compare metadata.
        keys_to_compare : iterable of `str`, optional
            Metadata keys to compare. All keys must be present in both
            objects' metadata; a missing key raises `KeyError`.

        Returns
        -------
        same : `bool`
            `True` if every listed key has equal values in both objects.
        """
        # A tuple default avoids sharing one mutable list between calls.
        return all(self.metadata[key] == other.metadata[key] for key in keys_to_compare)

    def interpolate(self, wavelengths, values, wavelength, kind, bounds_error, fill_value):
        """Interpolate the curve at the specified wavelength(s).

        Parameters
        ----------
        wavelengths : `astropy.units.Quantity`
            The wavelength values for the curve.
        values : `astropy.units.Quantity`
            The y-values for the curve.
        wavelength : `astropy.units.Quantity`
            The wavelength(s) at which to make the interpolation.
        kind : `str`
            The type of interpolation to do. See documentation for
            `scipy.interpolate.interp1d` for accepted values.
        bounds_error : `bool`
            Passed through to `scipy.interpolate.interp1d`.
        fill_value : `float`
            Passed through to `scipy.interpolate.interp1d`.

        Returns
        -------
        value : `astropy.units.Quantity`
            Interpolated value(s), in the units of ``values``.

        Raises
        ------
        ValueError
            If any of the inputs is not an astropy quantity.
        """
        if not isinstance(wavelength, u.Quantity):
            raise ValueError("Wavelengths at which to interpolate must be astropy quantities")
        if not (isinstance(wavelengths, u.Quantity) and isinstance(values, u.Quantity)):
            raise ValueError("Model to be interpolated must be astropy quantities")
        # Express the requested wavelengths in the curve's own unit, then
        # interpolate on plain magnitudes and re-attach the value unit.
        interp_wavelength = wavelength.to(wavelengths.unit)
        f = interp1d(wavelengths.value, values.value, kind=kind,
                     bounds_error=bounds_error, fill_value=fill_value)
        return f(interp_wavelength.value)*values.unit

    def getMetadata(self):
        """Return metadata

        Returns
        -------
        metadata : `dict`
            Dictionary of metadata for this curve.
        """
        # Needed to duck type as an object that can be ingested
        return self.metadata

    @classmethod
    def readText(cls, filename):
        """Class method for constructing a `Curve` object from
        the standardized text format.

        Parameters
        ----------
        filename : `str`
            Path to the text file to read.

        Returns
        -------
        curve : `Curve`
            A `Curve` subclass of the appropriate type according
            to the table metadata
        """
        table = QTable.read(filename, format='ascii.ecsv')
        return cls.subclasses[table.meta['MODE']].fromTable(table)

    @classmethod
    def readFits(cls, filename):
        """Class method for constructing a `Curve` object from
        the standardized FITS format.

        Parameters
        ----------
        filename : `str`
            Path to the FITS file to read.

        Returns
        -------
        curve : `Curve`
            A `Curve` subclass of the appropriate type according
            to the table metadata
        """
        table = QTable.read(filename, format='fits')
        return cls.subclasses[table.meta['MODE']].fromTable(table)

    @staticmethod
    def _check_cols(cols, table):
        """Check that the columns are in the table.

        Raises
        ------
        ValueError
            Naming the first column of ``cols`` missing from ``table``.
        """
        for col in cols:
            if col not in table.columns:
                raise ValueError(f'The table must include a column named "{col}".')

    def _to_table_with_meta(self):
        """Build the output table and stamp its creation date/time metadata."""
        # utcnow() returns a naive datetime, so "%Z" expands to an empty
        # string; strip() removes the resulting trailing space.
        now = datetime.datetime.utcnow()
        table = self.toTable()
        metadata = table.meta
        metadata["DATE"] = now.isoformat()
        metadata["CALIB_CREATION_DATE"] = now.strftime("%Y-%m-%d")
        metadata["CALIB_CREATION_TIME"] = now.strftime("%T %Z").strip()
        return table

    def writeText(self, filename):
        """Write the `Curve` out to a text file.

        Parameters
        ----------
        filename : `str`
            Path to the text file to write.

        Returns
        -------
        filename : `str`
            Because this method forces a particular extension return
            the name of the file actually written.
        """
        table = self._to_table_with_meta()
        # Force file extension to .ecsv
        path, ext = os.path.splitext(filename)
        filename = path + ".ecsv"
        table.write(filename, format="ascii.ecsv")
        return filename

    def writeFits(self, filename):
        """Write the `Curve` out to a FITS file.

        Parameters
        ----------
        filename : `str`
            Path to the FITS file to write.

        Returns
        -------
        filename : `str`
            Because this method forces a particular extension return
            the name of the file actually written.
        """
        table = self._to_table_with_meta()
        # Force file extension to .fits
        path, ext = os.path.splitext(filename)
        filename = path + ".fits"
        table.write(filename, format="fits")
        return filename

294 

295 

class DetectorCurve(Curve):
    """Subclass of `Curve` that represents a single curve per detector.

    Parameters
    ----------
    wavelength : `astropy.units.Quantity`
        Wavelength values for this curve
    efficiency : `astropy.units.Quantity`
        Quantum efficiency values for this curve
    metadata : `dict`
        Dictionary of metadata for this curve
    """
    mode = 'DETECTOR'

    def __eq__(self, other):
        """Two detector curves are equal when their compared metadata,
        wavelengths and efficiencies all match.
        """
        # Let Python fall back to the other operand (or identity) instead of
        # failing on a missing ``metadata`` attribute.
        if not isinstance(other, DetectorCurve):
            return NotImplemented
        # Previously the wavelength was compared twice and the efficiency
        # never, so curves differing only in QE compared equal.
        return (self.compare_metadata(other)
                and numpy.array_equal(self.wavelength, other.wavelength)
                and numpy.array_equal(self.efficiency, other.efficiency))

    @classmethod
    def fromTable(cls, table):
        # Docstring inherited from base class
        cls._check_cols(['wavelength', 'efficiency'], table)
        return cls(table['wavelength'], table['efficiency'], table.meta)

    def toTable(self):
        # Docstring inherited from base class
        return QTable({'wavelength': self.wavelength, 'efficiency': self.efficiency}, meta=self.metadata)

    def evaluate(self, detector, position, wavelength, kind='linear', bounds_error=False, fill_value=0):
        # Docstring inherited from base class; ``detector`` and ``position``
        # are ignored because there is a single curve for the whole detector.
        return self.interpolate(self.wavelength, self.efficiency, wavelength,
                                kind=kind, bounds_error=bounds_error, fill_value=fill_value)

329 

330 

class AmpCurve(Curve):
    """Subclass of `Curve` that represents a curve per amp.

    Parameters
    ----------
    amp_name_list : sequence of `str` or `bytes`
        The name of the amp for each entry (a list, array or table column).
    wavelength : `astropy.units.Quantity`
        Wavelength values for this curve
    efficiency : `astropy.units.Quantity`
        Quantum efficiency values for this curve
    metadata : `dict`
        Dictionary of metadata for this curve
    """
    mode = 'AMP'

    def __init__(self, amp_name_list, wavelength, efficiency, metadata):
        super().__init__(wavelength, efficiency, metadata)
        # Convert to an array so the element-wise comparison below also works
        # for plain Python lists (``list == str`` is just ``False``).
        amp_names = numpy.asarray(amp_name_list)
        self.data = {}
        # dict.fromkeys keeps first-appearance order, so the order of
        # ``self.data`` (and of the rows written by ``toTable``) is
        # deterministic, unlike iterating a ``set`` of strings.
        for amp_name in dict.fromkeys(amp_names.tolist()):
            idx = numpy.flatnonzero(amp_names == amp_name)
            # Deal with the case where the keys are bytes from FITS
            name = amp_name.decode() if isinstance(amp_name, bytes) else amp_name
            self.data[name] = (wavelength[idx], efficiency[idx])

    def __eq__(self, other):
        """Two amp curves are equal when their compared metadata, amp names
        and per-amp wavelengths and efficiencies all match.
        """
        if not isinstance(other, AmpCurve):
            return NotImplemented
        if not self.compare_metadata(other):
            return False
        # Different amp sets are unequal (previously extra amps in ``other``
        # were ignored and missing ones raised KeyError).
        if self.data.keys() != other.data.keys():
            return False
        return all(numpy.array_equal(wav, other.data[name][0])
                   and numpy.array_equal(eff, other.data[name][1])
                   for name, (wav, eff) in self.data.items())

    @classmethod
    def fromTable(cls, table):
        # Docstring inherited from base class
        cls._check_cols(['amp_name', 'wavelength', 'efficiency'], table)
        return cls(table['amp_name'], table['wavelength'],
                   table['efficiency'], table.meta)

    def toTable(self):
        # Docstring inherited from base class
        # Per-amp arrays are slices of the constructor inputs, so they share
        # these units; converting explicitly keeps the columns consistent.
        wunit = self.wavelength.unit
        eunit = self.efficiency.unit
        names, wavelengths, efficiencies = [], [], []
        # Loop over the amps and collect three same-length pieces per amp to
        # concatenate into columns for the table constructor.
        for amp_name, (wav, eff) in self.data.items():
            names.append(numpy.full(wav.shape, amp_name))
            wavelengths.append(wav.to(wunit).value)
            efficiencies.append(eff.to(eunit).value)
        if names:
            names = numpy.concatenate(names)
            wavelength = numpy.concatenate(wavelengths)
            efficiency = numpy.concatenate(efficiencies)
        else:
            # numpy.concatenate rejects an empty list; emit empty columns.
            names = numpy.array([], dtype=str)
            wavelength = numpy.array([], dtype=float)
            efficiency = numpy.array([], dtype=float)
        # Note that in future, the astropy.unit should make it through concatenation
        return QTable({'amp_name': names, 'wavelength': wavelength*wunit, 'efficiency': efficiency*eunit},
                      meta=self.metadata)

    def evaluate(self, detector, position, wavelength, kind='linear', bounds_error=False, fill_value=0):
        # Docstring inherited from base class
        # Pick the curve of the amp containing ``position``.
        amp = cgUtils.findAmp(detector, Point2I(position))  # cast to Point2I if Point2D passed
        w, e = self.data[amp.getName()]
        return self.interpolate(w, e, wavelength, kind=kind, bounds_error=bounds_error,
                                fill_value=fill_value)

407 

408 

class ImageCurve(Curve):
    """Placeholder `Curve` subclass registered under mode ``'IMAGE'``.

    None of its methods are implemented; each raises `NotImplementedError`.
    Because ``__eq__`` is not overridden, this class is still abstract and
    cannot be instantiated.
    """
    mode = 'IMAGE'

    @classmethod
    def fromTable(cls, table):
        # Docstring inherited from base class
        # A classmethod like the abstract declaration, so dispatch through
        # ``Curve.readText``/``readFits`` reaches NotImplementedError instead
        # of failing with a missing-argument TypeError.
        raise NotImplementedError()

    def toTable(self):
        # Docstring inherited from base class
        raise NotImplementedError()

    def evaluate(self, detector, position, wavelength, kind='linear', bounds_error=False, fill_value=0):
        # Docstring inherited from base class
        raise NotImplementedError()