Coverage for python/lsst/pipe/base/_task_metadata.py: 18% (219 statements; coverage.py v7.4.4, created 2024-03-23 03:26 -0700)

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28__all__ = [ 

29 "TaskMetadata", 

30 "SetDictMetadata", 

31 "GetDictMetadata", 

32 "GetSetDictMetadata", 

33 "NestedMetadataDict", 

34] 

35 

36import itertools 

37import numbers 

38import sys 

39from collections.abc import Collection, Iterator, Mapping, Sequence 

40from typing import Any, Protocol, TypeAlias, Union 

41 

42from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr 

43 

44# The types allowed in a Task metadata field are restricted 

45# to allow predictable serialization. 

46_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool) 

47 

48# Note that '|' syntax for unions doesn't work when we have to use a string 

49# literal (and we do since it's recursive and not an annotation). 

50NestedMetadataDict: TypeAlias = Mapping[str, Union[str, float, int, bool, "NestedMetadataDict"]] 

51 
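# Illustrative sketch (the keys and values below are made up): a value
# conforming to NestedMetadataDict has string keys, primitive leaves, and
# further mappings of the same shape for nesting, with no "." in any key, e.g.
#
#     {"nIterations": 3, "converged": True, "fit": {"chi2": 1.25, "model": "g"}}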

52 

53class PropertySetLike(Protocol): 

54 """Protocol that looks like a ``lsst.daf.base.PropertySet``. 

55 

56 Enough of the API is specified to support conversion of a 

57 ``PropertySet`` to a `TaskMetadata`. 

58 """ 

59 

60 def paramNames(self, topLevelOnly: bool = True) -> Collection[str]: ...

61 

62 def getArray(self, name: str) -> Any: ...

63 

64 

65def _isListLike(v: Any) -> bool: 

66 return isinstance(v, Sequence) and not isinstance(v, str) 

67 

68 

69class SetDictMetadata(Protocol): 

70 """Protocol for objects that can be assigned a possibly-nested `dict` of 

71 primitives. 

72 

73 This protocol is satisfied by `TaskMetadata`, `lsst.daf.base.PropertySet`, 

74 and `lsst.daf.base.PropertyList`, providing a consistent way to insert a 

75 dictionary into these objects that avoids their historical idiosyncrasies. 

76 

77 The form in which these entries appear in the object's native keys and 

78 values is implementation-defined. *Empty nested dictionaries may be 

79 dropped, and if the top-level dictionary is empty this method may do 

80 nothing.* 

81 

82 Neither the top-level key nor nested keys may contain ``.`` (period) 

83 characters. 

84 """ 

85 

86 def set_dict(self, key: str, nested: NestedMetadataDict) -> None: ...

87 

88 

89class GetDictMetadata(Protocol): 

90 """Protocol for objects that can extract a possibly-nested mapping of 

91 primitives. 

92 

93 This protocol is satisfied by `TaskMetadata`, `lsst.daf.base.PropertySet`, 

94 and `lsst.daf.base.PropertyList`, providing a consistent way to extract a 

95 dictionary from these objects that avoids their historical idiosyncrasies. 

96 

97 This is guaranteed to work for mappings inserted by 

98 `~SetDictMetadata.set_dict`. It should not be expected to work for values 

99 inserted in other ways. If a value was never inserted with the given key 

100 at all, *an empty `dict` will be returned* (this is a concession to 

101 implementation constraints in `~lsst.daf.base.PropertyList`). 

102 """ 

103 

104 def get_dict(self, key: str) -> NestedMetadataDict: ...

105 

106 

107class GetSetDictMetadata(SetDictMetadata, GetDictMetadata, Protocol): 

108 """Protocol for objects that can assign and extract a possibly-nested 

109 mapping of primitives. 

110 """ 

111 
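# Illustrative sketch of how the protocols above are intended to be used as
# type annotations (the function below is hypothetical): callers can pass a
# TaskMetadata, PropertySet, or PropertyList interchangeably.
#
#     def record_fit_summary(md: GetSetDictMetadata, chi2: float) -> None:
#         md.set_dict("fit_summary", {"chi2": chi2, "converged": True})
#         assert md.get_dict("fit_summary")["chi2"] == chi2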

112 

113class TaskMetadata(BaseModel): 

114 """Dict-like object for storing task metadata. 

115 

116 Metadata can be stored at two levels: single task or task plus subtasks. 

117 The latter is called the full metadata of a task and has the form 

118 

119 topLevelTaskName:subtaskName:subsubtaskName.itemName 

120 

121 The metadata item key of a task (``itemName`` above) must not contain ``.``, 

122 which serves as a separator in full metadata keys and turns 

123 the value into a sub-dictionary. Arbitrary hierarchies are supported. 

124 """ 

125 
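# Illustrative example of the hierarchical keys described above (the key
# names are made up):
#
#     >>> meta = TaskMetadata()
#     >>> meta["calibrate.astrometry.nMatches"] = 42
#     >>> meta["calibrate.astrometry.nMatches"]
#     42
#     >>> type(meta["calibrate"]).__name__  # intermediate levels are TaskMetadata
#     'TaskMetadata'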

126 scalars: dict[str, StrictFloat | StrictInt | StrictBool | StrictStr] = Field(default_factory=dict) 

127 arrays: dict[str, list[StrictFloat] | list[StrictInt] | list[StrictBool] | list[StrictStr]] = Field( 

128 default_factory=dict 

129 ) 

130 metadata: dict[str, "TaskMetadata"] = Field(default_factory=dict) 

131 

132 @classmethod 

133 def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata": 

134 """Create a TaskMetadata from a dictionary. 

135 

136 Parameters 

137 ---------- 

138 d : `~collections.abc.Mapping` 

139 Mapping to convert. Can be hierarchical. Any dictionaries 

140 in the hierarchy are converted to `TaskMetadata`. 

141 

142 Returns 

143 ------- 

144 meta : `TaskMetadata` 

145 Newly-constructed metadata. 

146 """ 

147 metadata = cls() 

148 for k, v in d.items(): 

149 metadata[k] = v 

150 return metadata 

151 
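# Illustrative use of from_dict() (made-up keys): nested dicts become nested
# TaskMetadata, and lists of primitives are stored as arrays.
#
#     >>> md = TaskMetadata.from_dict({"outer": {"n": 2, "flags": [True, False]}})
#     >>> md["outer.n"]
#     2
#     >>> md.getArray("outer.flags")
#     [True, False]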

152 @classmethod 

153 def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata": 

154 """Create a TaskMetadata from a PropertySet-like object. 

155 

156 Parameters 

157 ---------- 

158 ps : `PropertySetLike` or `TaskMetadata` 

159 A ``PropertySet``-like object to be transformed to a 

160 `TaskMetadata`. A `TaskMetadata` can be copied using this 

161 class method. 

162 

163 Returns 

164 ------- 

165 tm : `TaskMetadata` 

166 Newly-constructed metadata. 

167 

168 Notes 

169 ----- 

170 Items stored in single-element arrays in the supplied object 

171 will be converted to scalars in the newly-created object. 

172 """ 

173 # Use hierarchical names to assign values from input to output. 

174 # This API exists for both PropertySet and TaskMetadata. 

175 # from_dict() does not work because PropertySet is not declared 

176 # to be a Mapping. 

177 # PropertySet.toDict() is not present in TaskMetadata so is best 

178 # avoided. 

179 metadata = cls() 

180 for key in sorted(ps.paramNames(topLevelOnly=False)): 

181 value = ps.getArray(key) 

182 if len(value) == 1: 

183 value = value[0] 

184 metadata[key] = value 

185 return metadata 

186 
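# Illustrative sketch (made-up keys): because TaskMetadata itself satisfies
# PropertySetLike, from_metadata() can copy one, converting single-element
# arrays to scalars as noted above.
#
#     >>> src = TaskMetadata.from_dict({"task": {"values": [7]}})
#     >>> copied = TaskMetadata.from_metadata(src)
#     >>> copied["task.values"]
#     7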

187 def to_dict(self) -> dict[str, Any]: 

188 """Convert the class to a simple dictionary. 

189 

190 Returns 

191 ------- 

192 d : `dict` 

193 Simple dictionary that can contain scalar values, array values 

194 or other dictionary values. 

195 

196 Notes 

197 ----- 

198 Unlike `dict()`, this method hides the model layout and combines 

199 scalars, arrays, and other metadata in the same dictionary. Can be 

200 used when a simple dictionary is needed. Use 

201 `TaskMetadata.from_dict()` to convert it back. 

202 """ 

203 d: dict[str, Any] = {} 

204 d.update(self.scalars) 

205 d.update(self.arrays) 

206 for k, v in self.metadata.items(): 

207 d[k] = v.to_dict() 

208 return d 

209 
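# Illustrative round trip (made-up keys): to_dict() flattens the three slots
# into one plain dict, and from_dict() reverses it.
#
#     >>> md = TaskMetadata.from_dict({"n": 1, "sub": {"name": "x"}})
#     >>> md.to_dict()
#     {'n': 1, 'sub': {'name': 'x'}}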

210 def add(self, name: str, value: Any) -> None: 

211 """Store a new value, adding to a list if one already exists. 

212 

213 Parameters 

214 ---------- 

215 name : `str` 

216 Name of the metadata property. 

217 value : `~typing.Any` 

218 Metadata property value. 

219 """ 

220 keys = self._getKeys(name) 

221 key0 = keys.pop(0) 

222 if len(keys) == 0: 

223 # If add() is being used, always store the value in the arrays 

224 # property as a list. It's likely there will be another call. 

225 slot_type, value = self._validate_value(value) 

226 if slot_type == "array": 

227 pass 

228 elif slot_type == "scalar": 

229 value = [value] 

230 else: 

231 raise ValueError("add() can only be used for primitive types or sequences of those types.") 

232 

233 if key0 in self.metadata: 

234 raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata") 

235 

236 if key0 in self.scalars: 

237 # Convert scalar to array. 

238 # MyPy should be able to figure out that List[Union[T1, T2]] is 

239 # compatible with Union[List[T1], List[T2]] if the list has 

240 # only one element, but it can't. 

241 self.arrays[key0] = [self.scalars.pop(key0)] # type: ignore 

242 

243 if key0 in self.arrays: 

244 # Check that the type is not changing. 

245 if (curtype := type(self.arrays[key0][0])) is not (newtype := type(value[0])): 

246 raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}") 

247 self.arrays[key0].extend(value) 

248 else: 

249 self.arrays[key0] = value 

250 

251 return 

252 

253 self.metadata[key0].add(".".join(keys), value) 

254 
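# Illustrative behaviour of add() (made-up keys): an existing scalar is
# promoted to a list and subsequent values are appended.
#
#     >>> md = TaskMetadata()
#     >>> md["timing"] = 1.5
#     >>> md.add("timing", 2.5)
#     >>> md.getArray("timing")
#     [1.5, 2.5]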

255 def getScalar(self, key: str) -> str | int | float | bool: 

256 """Retrieve a scalar item even if the item is a list. 

257 

258 Parameters 

259 ---------- 

260 key : `str` 

261 Item to retrieve. 

262 

263 Returns 

264 ------- 

265 value : `str`, `int`, `float`, or `bool` 

266 Either the value associated with the key or, if the key 

267 corresponds to a list, the last item in the list. 

268 

269 Raises 

270 ------ 

271 KeyError 

272 Raised if the item is not found. 

273 """ 

274 # Used in pipe_tasks. 

275 # getScalar() is the default behavior for __getitem__. 

276 return self[key] 

277 

278 def getArray(self, key: str) -> list[Any]: 

279 """Retrieve an item as a list even if it is a scalar. 

280 

281 Parameters 

282 ---------- 

283 key : `str` 

284 Item to retrieve. 

285 

286 Returns 

287 ------- 

288 values : `list` of any 

289 A list containing the value or values associated with this item. 

290 

291 Raises 

292 ------ 

293 KeyError 

294 Raised if the item is not found. 

295 """ 

296 keys = self._getKeys(key) 

297 key0 = keys.pop(0) 

298 if len(keys) == 0: 

299 if key0 in self.arrays: 

300 return self.arrays[key0] 

301 elif key0 in self.scalars: 

302 return [self.scalars[key0]] 

303 elif key0 in self.metadata: 

304 return [self.metadata[key0]] 

305 raise KeyError(f"'{key}' not found") 

306 

307 try: 

308 return self.metadata[key0].getArray(".".join(keys)) 

309 except KeyError: 

310 # Report the correct key. 

311 raise KeyError(f"'{key}' not found") from None 

312 
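# Illustrative retrieval semantics (made-up keys): getArray() always returns
# a list, while plain indexing returns the last element of an array for
# PropertySet compatibility.
#
#     >>> md = TaskMetadata.from_dict({"values": [1, 2, 3], "flag": True})
#     >>> md.getArray("values"), md["values"]
#     ([1, 2, 3], 3)
#     >>> md.getArray("flag")
#     [True]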

313 def names(self) -> set[str]: 

314 """Return the hierarchical keys from the metadata. 

315 

316 Returns 

317 ------- 

318 names : `collections.abc.Set` 

319 A set of all keys, both top-level names and dot-separated 

320 names from the hierarchy. 

321 """ 

322 names = set() 

323 for k, v in self.items(): 

324 names.add(k) # Always include the current level 

325 if isinstance(v, TaskMetadata): 

326 names.update({k + "." + item for item in v.names()}) 

327 return names 

328 

329 def paramNames(self, topLevelOnly: bool) -> set[str]: 

330 """Return hierarchical names. 

331 

332 Parameters 

333 ---------- 

334 topLevelOnly : `bool` 

335 Control whether only top-level items are returned or items 

336 from the hierarchy. 

337 

338 Returns 

339 ------- 

340 paramNames : `set` of `str` 

341 If ``topLevelOnly`` is `True`, returns any keys that are not 

342 part of a hierarchy. If `False` also returns fully-qualified 

343 names from the hierarchy. Keys associated with the top 

344 of a hierarchy are never returned. 

345 """ 

346 # Currently used by the verify package. 

347 paramNames = set() 

348 for k, v in self.items(): 

349 if isinstance(v, TaskMetadata): 

350 if not topLevelOnly: 

351 paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)}) 

352 else: 

353 paramNames.add(k) 

354 return paramNames 

355 
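# Illustrative difference between names() and paramNames() (made-up keys):
# names() includes keys at the top of a hierarchy, paramNames() never does.
#
#     >>> md = TaskMetadata.from_dict({"isr": {"nGoodPix": 100}, "version": "1.0"})
#     >>> sorted(md.names())
#     ['isr', 'isr.nGoodPix', 'version']
#     >>> sorted(md.paramNames(topLevelOnly=False))
#     ['isr.nGoodPix', 'version']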

356 @staticmethod 

357 def _getKeys(key: str) -> list[str]: 

358 """Return the key hierarchy. 

359 

360 Parameters 

361 ---------- 

362 key : `str` 

363 The key to analyze. Can be dot-separated. 

364 

365 Returns 

366 ------- 

367 keys : `list` of `str` 

368 The key hierarchy that has been split on ``.``. 

369 

370 Raises 

371 ------ 

372 KeyError 

373 Raised if the key is not a string. 

374 """ 

375 try: 

376 keys = key.split(".") 

377 except Exception: 

378 raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None 

379 return keys 

380 

381 def keys(self) -> tuple[str, ...]: 

382 """Return the top-level keys.""" 

383 return tuple(k for k in self) 

384 

385 def items(self) -> Iterator[tuple[str, Any]]: 

386 """Yield the top-level keys and values.""" 

387 yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items()) 

388 

389 def __len__(self) -> int: 

390 """Return the number of items.""" 

391 return len(self.scalars) + len(self.arrays) + len(self.metadata) 

392 

393 # This is actually a Liskov substitution violation, because 

394 # pydantic.BaseModel says __iter__ should return something else. But the 

395 # pydantic docs say to do exactly this to in order to make a mapping-like 

396 # BaseModel, so that's what we do. 

397 def __iter__(self) -> Iterator[str]: # type: ignore 

398 """Return an iterator over each key.""" 

399 # The order of keys is not preserved since items can move 

400 # from scalar to array. 

401 return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata)) 

402 

403 def __getitem__(self, key: str) -> Any: 

404 """Retrieve the item associated with the key. 

405 

406 Parameters 

407 ---------- 

408 key : `str` 

409 The key to retrieve. Can be dot-separated hierarchical. 

410 

411 Returns 

412 ------- 

413 value : `TaskMetadata`, `float`, `int`, `bool`, `str` 

414 A scalar value. For compatibility with ``PropertySet``, if the key 

415 refers to an array, the final element is returned and not the 

416 array itself. 

417 

418 Raises 

419 ------ 

420 KeyError 

421 Raised if the item is not found. 

422 """ 

423 keys = self._getKeys(key) 

424 key0 = keys.pop(0) 

425 if len(keys) == 0: 

426 if key0 in self.scalars: 

427 return self.scalars[key0] 

428 if key0 in self.metadata: 

429 return self.metadata[key0] 

430 if key0 in self.arrays: 

431 return self.arrays[key0][-1] 

432 raise KeyError(f"'{key}' not found") 

433 # Hierarchical lookup so the top key can only be in the metadata 

434 # property. Trap KeyError and reraise so that the correct key 

435 # in the hierarchy is reported. 

436 try: 

437 # And forward request to that metadata. 

438 return self.metadata[key0][".".join(keys)] 

439 except KeyError: 

440 raise KeyError(f"'{key}' not found") from None 

441 

442 def get(self, key: str, default: Any = None) -> Any: 

443 """Retrieve the item associated with the key or a default. 

444 

445 Parameters 

446 ---------- 

447 key : `str` 

448 The key to retrieve. Can be dot-separated hierarchical. 

449 default : `~typing.Any` 

450 The value to return if the key does not exist. 

451 

452 Returns 

453 ------- 

454 value : `TaskMetadata`, `float`, `int`, `bool`, `str` 

455 A scalar value. If the key refers to an array, the final element 

456 is returned and not the array itself; this is consistent with 

457 `__getitem__` and `PropertySet.get`, but not ``to_dict().get``. 

458 """ 

459 try: 

460 return self[key] 

461 except KeyError: 

462 return default 

463 

464 def __setitem__(self, key: str, item: Any) -> None: 

465 """Store the given item.""" 

466 keys = self._getKeys(key) 

467 key0 = keys.pop(0) 

468 if len(keys) == 0: 

469 slots: dict[str, dict[str, Any]] = { 

470 "array": self.arrays, 

471 "scalar": self.scalars, 

472 "metadata": self.metadata, 

473 } 

474 primary: dict[str, Any] | None = None 

475 slot_type, item = self._validate_value(item) 

476 primary = slots.pop(slot_type, None) 

477 if primary is None: 

478 raise AssertionError(f"Unknown slot type returned from validator: {slot_type}") 

479 

480 # Assign the value to the right place. 

481 primary[key0] = item 

482 for property in slots.values(): 

483 # Remove any other entries. 

484 property.pop(key0, None) 

485 return 

486 

487 # This must be hierarchical so forward to the child TaskMetadata. 

488 if key0 not in self.metadata: 

489 self.metadata[key0] = TaskMetadata() 

490 self.metadata[key0][".".join(keys)] = item 

491 

492 # Ensure we have cleared out anything with the same name elsewhere. 

493 self.scalars.pop(key0, None) 

494 self.arrays.pop(key0, None) 

495 
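# Illustrative consequence of the slot bookkeeping above (made-up keys):
# assigning a key replaces any previous entry of a different kind.
#
#     >>> md = TaskMetadata()
#     >>> md["x"] = [1, 2]        # stored in the 'arrays' slot
#     >>> md["x"] = "replaced"    # moves to 'scalars'; the array entry is dropped
#     >>> md.getArray("x")
#     ['replaced']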

496 def __contains__(self, key: str) -> bool: 

497 """Determine if the key exists.""" 

498 keys = self._getKeys(key) 

499 key0 = keys.pop(0) 

500 if len(keys) == 0: 

501 return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata 

502 

503 if key0 in self.metadata: 

504 return ".".join(keys) in self.metadata[key0] 

505 return False 

506 

507 def __delitem__(self, key: str) -> None: 

508 """Remove the specified item. 

509 

510 Raises 

511 ------ 

512 KeyError 

513 Raised if the item is not present. 

514 """ 

515 keys = self._getKeys(key) 

516 key0 = keys.pop(0) 

517 if len(keys) == 0: 

518 # MyPy can't figure out that this way to combine the types in the 

519 # tuple is the one that matters, and annotating a local variable 

520 # helps it out. 

521 properties: tuple[dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata) 

522 for property in properties: 

523 if key0 in property: 

524 del property[key0] 

525 return 

526 raise KeyError(f"'{key}' not found") 

527 

528 try: 

529 del self.metadata[key0][".".join(keys)] 

530 except KeyError: 

531 # Report the correct key. 

532 raise KeyError(f"'{key}' not found") from None 

533 

534 def get_dict(self, key: str) -> NestedMetadataDict: 

535 """Return a possibly-hierarchical nested `dict`. 

536 

537 This implements the `GetDictMetadata` protocol for consistency with 

538 `lsst.daf.base.PropertySet` and `lsst.daf.base.PropertyList`. The 

539 returned `dict` is guaranteed to be a deep copy, not a view. 

540 

541 Parameters 

542 ---------- 

543 key : `str` 

544 String key associated with the mapping. May not have a ``.`` 

545 character. 

546 

547 Returns 

548 ------- 

549 value : `~collections.abc.Mapping` 

550 Possibly-nested mapping, with `str` keys and values that are `int`, 

551 `float`, `str`, `bool`, or another `dict` with the same key and 

552 value types. Will be empty if ``key`` does not exist. 

553 """ 

554 if value := self.get(key): 

555 return value.to_dict() 

556 else: 

557 return {} 

558 

559 def set_dict(self, key: str, value: NestedMetadataDict) -> None: 

560 """Assign a possibly-hierarchical nested `dict`. 

561 

562 This implements the `SetDictMetadata` protocol for consistency with 

563 `lsst.daf.base.PropertySet` and `lsst.daf.base.PropertyList`. 

564 

565 Parameters 

566 ---------- 

567 key : `str` 

568 String key associated with the mapping. May not have a ``.`` 

569 character. 

570 value : `~collections.abc.Mapping` 

571 Possibly-nested mapping, with `str` keys and values that are `int`, 

572 `float`, `str`, `bool`, or another `dict` with the same key and 

573 value types. Nested keys may not have a ``.`` character. 

574 """ 

575 self[key] = value 

576 
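# Illustrative round trip through the GetSetDictMetadata protocol on a
# TaskMetadata (made-up keys); absent keys yield an empty dict.
#
#     >>> md = TaskMetadata()
#     >>> md.set_dict("calibrate", {"astrometry": {"matches": 42}})
#     >>> md.get_dict("calibrate")
#     {'astrometry': {'matches': 42}}
#     >>> md.get_dict("missing")
#     {}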

577 def _validate_value(self, value: Any) -> tuple[str, Any]: 

578 """Validate the given value. 

579 

580 Parameters 

581 ---------- 

582 value : Any 

583 Value to check. 

584 

585 Returns 

586 ------- 

587 slot_type : `str` 

588 The type of value given. Options are "scalar", "array", "metadata". 

589 item : Any 

590 The item that was given but possibly modified to conform to 

591 the slot type. 

592 

593 Raises 

594 ------ 

595 ValueError 

596 Raised if the value is not a recognized type. 

597 """ 

598 # Test the simplest option first. 

599 value_type = type(value) 

600 if value_type in _ALLOWED_PRIMITIVE_TYPES: 

601 return "scalar", value 

602 

603 if isinstance(value, TaskMetadata): 

604 return "metadata", value 

605 if isinstance(value, Mapping): 

606 return "metadata", self.from_dict(value) 

607 

608 if _isListLike(value): 

609 # For model consistency, need to check that every item in the 

610 # list has the same type. 

611 value = list(value) 

612 

613 type0 = type(value[0]) 

614 for i in value: 

615 if type(i) != type0: 

616 raise ValueError( 

617 "Type mismatch in supplied list. TaskMetadata requires all" 

618 f" elements have same type but see {type(i)} and {type0}." 

619 ) 

620 

621 if type0 not in _ALLOWED_PRIMITIVE_TYPES: 

622 # Must check to see if we got numpy floats or something. 

623 type_cast: type 

624 if isinstance(value[0], numbers.Integral): 

625 type_cast = int 

626 elif isinstance(value[0], numbers.Real): 

627 type_cast = float 

628 else: 

629 raise ValueError( 

630 f"Supplied list has element of type '{type0}'. " 

631 "TaskMetadata can only accept primitive types in lists." 

632 ) 

633 

634 value = [type_cast(v) for v in value] 

635 

636 return "array", value 

637 

638 # Sometimes a numpy number is given. 

639 if isinstance(value, numbers.Integral): 

640 value = int(value) 

641 return "scalar", value 

642 if isinstance(value, numbers.Real): 

643 value = float(value) 

644 return "scalar", value 

645 

646 raise ValueError(f"TaskMetadata does not support values of type {value!r}.") 

647 
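# Illustrative summary of the slot decisions made by _validate_value() above
# (made-up values): primitives and homogeneous lists are accepted, mappings
# become nested TaskMetadata, numpy-like numbers are coerced to int/float,
# and mixed-type lists are rejected.
#
#     >>> md = TaskMetadata()
#     >>> md["ok"] = [1, 2, 3]       # homogeneous list -> arrays slot
#     >>> md["nested"] = {"a": 1}    # Mapping -> metadata slot
#     >>> md["bad"] = [1, "two"]     # raises ValueError (mixed types)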

648 # Work around the fact that Sphinx chokes on Pydantic docstring formatting, 

649 # when we inherit those docstrings in our public classes. 

650 if "sphinx" in sys.modules:

651 

652 def copy(self, *args: Any, **kwargs: Any) -> Any: 

653 """See `pydantic.BaseModel.copy`.""" 

654 return super().copy(*args, **kwargs) 

655 

656 def model_dump(self, *args: Any, **kwargs: Any) -> Any: 

657 """See `pydantic.BaseModel.model_dump`.""" 

658 return super().model_dump(*args, **kwargs) 

659 

660 def model_copy(self, *args: Any, **kwargs: Any) -> Any: 

661 """See `pydantic.BaseModel.model_copy`.""" 

662 return super().model_copy(*args, **kwargs) 

663 

664 @classmethod 

665 def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any: 

666 """See `pydantic.BaseModel.model_json_schema`.""" 

667 return super().model_json_schema(*args, **kwargs) 

668 

669 

670# Needed because a TaskMetadata can contain a TaskMetadata. 

671TaskMetadata.model_rebuild()