Coverage for python/lsst/pipe/base/_task_metadata.py: 16% (207 statements)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["TaskMetadata"]

import itertools
import numbers
import sys
from collections.abc import Collection, Iterator, Mapping, Sequence
from typing import Any, Protocol

from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr

# The types allowed in a Task metadata field are restricted
# to allow predictable serialization.
_ALLOWED_PRIMITIVE_TYPES = (str, float, int, bool)


class PropertySetLike(Protocol):
    """Protocol that looks like a ``lsst.daf.base.PropertySet``.

    Enough of the API is specified to support conversion of a
    ``PropertySet`` to a `TaskMetadata`.
    """

    def paramNames(self, topLevelOnly: bool = True) -> Collection[str]: ...

    def getArray(self, name: str) -> Any: ...
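
# For conversion via TaskMetadata.from_metadata() an object only has to
# provide the two methods above.  A minimal illustrative stand-in (purely
# hypothetical; the real implementation of this protocol is
# lsst.daf.base.PropertySet) could look like:
#
#     class InMemoryPropertySet:
#         def __init__(self, values: dict[str, list]) -> None:
#             self._values = values
#
#         def paramNames(self, topLevelOnly: bool = True) -> list[str]:
#             return list(self._values)
#
#         def getArray(self, name: str) -> list:
#             return self._values[name]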


def _isListLike(v: Any) -> bool:
    return isinstance(v, Sequence) and not isinstance(v, str)


class TaskMetadata(BaseModel):
    """Dict-like object for storing task metadata.

    Metadata can be stored at two levels: single task or task plus subtasks.
    The latter is called full metadata of a task and has a form

        topLevelTaskName:subtaskName:subsubtaskName.itemName

    Metadata item key of a task (`itemName` above) must not contain `.`,
    which serves as a separator in full metadata keys and turns
    the value into a sub-dictionary. Arbitrary hierarchies are supported.
    """

    scalars: dict[str, StrictFloat | StrictInt | StrictBool | StrictStr] = Field(default_factory=dict)
    arrays: dict[str, list[StrictFloat] | list[StrictInt] | list[StrictBool] | list[StrictStr]] = Field(
        default_factory=dict
    )
    metadata: dict[str, "TaskMetadata"] = Field(default_factory=dict)
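
    # Illustrative use of the hierarchical key scheme described in the class
    # docstring (a sketch only; the key names and values are made up):
    #
    #     meta = TaskMetadata()
    #     meta["results.magnitude"] = 21.3   # creates a nested TaskMetadata under "results"
    #     meta["results"]                    # -> the nested TaskMetadata
    #     meta["results.magnitude"]          # -> 21.3
    #     "results.magnitude" in meta        # -> True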

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> "TaskMetadata":
        """Create a TaskMetadata from a dictionary.

        Parameters
        ----------
        d : `~collections.abc.Mapping`
            Mapping to convert. Can be hierarchical. Any dictionaries
            in the hierarchy are converted to `TaskMetadata`.

        Returns
        -------
        meta : `TaskMetadata`
            Newly-constructed metadata.
        """
        metadata = cls()
        for k, v in d.items():
            metadata[k] = v
        return metadata

    @classmethod
    def from_metadata(cls, ps: PropertySetLike) -> "TaskMetadata":
        """Create a TaskMetadata from a PropertySet-like object.

        Parameters
        ----------
        ps : `PropertySetLike` or `TaskMetadata`
            A ``PropertySet``-like object to be transformed to a
            `TaskMetadata`. A `TaskMetadata` can be copied using this
            class method.

        Returns
        -------
        tm : `TaskMetadata`
            Newly-constructed metadata.

        Notes
        -----
        Items stored in single-element arrays in the supplied object
        will be converted to scalars in the newly-created object.
        """
        # Use hierarchical names to assign values from input to output.
        # This API exists for both PropertySet and TaskMetadata.
        # from_dict() does not work because PropertySet is not declared
        # to be a Mapping.
        # PropertySet.toDict() is not present in TaskMetadata so is best
        # avoided.
        metadata = cls()
        for key in sorted(ps.paramNames(topLevelOnly=False)):
            value = ps.getArray(key)
            if len(value) == 1:
                value = value[0]
            metadata[key] = value
        return metadata

    def to_dict(self) -> dict[str, Any]:
        """Convert the class to a simple dictionary.

        Returns
        -------
        d : `dict`
            Simple dictionary that can contain scalar values, array values
            or other dictionary values.

        Notes
        -----
        Unlike `dict()`, this method hides the model layout and combines
        scalars, arrays, and other metadata in the same dictionary. Can be
        used when a simple dictionary is needed. Use
        `TaskMetadata.from_dict()` to convert it back.
        """
        d: dict[str, Any] = {}
        d.update(self.scalars)
        d.update(self.arrays)
        for k, v in self.metadata.items():
            d[k] = v.to_dict()
        return d
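
    # A round trip through plain dictionaries (sketch; the values are
    # arbitrary):
    #
    #     meta = TaskMetadata.from_dict({"a": 1, "b": {"c": [1.5, 2.5]}})
    #     meta.to_dict()   # -> {'a': 1, 'b': {'c': [1.5, 2.5]}}
    #
    # to_dict() flattens the scalars/arrays/metadata model layout, so the
    # result feeds straight back into from_dict().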

    def add(self, name: str, value: Any) -> None:
        """Store a new value, adding to a list if one already exists.

        Parameters
        ----------
        name : `str`
            Name of the metadata property.
        value : `~typing.Any`
            Metadata property value.
        """
        keys = self._getKeys(name)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # If add() is being used, always store the value in the arrays
            # property as a list. It's likely there will be another call.
            slot_type, value = self._validate_value(value)
            if slot_type == "array":
                pass
            elif slot_type == "scalar":
                value = [value]
            else:
                raise ValueError("add() can only be used for primitive types or sequences of those types.")

            if key0 in self.metadata:
                raise ValueError(f"Can not add() to key '{name}' since that is a TaskMetadata")

            if key0 in self.scalars:
                # Convert scalar to array.
                # MyPy should be able to figure out that List[Union[T1, T2]] is
                # compatible with Union[List[T1], List[T2]] if the list has
                # only one element, but it can't.
                self.arrays[key0] = [self.scalars.pop(key0)]  # type: ignore

            if key0 in self.arrays:
                # Check that the type is not changing.
                if (curtype := type(self.arrays[key0][0])) is not (newtype := type(value[0])):
                    raise ValueError(f"Type mismatch in add() -- currently {curtype} but adding {newtype}")
                self.arrays[key0].extend(value)
            else:
                self.arrays[key0] = value

            return

        self.metadata[key0].add(".".join(keys), value)
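
    # Sketch of add() semantics (the key names and values are illustrative
    # only):
    #
    #     meta = TaskMetadata()
    #     meta["ra"] = 1.5          # stored in the "scalars" slot
    #     meta.add("ra", 2.5)       # existing scalar is promoted to a list
    #     meta.getArray("ra")       # -> [1.5, 2.5]
    #     meta.add("ra", "north")   # -> ValueError, element type changed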

    def getScalar(self, key: str) -> str | int | float | bool:
        """Retrieve a scalar item even if the item is a list.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        value : `str`, `int`, `float`, or `bool`
            Either the value associated with the key or, if the key
            corresponds to a list, the last item in the list.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        # Used in pipe_tasks.
        # getScalar() is the default behavior for __getitem__.
        return self[key]

    def getArray(self, key: str) -> list[Any]:
        """Retrieve an item as a list even if it is a scalar.

        Parameters
        ----------
        key : `str`
            Item to retrieve.

        Returns
        -------
        values : `list` of any
            A list containing the value or values associated with this item.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.arrays:
                return self.arrays[key0]
            elif key0 in self.scalars:
                return [self.scalars[key0]]
            elif key0 in self.metadata:
                return [self.metadata[key0]]
            raise KeyError(f"'{key}' not found")

        try:
            return self.metadata[key0].getArray(".".join(keys))
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None
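
    # getScalar() and getArray() differ only in how they treat lists
    # (illustrative values; the visit numbers are made up):
    #
    #     meta = TaskMetadata.from_dict({"visits": [903334, 903336], "filter": "r"})
    #     meta.getArray("visits")    # -> [903334, 903336]
    #     meta.getScalar("visits")   # -> 903336, the last element (PropertySet-compatible)
    #     meta.getArray("filter")    # -> ['r'], the scalar wrapped in a list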

    def names(self) -> set[str]:
        """Return the hierarchical keys from the metadata.

        Returns
        -------
        names : `collections.abc.Set`
            A set of all keys, including those from the hierarchy and the
            top-level hierarchy.
        """
        names = set()
        for k, v in self.items():
            names.add(k)  # Always include the current level
            if isinstance(v, TaskMetadata):
                names.update({k + "." + item for item in v.names()})
        return names

    def paramNames(self, topLevelOnly: bool) -> set[str]:
        """Return hierarchical names.

        Parameters
        ----------
        topLevelOnly : `bool`
            Control whether only top-level items are returned or items
            from the hierarchy.

        Returns
        -------
        paramNames : `set` of `str`
            If ``topLevelOnly`` is `True`, returns any keys that are not
            part of a hierarchy. If `False` also returns fully-qualified
            names from the hierarchy. Keys associated with the top
            of a hierarchy are never returned.
        """
        # Currently used by the verify package.
        paramNames = set()
        for k, v in self.items():
            if isinstance(v, TaskMetadata):
                if not topLevelOnly:
                    paramNames.update({k + "." + item for item in v.paramNames(topLevelOnly=topLevelOnly)})
            else:
                paramNames.add(k)
        return paramNames

    @staticmethod
    def _getKeys(key: str) -> list[str]:
        """Return the key hierarchy.

        Parameters
        ----------
        key : `str`
            The key to analyze. Can be dot-separated.

        Returns
        -------
        keys : `list` of `str`
            The key hierarchy that has been split on ``.``.

        Raises
        ------
        KeyError
            Raised if the key is not a string.
        """
        try:
            keys = key.split(".")
        except Exception:
            raise KeyError(f"Invalid key '{key}': only string keys are allowed") from None
        return keys

    def keys(self) -> tuple[str, ...]:
        """Return the top-level keys."""
        return tuple(k for k in self)

    def items(self) -> Iterator[tuple[str, Any]]:
        """Yield the top-level keys and values."""
        yield from itertools.chain(self.scalars.items(), self.arrays.items(), self.metadata.items())

    def __len__(self) -> int:
        """Return the number of items."""
        return len(self.scalars) + len(self.arrays) + len(self.metadata)

    # This is actually a Liskov substitution violation, because
    # pydantic.BaseModel says __iter__ should return something else. But the
    # pydantic docs say to do exactly this in order to make a mapping-like
    # BaseModel, so that's what we do.
    def __iter__(self) -> Iterator[str]:  # type: ignore
        """Return an iterator over each key."""
        # The order of keys is not preserved since items can move
        # from scalar to array.
        return itertools.chain(iter(self.scalars), iter(self.arrays), iter(self.metadata))

    def __getitem__(self, key: str) -> Any:
        """Retrieve the item associated with the key.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. For compatibility with ``PropertySet``, if the key
            refers to an array, the final element is returned and not the
            array itself.

        Raises
        ------
        KeyError
            Raised if the item is not found.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            if key0 in self.scalars:
                return self.scalars[key0]
            if key0 in self.metadata:
                return self.metadata[key0]
            if key0 in self.arrays:
                return self.arrays[key0][-1]
            raise KeyError(f"'{key}' not found")
        # Hierarchical lookup so the top key can only be in the metadata
        # property. Trap KeyError and reraise so that the correct key
        # in the hierarchy is reported.
        try:
            # And forward request to that metadata.
            return self.metadata[key0][".".join(keys)]
        except KeyError:
            raise KeyError(f"'{key}' not found") from None

    def get(self, key: str, default: Any = None) -> Any:
        """Retrieve the item associated with the key or a default.

        Parameters
        ----------
        key : `str`
            The key to retrieve. Can be dot-separated hierarchical.
        default : `~typing.Any`
            The value to return if the key does not exist.

        Returns
        -------
        value : `TaskMetadata`, `float`, `int`, `bool`, `str`
            A scalar value. If the key refers to an array, the final element
            is returned and not the array itself; this is consistent with
            `__getitem__` and `PropertySet.get`, but not ``to_dict().get``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key: str, item: Any) -> None:
        """Store the given item."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            slots: dict[str, dict[str, Any]] = {
                "array": self.arrays,
                "scalar": self.scalars,
                "metadata": self.metadata,
            }
            primary: dict[str, Any] | None = None
            slot_type, item = self._validate_value(item)
            primary = slots.pop(slot_type, None)
            if primary is None:
                raise AssertionError(f"Unknown slot type returned from validator: {slot_type}")

            # Assign the value to the right place.
            primary[key0] = item
            for property in slots.values():
                # Remove any other entries.
                property.pop(key0, None)
            return

        # This must be hierarchical so forward to the child TaskMetadata.
        if key0 not in self.metadata:
            self.metadata[key0] = TaskMetadata()
        self.metadata[key0][".".join(keys)] = item

        # Ensure we have cleared out anything with the same name elsewhere.
        self.scalars.pop(key0, None)
        self.arrays.pop(key0, None)
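
    # Assignment keeps a key in exactly one slot; reassigning with a new kind
    # of value moves it and clears the old entry (sketch with made-up values):
    #
    #     meta = TaskMetadata()
    #     meta["x"] = [1, 2, 3]     # lands in the "arrays" slot
    #     meta["x"] = "replaced"    # now in "scalars"; the arrays entry is removed
    #     meta.getArray("x")        # -> ['replaced']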

    def __contains__(self, key: str) -> bool:
        """Determine if the key exists."""
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            return key0 in self.scalars or key0 in self.arrays or key0 in self.metadata

        if key0 in self.metadata:
            return ".".join(keys) in self.metadata[key0]
        return False

    def __delitem__(self, key: str) -> None:
        """Remove the specified item.

        Raises
        ------
        KeyError
            Raised if the item is not present.
        """
        keys = self._getKeys(key)
        key0 = keys.pop(0)
        if len(keys) == 0:
            # MyPy can't figure out that this way to combine the types in the
            # tuple is the one that matters, and annotating a local variable
            # helps it out.
            properties: tuple[dict[str, Any], ...] = (self.scalars, self.arrays, self.metadata)
            for property in properties:
                if key0 in property:
                    del property[key0]
                    return
            raise KeyError(f"'{key}' not found")

        try:
            del self.metadata[key0][".".join(keys)]
        except KeyError:
            # Report the correct key.
            raise KeyError(f"'{key}' not found") from None

    def _validate_value(self, value: Any) -> tuple[str, Any]:
        """Validate the given value.

        Parameters
        ----------
        value : Any
            Value to check.

        Returns
        -------
        slot_type : `str`
            The type of value given. Options are "scalar", "array", "metadata".
        item : Any
            The item that was given but possibly modified to conform to
            the slot type.

        Raises
        ------
        ValueError
            Raised if the value is not a recognized type.
        """
        # Test the simplest option first.
        value_type = type(value)
        if value_type in _ALLOWED_PRIMITIVE_TYPES:
            return "scalar", value

        if isinstance(value, TaskMetadata):
            return "metadata", value
        if isinstance(value, Mapping):
            return "metadata", self.from_dict(value)

        if _isListLike(value):
            # For model consistency, need to check that every item in the
            # list has the same type.
            value = list(value)

            type0 = type(value[0])
            for i in value:
                if type(i) != type0:
                    raise ValueError(
                        "Type mismatch in supplied list. TaskMetadata requires all"
                        f" elements have same type but see {type(i)} and {type0}."
                    )

            if type0 not in _ALLOWED_PRIMITIVE_TYPES:
                # Must check to see if we got numpy floats or something.
                type_cast: type
                if isinstance(value[0], numbers.Integral):
                    type_cast = int
                elif isinstance(value[0], numbers.Real):
                    type_cast = float
                else:
                    raise ValueError(
                        f"Supplied list has element of type '{type0}'. "
                        "TaskMetadata can only accept primitive types in lists."
                    )

                value = [type_cast(v) for v in value]

            return "array", value

        # Sometimes a numpy number is given.
        if isinstance(value, numbers.Integral):
            value = int(value)
            return "scalar", value
        if isinstance(value, numbers.Real):
            value = float(value)
            return "scalar", value

        raise ValueError(f"TaskMetadata does not support values of type {value!r}.")
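
    # The numbers.Integral / numbers.Real branches exist to normalize
    # non-primitive numeric types.  numpy is not imported by this module; it
    # is only the most common example of such input (sketch):
    #
    #     import numpy as np
    #     tm = TaskMetadata()
    #     tm["nsources"] = np.int64(42)                       # coerced to a plain int
    #     tm["fluxes"] = [np.float64(1.25), np.float64(2.5)]  # coerced to list[float]
    #     type(tm["nsources"])                                # -> <class 'int'>
    #     tm.getArray("fluxes")                               # -> [1.25, 2.5]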

    # Work around the fact that Sphinx chokes on Pydantic docstring formatting
    # when we inherit those docstrings in our public classes.
    if "sphinx" in sys.modules:

        def copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.copy`."""
            return super().copy(*args, **kwargs)

        def model_dump(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_dump`."""
            return super().model_dump(*args, **kwargs)

        def model_copy(self, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_copy`."""
            return super().model_copy(*args, **kwargs)

        @classmethod
        def model_json_schema(cls, *args: Any, **kwargs: Any) -> Any:
            """See `pydantic.BaseModel.model_json_schema`."""
            return super().model_json_schema(*args, **kwargs)


# Needed because a TaskMetadata can contain a TaskMetadata.
TaskMetadata.model_rebuild()