Coverage for python/lsst/pipe/base/_task_metadata.py: 18%

222 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-02 03:31 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

# Public API of this module.
__all__ = [
    "TaskMetadata",
    "SetDictMetadata",
    "GetDictMetadata",
    "GetSetDictMetadata",
    "NestedMetadataDict",
]

35 

36import itertools 

37import numbers 

38import sys 

39from collections.abc import Collection, Iterator, Mapping, Sequence 

40from typing import Any, Protocol, TypeAlias, Union 

41 

42from pydantic import BaseModel, ConfigDict, Field, StrictBool, StrictFloat, StrictInt, StrictStr 

43 

# The types allowed in a Task metadata field are restricted
# to allow predictable serialization.
_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)

# Recursive alias for the nested-mapping form accepted by set_dict() and
# returned by get_dict().
# Note that '|' syntax for unions doesn't work when we have to use a string
# literal (and we do since it's recursive and not an annotation).
NestedMetadataDict: TypeAlias = Mapping[str, Union[str, float, int, bool, "NestedMetadataDict"]]

51 

52 

class PropertySetLike(Protocol):
    """Protocol that looks like a ``lsst.daf.base.PropertySet``.

    Enough of the API is specified to support conversion of a
    ``PropertySet`` to a `TaskMetadata`.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]:
        # Return the parameter names, optionally restricted to the top level.
        ...

    def getArray(self, name: str) -> Any:
        # Return the value(s) associated with ``name`` as an array.
        ...

63 

64 

65def _isListLike(v: Any) -> bool: 

66 return isinstance(v, Sequence) and not isinstance(v, str) 

67 

68 

class SetDictMetadata(Protocol):
    """Protocol for objects that can be assigned a possibly-nested `dict` of
    primitives.

    This protocol is satisfied by `TaskMetadata`, `lsst.daf.base.PropertySet`,
    and `lsst.daf.base.PropertyList`, providing a consistent way to insert a
    dictionary into these objects that avoids their historical idiosyncrasies.

    The form in which these entries appear in the object's native keys and
    values is implementation-defined.  *Empty nested dictionaries may be
    dropped, and if the top-level dictionary is empty this method may do
    nothing.*

    Neither the top-level key nor nested keys may contain ``.`` (period)
    characters.
    """

    def set_dict(self, key: str, nested: NestedMetadataDict) -> None:
        # Store ``nested`` under ``key``; storage layout is up to the
        # implementation.
        ...

87 

88 

class GetDictMetadata(Protocol):
    """Protocol for objects that can extract a possibly-nested mapping of
    primitives.

    This protocol is satisfied by `TaskMetadata`, `lsst.daf.base.PropertySet`,
    and `lsst.daf.base.PropertyList`, providing a consistent way to extract a
    dictionary from these objects that avoids their historical idiosyncrasies.

    This is guaranteed to work for mappings inserted by
    `~SetDictMetadata.set_dict`.  It should not be expected to work for values
    inserted in other ways.  If a value was never inserted with the given key
    at all, *an empty `dict` will be returned* (this is a concession to
    implementation constraints in `~lsst.daf.base.PropertyList`).
    """

    def get_dict(self, key: str) -> NestedMetadataDict:
        # Return the possibly-nested mapping stored under ``key``, or an
        # empty dict if the key does not exist.
        ...

105 

106 

class GetSetDictMetadata(SetDictMetadata, GetDictMetadata, Protocol):
    """Protocol for objects that can assign and extract a possibly-nested
    mapping of primitives.
    """

112 

class TaskMetadata(BaseModel):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called full metadata of a task and has a form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    Metadata item key of a task (``itemName`` above) must not contain ``.``,
    which serves as a separator in full metadata keys and turns
    the value into sub-dictionary.  Arbitrary hierarchies are supported.
    """

    # Pipelines regularly generate NaN and Inf so these need to be
    # supported even though that's a JSON extension.
    model_config = ConfigDict(ser_json_inf_nan="constants")

    # Scalars, arrays, and nested metadata live in separate slots so that
    # pydantic can validate each against a strict element type.  A given
    # top-level key appears in at most one of the three dicts.
    scalars: dict[str, StrictFloat | StrictInt | StrictBool | StrictStr] = Field(default_factory=dict)
    arrays: dict[str, list[StrictFloat] | list[StrictInt] | list[StrictBool] | list[StrictStr]] = Field(
        default_factory=dict
    )
    metadata: dict[str, "TaskMetadata"] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `~collections.abc.Mapping`
            Mapping to convert.  Can be hierarchical.  Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        for k, v in d.items():
            metadata[k] = v
        return metadata

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`.  A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                value = value[0]
            metadata[key] = value
        return metadata

    def to_dict(self) -> dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary.  Can be
        used when a simple dictionary is needed.  Use
        `TaskMetadata.from_dict()` to convert it back.
        """
        d: dict[str, Any] = {}
        d.update(self.scalars)
        d.update(self.arrays)
        for k, v in self.metadata.items():
            d[k] = v.to_dict()
        return d

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property.
        value : `~typing.Any`
            Metadata property value.

        Raises
        ------
        ValueError
            Raised if the value is not a primitive type (or sequence of
            primitives), or if the key already refers to a `TaskMetadata`.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # If add() is being used, always store the value in the arrays
            # property as a list.  It's likely there will be another call.
            slot_type, value = self._validate_value(value)
            if slot_type == "array":
                pass
            elif slot_type == "scalar":
                value = [value]
            else:
                raise ValueError("add() can only be used for primitive types or sequences of those types.")

            if key0 in self.metadata:
                raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

            if key0 in self.scalars:
                # Convert scalar to array.
                # MyPy should be able to figure out that List[Union[T1, T2]] is
                # compatible with Union[List[T1], List[T2]] if the list has
                # only one element, but it can't.
                self.arrays[key0] = [self.scalars.pop(key0)]  # type: ignore

            if key0 in self.arrays:
                # Check that the type is not changing.
                if (curtype := type(self.arrays[key0][0])) is not (newtype := type(value[0])):
                    raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")
                self.arrays[key0].extend(value)
            else:
                self.arrays[key0] = value

            return

        # NOTE(review): unlike __setitem__, this requires the intermediate
        # TaskMetadata to already exist; a missing child raises KeyError.
        # Confirm that callers rely on that behavior before changing it.
        self.metadata[key0].add(".".join(keys), value)

    def getScalar(self, key: str) -> str | int | float | bool:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> list[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A list containing the value or values associated with this item.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.arrays:
                return self.arrays[key0]
            elif key0 in self.scalars:
                return [self.scalars[key0]]
            elif key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def names(self) -> set[str]:
        """Return the hierarchical keys from the metadata.

        Returns
        -------
        names : `collections.abc.Set`
            A set of all keys, including those from the hierarchy and the
            top-level hierarchy.
        """
        names = set()
        for k, v in self.items():
            names.add(k)  # Always include the current level
            if isinstance(v, TaskMetadata):
                names.update({k + "." + item for item in v.names()})
        return names

    def paramNames(self, topLevelOnly: bool) -> set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`
            Control whether only top-level items are returned or items
            from the hierarchy.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy.  If `False` also returns fully-qualified
            names from the hierarchy.  Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames

    @staticmethod
    def _getKeys(key: str) -> list[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze.  Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        # Narrowed from a bare ``except Exception``: a non-string key fails
        # ``.split`` with AttributeError (or TypeError for bytes-like input).
        try:
            keys = key.split(".")
        except (AttributeError, TypeError):
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(k for k in self)

    def items(self) -> Iterator[tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else.  But the
    # pydantic docs say to do exactly this to in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve.  Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value.  For compatibility with ``PropertySet``, if the
            key refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property.  Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve.  Can be dot-separated hierarchical.
        default : `~typing.Any`
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value.  If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            slots: dict[str, dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            slot_type, item = self._validate_value(item)
            primary: dict[str, Any] | None = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for property in slots.values():
                # Remove any other entries.
                property.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        if key0 not in self.metadata:
            self.metadata[key0] = TaskMetadata()
        self.metadata[key0][".".join(keys)] = item

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            properties: tuple[dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for property in properties:
                if key0 in property:
                    del property[key0]
                    return
            # Fixed a stray trailing quote in the error message.
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def get_dict(self, key: str) -> NestedMetadataDict:
        """Return a possibly-hierarchical nested `dict`.

        This implements the `GetDictMetadata` protocol for consistency with
        `lsst.daf.base.PropertySet` and `lsst.daf.base.PropertyList`.  The
        returned `dict` is guaranteed to be a deep copy, not a view.

        Parameters
        ----------
        key : `str`
            String key associated with the mapping.  May not have a ``.``
            character.

        Returns
        -------
        value : `~collections.abc.Mapping`
            Possibly-nested mapping, with `str` keys and values that are
            `int`, `float`, `str`, `bool`, or another `dict` with the same
            key and value types.  Will be empty if ``key`` does not exist.
        """
        if value := self.get(key):
            return value.to_dict()
        else:
            return {}

    def set_dict(self, key: str, value: NestedMetadataDict) -> None:
        """Assign a possibly-hierarchical nested `dict`.

        This implements the `SetDictMetadata` protocol for consistency with
        `lsst.daf.base.PropertySet` and `lsst.daf.base.PropertyList`.

        Parameters
        ----------
        key : `str`
            String key associated with the mapping.  May not have a ``.``
            character.
        value : `~collections.abc.Mapping`
            Possibly-nested mapping, with `str` keys and values that are
            `int`, `float`, `str`, `bool`, or another `dict` with the same
            key and value types.  Nested keys may not have a ``.`` character.
        """
        self[key] = value

    def _validate_value(self, value: Any) -> tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given.  Options are "scalar", "array",
            "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)

            type0 = type(value[0])
            for i in value:
                # Identity comparison is the idiomatic exact-type check.
                if type(i) is not type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value!r}.")

    # Work around the fact that Sphinx chokes on Pydantic docstring
    # formatting, when we inherit those docstrings in our public classes.
    if "sphinx" in sys.modules:

        def copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.copy`."""
            return super().copy(*args, **kwargs)

        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump`."""
            return super().model_dump(*args, **kwargs)

        def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump_json`."""
            # Bug fix: previously delegated to model_dump(), which returns a
            # dict rather than a JSON string.
            return super().model_dump_json(*args, **kwargs)

        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_copy`."""
            return super().model_copy(*args, **kwargs)

        @classmethod
        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_json_schema`."""
            return super().model_json_schema(*args, **kwargs)

676 

677 

# Needed because a TaskMetadata can contain a TaskMetadata: resolve the
# recursive "TaskMetadata" forward reference now that the class exists.
TaskMetadata.model_rebuild()