Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

21 

22from itertools import chain 

23import pickle 

24import tempfile 

25import unittest 

26from lsst.daf.butler import DimensionUniverse 

27 

28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

29 DatasetTypeName, IncompatibleGraphError) 

30import lsst.pipe.base.connectionTypes as cT 

31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

32from lsst.pex.config import Field 

33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId 

34import lsst.utils.tests 

35 

36 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first dummy task of the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``, both with
    dimensions ``("A", "B")``, and declares the init-output
    ``Dummy1InitOutput`` (consumed as an init-input by
    `Dummy2Connections`).
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

49 

50 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for `Dummy1PipelineTask`, bound to `Dummy1Connections`."""

    # Placeholder integer option; not read by any code in this module.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

53 

54 

class Dummy1PipelineTask(PipelineTask):
    """Placeholder task using `Dummy1Config`; it defines no methods of its own."""

    ConfigClass = Dummy1Config

57 

58 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second dummy task of the test chain.

    Takes ``Dummy1InitOutput`` as an init-input and ``Dummy1Output`` as its
    regular input, i.e. it consumes what `Dummy1Connections` produces, and
    emits ``Dummy2Output`` plus the init-output ``Dummy2InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

74 

75 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for `Dummy2PipelineTask`, bound to `Dummy2Connections`."""

    # Placeholder integer option; not read by any code in this module.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

78 

79 

class Dummy2PipelineTask(PipelineTask):
    """Placeholder task using `Dummy2Config`; it defines no methods of its own."""

    ConfigClass = Dummy2Config

82 

83 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third (final) dummy task of the test chain.

    Takes ``Dummy2InitOutput`` as an init-input and ``Dummy2Output`` as its
    regular input, i.e. it consumes what `Dummy2Connections` produces, and
    emits ``Dummy3Output`` plus the init-output ``Dummy3InitOutput``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

99 

100 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for `Dummy3PipelineTask`, bound to `Dummy3Connections`."""

    # Placeholder integer option; not read by any code in this module.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

103 

104 

class Dummy3PipelineTask(PipelineTask):
    """Placeholder task using `Dummy3Config`; it defines no methods of its own."""

    ConfigClass = Dummy3Config

107 

108 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds three chained tasks (labels ``R``, ``S``, ``T``) in
    which each task's output dataset type is the next task's input, with
    one quantum per task for each of the data IDs ``{A: 1, B: 2}`` and
    ``{A: 3, B: 4}``. The graph therefore holds six quanta forming two
    independent three-quantum chains.
    """

    def setUp(self):
        # Minimal dimension universe with two integer-keyed dimensions.
        dimensionConfig = Config({
            "dimensions": {
                "version": 1,
                "skypix": {},
                "elements": {
                    "A": {
                        "keys": [{
                            "name": "id",
                            "type": "int",
                        }],
                    },
                    "B": {
                        "keys": [{
                            "name": "id",
                            "type": "int",
                        }],
                    }
                }
            }
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            # Dataset types do not depend on the data ID, so build them once
            # per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe)
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe)
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]})
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualifiers(self, restored):
        """Reduce quantum task names in ``self.qGraph`` and ``restored`` to
        their final dotted component, mutating both graphs' quanta in place.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the unittest
        module name when restored. Quanta are paired positionally, so both
        graphs must iterate their quanta in the same order.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(), restored._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId carrying a different BuildId must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualifiers(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # Three tasks times two data IDs.
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWithDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Membership counts across both subgraphs must add up to the total
        # number of quanta in the full graph.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAncestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualifiers(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

337 

338 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the tests inherited from `lsst.utils.tests.MemoryTestCase`.

    Adds nothing of its own; presumably these are memory/resource leak
    checks — see the base class.
    """

341 

342 

def setup_module(module):
    """Initialize ``lsst.utils.tests`` before this module's tests run.

    Presumably invoked by pytest as its module-level setup hook.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

345 

346 

if __name__ == "__main__":
    # Running as a script: initialize the LSST test utilities, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()