Coverage for python/lsst/pipe/base/_task_metadata.py: 15%

219 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 15:47 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by a star import of this module.
__all__ = [
    "TaskMetadata",
]

23 

24import itertools 

25import numbers 

26import warnings 

27from collections.abc import Sequence 

28from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Protocol, Set, Tuple, Union 

29 

30from deprecated.sphinx import deprecated 

31from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr 

32 

# Text shared by the FutureWarning deprecations on TaskMetadata methods.
_DEPRECATION_REASON: str = "Will be removed after v25."
_DEPRECATION_VERSION: str = "v24"

# Task metadata values are restricted to these primitive types so that
# serialization stays predictable.
_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)

39 

40 

class PropertySetLike(Protocol):
    """Structural stand-in for ``lsst.daf.base.PropertySet``.

    Only the part of the ``PropertySet`` API that is needed to convert such
    an object into a `TaskMetadata` is declared here.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]:
        """Return parameter names, optionally including hierarchical ones."""

    def getArray(self, name: str) -> Any:
        """Return the values stored under ``name`` as a sequence."""

53 

54 

def _isListLike(v: Any) -> bool:
    """Report whether ``v`` should be treated as a list of values.

    Any `~collections.abc.Sequence` qualifies except `str`. A `str` is a
    sequence of characters, but it is stored as a single scalar value.
    NOTE(review): `bytes` is also a `Sequence` and therefore qualifies.
    """
    if isinstance(v, str):
        return False
    return isinstance(v, Sequence)

57 

58 

class TaskMetadata(BaseModel):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called full metadata of a task and has a form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    Metadata item key of a task (`itemName` above) must not contain `.`,
    which serves as a separator in full metadata keys and turns
    the value into sub-dictionary. Arbitrary hierarchies are supported.

    Values live in one of three slots:

    - ``scalars`` holds single primitive values.
    - ``arrays`` holds lists of primitive values that all share one type.
    - ``metadata`` holds nested `TaskMetadata` objects.

    The mutating methods of this class keep each top-level key in at most
    one slot.

    Deprecated methods are for compatibility with
    the predecessor containers.
    """

    scalars: Dict[str, Union[StrictFloat, StrictInt, StrictBool, StrictStr]] = Field(default_factory=dict)
    arrays: Dict[str, Union[List[StrictFloat], List[StrictInt], List[StrictBool], List[StrictStr]]] = Field(
        default_factory=dict
    )
    metadata: Dict[str, "TaskMetadata"] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `Mapping`
            Mapping to convert. Can be hierarchical. Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        for k, v in d.items():
            metadata[k] = v
        return metadata

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`. A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                value = value[0]
            metadata[key] = value
        return metadata

    def to_dict(self) -> Dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary. Can be
        used when a simple dictionary is needed. Use
        `TaskMetadata.from_dict()` to convert it back.

        Array values are copied, so modifying the returned dictionary does
        not modify this object.
        """
        d: Dict[str, Any] = dict(self.scalars)
        # Copy the lists so the caller cannot mutate the stored arrays
        # through the returned dictionary.
        d.update((k, list(v)) for k, v in self.arrays.items())
        for k, v in self.metadata.items():
            d[k] = v.to_dict()
        return d

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property. Can be dot-separated
            hierarchical; missing intermediate levels are created.
        value
            Metadata property value. A scalar is appended to the stored
            list, and a sequence extends it.

        Raises
        ------
        ValueError
            Raised in any of these cases:

            - the value is not a primitive type or a non-empty sequence of
              a single primitive type;
            - ``name`` refers to a `TaskMetadata`;
            - the type of the new values differs from the type of the
              values already stored.

            The metadata is left unchanged in these cases.
        KeyError
            Raised if the name is not a string, or if an intermediate level
            of a hierarchical name holds a scalar or array value.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # If add() is being used, always store the value in the arrays
            # property as a list. It's likely there will be another call.
            slot_type, value = self._validate_value(value)
            if slot_type == "scalar":
                value = [value]
            elif slot_type != "array":
                raise ValueError("add() can only be used for primitive types or sequences of those types.")

            if key0 in self.metadata:
                raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

            # Check that the type is not changing *before* touching any
            # stored state, so a rejected add() leaves the metadata
            # unchanged. An empty stored array (only possible through direct
            # model construction) has no type to compare against.
            if key0 in self.scalars:
                current: List[Any] = [self.scalars[key0]]
            else:
                current = self.arrays.get(key0, [])
            if current and (curtype := type(current[0])) is not (newtype := type(value[0])):
                raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")

            if key0 in self.scalars:
                # Convert scalar to array.
                # MyPy should be able to figure out that List[Union[T1, T2]]
                # is compatible with Union[List[T1], List[T2]] if the list
                # has only one type of element, but it can't.
                self.arrays[key0] = [self.scalars.pop(key0)] + value  # type: ignore
            elif key0 in self.arrays:
                self.arrays[key0].extend(value)
            else:
                self.arrays[key0] = value
            return

        # Hierarchical name: forward to the child TaskMetadata. A missing
        # child is created, as __setitem__ does. A new child is only attached
        # after the value has been accepted, so a failure leaves no partial
        # state behind.
        child = self.metadata.get(key0)
        if child is None:
            if key0 in self.scalars or key0 in self.arrays:
                raise KeyError(f"Can not add() to key '{name}' since '{key0}' is not a TaskMetadata")
            child = TaskMetadata()
            child.add(".".join(keys), value)
            self.metadata[key0] = child
        else:
            child.add(".".join(keys), value)

    @deprecated(
        reason="Cast the return value to float explicitly. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def getAsDouble(self, key: str) -> float:
        """Return the value cast to a `float`.

        Parameters
        ----------
        key : `str`
            Item to return. Can be dot-separated hierarchical.

        Returns
        -------
        value : `float`
            The value cast to a `float`.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        return float(self.__getitem__(key))

    def getScalar(self, key: str) -> Union[str, int, float, bool]:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> List[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A new list containing the value or values associated with this
            item. Modifying it does not modify the stored metadata.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.arrays:
                # Copy so that callers cannot mutate the stored array.
                return list(self.arrays[key0])
            elif key0 in self.scalars:
                return [self.scalars[key0]]
            elif key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def names(self, topLevelOnly: bool = True) -> Set[str]:
        """Return the hierarchical keys from the metadata.

        Parameters
        ----------
        topLevelOnly : `bool`
            If true, return top-level keys, otherwise full metadata item keys.

        Returns
        -------
        names : `collections.abc.Set`
            A set of top-level keys or full metadata item keys, including
            the top-level keys.

        Notes
        -----
        Should never be called in new code with ``topLevelOnly`` set to
        `True`. That is equivalent to asking for the keys, which is the
        default when iterating through the task metadata. In this case a
        deprecation warning is issued, and this usage will raise an
        exception in a future release.

        When ``topLevelOnly`` is `False` all keys, including those from the
        hierarchy and the top-level hierarchy, are returned.
        """
        if topLevelOnly:
            warnings.warn("Use keys() instead. " + _DEPRECATION_REASON, FutureWarning)
            return set(self.keys())
        else:
            names = set()
            for k, v in self.items():
                names.add(k)  # Always include the current level
                if isinstance(v, TaskMetadata):
                    names.update({k + "." + item for item in v.names(topLevelOnly=topLevelOnly)})
            return names

    def paramNames(self, topLevelOnly: bool) -> Set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`
            Control whether only top-level items are returned or items
            from the hierarchy.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy. If `False` also returns fully-qualified
            names from the hierarchy. Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames

    @deprecated(
        reason="Use standard assignment syntax. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def set(self, key: str, item: Any) -> None:
        """Set the value of the supplied key."""
        self.__setitem__(key, item)

    @deprecated(
        reason="Use standard del dict syntax. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def remove(self, key: str) -> None:
        """Remove the item without raising if absent."""
        try:
            self.__delitem__(key)
        except KeyError:
            # The PropertySet.remove() should always work.
            pass

    @staticmethod
    def _getKeys(key: str) -> List[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze. Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        try:
            keys = key.split(".")
        except (AttributeError, TypeError):
            # AttributeError: no split() at all (e.g. int, None).
            # TypeError: split() exists but rejects a str separator
            # (e.g. bytes).
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> Tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(iter(self))

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else. But the
    # pydantic docs say to do exactly this to in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. For compatibility with ``PropertySet``, if the key
            refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property. Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.
        default
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item.

        Parameters
        ----------
        key : `str`
            The key to store under. Can be dot-separated hierarchical.
            Missing intermediate levels are created, and any scalar or
            array stored under the top-level part of the key is replaced.
        item
            The value to store. Mappings are converted to `TaskMetadata`.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        ValueError
            Raised if the item is not of a supported type. The metadata is
            left unchanged in that case.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            slots: Dict[str, Dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            slot_type, item = self._validate_value(item)
            primary: Optional[Dict[str, Any]] = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for other_slot in slots.values():
                # Remove any other entries.
                other_slot.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        # A new child is only attached once the nested assignment has
        # succeeded; otherwise a rejected value would leave key0 present in
        # both the metadata slot and the scalar/array slot.
        child = self.metadata.get(key0)
        if child is None:
            child = TaskMetadata()
            child[".".join(keys)] = item
            self.metadata[key0] = child
        else:
            child[".".join(keys)] = item

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists. Can be dot-separated hierarchical."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            slots: Tuple[Dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for slot in slots:
                if key0 in slot:
                    del slot[key0]
                    return
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def _validate_value(self, value: Any) -> Tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given. Options are "scalar", "array", "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type. Sequences are always returned as a new `list`.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type, or is an empty or
            mixed-type sequence.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)
            if not value:
                # There is no element from which to infer the array type, and
                # an empty array would break __getitem__ and add().
                raise ValueError("TaskMetadata does not support empty sequences.")

            type0 = type(value[0])
            for i in value:
                if type(i) is not type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value!r}.")

621 

622 

# Needed because a TaskMetadata can contain a TaskMetadata: the ``metadata``
# field is annotated with the string "TaskMetadata", which can only be
# resolved once the class has been fully defined. This call must therefore
# stay after the class definition.
TaskMetadata.update_forward_refs()