from __future__ import annotations

__all__ = ("ByDimensionsDatasetRecordStorage",)

from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Optional,
    TYPE_CHECKING,
)

import sqlalchemy

from lsst.daf.butler import (
    CollectionType,
    DataCoordinate,
    DatasetRef,
    DatasetType,
    SimpleQuery,
)
from lsst.daf.butler.registry.interfaces import DatasetRecordStorage

if TYPE_CHECKING:
    from ...interfaces import CollectionManager, CollectionRecord, Database, RunRecord
    from .tables import StaticDatasetTablesTuple


class ByDimensionsDatasetRecordStorage(DatasetRecordStorage):
    """Dataset record storage implementation paired with
    `ByDimensionsDatasetRecordStorageManager`; see that class for more
    information.

    Instances of this class should never be constructed directly; use
    `DatasetRecordStorageManager.register` instead.
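
    Notes
    -----
    A minimal usage sketch (hedged: ``manager`` below is assumed to be an
    already-constructed `ByDimensionsDatasetRecordStorageManager`; check
    that class for the exact signature of ``register``)::

        storage = manager.register(datasetType)
        refs = list(storage.insert(run, dataIds))  # run is a RunRecord
        ref = storage.find(collection, dataId)     # collection is a CollectionRecord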

    """
    def __init__(self, *, datasetType: DatasetType,
                 db: Database,
                 dataset_type_id: int,
                 collections: CollectionManager,
                 static: StaticDatasetTablesTuple,
                 dynamic: sqlalchemy.sql.Table):
        super().__init__(datasetType=datasetType)
        self._dataset_type_id = dataset_type_id
        self._db = db
        self._collections = collections
        self._static = static
        self._dynamic = dynamic
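        # Name of the column in the static dataset table that holds the
        # foreign key to the RUN collection; used when inserting and querying.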
        self._runKeyColumn = collections.getRunForeignKeyName()

    def insert(self, run: RunRecord, dataIds: Iterable[DataCoordinate]) -> Iterator[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
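        # Note: this method is a generator, so none of the work below
        # (including the transaction) happens until the returned iterator
        # is actually consumed.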
        staticRow = {
            "dataset_type_id": self._dataset_type_id,
            self._runKeyColumn: run.key,
        }
        dataIds = list(dataIds)
        # Insert into the static dataset table, generating autoincrement
        # dataset_id values.
        with self._db.transaction():
            datasetIds = self._db.insert(self._static.dataset, *([staticRow]*len(dataIds)),
                                         returnIds=True)
            assert datasetIds is not None
            # Combine the generated dataset_id values and data ID fields to
            # form rows to be inserted into the dynamic table.
            protoDynamicRow = {
                "dataset_type_id": self._dataset_type_id,
                self._collections.getCollectionForeignKeyName(): run.key,
            }
            dynamicRows = [
                dict(protoDynamicRow, dataset_id=dataset_id, **dataId.byName())
                for dataId, dataset_id in zip(dataIds, datasetIds)
            ]
            # Insert those rows into the dynamic table. This is where we'll
            # get any unique constraint violations.
            self._db.insert(self._dynamic, *dynamicRows)
        for dataId, datasetId in zip(dataIds, datasetIds):
            yield DatasetRef(
                datasetType=self.datasetType,
                dataId=dataId,
                id=datasetId,
                run=run.name,
            )

    def find(self, collection: CollectionRecord, dataId: DataCoordinate) -> Optional[DatasetRef]:
        # Docstring inherited from DatasetRecordStorage.
        assert dataId.graph == self.datasetType.dimensions
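        # Delegate query construction to `select`, requesting the dataset id
        # and run key as result columns, then translate the first matching
        # row (if any) into a DatasetRef.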
        sql = self.select(collection=collection, dataId=dataId, id=SimpleQuery.Select,
                          run=SimpleQuery.Select).combine()
        row = self._db.query(sql).fetchone()
        if row is None:
            return None
        return DatasetRef(
            datasetType=self.datasetType,
            dataId=dataId,
            id=row["id"],
            run=self._collections[row[self._runKeyColumn]].name
        )

    def delete(self, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        # Only delete from common dataset table; ON DELETE foreign key clauses
        # will handle the rest.
        self._db.delete(
            self._static.dataset,
            ["id"],
            *[{"id": dataset.getCheckedId()} for dataset in datasets],
        )

    def associate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot associate into collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        protoRow = {
            self._collections.getCollectionForeignKeyName(): collection.key,
            "dataset_type_id": self._dataset_type_id,
        }
        rows = []
        for dataset in datasets:
            row = dict(protoRow, dataset_id=dataset.getCheckedId())
            for dimension, value in dataset.dataId.items():
                row[dimension.name] = value
            rows.append(row)
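        # Use replace rather than insert so that re-associating a dataset
        # that is already in the collection overwrites the existing row
        # instead of raising on the unique constraint.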
        self._db.replace(self._dynamic, *rows)

    def disassociate(self, collection: CollectionRecord, datasets: Iterable[DatasetRef]) -> None:
        # Docstring inherited from DatasetRecordStorage.
        if collection.type is not CollectionType.TAGGED:
            raise TypeError(f"Cannot disassociate from collection '{collection}' "
                            f"of type {collection.type.name}; must be TAGGED.")
        rows = [
            {
                "dataset_id": dataset.getCheckedId(),
                self._collections.getCollectionForeignKeyName(): collection.key
            }
            for dataset in datasets
        ]
        self._db.delete(self._dynamic, ["dataset_id", self._collections.getCollectionForeignKeyName()],
                        *rows)

    def select(self, collection: CollectionRecord,
               dataId: SimpleQuery.Select.Or[DataCoordinate] = SimpleQuery.Select,
               id: SimpleQuery.Select.Or[Optional[int]] = SimpleQuery.Select,
               run: SimpleQuery.Select.Or[None] = SimpleQuery.Select,
               ) -> SimpleQuery:
        # Docstring inherited from DatasetRecordStorage.
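        # Each Select.Or parameter is either a concrete value to constrain
        # the query with, or SimpleQuery.Select to request that column in
        # the results; callers turn the returned SimpleQuery into an
        # executable statement via .combine() (see `find` above).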
        assert collection.type is not CollectionType.CHAINED
        query = SimpleQuery()
        # We always include the _static.dataset table, and we can always get
        # the id and run fields from that; passing them as kwargs here tells
        # SimpleQuery to handle them whether they're constraints or results.
        # We always constrain the dataset_type_id here as well.
        query.join(
            self._static.dataset,
            id=id,
            dataset_type_id=self._dataset_type_id,
            **{self._runKeyColumn: run}
        )
        # If and only if the collection is a RUN, we constrain it in the
        # static table (and also the dynamic table below).
        if collection.type is CollectionType.RUN:
            query.where.append(self._static.dataset.columns[self._runKeyColumn]
                               == collection.key)
        # We get or constrain the data ID from the dynamic table, but that's
        # multiple columns, not one, so we need to transform the one Select.Or
        # argument into a dictionary of them.
        kwargs: Dict[str, Any]
        if dataId is SimpleQuery.Select:
            kwargs = {dim.name: SimpleQuery.Select for dim in self.datasetType.dimensions.required}
        else:
            kwargs = dict(dataId.byName())
        # We always constrain (never retrieve) the collection from the dynamic
        # table.
        kwargs[self._collections.getCollectionForeignKeyName()] = collection.key
        # And now we finally join in the dynamic table.
        query.join(
            self._dynamic,
            onclause=(self._static.dataset.columns.id == self._dynamic.columns.dataset_id),
            **kwargs
        )
        return query

    def getDataId(self, id: int) -> DataCoordinate:
        # Docstring inherited from DatasetRecordStorage.
        # This query could return multiple rows (one for each tagged
        # collection the dataset is in, plus one for its run collection),
        # and we don't care which of those we get.
        sql = self._dynamic.select().where(
            sqlalchemy.sql.and_(
                self._dynamic.columns.dataset_id == id,
                self._dynamic.columns.dataset_type_id == self._dataset_type_id
            )
        ).limit(1)
        row = self._db.query(sql).fetchone()
        assert row is not None, "Should be guaranteed by caller and foreign key constraints."
        return DataCoordinate.standardize(
            {dimension.name: row[dimension.name] for dimension in self.datasetType.dimensions.required},
            graph=self.datasetType.dimensions
        )