Coverage for python/lsst/daf/butler/dimensions/_record_set.py: 23%

106 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-25 10:50 +0000

1# This file is part of daf_butler. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30__all__ = ("DimensionRecordSet", "DimensionRecordFactory") 

31 

32from collections.abc import Collection, Iterable, Iterator 

33from typing import TYPE_CHECKING, Any, Protocol, final 

34 

35from ._coordinate import DataCoordinate, DataIdValue 

36from ._records import DimensionRecord 

37 

38if TYPE_CHECKING: 

39 from ._elements import DimensionElement 

40 from ._universe import DimensionUniverse 

41 

42 

class DimensionRecordFactory(Protocol):
    """Callback protocol for producing a `DimensionRecord` on demand.

    `DimensionRecordSet` lookup methods invoke such a callback only when no
    existing record matches the requested data ID, and then add whatever the
    callback returns to the set. An implementation may also refuse by
    raising, as `fail_record_lookup` does.
    """

    def __call__(
        self,
        record_class: type[DimensionRecord],
        required_values: tuple[DataIdValue, ...],
    ) -> DimensionRecord:
        """Create a new `DimensionRecord`.

        Parameters
        ----------
        record_class : `type` [ `DimensionRecord` ]
            Concrete `DimensionRecord` subclass to instantiate.
        required_values : `tuple`
            Data ID values, ordered to match
            ``record_class.definition.required``.

        Returns
        -------
        record : `DimensionRecord`
            The newly created record.
        """
        ...  # pragma: no cover

62 

63 

def fail_record_lookup(
    record_class: type[DimensionRecord],
    required_values: tuple[DataIdValue, ...],
) -> DimensionRecord:
    """Refuse to create a record by always raising `LookupError`.

    Intended as the default for any argument that accepts a
    `DimensionRecordFactory` callback, so that a failed lookup is an error
    unless the caller supplies a real factory.

    Parameters
    ----------
    record_class : `type` [ `DimensionRecord` ]
        Type of record that was requested.
    required_values : `tuple`
        Required data ID values that identify the missing record.

    Returns
    -------
    record : `DimensionRecord`
        Never returned.

    Raises
    ------
    LookupError
        Always raised; the message names the element and the data ID.
    """
    definition = record_class.definition
    element_name = definition.name
    # Build a full DataCoordinate purely so the error message is readable.
    missing_data_id = DataCoordinate.from_required_values(definition.minimal_group, required_values)
    raise LookupError(f"No {element_name!r} record with data ID {missing_data_id}.")

90 

91 

92@final 

93class DimensionRecordSet(Collection[DimensionRecord]): # numpydoc ignore=PR01 

94 """A mutable set-like container specialized for `DimensionRecord` objects. 

95 

96 Parameters 

97 ---------- 

98 element : `DimensionElement` or `str`, optional 

99 The dimension element that defines the records held by this set. If 

100 not a `DimensionElement` instance, ``universe`` must be provided. 

101 records : `~collections.abc.Iterable` [ `DimensionRecord` ], optional 

102 Dimension records to add to the set. 

103 universe : `DimensionUniverse`, optional 

104 Object that defines all dimensions. Ignored if ``element`` is a 

105 `DimensionElement` instance. 

106 

107 Notes 

108 ----- 

109 `DimensionRecordSet` maintains its insertion order (like `dict`, and unlike 

110 `set`). 

111 

112 `DimensionRecordSet` implements `collections.abc.Collection` but not 

113 `collections.abc.Set` because the latter would require interoperability 

114 with all other `~collections.abc.Set` implementations rather than just 

115 `DimensionRecordSet`, and that adds a lot of complexity without much clear 

116 value. To help make this clear to type checkers it implements only the 

117 named-method versions of these operations (e.g. `issubset`) rather than the 

118 operator special methods (e.g. ``__le__``). 

119 

120 `DimensionRecord` equality is defined in terms of a record's data ID fields 

121 only, and `DimensionRecordSet` does not generally specify which record 

122 "wins" when two records with the same data ID interact (e.g. in 

123 `intersection`). The `add` and `update` methods are notable exceptions: 

124 they always replace the existing record with the new one. 

125 

126 Dimension records can also be held by `DimensionRecordTable`, which 

127 provides column-oriented access and Arrow interoperability. 

128 """ 

129 

130 def __init__( 

131 self, 

132 element: DimensionElement | str, 

133 records: Iterable[DimensionRecord] = (), 

134 universe: DimensionUniverse | None = None, 

135 *, 

136 _by_required_values: dict[tuple[DataIdValue, ...], DimensionRecord] | None = None, 

137 ): 

138 if isinstance(element, str): 

139 if universe is None: 

140 raise TypeError("'universe' must be provided if 'element' is not a DimensionElement.") 

141 element = universe[element] 

142 else: 

143 universe = element.universe 

144 if _by_required_values is None: 

145 _by_required_values = {} 

146 self._record_type = element.RecordClass 

147 self._by_required_values = _by_required_values 

148 self._dimensions = element.minimal_group 

149 self.update(records) 

150 

151 @property 

152 def element(self) -> DimensionElement: 

153 """Name of the dimension element these records correspond to.""" 

154 return self._record_type.definition 

155 

156 def __contains__(self, key: object) -> bool: 

157 match key: 

158 case DimensionRecord() if key.definition == self.element: 

159 required_values = key.dataId.required_values 

160 case DataCoordinate() if key.dimensions == self.element.minimal_group: 

161 required_values = key.required_values 

162 case _: 

163 return False 

164 return required_values in self._by_required_values 

165 

166 def __len__(self) -> int: 

167 return len(self._by_required_values) 

168 

169 def __iter__(self) -> Iterator[DimensionRecord]: 

170 return iter(self._by_required_values.values()) 

171 

172 def __eq__(self, other: object) -> bool: 

173 if not isinstance(other, DimensionRecordSet): 

174 return False 

175 return ( 

176 self._record_type is other._record_type 

177 and self._by_required_values.keys() == other._by_required_values.keys() 

178 ) 

179 

180 def issubset(self, other: DimensionRecordSet) -> bool: 

181 """Test whether all elements in ``self`` are in ``other``. 

182 

183 Parameters 

184 ---------- 

185 other : `DimensionRecordSet` 

186 Another record set with the same record type. 

187 

188 Returns 

189 ------- 

190 issubset ; `bool` 

191 Whether all elements in ``self`` are in ``other``. 

192 """ 

193 if self._record_type is not other._record_type: 

194 raise ValueError( 

195 "Invalid comparison between dimension record sets for elements " 

196 f"{self.element.name!r} and {other.element.name!r}." 

197 ) 

198 return self._by_required_values.keys() <= other._by_required_values.keys() 

199 

200 def issuperset(self, other: DimensionRecordSet) -> bool: 

201 """Test whether all elements in ``other`` are in ``self``. 

202 

203 Parameters 

204 ---------- 

205 other : `DimensionRecordSet` 

206 Another record set with the same record type. 

207 

208 Returns 

209 ------- 

210 issuperset ; `bool` 

211 Whether all elements in ``other`` are in ``self``. 

212 """ 

213 if self._record_type is not other._record_type: 

214 raise ValueError( 

215 "Invalid comparison between dimension record sets for elements " 

216 f"{self.element.name!r} and {other.element.name!r}." 

217 ) 

218 return self._by_required_values.keys() >= other._by_required_values.keys() 

219 

220 def isdisjoint(self, other: DimensionRecordSet) -> bool: 

221 """Test whether the intersection of ``self`` and ``other`` is empty. 

222 

223 Parameters 

224 ---------- 

225 other : `DimensionRecordSet` 

226 Another record set with the same record type. 

227 

228 Returns 

229 ------- 

230 isdisjoint ; `bool` 

231 Whether the intersection of ``self`` and ``other`` is empty. 

232 """ 

233 if self._record_type is not other._record_type: 

234 raise ValueError( 

235 "Invalid comparison between dimension record sets for elements " 

236 f"{self.element.name!r} and {other.element.name!r}." 

237 ) 

238 return self._by_required_values.keys().isdisjoint(other._by_required_values.keys()) 

239 

240 def intersection(self, other: DimensionRecordSet) -> DimensionRecordSet: 

241 """Return a new set with only records that are in both ``self`` and 

242 ``other``. 

243 

244 Parameters 

245 ---------- 

246 other : `DimensionRecordSet` 

247 Another record set with the same record type. 

248 

249 Returns 

250 ------- 

251 intersection : `DimensionRecordSet` 

252 A new record set with all elements in both sets. 

253 """ 

254 if self._record_type is not other._record_type: 

255 raise ValueError( 

256 "Invalid intersection between dimension record sets for elements " 

257 f"{self.element.name!r} and {other.element.name!r}." 

258 ) 

259 return DimensionRecordSet( 

260 self.element, 

261 _by_required_values={ 

262 k: v for k, v in self._by_required_values.items() if k in other._by_required_values 

263 }, 

264 ) 

265 

266 def difference(self, other: DimensionRecordSet) -> DimensionRecordSet: 

267 """Return a new set with only records that are in ``self`` and not in 

268 ``other``. 

269 

270 Parameters 

271 ---------- 

272 other : `DimensionRecordSet` 

273 Another record set with the same record type. 

274 

275 Returns 

276 ------- 

277 difference : `DimensionRecordSet` 

278 A new record set with all elements ``self`` that are not in 

279 ``other``. 

280 """ 

281 if self._record_type is not other._record_type: 

282 raise ValueError( 

283 "Invalid difference between dimension record sets for elements " 

284 f"{self.element.name!r} and {other.element.name!r}." 

285 ) 

286 return DimensionRecordSet( 

287 self.element, 

288 _by_required_values={ 

289 k: v for k, v in self._by_required_values.items() if k not in other._by_required_values 

290 }, 

291 ) 

292 

293 def union(self, other: DimensionRecordSet) -> DimensionRecordSet: 

294 """Return a new set with all records that are either in ``self`` or 

295 ``other``. 

296 

297 Parameters 

298 ---------- 

299 other : `DimensionRecordSet` 

300 Another record set with the same record type. 

301 

302 Returns 

303 ------- 

304 intersection : `DimensionRecordSet` 

305 A new record set with all elements in either set. 

306 """ 

307 if self._record_type is not other._record_type: 

308 raise ValueError( 

309 "Invalid union between dimension record sets for elements " 

310 f"{self.element.name!r} and {other.element.name!r}." 

311 ) 

312 return DimensionRecordSet( 

313 self.element, 

314 _by_required_values=self._by_required_values | other._by_required_values, 

315 ) 

316 

317 def find( 

318 self, 

319 data_id: DataCoordinate, 

320 or_add: DimensionRecordFactory = fail_record_lookup, 

321 ) -> DimensionRecord: 

322 """Return the record with the given data ID. 

323 

324 Parameters 

325 ---------- 

326 data_id : `DataCoordinate` 

327 Data ID to match. 

328 or_add : `DimensionRecordFactory` 

329 Callback that is invoked if no existing record is found, to create 

330 a new record that is added to the set and returned. The return 

331 value of this callback is *not* checked to see if it is a valid 

332 dimension record with the right element and data ID. 

333 

334 Returns 

335 ------- 

336 record : `DimensionRecord` 

337 Matching record. 

338 

339 Raises 

340 ------ 

341 KeyError 

342 Raised if no record with this data ID was found. 

343 ValueError 

344 Raised if the data ID did not have the right dimensions. 

345 """ 

346 if data_id.dimensions != self._dimensions: 

347 raise ValueError( 

348 f"data ID {data_id} has incorrect dimensions for dimension records for {self.element!r}." 

349 ) 

350 return self.find_with_required_values(data_id.required_values, or_add) 

351 

352 def find_with_required_values( 

353 self, required_values: tuple[DataIdValue, ...], or_add: DimensionRecordFactory = fail_record_lookup 

354 ) -> DimensionRecord: 

355 """Return the record whose data ID has the given required values. 

356 

357 Parameters 

358 ---------- 

359 required_values : `tuple` [ `int` or `str` ] 

360 Data ID values to match. 

361 or_add : `DimensionRecordFactory` 

362 Callback that is invoked if no existing record is found, to create 

363 a new record that is added to the set and returned. The return 

364 value of this callback is *not* checked to see if it is a valid 

365 dimension record with the right element and data ID. 

366 

367 Returns 

368 ------- 

369 record : `DimensionRecord` 

370 Matching record. 

371 

372 Raises 

373 ------ 

374 ValueError 

375 Raised if the data ID did not have the right dimensions. 

376 """ 

377 if (result := self._by_required_values.get(required_values)) is None: 

378 result = or_add(self._record_type, required_values) 

379 self._by_required_values[required_values] = result 

380 return result 

381 

382 def add(self, value: DimensionRecord, replace: bool = True) -> None: 

383 """Add a new record to the set. 

384 

385 Parameters 

386 ---------- 

387 value : `DimensionRecord` 

388 Record to add. 

389 replace : `bool`, optional 

390 If `True` (default) replace any existing record with the same data 

391 ID. If `False` the existing record will be kept. 

392 

393 Raises 

394 ------ 

395 ValueError 

396 Raised if ``value.element != self.element``. 

397 """ 

398 if value.definition.name != self.element: 

399 raise ValueError( 

400 f"Cannot add record {value} for {value.definition.name!r} to set for {self.element!r}." 

401 ) 

402 if replace: 

403 self._by_required_values[value.dataId.required_values] = value 

404 else: 

405 self._by_required_values.setdefault(value.dataId.required_values, value) 

406 

407 def update(self, values: Iterable[DimensionRecord], replace: bool = True) -> None: 

408 """Add new records to the set. 

409 

410 Parameters 

411 ---------- 

412 values : `~collections.abc.Iterable` [ `DimensionRecord` ] 

413 Records to add. 

414 replace : `bool`, optional 

415 If `True` (default) replace any existing records with the same data 

416 IDs. If `False` the existing records will be kept. 

417 

418 Raises 

419 ------ 

420 ValueError 

421 Raised if ``value.element != self.element``. 

422 """ 

423 for value in values: 

424 self.add(value, replace=replace) 

425 

426 def update_from_data_coordinates(self, data_coordinates: Iterable[DataCoordinate]) -> None: 

427 """Add records to the set by extracting and deduplicating them from 

428 data coordinates. 

429 

430 Parameters 

431 ---------- 

432 data_coordinates : `~collections.abc.Iterable` [ `DataCoordinate` ] 

433 Data coordinates to extract from. `DataCoordinate.hasRecords` must 

434 be `True`. 

435 """ 

436 for data_coordinate in data_coordinates: 

437 if record := data_coordinate._record(self.element.name): 

438 self._by_required_values[record.dataId.required_values] = record 

439 

440 def discard(self, value: DimensionRecord | DataCoordinate) -> None: 

441 """Remove a record if it exists. 

442 

443 Parameters 

444 ---------- 

445 value : `DimensionRecord` or `DataCoordinate` 

446 Record to remove, or its data ID. 

447 """ 

448 if isinstance(value, DimensionRecord): 

449 value = value.dataId 

450 if value.dimensions != self._dimensions: 

451 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.") 

452 self._by_required_values.pop(value.required_values, None) 

453 

454 def remove(self, value: DimensionRecord | DataCoordinate) -> None: 

455 """Remove a record. 

456 

457 Parameters 

458 ---------- 

459 value : `DimensionRecord` or `DataCoordinate` 

460 Record to remove, or its data ID. 

461 

462 Raises 

463 ------ 

464 KeyError 

465 Raised if there is no matching record. 

466 """ 

467 if isinstance(value, DimensionRecord): 

468 value = value.dataId 

469 if value.dimensions != self._dimensions: 

470 raise ValueError(f"{value} has incorrect dimensions for dimension records for {self.element!r}.") 

471 del self._by_required_values[value.required_values] 

472 

473 def pop(self) -> DimensionRecord: 

474 """Remove and return an arbitrary record.""" 

475 return self._by_required_values.popitem()[1] 

476 

477 def __deepcopy__(self, memo: dict[str, Any]) -> DimensionRecordSet: 

478 return DimensionRecordSet(self.element, _by_required_values=self._by_required_values.copy())