Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import pickle 

24import tempfile 

25import unittest 

26from lsst.daf.butler import DimensionUniverse 

27 

28from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

29 DatasetTypeName, IncompatibleGraphError) 

30import lsst.pipe.base.connectionTypes as cT 

31from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

32from lsst.pex.config import Field 

33from lsst.pipe.base.graph.quantumNode import NodeId, BuildId 

34import lsst.utils.tests 

35 

36 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Declares one init-output (``Dummy1InitOutput``), one input
    (``Dummy1Input``) and one output (``Dummy1Output``); the input and the
    output both carry the dimensions ``("A", "B")``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Input",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

49 

50 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of the first dummy task: a single integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

53 

54 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only binds ``Dummy1Config`` as its config class."""

    ConfigClass = Dummy1Config

57 

58 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Its init-input and input reuse the dataset type names
    ``Dummy1InitOutput`` and ``Dummy1Output``; it declares
    ``Dummy2InitOutput`` and ``Dummy2Output`` as its own products.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy1Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

74 

75 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of the second dummy task: a single integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

78 

79 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only binds ``Dummy2Config`` as its config class."""

    ConfigClass = Dummy2Config

82 

83 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Its init-input and input reuse the dataset type names
    ``Dummy2InitOutput`` and ``Dummy2Output``; it declares
    ``Dummy3InitOutput`` and ``Dummy3Output`` as its own products.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        name="Dummy2Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        doc="n/a",
        name="Dummy3Output",
        storageClass="ExposureF",
        dimensions=("A", "B"),
    )

99 

100 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of the third dummy task: a single integer field."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

103 

104 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only binds ``Dummy3Config`` as its config class."""

    ConfigClass = Dummy3Config

107 

108 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three dummy tasks (labels "R", "S" and
    "T") whose connections chain by dataset type name.  Each task gets one
    quantum for the data ID ``{A: 1, B: 2}`` and one for ``{A: 3, B: 4}``,
    so the graph holds six quanta in total.
    """

    def setUp(self):
        """Build the dimension universe, the quanta and the graph under test.

        Stores ``self.tasks`` (list of `TaskDef` in R, S, T order),
        ``self.quantumMap`` (`TaskDef` -> set of `Quantum`), ``self.qGraph``
        and ``self.universe`` for use by the individual tests.
        """
        # Minimal universe with just the two dimensions "A" and "B" that the
        # dummy connections refer to.
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task, not on the data ID,
            # so they are built once per task rather than once per quantum.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                # Dummy1 declares no init-input.
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _stripTaskNameQualifiers(self, restored):
        """Reduce task names to their last dotted component on both graphs.

        This is a hack for the unit test since the qualified name will be
        different as it will be __main__ here, but qualified to the
        unittest module name when restored.

        Parameters
        ----------
        restored : `QuantumGraph`
            Graph obtained by round-tripping ``self.qGraph``; the keys of its
            ``_quanta`` mapping are modified in place, as are those of
            ``self.qGraph``.
        """
        for saved, loaded in zip(self.qGraph._quanta.keys(),
                                 restored._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every `TaskDef` used to build the graph is in ``taskGraph``."""
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input `Quantum` is carried by some node of ``graph``."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by node id returns the node; a foreign build id raises."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph survives a pickle round trip."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._stripTaskNameQualifiers(restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """``inputQuanta`` are exactly the quanta of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """``outputQuanta`` are exactly the quanta of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks times two data IDs gives six nodes."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns the set each task was built with."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The second task is found as a consumer of ``Dummy1Output``."""
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        """The first task is found as the producer of ``Dummy1Output``."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is associated with the first two tasks."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """Lookup by task name returns the first `TaskDef` first."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Lookup by label "R" returns the first `TaskDef`."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """``Dummy1Input`` is associated with the quanta of the first task."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` matches the names declared by all connections."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A single-node subset keeps the quantum and the build id."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; a two-node subset is connected."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # restrict the graph to its first two nodes
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """The graph splits into two connected three-node graphs."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count the memberships of every node across both sub-graphs; the
        # total must equal the size of the original graph.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The only input to the second node is the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The third node is an output of the second node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of the second node are the first three nodes."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestor graph of the third node is the first three nodes."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The chained test graph contains no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """A graph survives a save/load round trip through a file object."""
        with tempfile.TemporaryFile() as tmpFile:
            self.qGraph.save(tmpFile)
            # rewind so that load reads from the start of what was written
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._stripTaskNameQualifiers(restore)
            self.assertEqual(self.qGraph, restore)

    def testContains(self):
        """A node obtained by iteration passes the ``in`` test."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

348 

349 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Adds the checks inherited from `lsst.utils.tests.MemoryTestCase`."""

352 

353 

def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : module
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()

356 

357 

if __name__ == "__main__":
    # Direct execution: initialise the LSST test utilities, then hand over
    # to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()