Coverage for python/lsst/afw/table/_base.py: 12%

172 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-18 02:24 -0800

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import numpy as np 

22 

23from lsst.utils import continueClass, TemplateMeta 

24from ._table import BaseRecord, BaseCatalog 

25from ._schema import Key 

26 

27 

28__all__ = ["Catalog"] 

29 

30 

31@continueClass 

32class BaseRecord: # noqa: F811 

33 

34 def extract(self, *patterns, **kwargs): 

35 """Extract a dictionary of {<name>: <field-value>} in which the field 

36 names match the given shell-style glob pattern(s). 

37 

38 Any number of glob patterns may be passed; the result will be the union 

39 of all the result of each glob considered separately. 

40 

41 Parameters 

42 ---------- 

43 items : `dict` 

44 The result of a call to self.schema.extract(); this will be used 

45 instead of doing any new matching, and allows the pattern matching 

46 to be reused to extract values from multiple records. This 

47 keyword is incompatible with any position arguments and the regex, 

48 sub, and ordered keyword arguments. 

49 regex : `str` or `re` pattern object 

50 A regular expression to be used in addition to any glob patterns 

51 passed as positional arguments. Note that this will be compared 

52 with re.match, not re.search. 

53 sub : `str` 

54 A replacement string (see `re.MatchObject.expand`) used to set the 

55 dictionary keys of any fields matched by regex. 

56 ordered : `bool` 

57 If `True`, a `collections.OrderedDict` will be returned instead of 

58 a standard dict, with the order corresponding to the definition 

59 order of the `Schema`. Default is `False`. 

60 """ 

61 d = kwargs.pop("items", None) 

62 if d is None: 

63 d = self.schema.extract(*patterns, **kwargs).copy() 

64 elif kwargs: 

65 kwargsStr = ", ".join(kwargs.keys()) 

66 raise ValueError(f"Unrecognized keyword arguments for extract: {kwargsStr}") 

67 return {name: self.get(schemaItem.key) for name, schemaItem in d.items()} 

68 

69 def __repr__(self): 

70 return f"{type(self)}\n{self}" 

71 

72 

73class Catalog(metaclass=TemplateMeta): 

74 

75 def getColumnView(self): 

76 self._columns = self._getColumnView() 

77 return self._columns 

78 

79 def __getColumns(self): 

80 if not hasattr(self, "_columns") or self._columns is None: 

81 self._columns = self._getColumnView() 

82 return self._columns 

83 columns = property(__getColumns, doc="a column view of the catalog") 

84 

85 def __getitem__(self, key): 

86 """Return the record at index key if key is an integer, 

87 return a column if `key` is a string field name or Key, 

88 or return a subset of the catalog if key is a slice 

89 or boolean NumPy array. 

90 """ 

91 if type(key) is slice: 

92 (start, stop, step) = (key.start, key.stop, key.step) 

93 if step is None: 

94 step = 1 

95 if start is None: 

96 start = 0 

97 if stop is None: 

98 stop = len(self) 

99 return self.subset(start, stop, step) 

100 elif isinstance(key, np.ndarray): 

101 if key.dtype == bool: 

102 return self.subset(key) 

103 raise RuntimeError(f"Unsupported array type for indexing non-contiguous Catalog: {key.dtype}") 

104 elif isinstance(key, Key) or isinstance(key, str): 

105 if not self.isContiguous(): 

106 if isinstance(key, str): 

107 key = self.schema[key].asKey() 

108 array = self._getitem_(key) 

109 # This array doesn't share memory with the Catalog, so don't let it be modified by 

110 # the user who thinks that the Catalog itself is being modified. 

111 # Just be aware that this array can only be passed down to C++ as an ndarray::Array<T const> 

112 # instead of an ordinary ndarray::Array<T>. If pybind isn't letting it down into C++, 

113 # you may have left off the 'const' in the definition. 

114 array.flags.writeable = False 

115 return array 

116 return self.columns[key] 

117 else: 

118 return self._getitem_(key) 

119 

120 def __setitem__(self, key, value): 

121 """If ``key`` is an integer, set ``catalog[key]`` to 

122 ``value``. Otherwise select column ``key`` and set it to 

123 ``value``. 

124 """ 

125 self._columns = None 

126 if isinstance(key, str): 

127 key = self.schema[key].asKey() 

128 if isinstance(key, Key): 

129 if isinstance(key, Key["Flag"]): 

130 self._set_flag(key, value) 

131 else: 

132 self.columns[key] = value 

133 else: 

134 return self.set(key, value) 

135 

136 def __delitem__(self, key): 

137 self._columns = None 

138 if isinstance(key, slice): 

139 self._delslice_(key) 

140 else: 

141 self._delitem_(key) 

142 

143 def append(self, record): 

144 self._columns = None 

145 self._append(record) 

146 

147 def insert(self, key, value): 

148 self._columns = None 

149 self._insert(key, value) 

150 

151 def clear(self): 

152 self._columns = None 

153 self._clear() 

154 

155 def addNew(self): 

156 self._columns = None 

157 return self._addNew() 

158 

159 def cast(self, type_, deep=False): 

160 """Return a copy of the catalog with the given type. 

161 

162 Parameters 

163 ---------- 

164 type_ : 

165 Type of catalog to return. 

166 deep : `bool`, optional 

167 If `True`, clone the table and deep copy all records. 

168 

169 Returns 

170 ------- 

171 copy : 

172 Copy of catalog with the requested type. 

173 """ 

174 if deep: 

175 table = self.table.clone() 

176 table.preallocate(len(self)) 

177 else: 

178 table = self.table 

179 copy = type_(table) 

180 copy.extend(self, deep=deep) 

181 return copy 

182 

183 def copy(self, deep=False): 

184 """ 

185 Copy a catalog (default is not a deep copy). 

186 """ 

187 return self.cast(type(self), deep) 

188 

189 def extend(self, iterable, deep=False, mapper=None): 

190 """Append all records in the given iterable to the catalog. 

191 

192 Parameters 

193 ---------- 

194 iterable : 

195 Any Python iterable containing records. 

196 deep : `bool`, optional 

197 If `True`, the records will be deep-copied; ignored if 

198 mapper is not `None` (that always implies `True`). 

199 mapper : `lsst.afw.table.schemaMapper.SchemaMapper`, optional 

200 Used to translate records. 

201 """ 

202 self._columns = None 

203 # We can't use isinstance here, because the SchemaMapper symbol isn't available 

204 # when this code is part of a subclass of Catalog in another package. 

205 if type(deep).__name__ == "SchemaMapper": 

206 mapper = deep 

207 deep = None 

208 if isinstance(iterable, type(self)): 

209 if mapper is not None: 

210 self._extend(iterable, mapper) 

211 else: 

212 self._extend(iterable, deep) 

213 else: 

214 for record in iterable: 

215 if mapper is not None: 

216 self._append(self.table.copyRecord(record, mapper)) 

217 elif deep: 

218 self._append(self.table.copyRecord(record)) 

219 else: 

220 self._append(record) 

221 

222 def __reduce__(self): 

223 import lsst.afw.fits 

224 return lsst.afw.fits.reduceToFits(self) 

225 

226 def asAstropy(self, cls=None, copy=False, unviewable="copy"): 

227 """Return an astropy.table.Table (or subclass thereof) view into this catalog. 

228 

229 Parameters 

230 ---------- 

231 cls : 

232 Table subclass to use; `None` implies `astropy.table.Table` 

233 itself. Use `astropy.table.QTable` to get Quantity columns. 

234 copy : bool, optional 

235 If `True`, copy data from the LSST catalog to the astropy 

236 table. Not copying is usually faster, but can keep memory 

237 from being freed if columns are later removed from the 

238 Astropy view. 

239 unviewable : `str`, optional 

240 One of the following options (which is ignored if 

241 copy=`True` ), indicating how to handle field types (`str` 

242 and `Flag`) for which views cannot be constructed: 

243 

244 - 'copy' (default): copy only the unviewable fields. 

245 - 'raise': raise ValueError if unviewable fields are present. 

246 - 'skip': do not include unviewable fields in the Astropy Table. 

247 

248 Returns 

249 ------- 

250 cls : `astropy.table.Table` 

251 Astropy view into the catalog. 

252 

253 Raises 

254 ------ 

255 ValueError 

256 Raised if the `unviewable` option is not a known value, or 

257 if the option is 'raise' and an uncopyable field is found. 

258 

259 """ 

260 import astropy.table 

261 if cls is None: 

262 cls = astropy.table.Table 

263 if unviewable not in ("copy", "raise", "skip"): 

264 raise ValueError( 

265 f"'unviewable'={unviewable!r} must be one of 'copy', 'raise', or 'skip'") 

266 ps = self.getMetadata() 

267 meta = ps.toOrderedDict() if ps is not None else None 

268 columns = [] 

269 items = self.schema.extract("*", ordered=True) 

270 for name, item in items.items(): 

271 key = item.key 

272 unit = item.field.getUnits() or None # use None instead of "" when empty 

273 if key.getTypeString() == "String": 

274 if not copy: 

275 if unviewable == "raise": 

276 raise ValueError("Cannot extract string " 

277 "unless copy=True or unviewable='copy' or 'skip'.") 

278 elif unviewable == "skip": 

279 continue 

280 data = np.zeros( 

281 len(self), dtype=np.dtype((str, key.getSize()))) 

282 for i, record in enumerate(self): 

283 data[i] = record.get(key) 

284 elif key.getTypeString() == "Flag": 

285 if not copy: 

286 if unviewable == "raise": 

287 raise ValueError("Cannot extract packed bit columns " 

288 "unless copy=True or unviewable='copy' or 'skip'.") 

289 elif unviewable == "skip": 

290 continue 

291 data = self.columns.get_bool_array(key) 

292 elif key.getTypeString() == "Angle": 

293 data = self.columns.get(key) 

294 unit = "radian" 

295 if copy: 

296 data = data.copy() 

297 elif "Array" in key.getTypeString() and key.isVariableLength(): 

298 # Can't get columns for variable-length array fields. 

299 if unviewable == "raise": 

300 raise ValueError("Cannot extract variable-length array fields unless unviewable='skip'.") 

301 elif unviewable == "skip" or unviewable == "copy": 

302 continue 

303 else: 

304 data = self.columns.get(key) 

305 if copy: 

306 data = data.copy() 

307 columns.append( 

308 astropy.table.Column( 

309 data, 

310 name=name, 

311 unit=unit, 

312 description=item.field.getDoc() 

313 ) 

314 ) 

315 return cls(columns, meta=meta, copy=False) 

316 

317 def __dir__(self): 

318 """ 

319 This custom dir is necessary due to the custom getattr below. 

320 Without it, not all of the methods available are returned with dir. 

321 See DM-7199. 

322 """ 

323 def recursive_get_class_dir(cls): 

324 """ 

325 Return a set containing the names of all methods 

326 for a given class *and* all of its subclasses. 

327 """ 

328 result = set() 

329 if cls.__bases__: 

330 for subcls in cls.__bases__: 

331 result |= recursive_get_class_dir(subcls) 

332 result |= set(cls.__dict__.keys()) 

333 return result 

334 return sorted(set(dir(self.columns)) | set(dir(self.table)) 

335 | recursive_get_class_dir(type(self)) | set(self.__dict__.keys())) 

336 

337 def __getattr__(self, name): 

338 # Catalog forwards unknown method calls to its table and column view 

339 # for convenience. (Feature requested by RHL; complaints about magic 

340 # should be directed to him.) 

341 if name == "_columns": 

342 self._columns = None 

343 return None 

344 try: 

345 return getattr(self.table, name) 

346 except AttributeError: 

347 return getattr(self.columns, name) 

348 

349 def __str__(self): 

350 if self.isContiguous(): 

351 return str(self.asAstropy()) 

352 else: 

353 fields = ' '.join(x.field.getName() for x in self.schema) 

354 return f"Non-contiguous afw.Catalog of {len(self)} rows.\ncolumns: {fields}" 

355 

356 def __repr__(self): 

357 return "%s\n%s" % (type(self), self) 

358 

359 

360Catalog.register("Base", BaseCatalog)