Coverage for python/lsst/afw/table/_base.py : 11%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of afw.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21import numpy as np
23from lsst.utils import continueClass, TemplateMeta
24from ._table import BaseRecord, BaseCatalog
25from ._schema import Key
# Public names exported by this module; only the Catalog template is listed.
__all__ = ["Catalog"]
@continueClass  # noqa: F811
class BaseRecord:

    def extract(self, *patterns, **kwargs):
        """Extract a dictionary of {<name>: <field-value>} in which the field
        names match the given shell-style glob pattern(s).

        Any number of glob patterns may be passed; the result will be the
        union of all the results of each glob considered separately.

        Parameters
        ----------
        *patterns : `str`
            Shell-style glob patterns, forwarded to ``self.schema.extract``.
            Not used when ``items`` is given.
        items : `dict`
            The result of a call to self.schema.extract(); this will be used
            instead of doing any new matching, and allows the pattern matching
            to be reused to extract values from multiple records. This
            keyword is incompatible with any position arguments and the regex,
            sub, and ordered keyword arguments. The mapping passed here is
            not modified, so it can safely be reused for other records.
        split : `bool`
            If `True`, fields with named subfields (e.g. points) will be split
            into separate items in the dict; instead of {"point":
            lsst.geom.Point2I(2,3)}, for instance, you'd get {"point.x":
            2, "point.y": 3}. Default is `False`.
        regex : `str` or `re` pattern object
            A regular expression to be used in addition to any glob patterns
            passed as positional arguments. Note that this will be compared
            with re.match, not re.search.
        sub : `str`
            A replacement string (see `re.MatchObject.expand`) used to set the
            dictionary keys of any fields matched by regex.
        ordered : `bool`
            If `True`, a `collections.OrderedDict` will be returned instead of
            a standard dict, with the order corresponding to the definition
            order of the `Schema`. Default is `False`.

        Returns
        -------
        values : `dict` or `collections.OrderedDict`
            Mapping from field name (or ``"<name>.<subfield>"`` when ``split``
            is `True`) to the value read from this record with ``self.get``.
            The container has the same type as the schema-item mapping and
            follows its iteration order.

        Raises
        ------
        ValueError
            Raised if ``items`` is given together with any keyword argument
            other than ``split``.
        """
        items = kwargs.pop("items", None)
        split = kwargs.pop("split", False)
        if items is None:
            items = self.schema.extract(*patterns, **kwargs)
        elif kwargs:
            kwargsStr = ", ".join(kwargs.keys())
            raise ValueError(f"Unrecognized keyword arguments for extract: {kwargsStr}")
        # Build a new container instead of rewriting ``items`` in place.
        # A caller-supplied ``items`` must keep holding schema items so it
        # can be reused for other records. Filling in iteration order also
        # keeps split subfields where their parent field was, preserving
        # schema order for ``ordered=True``. type(items)() keeps the
        # dict/OrderedDict distinction.
        result = type(items)()
        for name, schemaItem in items.items():
            key = schemaItem.key
            if split and key.HAS_NAMED_SUBFIELDS:
                for subname, subkey in zip(key.subfields, key.subkeys):
                    result[f"{name}.{subname}"] = self.get(subkey)
            else:
                result[name] = self.get(key)
        return result

    def __repr__(self):
        """Return the record's type followed by its string form."""
        return f"{type(self)}\n{self}"
class Catalog(metaclass=TemplateMeta):

    def getColumnView(self):
        """Rebuild the column view of this catalog, cache it, and return it.

        Returns
        -------
        columns
            The object returned by ``self._getColumnView()``; it is also
            stored as the cached view used by the ``columns`` property.
        """
        self._columns = self._getColumnView()
        return self._columns

    def __getColumns(self):
        # Lazily (re)build the cached column view. Every mutating method in
        # this class resets ``_columns`` to None so the next access rebuilds it.
        if not hasattr(self, "_columns") or self._columns is None:
            self._columns = self._getColumnView()
        return self._columns
    columns = property(__getColumns, doc="a column view of the catalog")

    def __getitem__(self, key):
        """Return the record at index key if key is an integer,
        return a column if `key` is a string field name or Key,
        or return a subset of the catalog if key is a slice
        or boolean NumPy array.

        Raises
        ------
        RuntimeError
            Raised if ``key`` is a NumPy array whose dtype is not `bool`.
        """
        if isinstance(key, slice):
            # NOTE(review): these defaults assume a positive step; e.g.
            # ``cat[::-1]`` is forwarded as subset(0, len(self), -1). Confirm
            # that this is what the C++ subset expects for reversed slices.
            start = 0 if key.start is None else key.start
            stop = len(self) if key.stop is None else key.stop
            step = 1 if key.step is None else key.step
            return self.subset(start, stop, step)
        if isinstance(key, np.ndarray):
            if key.dtype == bool:
                return self.subset(key)
            raise RuntimeError(f"Unsupported array type for indexing non-contiguous Catalog: {key.dtype}")
        if isinstance(key, (Key, str)):
            if self.isContiguous():
                return self.columns[key]
            if isinstance(key, str):
                key = self.schema[key].asKey()
            array = self._getitem_(key)
            # This array doesn't share memory with the Catalog, so don't let it be modified by
            # the user who thinks that the Catalog itself is being modified.
            # Just be aware that this array can only be passed down to C++ as an ndarray::Array<T const>
            # instead of an ordinary ndarray::Array<T>. If pybind isn't letting it down into C++,
            # you may have left off the 'const' in the definition.
            array.flags.writeable = False
            return array
        return self._getitem_(key)

    def __setitem__(self, key, value):
        """If ``key`` is an integer, set ``catalog[key]`` to
        ``value``. Otherwise select column ``key`` and set it to
        ``value``.

        String keys are resolved to a `Key` through the schema. Flag columns
        are written with ``_set_flag``; other columns are assigned through the
        column view.
        """
        self._columns = None
        if isinstance(key, str):
            key = self.schema[key].asKey()
        if not isinstance(key, Key):
            return self.set(key, value)
        if isinstance(key, Key["Flag"]):
            self._set_flag(key, value)
        else:
            self.columns[key] = value

    def __delitem__(self, key):
        """Delete the record(s) selected by an index or a slice."""
        self._columns = None
        if isinstance(key, slice):
            self._delslice_(key)
        else:
            self._delitem_(key)

    def append(self, record):
        """Append ``record`` to the catalog."""
        self._columns = None
        self._append(record)

    def insert(self, key, value):
        """Insert record ``value`` at position ``key``."""
        self._columns = None
        self._insert(key, value)

    def clear(self):
        """Remove all records from the catalog."""
        self._columns = None
        self._clear()

    def addNew(self):
        """Create a new record, add it to the catalog, and return it."""
        self._columns = None
        return self._addNew()

    def cast(self, type_, deep=False):
        """Return a copy of the catalog with the given type.

        Parameters
        ----------
        type_ :
            Type of catalog to return.
        deep : `bool`, optional
            If `True`, clone the table and deep copy all records.

        Returns
        -------
        copy :
            Copy of catalog with the requested type.
        """
        if deep:
            table = self.table.clone()
            # Reserve space up front for the records copied below.
            table.preallocate(len(self))
        else:
            table = self.table
        result = type_(table)
        result.extend(self, deep=deep)
        return result

    def copy(self, deep=False):
        """
        Copy a catalog (default is not a deep copy).
        """
        return self.cast(type(self), deep)

    def extend(self, iterable, deep=False, mapper=None):
        """Append all records in the given iterable to the catalog.

        Parameters
        ----------
        iterable :
            Any Python iterable containing records.
        deep : `bool`, optional
            If `True`, the records will be deep-copied; ignored if
            mapper is not `None` (that always implies `True`). A
            SchemaMapper passed positionally here is treated as ``mapper``.
        mapper : `lsst.afw.table.schemaMapper.SchemaMapper`, optional
            Used to translate records.
        """
        self._columns = None
        # We can't use isinstance here, because the SchemaMapper symbol isn't available
        # when this code is part of a subclass of Catalog in another package.
        if type(deep).__name__ == "SchemaMapper":
            mapper = deep
            deep = None
        if isinstance(iterable, type(self)):
            # Same catalog type: delegate the whole batch to the bound _extend.
            self._extend(iterable, deep if mapper is None else mapper)
            return
        for record in iterable:
            if mapper is not None:
                self._append(self.table.copyRecord(record, mapper))
            elif deep:
                self._append(self.table.copyRecord(record))
            else:
                self._append(record)

    def __reduce__(self):
        """Pickle support via ``lsst.afw.fits.reduceToFits``."""
        import lsst.afw.fits
        return lsst.afw.fits.reduceToFits(self)

    def asAstropy(self, cls=None, copy=False, unviewable="copy"):
        """Return an astropy.table.Table (or subclass thereof) view into this catalog.

        Parameters
        ----------
        cls :
            Table subclass to use; `None` implies `astropy.table.Table`
            itself. Use `astropy.table.QTable` to get Quantity columns.
        copy : `bool`, optional
            If `True`, copy data from the LSST catalog to the astropy
            table. Not copying is usually faster, but can keep memory
            from being freed if columns are later removed from the
            Astropy view.
        unviewable : `str`, optional
            One of the following options (which is ignored if
            copy=`True` ), indicating how to handle field types (`str`
            and `Flag`) for which views cannot be constructed:
                - 'copy' (default): copy only the unviewable fields.
                - 'raise': raise ValueError if unviewable fields are present.
                - 'skip': do not include unviewable fields in the Astropy Table.
            Variable-length array fields can be neither viewed nor copied.
            They are always omitted, except that 'raise' raises ValueError
            (even when copy=`True`).

        Returns
        -------
        cls : `astropy.table.Table`
            Astropy view into the catalog.

        Raises
        ------
        ValueError
            Raised if the `unviewable` option is not a known value, or
            if the option is 'raise' and an uncopyable field is found.
        """
        import astropy.table
        if cls is None:
            cls = astropy.table.Table
        if unviewable not in ("copy", "raise", "skip"):
            raise ValueError(
                f"'unviewable'={unviewable!r} must be one of 'copy', 'raise', or 'skip'")
        ps = self.getMetadata()
        meta = ps.toOrderedDict() if ps is not None else None
        columns = []
        items = self.schema.extract("*", ordered=True)
        for name, item in items.items():
            key = item.key
            typeString = key.getTypeString()
            unit = item.field.getUnits() or None  # use None instead of "" when empty
            if typeString == "String":
                if not copy:
                    if unviewable == "raise":
                        raise ValueError("Cannot extract string "
                                         "unless copy=True or unviewable='copy' or 'skip'.")
                    elif unviewable == "skip":
                        continue
                # Strings cannot be viewed; fill a new fixed-width array per record.
                data = np.zeros(len(self), dtype=np.dtype((str, key.getSize())))
                for i, record in enumerate(self):
                    data[i] = record.get(key)
            elif typeString == "Flag":
                if not copy:
                    if unviewable == "raise":
                        raise ValueError("Cannot extract packed bit columns "
                                         "unless copy=True or unviewable='copy' or 'skip'.")
                    elif unviewable == "skip":
                        continue
                data = self.columns.get_bool_array(key)
            elif typeString == "Angle":
                data = self.columns.get(key)
                unit = "radian"
                if copy:
                    data = data.copy()
            elif "Array" in typeString and key.isVariableLength():
                # Can't get columns for variable-length array fields.
                if unviewable == "raise":
                    raise ValueError("Cannot extract variable-length array fields unless unviewable='skip'.")
                # 'unviewable' was validated above, so it is 'skip' or 'copy'
                # here; neither can represent this field, so omit it.
                continue
            else:
                data = self.columns.get(key)
                if copy:
                    data = data.copy()
            columns.append(
                astropy.table.Column(
                    data,
                    name=name,
                    unit=unit,
                    description=item.field.getDoc()
                )
            )
        return cls(columns, meta=meta, copy=False)

    def __dir__(self):
        """
        This custom dir is necessary due to the custom getattr below.
        Without it, not all of the methods available are returned with dir.
        See DM-7199.
        """
        # Names defined on this class and on every one of its base classes.
        classNames = set()
        for klass in type(self).__mro__:
            classNames.update(klass.__dict__.keys())
        return sorted(set(dir(self.columns)) | set(dir(self.table))
                      | classNames | set(self.__dict__.keys()))

    def __getattr__(self, name):
        # Catalog forwards unknown method calls to its table and column view
        # for convenience. (Feature requested by RHL; complaints about magic
        # should be directed to him.)
        if name == "_columns":
            # The column cache was never set: initialize it so the
            # ``columns`` property treats it as "not built yet".
            self._columns = None
            return None
        try:
            return getattr(self.table, name)
        except AttributeError:
            return getattr(self.columns, name)

    def __str__(self):
        """Render contiguous catalogs as an astropy table, others as a summary."""
        if not self.isContiguous():
            fields = ' '.join(x.field.getName() for x in self.schema)
            return f"Non-contiguous afw.Catalog of {len(self)} rows.\ncolumns: {fields}"
        return str(self.asAstropy())

    def __repr__(self):
        """Return the catalog's type followed by its string form."""
        return f"{type(self)}\n{self}"


# Make BaseCatalog a registered instantiation of the Catalog template.
Catalog.register("Base", BaseCatalog)