Coverage for python/lsst/pipe/base/_task_metadata.py: 15%
205 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-17 10:52 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28__all__ = ["TaskMetadata"]
30import itertools
31import numbers
32import warnings
33from collections.abc import Collection, Iterator, Mapping, Sequence
34from typing import Any, Protocol
36from lsst.daf.butler._compat import _BaseModelCompat
37from lsst.utils.introspection import find_outside_stacklevel
38from pydantic import Field, StrictBool, StrictFloat, StrictInt, StrictStr
40# The types allowed in a Task metadata field are restricted
41# to allow predictable serialization.
42_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)
class PropertySetLike(Protocol):
    """Structural interface matching ``lsst.daf.base.PropertySet``.

    Only the two methods needed to convert a ``PropertySet`` into a
    `TaskMetadata` are declared here; any object that provides them
    satisfies the protocol.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]:
        """Return the names of the stored parameters."""
        ...

    def getArray(self, name: str) -> Any:
        """Return the value(s) associated with ``name`` as an array."""
        ...
59def _isListLike(v: Any) -> bool:
60 return isinstance(v, Sequence) and not isinstance(v, str)
class TaskMetadata(_BaseModelCompat):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called full metadata of a task and has a form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    Metadata item key of a task (``itemName`` above) must not contain ``.``,
    which serves as a separator in full metadata keys and turns
    the value into sub-dictionary. Arbitrary hierarchies are supported.
    """

    # Values are partitioned into three slots (scalars, arrays, nested
    # metadata) so pydantic can validate and serialize each predictably.
    scalars: dict[str, StrictFloat | StrictInt | StrictBool | StrictStr] = Field(default_factory=dict)
    arrays: dict[str, list[StrictFloat] | list[StrictInt] | list[StrictBool] | list[StrictStr]] = Field(
        default_factory=dict
    )
    metadata: dict[str, "TaskMetadata"] = Field(default_factory=dict)

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `~collections.abc.Mapping`
            Mapping to convert. Can be hierarchical. Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        # __setitem__ handles hierarchical keys and nested mappings.
        for k, v in d.items():
            metadata[k] = v
        return metadata

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`. A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                # Single-element arrays become scalars.
                value = value[0]
            metadata[key] = value
        return metadata

    def to_dict(self) -> dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary. Can be
        used when a simple dictionary is needed. Use
        `TaskMetadata.from_dict()` to convert it back.
        """
        d: dict[str, Any] = {}
        d.update(self.scalars)
        d.update(self.arrays)
        for k, v in self.metadata.items():
            # Recurse so nested TaskMetadata become plain dicts too.
            d[k] = v.to_dict()
        return d

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property.
        value
            Metadata property value.

        Raises
        ------
        ValueError
            Raised if the value is not a primitive type or a sequence of
            primitive types, or if ``name`` already refers to a nested
            `TaskMetadata`.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # If add() is being used, always store the value in the arrays
            # property as a list. It's likely there will be another call.
            slot_type, value = self._validate_value(value)
            if slot_type == "array":
                pass
            elif slot_type == "scalar":
                value = [value]
            else:
                raise ValueError("add() can only be used for primitive types or sequences of those types.")

            if key0 in self.metadata:
                raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

            if key0 in self.scalars:
                # Convert scalar to array.
                # MyPy should be able to figure out that List[Union[T1, T2]] is
                # compatible with Union[List[T1], List[T2]] if the list has
                # only one element, but it can't.
                self.arrays[key0] = [self.scalars.pop(key0)]  # type: ignore

            if key0 in self.arrays:
                # Check that the type is not changing.
                if (curtype := type(self.arrays[key0][0])) is not (newtype := type(value[0])):
                    raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")
                self.arrays[key0].extend(value)
            else:
                self.arrays[key0] = value

            return

        # NOTE(review): unlike __setitem__, this raises KeyError when the
        # intermediate level ``key0`` does not yet exist — confirm whether
        # add() should create intermediate TaskMetadata levels instead.
        self.metadata[key0].add(".".join(keys), value)

    def getScalar(self, key: str) -> str | int | float | bool:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> list[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A list containing the value or values associated with this item.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.arrays:
                return self.arrays[key0]
            elif key0 in self.scalars:
                # Wrap the scalar in a single-element list.
                return [self.scalars[key0]]
            elif key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def names(self, topLevelOnly: bool | None = None) -> set[str]:
        """Return the hierarchical keys from the metadata.

        Parameters
        ----------
        topLevelOnly : `bool` or `None`, optional
            This parameter is deprecated and will be removed in the future.
            If given it can only be `False`. All names in the hierarchy are
            always returned.

        Returns
        -------
        names : `collections.abc.Set`
            A set of all keys, including those from the hierarchy and the
            top-level hierarchy.

        Raises
        ------
        RuntimeError
            Raised if ``topLevelOnly`` is `True`, which is no longer
            supported.
        """
        if topLevelOnly:
            raise RuntimeError(
                "The topLevelOnly parameter is no longer supported and can not have a True value."
            )

        if topLevelOnly is False:
            # Explicit False still warns so callers drop the argument.
            warnings.warn(
                "The topLevelOnly parameter is deprecated and is always assumed to be False."
                " It will be removed completely after v26.",
                category=FutureWarning,
                stacklevel=find_outside_stacklevel("lsst.pipe.base"),
            )

        names = set()
        for k, v in self.items():
            names.add(k)  # Always include the current level
            if isinstance(v, TaskMetadata):
                names.update({k + "." + item for item in v.names()})
        return names

    def paramNames(self, topLevelOnly: bool) -> set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`
            Control whether only top-level items are returned or items
            from the hierarchy.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy. If `False` also returns fully-qualified
            names from the hierarchy. Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames

    @staticmethod
    def _getKeys(key: str) -> list[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze. Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        try:
            keys = key.split(".")
        except Exception:
            # A non-string key has no .split(); report it as a KeyError
            # for consistency with mapping semantics.
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(self)

    def items(self) -> Iterator[tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else. But the
    # pydantic docs say to do exactly this to in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. For compatibility with ``PropertySet``, if the key
            refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property. Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.
        default
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            slots: dict[str, dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            slot_type, item = self._validate_value(item)
            primary: dict[str, Any] | None = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for property in slots.values():
                # Remove any other entries.
                property.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        if key0 not in self.metadata:
            self.metadata[key0] = TaskMetadata()
        self.metadata[key0][".".join(keys)] = item

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            properties: tuple[dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for property in properties:
                if key0 in property:
                    del property[key0]
                    return
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def _validate_value(self, value: Any) -> tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given. Options are "scalar", "array", "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)

            type0 = type(value[0])
            for i in value:
                if type(i) is not type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value!r}.")
# Needed because a TaskMetadata can contain a TaskMetadata.
# model_rebuild() resolves the "TaskMetadata" forward reference in the
# ``metadata`` field now that the class definition is complete.
TaskMetadata.model_rebuild()