Coverage for tests/test_quantumGraph.py: 33%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

246 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26from typing import Iterable 

27import unittest 

28import random 

29from lsst.daf.butler import DimensionUniverse 

30 

31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

32 DatasetTypeName, IncompatibleGraphError) 

33import lsst.pipe.base.connectionTypes as cT 

34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

35from lsst.pex.config import Field 

36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode 

37import lsst.utils.tests 

38 

try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either boto3 or moto is unavailable: the S3 test is skipped via
    # ``boto3 is None``, but its ``@mock_s3`` decorator must still resolve.
    boto3 = None

    def mock_s3(cls):
        """Return the decorated object unchanged.

        Stand-in used when ``moto.mock_s3`` cannot be imported.
        """
        return cls

49 

# Metadata attached to the test graph; checked again after a save/load round trip.
METADATA = dict(a=[1, 2, 3])

51 

52 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task in the test chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` over the ``A``/``B``
    dimensions, and declares the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

65 

66 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration bound to `Dummy1Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

69 

70 

class Dummy1PipelineTask(PipelineTask):
    """Placeholder task (no run method) used to build graph fixtures."""

    ConfigClass = Dummy1Config

73 

74 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the second task in the test chain.

    Consumes the first task's ``Dummy1InitOutput`` and ``Dummy1Output`` and
    produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

90 

91 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration bound to `Dummy2Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

94 

95 

class Dummy2PipelineTask(PipelineTask):
    """Placeholder task (no run method) used to build graph fixtures."""

    ConfigClass = Dummy2Config

98 

99 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the third (last) task in the test chain.

    Consumes the second task's ``Dummy2InitOutput`` and ``Dummy2Output`` and
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

115 

116 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration bound to `Dummy3Connections`."""

    conf1 = Field(
        dtype=int,
        default=1,
        doc="dummy config",
    )

119 

120 

class Dummy3PipelineTask(PipelineTask):
    """Placeholder task (no run method) used to build graph fixtures."""

    ConfigClass = Dummy3Config

123 

124 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a chain of three tasks (labels ``R``, ``S``, ``T``) in
    which each task's output dataset is the next task's input. Each task gets
    one quantum per data ID, ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``. The graph
    therefore holds six quanta forming two independent three-quantum chains.
    """

    def setUp(self):
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)
        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            # Separate name so the dimension-universe config above is not
            # shadowed by the per-task config.
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # Dataset types depend only on the task's connections, not on the
            # data ID, so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              tuple(),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe,
                                      )
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe,
                                       )

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                # The quantum and its input/output refs all share one data ID.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from quantum task names, in place.

        This is a hack for the unit test: the qualified name is ``__main__``
        here but is qualified to the unittest module name when restored.
        Quanta are paired positionally by iterating both graphs' ``_quanta``
        keys in order, so this assumes both graphs enumerate their quanta in
        the same order.
        """
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId from a different build must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    # NOTE(review): "Outputt" is a typo; the name is kept so the test ID is stable.
    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        # Expected: every connection name except the InitOutputs.
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every quantum of the full graph must land in one of the two pieces.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(set(self.tasks), set(connectedGraphs[0].taskGraph))
        self.assertEqual(set(self.tasks), set(connectedGraphs[1].taskGraph))
        allNodes = list(self.qGraph)
        # Relies on graph iteration order placing a direct producer of
        # allNodes[1] at allNodes[0].
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ansestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ansestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        # Only reserve a unique path here; the handle is closed before the
        # graph is written so the save does not contend with an open file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # Two distinct node numbers: the first is loaded alone, then both
            # together.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
            # verify an error when requesting a non existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        # Saving to a path without the graph extension is rejected.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from an mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

434 

435 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase` unchanged."""

438 

439 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Follows pytest's module-level ``setup_module`` hook convention; the
    ``module`` argument is not used.
    """
    lsst.utils.tests.init()

442 

443 

if __name__ == "__main__":
    # Direct execution: set up the LSST test helpers first, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()