Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import pickle 

24import tempfile 

25import unittest 

26from lsst.daf.butler import DimensionUniverse 

27 

28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

29 DatasetTypeName, IncompatibleGraphError) 

30import lsst.pipe.base.connectionTypes as cT 

31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

32from lsst.pex.config import Field 

33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId 

34import lsst.utils.tests 

35 

36 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Declares an init-output ``Dummy1InitOutput``, reads ``Dummy1Input`` and
    writes ``Dummy1Output``; both regular datasets use dimensions
    ``("A", "B")`` and every dataset uses storage class ``ExposureF``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

49 

50 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for the first dummy task, wired to ``Dummy1Connections``."""

    # Placeholder integer field; nothing in this module reads it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

53 

54 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only binds ``Dummy1Config`` as its config class."""

    ConfigClass = Dummy1Config

57 

58 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Consumes ``Dummy1InitOutput`` as init-input and ``Dummy1Output`` as input
    (both produced by ``Dummy1Connections``), and produces
    ``Dummy2InitOutput`` and ``Dummy2Output``.  Every dataset uses storage
    class ``ExposureF``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

74 

75 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for the second dummy task, wired to ``Dummy2Connections``."""

    # Placeholder integer field; nothing in this module reads it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

78 

79 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only binds ``Dummy2Config`` as its config class."""

    ConfigClass = Dummy2Config

82 

83 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Consumes ``Dummy2InitOutput`` as init-input and ``Dummy2Output`` as input
    (both produced by ``Dummy2Connections``), and produces
    ``Dummy3InitOutput`` and ``Dummy3Output``.  Every dataset uses storage
    class ``ExposureF``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        dimensions=("A", "B"),
        storageClass="ExposureF",
    )

99 

100 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for the third dummy task, wired to ``Dummy3Connections``."""

    # Placeholder integer field; nothing in this module reads it.
    conf1 = Field(doc="dummy config", dtype=int, default=1)

103 

104 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only binds ``Dummy3Config`` as its config class."""

    ConfigClass = Dummy3Config

107 

108 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three chained tasks (labels ``R``, ``S``
    and ``T``; each task's input is the previous task's output).  Each task
    gets two quanta, one for data ID ``{A: 1, B: 2}`` and one for
    ``{A: 3, B: 4}``.  The graph therefore holds six nodes that form two
    independent three-node chains; several tests rely on that shape.
    """

    def setUp(self):
        # Minimal dimension universe with two integer-keyed dimensions, A and B.
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # Mapping of TaskDef -> set of Quantum objects for that task.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task instead of once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                # Only the first task (Dummy1) declares no init-input.
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID per quantum, shared by its inputs, outputs and
                # the quantum itself.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualification(self, restored):
        """Reduce task names in ``self.qGraph`` and ``restored`` to their last
        dotted component so that the two graphs can be compared.

        This is a hack for the unit test.  The qualified name is ``__main__``
        here, but when restored it is qualified with the unittest module
        name.  Quanta are paired positionally by zipping the two graphs'
        ``_quanta`` keys, and are modified in place.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(), restored._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum is wrapped by some node of the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Nodes can be looked up by NodeId; an ID from another build raises."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualification(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The graph's input quanta are exactly those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputQuanta(self):
        """The graph's output quanta are exactly those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks with two quanta each give six nodes."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """getQuantaForTask returns the quanta each task was built with."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """Dummy1Output is an input of the second task only."""
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        """Dummy1Output is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """Dummy1Output is used by its producer and its consumer."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """A TaskDef can be found by its task class's qualified name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """A TaskDef can be found by its label."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWithDSType(self):
        """Dummy1Input is only used by the first task's quanta."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """allDatasetTypes covers every connection name of the three tasks."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps that quantum and the parent's build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; a single-chain subset is connected."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions.
        self.assertFalse(self.qGraph.isConnected)
        # Subset to the first two nodes.  This relies on graph iteration
        # yielding two nodes from the same chain first.
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because the subset covers only one chain.
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """The graph splits into two disjoint connected three-node graphs."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node must land in exactly one of the two connected graphs.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The second node's only upstream node is the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The second node's downstream node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of the middle node are its whole chain."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAncestorsOfQuantumNode(self):
        """The ancestors of the last node in a chain (plus itself) are that chain."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The linear task chain contains no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """A graph survives a save/load round trip through a temporary file."""
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualification(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        """A node yielded by iteration is reported as contained in the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

342 

343 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the inherited ``lsst.utils.tests.MemoryTestCase`` checks for this module."""

346 

347 

def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    ``module`` is accepted only to match the hook signature and is not
    used.  Presumably this hook is invoked by pytest's xunit-style setup;
    confirm against the test runner.
    """
    lsst.utils.tests.init()

350 

351 

if __name__ == "__main__":
    # Direct execution: initialize lsst.utils.tests, then hand off to unittest.
    lsst.utils.tests.init()
    unittest.main()