lsst.pex.config  13.0-2-g483026c+4
 All Classes Namespaces Files Functions Variables Properties Macros Pages
listField.py
Go to the documentation of this file.
1 #
2 # LSST Data Management System
3 # Copyright 2008, 2009, 2010 LSST Corporation.
4 #
5 # This product includes software developed by the
6 # LSST Project (http://www.lsst.org/).
7 #
8 # This program is free software: you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation, either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the LSST License Statement and
19 # the GNU General Public License along with this program. If not,
20 # see <http://www.lsstcorp.org/LegalNotices/>.
21 #
from builtins import zip
from builtins import str
from builtins import range

import collections
try:
    from collections.abc import MutableSequence
except ImportError:  # Python 2: the container ABCs live directly in ``collections``
    from collections import MutableSequence

from .config import Field, FieldValidationError, _typeStr, _autocast, _joinNamePath
from .comparison import compareScalars, getComparisonName
from .callStack import getCallStack, getStackFrame
31 
# Only the ListField descriptor is exported; List is the container type that
# ListField.__set__ stores as the field's value.
__all__ = ["ListField"]
33 
34 
class List(MutableSequence):
    """Mutable sequence holding the value of a `ListField` on a Config.

    Every stored item is autocast to the field's ``itemtype`` and validated
    (type check plus the optional ``itemCheck``). All mutations are rejected
    with `FieldValidationError` once the owning config is frozen. Unless
    ``setHistory=False`` is passed, each mutation appends a
    ``(snapshot, callStack, label)`` tuple to the field's history.
    """

    def __init__(self, config, field, value, at, label, setHistory=True):
        """Create the list for ``field`` on ``config`` from ``value``.

        @param[in] config      Config instance that owns this list.
        @param[in] field       ListField providing ``itemtype``, ``itemCheck``,
                               ``name`` and ``doc``.
        @param[in] value       Iterable of initial items, or None for an empty list.
        @param[in] at          Call stack stored in the history entry.
        @param[in] label       Label stored in the history entry.
        @param[in] setHistory  If True, record the initial contents in the history.

        @throw FieldValidationError if ``value`` is not iterable (any TypeError
               raised while inserting items is reported this way), or if an
               item fails validation.
        """
        self._field = field
        self._config = config
        # The history list is shared with the config: entries appended here are
        # visible through ``config._history[field.name]``.
        self._history = self._config._history.setdefault(self._field.name, [])
        self._list = []
        self.__doc__ = field.doc
        if value is not None:
            try:
                for i, x in enumerate(value):
                    self.insert(i, x, setHistory=False)
            except TypeError:
                msg = "Value %s is of incorrect type %s. Sequence type expected" % (value, _typeStr(value))
                raise FieldValidationError(self._field, self._config, msg)
        if setHistory:
            self.history.append((list(self._list), at, label))

    def validateItem(self, i, x):
        """Check that item ``x`` (destined for position ``i``) is acceptable.

        None is always accepted by the type check. Otherwise ``x`` must be an
        instance of the field's ``itemtype``. If the field has an ``itemCheck``,
        it must return a true value for ``x``.

        @throw FieldValidationError if either check fails.
        """
        if not isinstance(x, self._field.itemtype) and x is not None:
            msg = "Item at position %d with value %s is of incorrect type %s. Expected %s" % \
                (i, x, _typeStr(x), _typeStr(self._field.itemtype))
            raise FieldValidationError(self._field, self._config, msg)

        if self._field.itemCheck is not None and not self._field.itemCheck(x):
            msg = "Item at position %d is not a valid value: %s" % (i, x)
            raise FieldValidationError(self._field, self._config, msg)

    def list(self):
        """Return the underlying Python list (the internal object, not a copy)."""
        return self._list

    @property
    def history(self):
        """Read-only history: the list of ``(snapshot, callStack, label)`` entries."""
        return self._history

    def __contains__(self, x):
        return x in self._list

    def __len__(self):
        return len(self._list)

    def __setitem__(self, i, x, at=None, label="setitem", setHistory=True):
        """Set item(s) ``i`` to ``x``, casting and validating every new item.

        @param[in] i           Index or slice.
        @param[in] x           New item, or an iterable of new items when ``i``
                               is a slice. The caller's iterable is never modified.
        @param[in] at          Call stack for the history entry; captured here if None.
        @param[in] label       Label for the history entry.
        @param[in] setHistory  If True, append a snapshot of the list to the history.

        @throw FieldValidationError if the config is frozen or an item is invalid.
        """
        if self._config._frozen:
            raise FieldValidationError(self._field, self._config,
                                       "Cannot modify a frozen Config")
        if isinstance(i, slice):
            # Collect the cast items into a fresh list instead of writing them
            # back into ``x``. Writing back would mutate the caller's sequence
            # and fails for tuples/generators, which don't support item
            # assignment. ``k`` tracks the target position for error messages only.
            k, _, step = i.indices(len(self))
            items = []
            for xj in x:
                xj = _autocast(xj, self._field.itemtype)
                self.validateItem(k, xj)
                items.append(xj)
                k += step
            x = items
        else:
            x = _autocast(x, self._field.itemtype)
            self.validateItem(i, x)

        self._list[i] = x
        if setHistory:
            if at is None:
                at = getCallStack()
            self.history.append((list(self._list), at, label))

    def __getitem__(self, i):
        return self._list[i]

    def __delitem__(self, i, at=None, label="delitem", setHistory=True):
        """Delete item(s) ``i``, recording the change in the history.

        @throw FieldValidationError if the config is frozen.
        """
        if self._config._frozen:
            raise FieldValidationError(self._field, self._config,
                                       "Cannot modify a frozen Config")
        del self._list[i]
        if setHistory:
            if at is None:
                at = getCallStack()
            self.history.append((list(self._list), at, label))

    def __iter__(self):
        return iter(self._list)

    def insert(self, i, x, at=None, label="insert", setHistory=True):
        """Insert ``x`` before position ``i`` (same semantics as ``list.insert``).

        Implemented as the empty-slice assignment ``self[i:i] = [x]``, so ``x``
        is cast and validated like any other stored item.
        """
        # Only capture a call stack when it will be recorded; __init__ inserts
        # every initial item with setHistory=False.
        if at is None and setHistory:
            at = getCallStack()
        self.__setitem__(slice(i, i), [x], at=at, label=label, setHistory=setHistory)

    def __repr__(self):
        return repr(self._list)

    def __str__(self):
        return str(self._list)

    def __eq__(self, other):
        """Element-wise equality with any sized iterable; False for anything else."""
        try:
            if len(self) != len(other):
                return False

            for i, j in zip(self, other):
                if i != j:
                    return False
            return True
        except (AttributeError, TypeError):
            # ``other`` is not a sequence: len() raises TypeError for objects
            # without __len__ (e.g. ints, None), and iteration may do the same.
            return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __setattr__(self, attr, value, at=None, label="assignment"):
        """Allow only properties and the known private attributes to be set.

        @throw FieldValidationError for any other attribute name.
        """
        if hasattr(getattr(self.__class__, attr, None), '__set__'):
            # This allows properties to work.
            object.__setattr__(self, attr, value)
        elif attr in self.__dict__ or attr in ["_field", "_config", "_history", "_list", "__doc__"]:
            # This allows specific private attributes to work.
            object.__setattr__(self, attr, value)
        else:
            # We throw everything else.
            msg = "%s has no attribute %s" % (_typeStr(self._field), attr)
            raise FieldValidationError(self._field, self._config, msg)
153 
154 
class ListField(Field):
    """
    Defines a field which is a container of values of type dtype

    If length is not None, then instances of this field must match this length
    exactly (minLength and maxLength are then ignored).
    If minLength is not None, then instances of the field must be no shorter
    than minLength
    If maxLength is not None, then instances of the field must be no longer
    than maxLength

    Additionally users can provide two check functions:
    listCheck - used to validate the list as a whole, and
    itemCheck - used to validate each item individually

    Non-None values are stored as `List` instances.
    """
    def __init__(self, doc, dtype, default=None, optional=False,
                 listCheck=None, itemCheck=None,
                 length=None, minLength=None, maxLength=None):
        """Construct a ListField.

        @param[in] doc        Documentation string for the field.
        @param[in] dtype      Type of the list items; must be in Field.supportedTypes.
        @param[in] default    Default value, or None.
        @param[in] optional   If False, validation rejects a None value.
        @param[in] listCheck  Optional callable applied to the whole list by validate().
        @param[in] itemCheck  Optional callable applied to each item when it is stored.
        @param[in] length     Exact required length (must be positive); when given,
                              minLength and maxLength are discarded.
        @param[in] minLength  Minimum allowed length, or None.
        @param[in] maxLength  Maximum allowed length (must be positive), or None.

        @throw ValueError if dtype is unsupported, a length bound is invalid,
               or listCheck/itemCheck is not callable.
        """
        if dtype not in Field.supportedTypes:
            raise ValueError("Unsupported dtype %s" % _typeStr(dtype))
        if length is not None:
            if length <= 0:
                raise ValueError("'length' (%d) must be positive" % length)
            # An exact length makes the range bounds redundant.
            minLength = None
            maxLength = None
        else:
            if maxLength is not None and maxLength <= 0:
                raise ValueError("'maxLength' (%d) must be positive" % maxLength)
            if minLength is not None and maxLength is not None \
                    and minLength > maxLength:
                raise ValueError("'maxLength' (%d) must be at least"
                                 " as large as 'minLength' (%d)" % (maxLength, minLength))

        if listCheck is not None and not callable(listCheck):
            raise ValueError("'listCheck' must be callable")
        if itemCheck is not None and not callable(itemCheck):
            raise ValueError("'itemCheck' must be callable")

        source = getStackFrame()
        # The Field-level dtype is the List container; the item type is kept
        # separately in ``itemtype``.
        self._setup(doc=doc, dtype=List, default=default, check=None, optional=optional, source=source)
        self.listCheck = listCheck
        self.itemCheck = itemCheck
        self.itemtype = dtype
        self.length = length
        self.minLength = minLength
        self.maxLength = maxLength

    def validate(self, instance):
        """
        ListField validation ensures that non-optional fields are not None,
        and that non-None values comply with length requirements and
        that the list passes listCheck if supplied by the user.
        Individual Item checks are applied at set time and are not re-checked.

        @throw FieldValidationError on the first failed requirement.
        """
        Field.validate(self, instance)
        value = self.__get__(instance)
        if value is not None:
            lenValue = len(value)
            if self.length is not None and not lenValue == self.length:
                msg = "Required list length=%d, got length=%d" % (self.length, lenValue)
                raise FieldValidationError(self, instance, msg)
            elif self.minLength is not None and lenValue < self.minLength:
                msg = "Minimum allowed list length=%d, got length=%d" % (self.minLength, lenValue)
                raise FieldValidationError(self, instance, msg)
            elif self.maxLength is not None and lenValue > self.maxLength:
                msg = "Maximum allowed list length=%d, got length=%d" % (self.maxLength, lenValue)
                raise FieldValidationError(self, instance, msg)
            elif self.listCheck is not None and not self.listCheck(value):
                msg = "%s is not a valid value" % str(value)
                raise FieldValidationError(self, instance, msg)

    def __set__(self, instance, value, at=None, label="assignment"):
        """Assign ``value`` to this field on ``instance``.

        A non-None ``value`` is wrapped in a new `List`, which casts and
        validates each item and records the history entry. A None value is
        stored as-is, with its history entry appended here.

        @throw FieldValidationError if ``instance`` is frozen or ``value`` is invalid.
        """
        if instance._frozen:
            raise FieldValidationError(self, instance, "Cannot modify a frozen Config")

        if at is None:
            at = getCallStack()

        if value is not None:
            value = List(instance, self, value, at, label)
        else:
            history = instance._history.setdefault(self.name, [])
            history.append((value, at, label))

        instance._storage[self.name] = value

    def toDict(self, instance):
        """Return the field value as a plain Python list (a copy), or None."""
        value = self.__get__(instance)
        return list(value) if value is not None else None

    def _compare(self, instance1, instance2, shortcut, rtol, atol, output):
        """Helper function for Config.compare; used to compare two fields for equality.

        @param[in] instance1  LHS Config instance to compare.
        @param[in] instance2  RHS Config instance to compare.
        @param[in] shortcut   If True, return as soon as an inequality is found.
        @param[in] rtol       Relative tolerance for floating point comparisons.
        @param[in] atol       Absolute tolerance for floating point comparisons.
        @param[in] output     If not None, a callable that takes a string, used (possibly repeatedly)
                              to report inequalities.

        Floating point comparisons are performed by numpy.allclose; refer to that for details.
        """
        l1 = getattr(instance1, self.name)
        l2 = getattr(instance2, self.name)
        name = getComparisonName(
            _joinNamePath(instance1._name, self.name),
            _joinNamePath(instance2._name, self.name)
        )
        if not compareScalars("isnone for %s" % name, l1 is None, l2 is None, output=output):
            return False
        if l1 is None and l2 is None:
            return True
        if not compareScalars("size for %s" % name, len(l1), len(l2), output=output):
            return False
        equal = True
        for n, (v1, v2) in enumerate(zip(l1, l2)):
            # Items must be compared using the item type. self.dtype is the List
            # container class set by _setup(), which would stop compareScalars
            # from applying the rtol/atol float comparison described above.
            result = compareScalars("%s[%d]" % (name, n), v1, v2, dtype=self.itemtype,
                                    rtol=rtol, atol=atol, output=output)
            if not result and shortcut:
                return False
            equal = equal and result
        return equal