Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26from typing import Iterable 

27import unittest 

28import random 

29from lsst.daf.butler import DimensionUniverse 

30 

31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

32 DatasetTypeName, IncompatibleGraphError) 

33import lsst.pipe.base.connectionTypes as cT 

34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

35from lsst.pex.config import Field 

36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode 

37import lsst.utils.tests 

38 

try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either package being unavailable disables the S3 round-trip test:
    # ``boto3`` is set to None so the ``skipIf`` guard on that test fires.
    boto3 = None

    def mock_s3(cls):
        """Stand-in for ``moto.mock_s3`` used when the import above fails.

        Hands the decorated object back unchanged so that ``@mock_s3`` can
        still be applied at definition time.
        """
        return cls

49 

50 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Reads ``Dummy1Input`` and writes ``Dummy1Output``; also declares the
    init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

63 

64 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration of the first dummy task, bound to `Dummy1Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

67 

68 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only ties `Dummy1Config` to a task class."""

    ConfigClass = Dummy1Config

71 

72 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Reads ``Dummy1Output`` and writes ``Dummy2Output``; takes
    ``Dummy1InitOutput`` as init-input and declares the init-output
    ``Dummy2InitOutput``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

88 

89 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration of the second dummy task, bound to `Dummy2Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

92 

93 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only ties `Dummy2Config` to a task class."""

    ConfigClass = Dummy2Config

96 

97 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Reads ``Dummy2Output`` and writes ``Dummy3Output``; takes
    ``Dummy2InitOutput`` as init-input and declares the init-output
    ``Dummy3InitOutput``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

113 

114 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration of the third dummy task, bound to `Dummy3Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

117 

118 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only ties `Dummy3Config` to a task class."""

    ConfigClass = Dummy3Config

121 

122 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from three dummy tasks (labels "R", "S" and
    "T") whose dataset type names chain together, with one quantum per task
    for each of the data IDs ``(A=1, B=2)`` and ``(A=3, B=4)``: six quanta
    in total.
    """

    def setUp(self):
        """Create the dimension universe, the task definitions and the graph.

        Sets ``self.tasks`` (list of `TaskDef` in pipeline order),
        ``self.quantumMap`` (`TaskDef` -> set of `Quantum`), ``self.qGraph``
        and ``self.universe``.
        """
        # Minimal dimension configuration declaring two integer-keyed
        # elements, "A" and "B", which are the dimensions the dummy
        # connections use.
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        labelledTasks = (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
        )
        for task, label in labelledTasks:
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections
            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # One data ID serves the input ref, the output ref and the
                # quantum itself.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if connections.initInputs:
                    initInputDSType = DatasetType(connections.initInput.name,
                                                  (),
                                                  storageClass=connections.initInput.storageClass,
                                                  universe=universe)
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                inputDSType = DatasetType(connections.input.name,
                                          connections.input.dimensions,
                                          storageClass=connections.input.storageClass,
                                          universe=universe,
                                          )
                outputDSType = DatasetType(connections.output.name,
                                           connections.output.dimensions,
                                           storageClass=connections.output.storageClass,
                                           universe=universe,
                                           )
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip the module prefix from every task name in both graphs.

        Both graphs are updated in place.
        """
        # This is a hack for the unit test since the qualified name will be
        # different as it will be __main__ here, but qualified to the
        # unittest module name when restored
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every task definition must appear in the task graph."""
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum handed to the constructor must appear in the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Look a node up by its ID; an ID with another build ID must raise."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip must give back an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The graph's input quanta are those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """The graph's output quanta are those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks times two data IDs gives six nodes."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """``getQuantaForTask`` returns the quanta supplied for each task."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        """``getNodesForTask`` returns nodes wrapping each task's quanta."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by the second task."""
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` is referenced by the first two tasks."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """Find a task definition from the task class name."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label "R" was given to the first task in ``setUp``."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """``Dummy1Input`` is only used by the first task's quanta."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """``allDatasetTypes`` matches the names of all declared connections."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that node's quantum and the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is not connected; a subset of one chain is."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting gives two connected, distinct, three-node graphs."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Count, over every node of the full graph, how many of the two
        # sub-graphs report containing its quantum; the total must equal the
        # size of the full graph.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The only input to the second node is the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The first output of the second node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """Connections of the second node match a subset of the first three."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """Ancestors of the third node match a subset of the first three."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The test graph has no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip the graph through an open file object."""
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        """Round-trip the graph through a local file URI.

        Also checks loading a subset of nodes, the errors for an unknown
        node number and a mismatched graph ID, and that saving to a URI
        without the expected extension raises `TypeError`.
        """
        uri = None
        try:
            # Only the file name is needed; close the handle before the
            # graph is written through the URI. The file is kept
            # (delete=False) and removed in the ``finally`` clause.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Two distinct, randomly chosen node numbers.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
            # verify an error when requesting a nonexistent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip the graph through a mocked S3 bucket."""
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        """A node obtained by iterating the graph is ``in`` the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

430 

431 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Expose ``lsst.utils.tests.MemoryTestCase`` to this module, unchanged."""

434 

435 

def setup_module(module):
    """Initialise the LSST test utilities.

    The ``module`` argument is accepted but not used.
    """
    lsst.utils.tests.init()

438 

439 

# Script entry point: initialise the LSST test utilities, then hand over to
# the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()