Coverage for python/lsst/pipe/base/_task_metadata.py: 15%

209 statements  

coverage.py v7.4.0, created at 2024-01-23 10:54 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["TaskMetadata"]

import itertools
import numbers
import sys
from collections.abc import Collection, Iterator, Mapping, Sequence
from typing import Any, Protocol

from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr

# The types allowed in a Task metadata field are restricted
# to allow predictable serialization.
_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)


class PropertySetLike(Protocol):
    """Protocol that looks like a ``lsst.daf.base.PropertySet``.

    Enough of the API is specified to support conversion of a
    ``PropertySet`` to a `TaskMetadata`.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]:
        ...

    def getArray(self, name: str) -> Any:
        ...
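
# Both ``lsst.daf.base.PropertySet`` and `TaskMetadata` itself provide these
# two methods (see `TaskMetadata.paramNames` and `TaskMetadata.getArray`
# below), so either can be passed to `TaskMetadata.from_metadata`.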

def _isListLike(v: Any) -> bool:
    return isinstance(v, Sequence) and not isinstance(v, str)


class TaskMetadata(BaseModel):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called the full metadata of a task and has the form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    The metadata item key of a task (``itemName`` above) must not contain
    ``.``, which serves as the separator in full metadata keys and turns the
    value into a sub-dictionary. Arbitrary hierarchies are supported.
    """

    scalars: dict[str, StrictFloat | StrictInt | StrictBool | StrictStr] = Field(default_factory=dict)
    arrays: dict[str, list[StrictFloat] | list[StrictInt] | list[StrictBool] | list[StrictStr]] = Field(
        default_factory=dict
    )
    metadata: dict[str, "TaskMetadata"] = Field(default_factory=dict)
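
    # Storage model: scalar values, homogeneous lists, and nested
    # `TaskMetadata` objects live in separate fields so that serialization
    # stays predictable. A usage sketch (illustrative only; the key names are
    # made up):
    #
    #     meta = TaskMetadata()
    #     meta["processCcd.isr.numSaturated"] = 17   # creates nested metadata
    #     meta["processCcd.isr.numSaturated"]        # -> 17
    #     "processCcd.isr" in meta                   # -> True
    #     meta["processCcd"]                         # -> TaskMetadata for the subtask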

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `~collections.abc.Mapping`
            Mapping to convert. Can be hierarchical. Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        for k, v in d.items():
            metadata[k] = v
        return metadata
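
    # For example (illustrative): nested plain dictionaries become nested
    # TaskMetadata instances.
    #
    #     meta = TaskMetadata.from_dict({"isr": {"nSat": 5}, "version": "1.0"})
    #     meta["isr.nSat"]   # -> 5
    #     meta["isr"]        # -> TaskMetadata with a single scalar entry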

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`. A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                value = value[0]
            metadata[key] = value
        return metadata
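
    # Illustrative note: because this only relies on paramNames() and
    # getArray(), ``TaskMetadata.from_metadata(existing_metadata)`` can also
    # be used to copy a `TaskMetadata`, with single-element lists collapsed to
    # scalars as described in the Notes above.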

    def to_dict(self) -> dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary. Can be
        used when a simple dictionary is needed. Use
        `TaskMetadata.from_dict()` to convert it back.
        """
        d: dict[str, Any] = {}
        d.update(self.scalars)
        d.update(self.arrays)
        for k, v in self.metadata.items():
            d[k] = v.to_dict()
        return d
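
    # Round-trip sketch (illustrative):
    #
    #     d = meta.to_dict()          # plain nested dicts, lists, and scalars
    #     TaskMetadata.from_dict(d)   # rebuilds an equivalent TaskMetadata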

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property.
        value : `~typing.Any`
            Metadata property value.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # If add() is being used, always store the value in the arrays
            # property as a list. It's likely there will be another call.
            slot_type, value = self._validate_value(value)
            if slot_type == "array":
                pass
            elif slot_type == "scalar":
                value = [value]
            else:
                raise ValueError("add() can only be used for primitive types or sequences of those types.")

            if key0 in self.metadata:
                raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

            if key0 in self.scalars:
                # Convert scalar to array.
                # MyPy should be able to figure out that List[Union[T1, T2]] is
                # compatible with Union[List[T1], List[T2]] if the list has
                # only one element, but it can't.
                self.arrays[key0] = [self.scalars.pop(key0)]  # type: ignore

            if key0 in self.arrays:
                # Check that the type is not changing.
                if (curtype := type(self.arrays[key0][0])) is not (newtype := type(value[0])):
                    raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")
                self.arrays[key0].extend(value)
            else:
                self.arrays[key0] = value

            return

        self.metadata[key0].add(".".join(keys), value)
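
    # Behaviour sketch (illustrative):
    #
    #     meta = TaskMetadata()
    #     meta.add("nIter", 1)      # stored as the list [1]
    #     meta.add("nIter", 2)      # extended to [1, 2]
    #     meta.getArray("nIter")    # -> [1, 2]
    #     meta["nIter"]             # -> 2 (last element, PropertySet style)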

    def getScalar(self, key: str) -> str | int | float | bool:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> list[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A list containing the value or values associated with this item.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.arrays:
                return self.arrays[key0]
            elif key0 in self.scalars:
                return [self.scalars[key0]]
            elif key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None
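
    # getScalar() vs getArray() (illustrative):
    #
    #     meta["values"] = [1, 2, 3]
    #     meta.getArray("values")    # -> [1, 2, 3]
    #     meta.getScalar("values")   # -> 3 (last element)
    #     meta["answer"] = 42
    #     meta.getArray("answer")    # -> [42] (scalar wrapped in a list)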

    def names(self) -> set[str]:
        """Return the hierarchical keys from the metadata.

        Returns
        -------
        names : `collections.abc.Set`
            A set of all keys, including both the top-level keys and the
            fully-qualified keys from the hierarchy.
        """
        names = set()
        for k, v in self.items():
            names.add(k)  # Always include the current level
            if isinstance(v, TaskMetadata):
                names.update({k + "." + item for item in v.names()})
        return names

    def paramNames(self, topLevelOnly: bool) -> set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`
            Control whether only top-level items are returned or items
            from the hierarchy.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy. If `False`, also returns fully-qualified
            names from the hierarchy. Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames
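
    # names() vs paramNames() (illustrative):
    #
    #     meta["task.sub.nIter"] = 3
    #     meta.names()                         # {'task', 'task.sub', 'task.sub.nIter'}
    #     meta.paramNames(topLevelOnly=False)  # {'task.sub.nIter'} -- leaf keys only
    #     meta.paramNames(topLevelOnly=True)   # set(); no non-hierarchical keys here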

    @staticmethod
    def _getKeys(key: str) -> list[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze. Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        try:
            keys = key.split(".")
        except Exception:
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(k for k in self)

    def items(self) -> Iterator[tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else. But the
    # pydantic docs say to do exactly this in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. For compatibility with ``PropertySet``, if the key
            refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property. Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.
        default : `~typing.Any`
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            slots: dict[str, dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            primary: dict[str, Any] | None = None
            slot_type, item = self._validate_value(item)
            primary = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for property in slots.values():
                # Remove any other entries.
                property.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        if key0 not in self.metadata:
            self.metadata[key0] = TaskMetadata()
        self.metadata[key0][".".join(keys)] = item

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)
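
    # Slot handling sketch (illustrative): assignment picks a slot from the
    # value's type and clears the same key from the other slots.
    #
    #     meta["x"] = [1, 2]    # stored in .arrays
    #     meta["x"] = 3.5       # moves to .scalars; the array entry is removed
    #     meta["x"] = {"y": 1}  # becomes a nested TaskMetadata in .metadata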

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            properties: tuple[dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for property in properties:
                if key0 in property:
                    del property[key0]
                    return
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def _validate_value(self, value: Any) -> tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given. Options are "scalar", "array", "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)

            type0 = type(value[0])
            for i in value:
                if type(i) != type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value!r}.")
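
    # Coercion sketch (illustrative; assumes numpy is installed):
    #
    #     meta["n"] = numpy.int64(5)                      # stored as Python int 5
    #     meta["ratio"] = numpy.float32(0.5)              # stored as Python float
    #     meta["seq"] = [numpy.int64(1), numpy.int64(2)]  # -> [1, 2] as Python ints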

    # Work around the fact that Sphinx chokes on Pydantic docstring formatting,
    # when we inherit those docstrings in our public classes.
    if "sphinx" in sys.modules:

        def copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.copy`."""
            return super().copy(*args, **kwargs)

        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump`."""
            return super().model_dump(*args, **kwargs)

        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_copy`."""
            return super().model_copy(*args, **kwargs)

        @classmethod
        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_json_schema`."""
            return super().model_json_schema(*args, **kwargs)


# Needed because a TaskMetadata can contain a TaskMetadata.
TaskMetadata.model_rebuild()
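

# A minimal end-to-end sketch (illustrative only, not part of the library
# API); guarded so that importing this module has no side effects.
if __name__ == "__main__":
    meta = TaskMetadata()
    meta["isr.numSaturated"] = 17              # hierarchical assignment
    meta.add("isr.numSaturated", 18)           # scalar is promoted to a list
    print(meta.getArray("isr.numSaturated"))   # [17, 18]
    print(meta["isr.numSaturated"])            # 18 (last element)
    print(sorted(meta.names()))                # ['isr', 'isr.numSaturated']
    print(meta.to_dict())                      # {'isr': {'numSaturated': [17, 18]}}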