Coverage for python/lsst/scarlet/lite/io.py: 47%

146 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

1from __future__ import annotations 

2 

3import json 

4import logging 

5from dataclasses import dataclass 

6from typing import Any, Callable 

7 

8import numpy as np 

9from numpy.typing import DTypeLike 

10 

11from .bbox import Box 

12from .blend import Blend 

13from .component import Component, FactorizedComponent 

14from .image import Image 

15from .observation import Observation 

16from .parameters import FixedParameter 

17from .source import Source 

18 

# Names exported by ``from <module> import *``.
__all__ = [
    "ScarletComponentData",
    "ScarletFactorizedComponentData",
    "ScarletSourceData",
    "ScarletBlendData",
    "ScarletModelData",
    "ComponentCube",
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

29 

30 

31@dataclass(kw_only=True) 

32class ScarletComponentData: 

33 """Data for a component expressed as a 3D data cube 

34 

35 This is used for scarlet component models that are not factorized, 

36 storing their entire model as a 3D data cube (bands, y, x). 

37 

38 Attributes 

39 ---------- 

40 origin: 

41 The lower bound of the components bounding box. 

42 peak: 

43 The peak of the component. 

44 model: 

45 The model for the component. 

46 """ 

47 

48 origin: tuple[int, int] 

49 peak: tuple[float, float] 

50 model: np.ndarray 

51 

52 @property 

53 def shape(self): 

54 return self.model.shape[-2:] 

55 

56 def as_dict(self) -> dict: 

57 """Return the object encoded into a dict for JSON serialization 

58 

59 Returns 

60 ------- 

61 result: 

62 The object encoded as a JSON compatible dict 

63 """ 

64 return { 

65 "origin": self.origin, 

66 "shape": self.model.shape, 

67 "peak": self.peak, 

68 "model": tuple(self.model.flatten().astype(float)), 

69 } 

70 

71 @classmethod 

72 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletComponentData: 

73 """Reconstruct `ScarletComponentData` from JSON compatible dict 

74 

75 Parameters 

76 ---------- 

77 data: 

78 Dictionary representation of the object 

79 dtype: 

80 Datatype of the resulting model. 

81 

82 Returns 

83 ------- 

84 result: 

85 The reconstructed object 

86 """ 

87 shape = tuple(data["shape"]) 

88 

89 return cls( 

90 origin=tuple(data["origin"]), # type: ignore 

91 peak=data["peak"], 

92 model=np.array(data["model"]).reshape(shape).astype(dtype), 

93 ) 

94 

95 

96@dataclass(kw_only=True) 

97class ScarletFactorizedComponentData: 

98 """Data for a factorized component 

99 

100 Attributes 

101 ---------- 

102 origin: 

103 The lower bound of the component's bounding box. 

104 peak: 

105 The ``(y, x)`` peak of the component. 

106 spectrum: 

107 The SED of the component. 

108 morph: 

109 The 2D morphology of the component. 

110 """ 

111 

112 origin: tuple[int, int] 

113 peak: tuple[float, float] 

114 spectrum: np.ndarray 

115 morph: np.ndarray 

116 

117 @property 

118 def shape(self): 

119 return self.morph.shape 

120 

121 def as_dict(self) -> dict: 

122 """Return the object encoded into a dict for JSON serialization 

123 

124 Returns 

125 ------- 

126 result: 

127 The object encoded as a JSON compatible dict 

128 """ 

129 return { 

130 "origin": tuple(int(o) for o in self.origin), 

131 "shape": tuple(int(s) for s in self.morph.shape), 

132 "peak": tuple(int(p) for p in self.peak), 

133 "spectrum": tuple(self.spectrum.astype(float)), 

134 "morph": tuple(self.morph.flatten().astype(float)), 

135 } 

136 

137 @classmethod 

138 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletFactorizedComponentData: 

139 """Reconstruct `ScarletFactorizedComponentData` from JSON compatible 

140 dict. 

141 

142 Parameters 

143 ---------- 

144 data: 

145 Dictionary representation of the object 

146 dtype: 

147 Datatype of the resulting model. 

148 

149 Returns 

150 ------- 

151 result: 

152 The reconstructed object 

153 """ 

154 shape = tuple(data["shape"]) 

155 

156 return cls( 

157 origin=tuple(data["origin"]), # type: ignore 

158 peak=data["peak"], 

159 spectrum=np.array(data["spectrum"]).astype(dtype), 

160 morph=np.array(data["morph"]).reshape(shape).astype(dtype), 

161 ) 

162 

163 

164@dataclass(kw_only=True) 

165class ScarletSourceData: 

166 """Data for a scarlet source 

167 

168 Attributes 

169 ---------- 

170 components: 

171 The components contained in the source that are not factorized. 

172 factorized_components: 

173 The components contained in the source that are factorized. 

174 peak_id: 

175 The peak ID of the source in it's parent's footprint peak catalog. 

176 """ 

177 

178 components: list[ScarletComponentData] 

179 factorized_components: list[ScarletFactorizedComponentData] 

180 peak_id: int 

181 

182 def as_dict(self) -> dict: 

183 """Return the object encoded into a dict for JSON serialization 

184 

185 Returns 

186 ------- 

187 result: 

188 The object encoded as a JSON compatible dict 

189 """ 

190 result = { 

191 "components": [component.as_dict() for component in self.components], 

192 "factorized": [component.as_dict() for component in self.factorized_components], 

193 "peak_id": self.peak_id, 

194 } 

195 return result 

196 

197 @classmethod 

198 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletSourceData: 

199 """Reconstruct `ScarletSourceData` from JSON compatible 

200 dict. 

201 

202 Parameters 

203 ---------- 

204 data: 

205 Dictionary representation of the object 

206 dtype: 

207 Datatype of the resulting model. 

208 

209 Returns 

210 ------- 

211 result: 

212 The reconstructed object 

213 """ 

214 components = [] 

215 for component in data["components"]: 

216 component = ScarletComponentData.from_dict(component, dtype=dtype) 

217 components.append(component) 

218 

219 factorized = [] 

220 for component in data["factorized"]: 

221 component = ScarletFactorizedComponentData.from_dict(component, dtype=dtype) 

222 factorized.append(component) 

223 

224 return cls(components=components, factorized_components=factorized, peak_id=int(data["peak_id"])) 

225 

226 

227@dataclass(kw_only=True) 

228class ScarletBlendData: 

229 """Data for an entire blend. 

230 

231 Attributes 

232 ---------- 

233 origin: 

234 The lower bound of the blend's bounding box. 

235 shape: 

236 The shape of the blend's bounding box. 

237 sources: 

238 Data for the sources contained in the blend, 

239 indexed by the source id. 

240 psf_center: 

241 The location used for the center of the PSF for 

242 the blend. 

243 psf: 

244 The PSF of the observation. 

245 bands : `list` of `str` 

246 The names of the bands. 

247 The order of the bands must be the same as the order of 

248 the multiband model arrays, and SEDs. 

249 """ 

250 

251 origin: tuple[int, int] 

252 shape: tuple[int, int] 

253 sources: dict[int, ScarletSourceData] 

254 psf_center: tuple[float, float] 

255 psf: np.ndarray 

256 bands: tuple[str] 

257 

258 def as_dict(self) -> dict: 

259 """Return the object encoded into a dict for JSON serialization 

260 

261 Returns 

262 ------- 

263 result: 

264 The object encoded as a JSON compatible dict 

265 """ 

266 result = { 

267 "origin": self.origin, 

268 "shape": self.shape, 

269 "psf_center": self.psf_center, 

270 "psf_shape": self.psf.shape, 

271 "psf": tuple(self.psf.flatten().astype(float)), 

272 "sources": {bid: source.as_dict() for bid, source in self.sources.items()}, 

273 "bands": self.bands, 

274 } 

275 return result 

276 

277 @classmethod 

278 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletBlendData: 

279 """Reconstruct `ScarletBlendData` from JSON compatible 

280 dict. 

281 

282 Parameters 

283 ---------- 

284 data: 

285 Dictionary representation of the object 

286 dtype: 

287 Datatype of the resulting model. 

288 

289 Returns 

290 ------- 

291 result: 

292 The reconstructed object 

293 """ 

294 psf_shape = data["psf_shape"] 

295 return cls( 

296 origin=tuple(data["origin"]), # type: ignore 

297 shape=tuple(data["shape"]), # type: ignore 

298 psf_center=tuple(data["psf_center"]), # type: ignore 

299 psf=np.array(data["psf"]).reshape(psf_shape).astype(dtype), 

300 sources={ 

301 int(bid): ScarletSourceData.from_dict(source, dtype=dtype) 

302 for bid, source in data["sources"].items() 

303 }, 

304 bands=tuple(data["bands"]), # type: ignore 

305 ) 

306 

307 def minimal_data_to_blend(self, model_psf: np.ndarray, dtype: DTypeLike) -> Blend: 

308 """Convert the storage data model into a scarlet lite blend 

309 

310 Parameters 

311 ---------- 

312 model_psf: 

313 PSF in model space (usually a nyquist sampled circular Gaussian). 

314 dtype: 

315 The data type of the model that is generated. 

316 

317 Returns 

318 ------- 

319 blend: 

320 A scarlet blend model extracted from persisted data. 

321 """ 

322 model_box = Box(self.shape, origin=(0, 0)) 

323 observation = Observation.empty( 

324 bands=self.bands, 

325 psfs=self.psf, 

326 model_psf=model_psf, 

327 bbox=model_box, 

328 dtype=dtype, 

329 ) 

330 return self.to_blend(observation) 

331 

332 def to_blend(self, observation: Observation) -> Blend: 

333 """Convert the storage data model into a scarlet lite blend 

334 

335 Parameters 

336 ---------- 

337 observation: 

338 The observation that contains the blend. 

339 If `observation` is ``None`` then an `Observation` containing 

340 no image data is initialized. 

341 

342 Returns 

343 ------- 

344 blend: 

345 A scarlet blend model extracted from persisted data. 

346 """ 

347 sources = [] 

348 for source_id, source_data in self.sources.items(): 

349 components: list[Component] = [] 

350 for component_data in source_data.components: 

351 bbox = Box(component_data.shape, origin=component_data.origin) 

352 model = component_data.model 

353 if component_data.peak is None: 

354 peak = None 

355 else: 

356 peak = (int(np.round(component_data.peak[0])), int(np.round(component_data.peak[0]))) 

357 component = ComponentCube( 

358 bands=observation.bands, 

359 bbox=bbox, 

360 model=Image(model, yx0=bbox.origin, bands=observation.bands), # type: ignore 

361 peak=peak, 

362 ) 

363 components.append(component) 

364 for factorized_data in source_data.factorized_components: 

365 bbox = Box(factorized_data.shape, origin=factorized_data.origin) 

366 # Add dummy values for properties only needed for 

367 # model fitting. 

368 spectrum = FixedParameter(factorized_data.spectrum) 

369 morph = FixedParameter(factorized_data.morph) 

370 # Note: since we aren't fitting a model, we don't need to 

371 # set the RMS of the background. 

372 # We set it to NaN just to be safe. 

373 factorized = FactorizedComponent( 

374 bands=observation.bands, 

375 spectrum=spectrum, 

376 morph=morph, 

377 peak=tuple(int(np.round(p)) for p in factorized_data.peak), # type: ignore 

378 bbox=bbox, 

379 bg_rms=np.full((len(observation.bands),), np.nan), 

380 ) 

381 components.append(factorized) 

382 

383 source = Source(components=components) 

384 # Store identifiers for the source 

385 source.record_id = source_id # type: ignore 

386 source.peak_id = source_data.peak_id # type: ignore 

387 sources.append(source) 

388 

389 return Blend(sources=sources, observation=observation) 

390 

391 @staticmethod 

392 def from_blend(blend: Blend, psf_center: tuple[int, int]) -> ScarletBlendData: 

393 """Convert a scarlet lite blend into a persistable data object 

394 

395 Parameters 

396 ---------- 

397 blend: 

398 The blend that is being persisted. 

399 psf_center: 

400 The center of the PSF. 

401 

402 Returns 

403 ------- 

404 blend_data: 

405 The data model for a single blend. 

406 """ 

407 sources = {} 

408 for source in blend.sources: 

409 components = [] 

410 factorized = [] 

411 for component in source.components: 

412 if type(component) is FactorizedComponent: 

413 factorized_data = ScarletFactorizedComponentData( 

414 origin=component.bbox.origin, # type: ignore 

415 peak=component.peak, # type: ignore 

416 spectrum=component.spectrum, 

417 morph=component.morph, 

418 ) 

419 factorized.append(factorized_data) 

420 else: 

421 component_data = ScarletComponentData( 

422 origin=component.bbox.origin, # type: ignore 

423 peak=component.peak, # type: ignore 

424 model=component.get_model().data, 

425 ) 

426 components.append(component_data) 

427 source_data = ScarletSourceData( 

428 components=components, 

429 factorized_components=factorized, 

430 peak_id=source.peak_id, # type: ignore 

431 ) 

432 sources[source.record_id] = source_data # type: ignore 

433 

434 blend_data = ScarletBlendData( 

435 origin=blend.bbox.origin, # type: ignore 

436 shape=blend.bbox.shape, # type: ignore 

437 sources=sources, 

438 psf_center=psf_center, 

439 psf=blend.observation.psfs, 

440 bands=blend.observation.bands, # type: ignore 

441 ) 

442 

443 return blend_data 

444 

445 

class ScarletModelData:
    """A container that propagates scarlet models for an entire catalog."""

    def __init__(self, psf: np.ndarray, blends: dict[int, ScarletBlendData] | None = None):
        """Initialize an instance

        Parameters
        ----------
        psf:
            The 2D array of the PSF in scarlet model space.
            This is typically a narrow Gaussian integrated over the
            pixels in the exposure.
        blends:
            Map from parent IDs in the source catalog
            to scarlet model data for each parent ID (blend).
            An empty map is created when this is ``None``.
        """
        self.psf = psf
        if blends is None:
            blends = {}
        self.blends = blends

    def json(self) -> str:
        """Serialize the data model to a JSON formatted string

        Returns
        -------
        result : `str`
            The result of the object converted into a JSON format
        """
        result = {
            "psfShape": self.psf.shape,
            "psf": list(self.psf.flatten().astype(float)),
            # ``json.dumps`` rejects numpy integers as object keys, and
            # `parse_obj` reads the keys back with ``int``.
            "blends": {int(bid): blend.as_dict() for bid, blend in self.blends.items()},
        }
        return json.dumps(result)

    @classmethod
    def parse_obj(cls, data: dict) -> ScarletModelData:
        """Construct a ScarletModelData from python decoded JSON object.

        Parameters
        ----------
        data:
            The result of json.load(s) on a JSON persisted ScarletModelData

        Returns
        -------
        result:
            The `ScarletModelData` that was loaded from the input object
        """
        model_psf = np.array(data["psf"]).reshape(data["psfShape"]).astype(np.float32)
        return cls(
            psf=model_psf,
            blends={int(bid): ScarletBlendData.from_dict(blend) for bid, blend in data["blends"].items()},
        )

505 

506 

class ComponentCube(Component):
    """Dummy component for a component cube.

    This is duck-typed to a `lsst.scarlet.lite.Component` in order to
    generate a model from the component.

    If scarlet lite ever implements a component as a data cube,
    this class can be removed.
    """

    def __init__(self, bands: tuple[Any, ...], bbox: Box, model: Image, peak: tuple[int, int] | None):
        """Initialization

        Parameters
        ----------
        bands:
            The bands of the component, passed to the base class.
        bbox:
            The bounding box of the component.
        model:
            The 3D (bands, y, x) model of the component.
        peak:
            The `(y, x)` peak of the component.
            `ScarletBlendData.to_blend` passes ``None`` when the persisted
            component has no peak.
        """
        super().__init__(bands, bbox)
        self._model = model
        self.peak = peak

    def get_model(self) -> Image:
        """Generate the model for the source

        Returns
        -------
        model:
            The model as a 3D `(band, y, x)` array.
        """
        return self._model

    def resize(self, model_box: Box) -> bool:
        """Test whether or not the component needs to be resized

        Always returns ``False``; the stored model is never resized.
        """
        return False

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Implementation of unused abstract method"""

    def parameterize(self, parameterization: Callable) -> None:
        """Implementation of unused abstract method"""