Coverage for python/lsst/pipe/base/_task_metadata.py: 16%

219 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-26 02:35 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: only the metadata container is exported.
__all__ = [
    "TaskMetadata",
]

23 

24import itertools 

25import numbers 

26import warnings 

27from collections.abc import Sequence 

28from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Protocol, Set, Tuple, Union 

29 

30from deprecated.sphinx import deprecated 

31from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr 

32 

# Version and explanatory text shared by every deprecation notice issued
# from this module.
_DEPRECATION_VERSION = "v24"
_DEPRECATION_REASON = "Will be removed after v25."

# Only these primitive types may be stored in a Task metadata field;
# restricting the set keeps serialization predictable.
_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)

39 

40 

class PropertySetLike(Protocol):
    """Structural type for objects resembling ``lsst.daf.base.PropertySet``.

    Only the part of the API required to convert a ``PropertySet``
    into a `TaskMetadata` is declared here.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]:
        # Protocol stub: implementations supply the parameter names.
        ...

    def getArray(self, name: str) -> Any:
        # Protocol stub: implementations supply the values stored under
        # ``name``.
        ...

53 

54 

def _isListLike(v: Any) -> bool:
    """Return `True` if ``v`` is a sequence that is not a `str`."""
    # Strings are sequences too, but are treated as scalar values here.
    if isinstance(v, str):
        return False
    return isinstance(v, Sequence)

57 

58 

class TaskMetadata(BaseModel):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called full metadata of a task and has a form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    Metadata item key of a task (`itemName` above) must not contain `.`,
    which serves as a separator in full metadata keys and turns
    the value into sub-dictionary. Arbitrary hierarchies are supported.

    Deprecated methods are for compatibility with
    the predecessor containers.
    """

    # Single primitive values, keyed by top-level name.
    scalars: Dict[str, Union[StrictFloat, StrictInt, StrictBool, StrictStr]] = Field(default_factory=dict)
    # Homogeneous lists of primitive values, keyed by top-level name.
    arrays: Dict[str, Union[List[StrictFloat], List[StrictInt], List[StrictBool], List[StrictStr]]] = Field(
        default_factory=dict
    )
    # Nested metadata, keyed by top-level name.
    metadata: Dict[str, "TaskMetadata"] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `Mapping`
            Mapping to convert. Can be hierarchical. Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        for k, v in d.items():
            metadata[k] = v
        return metadata

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`. A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                value = value[0]
            metadata[key] = value
        return metadata

    def to_dict(self) -> Dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary. Can be
        used when a simple dictionary is needed.  Use
        `TaskMetadata.from_dict()` to convert it back.
        """
        d: Dict[str, Any] = {}
        d.update(self.scalars)
        d.update(self.arrays)
        for k, v in self.metadata.items():
            d[k] = v.to_dict()
        return d

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property. Can be dot-separated
            hierarchical, in which case the request is forwarded to the
            child metadata.
        value
            Metadata property value.

        Raises
        ------
        ValueError
            Raised if the value is not a primitive or a sequence of
            primitives, if the key refers to a nested `TaskMetadata`, or if
            the type of the new value differs from the type already stored.
            The stored content is left untouched in all these cases.
        KeyError
            Raised if a hierarchical name is given and the leading
            component is not an existing nested `TaskMetadata`.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if keys:
            # Hierarchical name: delegate to the child metadata.
            self.metadata[key0].add(".".join(keys), value)
            return

        # If add() is being used, always store the value in the arrays
        # property as a list. It's likely there will be another call.
        slot_type, value = self._validate_value(value)
        if slot_type == "scalar":
            value = [value]
        elif slot_type != "array":
            raise ValueError("add() can only be used for primitive types or sequences of those types.")

        if key0 in self.metadata:
            raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

        # Find what is currently stored without modifying anything yet, so
        # that a rejected add() does not leave a half-converted entry.
        promote = key0 in self.scalars
        existing: Optional[List[Any]]
        if promote:
            existing = [self.scalars[key0]]
        else:
            existing = self.arrays.get(key0)

        if existing is None:
            self.arrays[key0] = value
            return

        # Check that the type is not changing.
        if existing and (curtype := type(existing[0])) is not (newtype := type(value[0])):
            raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")

        if promote:
            # Convert scalar to array.
            # MyPy should be able to figure out that List[Union[T1, T2]] is
            # compatible with Union[List[T1], List[T2]] if the list has
            # only one element, but it can't.
            del self.scalars[key0]
            self.arrays[key0] = existing  # type: ignore
        existing.extend(value)

    @deprecated(
        reason="Cast the return value to float explicitly. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def getAsDouble(self, key: str) -> float:
        """Return the value cast to a `float`.

        Parameters
        ----------
        key : `str`
            Item to return. Can be dot-separated hierarchical.

        Returns
        -------
        value : `float`
            The value cast to a `float`.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        return float(self[key])

    def getScalar(self, key: str) -> Union[str, int, float, bool]:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> List[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A list containing the value or values associated with this item.
            For an item stored as an array this is the stored list itself,
            not a copy.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if not keys:
            if key0 in self.arrays:
                return self.arrays[key0]
            if key0 in self.scalars:
                return [self.scalars[key0]]
            if key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def names(self, topLevelOnly: bool = True) -> Set[str]:
        """Return the hierarchical keys from the metadata.

        Parameters
        ----------
        topLevelOnly  : `bool`
            If true, return top-level keys, otherwise full metadata item keys.

        Returns
        -------
        names : `set` of `str`
            A set of top-level keys or full metadata item keys, including
            the top-level keys.

        Notes
        -----
        Should never be called in new code with ``topLevelOnly`` set to `True`
        -- this is equivalent to asking for the keys and is the default
        when iterating through the task metadata. In this case a deprecation
        message will be issued and the ability will raise an exception
        in a future release.

        When ``topLevelOnly`` is `False` all keys, including those from the
        hierarchy and the top-level hierarchy, are returned.
        """
        if topLevelOnly:
            # stacklevel=2 attributes the warning to the calling code.
            warnings.warn("Use keys() instead. " + _DEPRECATION_REASON, FutureWarning, stacklevel=2)
            return set(self.keys())

        names = set()
        for k, v in self.items():
            names.add(k)  # Always include the current level
            if isinstance(v, TaskMetadata):
                names.update({k + "." + item for item in v.names(topLevelOnly=topLevelOnly)})
        return names

    def paramNames(self, topLevelOnly: bool = True) -> Set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`, optional
            Control whether only top-level items are returned or items
            from the hierarchy. Defaults to `True`, matching the
            `PropertySetLike` protocol.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy. If `False` also returns fully-qualified
            names from the hierarchy. Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames

    @deprecated(
        reason="Use standard assignment syntax. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def set(self, key: str, item: Any) -> None:
        """Set the value of the supplied key."""
        self[key] = item

    @deprecated(
        reason="Use standard del dict syntax. " + _DEPRECATION_REASON,
        version=_DEPRECATION_VERSION,
        category=FutureWarning,
    )
    def remove(self, key: str) -> None:
        """Remove the item without raising if absent."""
        try:
            del self[key]
        except KeyError:
            # The PropertySet.remove() should always work.
            pass

    @staticmethod
    def _getKeys(key: str) -> List[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze. Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        try:
            keys = key.split(".")
        except Exception:
            # Any failure to split (no split() method, or a split() that
            # does not accept a str separator) is reported as a bad key.
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> Tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(self)

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else.  But the
    # pydantic docs say to do exactly this to in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. For compatibility with ``PropertySet``, if the key
            refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if not keys:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property. Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.
        default
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item.

        Parameters
        ----------
        key : `str`
            The key to assign to. Can be dot-separated hierarchical, in
            which case intermediate `TaskMetadata` levels are created as
            needed.
        item
            The value to store. Any value previously stored under the same
            top-level key, in any slot, is replaced.

        Raises
        ------
        ValueError
            Raised if the value is not a supported type. The stored content
            is left unchanged in that case.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if not keys:
            slots: Dict[str, Dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            slot_type, item = self._validate_value(item)
            primary = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for other in slots.values():
                # Remove any other entries.
                other.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        # A newly-created child is only attached once the nested assignment
        # has succeeded, so a rejected value does not leave an empty child
        # (or a key present in two slots) behind.
        child = self.metadata.get(key0)
        if child is None:
            child = TaskMetadata()
        child[".".join(keys)] = item
        self.metadata[key0] = child

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if not keys:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if not keys:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            properties: Tuple[Dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for slot in properties:
                if key0 in slot:
                    del slot[key0]
                    return
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def _validate_value(self, value: Any) -> Tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given. Options are "scalar", "array", "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type, if a sequence is
            empty, or if a sequence mixes element types.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)
            if not value:
                # The element type of an empty sequence can not be
                # determined, and scalar access returns the final element.
                raise ValueError("TaskMetadata can not store an empty sequence.")

            type0 = type(value[0])
            for i in value:
                if type(i) is not type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value_type}.")

622 

623 

# The ``metadata`` field refers to "TaskMetadata" by name, so the forward
# reference has to be resolved once the class definition is complete.
TaskMetadata.update_forward_refs()