Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26import unittest 

27from lsst.daf.butler import DimensionUniverse 

28 

29from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

30 DatasetTypeName, IncompatibleGraphError) 

31import lsst.pipe.base.connectionTypes as cT 

32from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

33from lsst.pex.config import Field 

34from lsst.pipe.base.graph.quantumNode import NodeId, BuildId 

35import lsst.utils.tests 

36 

37 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first task (label ``R``) in the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` per ``(A, B)`` data ID,
    and produces the init-output ``Dummy1InitOutput`` that
    ``Dummy2Connections`` declares as its init-input.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

50 

51 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

54 

55 

class Dummy1PipelineTask(PipelineTask):
    """First task of the test chain; defines no run logic."""

    ConfigClass = Dummy1Config

58 

59 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second task (label ``S``) in the test chain.

    Consumes what `Dummy1Connections` produces (``Dummy1InitOutput`` and
    ``Dummy1Output``) and produces ``Dummy2InitOutput`` / ``Dummy2Output``
    for the third task.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

75 

76 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

79 

80 

class Dummy2PipelineTask(PipelineTask):
    """Second task of the test chain; defines no run logic."""

    ConfigClass = Dummy2Config

83 

84 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third and last task (label ``T``) in the test chain.

    Consumes ``Dummy2InitOutput`` / ``Dummy2Output`` and produces
    ``Dummy3InitOutput`` / ``Dummy3Output``, which nothing downstream reads.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

100 

101 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    # Placeholder field; the tests never read it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

104 

105 

class Dummy3PipelineTask(PipelineTask):
    """Third task of the test chain; defines no run logic."""

    ConfigClass = Dummy3Config

108 

109 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    Every test operates on the graph built in `setUp`: three chained tasks
    (labels ``R`` -> ``S`` -> ``T``), each with one quantum per data ID
    ``(A=1, B=2)`` and ``(A=3, B=4)``, i.e. six quanta forming two
    independent three-quantum chains.
    """

    def setUp(self):
        """Build the dimension universe, the quanta, and the `QuantumGraph`.

        Sets ``self.tasks`` (the `TaskDef` list in chain order),
        ``self.quantumMap`` (`TaskDef` -> set of `Quantum`), ``self.qGraph``
        and ``self.universe``.
        """
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # Mapping of TaskDef to the set of quanta that task runs.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                # Dummy1 has no init-input connection.
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe)
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe)

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    # Init-inputs have no dimensions, hence the empty data ID.
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]})
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from quantum task names, in place.

        This is a hack for the unit test: the task names built in `setUp`
        are qualified with ``__main__``, while a restored graph qualifies
        them with this test module's name. Both graphs' ``_quanta`` keys are
        modified in place and are paired positionally by iteration order.
        """
        for saved, loaded in zip(graph1._quanta.keys(), graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # Three tasks x two data IDs.
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWithDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        # A subset keeps the build id of its parent graph.
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph lands in exactly one component.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        node = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAncestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

    def testSaveLoadUri(self):
        # A temporary directory is removed automatically, and saveUri never
        # writes to a path that is still held open by this process.
        with tempfile.TemporaryDirectory() as tmpDir:
            uri = os.path.join(tmpDir, "graph.pickle")
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # A non-``.pickle`` extension is expected to be rejected. The path
            # is placed under the temporary directory so nothing is left in
            # the working directory even if the call were to write a file.
            with self.assertRaises(TypeError):
                self.qGraph.saveUri(os.path.join(tmpDir, "test.notpickle"))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

365 

366 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`.

    No tests are added here; everything comes from the base class.
    """

369 

370 

def setup_module(module):
    """Module-level setup hook: initialise `lsst.utils.tests`.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up. It is not used; the parameter exists only to
        match the ``setup_module`` hook signature.
    """
    lsst.utils.tests.init()

373 

374 

# Allow running this file directly as well as through a test runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()