Coverage for tests/test_quantumGraph.py: 21%

300 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-05 10:08 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22import os 

23import pickle 

24import random 

25import tempfile 

26import unittest 

27import uuid 

28from itertools import chain 

29from typing import Iterable 

30 

31import lsst.pipe.base.connectionTypes as cT 

32import lsst.utils.tests 

33from lsst.daf.butler import Config, DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum 

34from lsst.pex.config import Field 

35from lsst.pipe.base import ( 

36 DatasetTypeName, 

37 PipelineTask, 

38 PipelineTaskConfig, 

39 PipelineTaskConnections, 

40 QuantumGraph, 

41 TaskDef, 

42) 

43from lsst.pipe.base.graph.quantumNode import QuantumNode 

44from lsst.utils.introspection import get_full_type_name 

45 

# boto3/moto may not be installed in every test environment. When either
# import fails, fall back to a no-op ``mock_s3`` decorator so this module
# still imports.
# NOTE(review): neither ``boto3`` nor ``mock_s3`` appears to be referenced
# elsewhere in this file; possibly a leftover — confirm before removing.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    boto3 = None

    def mock_s3(cls):
        """A no-op decorator in case moto mock_s3 can not be imported."""
        # Return the decorated object unchanged.
        return cls

55 

56 

# Metadata attached to the test graph; round-trip tests compare against it.
METADATA = dict(a=[1, 2, 3])

58 

59 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the first task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` per (A, B) data ID, and
    declares the ``Dummy1InitOutput`` init-output that Dummy2 consumes.
    """

    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    input = cT.Input(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Input"
    )
    output = cT.Output(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Output"
    )


class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for `Dummy1PipelineTask`, wired to `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", default=1, dtype=int)


class Dummy1PipelineTask(PipelineTask):
    """Placeholder task that defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy1Config

72 

73 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the middle task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Consumes Dummy1's init-output and output, and produces
    ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy1InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    input = cT.Input(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy1Output"
    )
    output = cT.Output(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy2Output"
    )


class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for `Dummy2PipelineTask`, wired to `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", default=1, dtype=int)


class Dummy2PipelineTask(PipelineTask):
    """Placeholder task that defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy2Config

87 

88 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for the last task of the Dummy1 -> Dummy2 -> Dummy3 chain.

    Consumes Dummy2's init-output and output, and produces
    ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(doc="n/a", storageClass="ExposureF", name="Dummy2InitOutput")
    initOutput = cT.InitOutput(doc="n/a", storageClass="ExposureF", name="Dummy3InitOutput")
    input = cT.Input(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy2Output"
    )
    output = cT.Output(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy3Output"
    )


class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for `Dummy3PipelineTask`, wired to `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", default=1, dtype=int)


class Dummy3PipelineTask(PipelineTask):
    """Placeholder task that defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy3Config

102 

103 

# A task that shares no datasets with the other tasks, used to check that an
# isolated task is handled correctly by the graph.
class Dummy4Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections for a standalone task: ``Dummy4Input`` -> ``Dummy4Output``.

    No init-inputs or init-outputs are declared.
    """

    input = cT.Input(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy4Input"
    )
    output = cT.Output(
        doc="n/a", dimensions=("A", "B"), storageClass="ExposureF", name="Dummy4Output"
    )


class Dummy4Config(PipelineTaskConfig, pipelineConnections=Dummy4Connections):
    """Configuration for `Dummy4PipelineTask`, wired to `Dummy4Connections`."""

    conf1 = Field(doc="dummy config", default=1, dtype=int)


class Dummy4PipelineTask(PipelineTask):
    """Placeholder task that defines nothing beyond its ``ConfigClass``."""

    ConfigClass = Dummy4Config

117 

118 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph."""

    def setUp(self):
        """Build a four-task QuantumGraph over a two-dimension universe.

        Tasks labelled "R", "S" and "T" (Dummy1-3) are chained through their
        init-outputs/init-inputs and their outputs/inputs. Task "U" (Dummy4)
        shares no datasets with the others. Every task gets one quantum for
        each of the data IDs ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
        """
        self.config = Config(
            {
                "version": 1,
                "namespace": "pipe_base_test",
                "skypix": {
                    "common": "htm7",
                    "htm": {
                        "class": "lsst.sphgeom.HtmPixelization",
                        "max_level": 24,
                    },
                },
                "elements": {
                    "A": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                    "B": {
                        "keys": [
                            {
                                "name": "id",
                                "type": "int",
                            }
                        ],
                        "storage": {
                            "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                        },
                    },
                },
                "packers": {},
            }
        )
        universe = DimensionUniverse(config=self.config)

        def _makeDatasetType(connection):
            """Return a DatasetType matching ``connection`` in ``universe``."""
            return DatasetType(
                connection.name,
                getattr(connection, "dimensions", ()),
                storageClass=connection.storageClass,
                universe=universe,
            )

        # Mapping of TaskDef to the set of quanta generated for that task.
        quantumMap = {}
        tasks = []
        initInputs = {}
        initOutputs = {}
        for task, label in (
            (Dummy1PipelineTask, "R"),
            (Dummy2PipelineTask, "S"),
            (Dummy3PipelineTask, "T"),
            (Dummy4PipelineTask, "U"),
        ):
            config = task.ConfigClass()
            taskDef = TaskDef(get_full_type_name(task), config, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            # Init-input and init-output refs are kept in separate variables so
            # each quantum receives only the task's init-inputs, never its
            # init-outputs.
            if connections.initInputs:
                initInputDSType = _makeDatasetType(connections.initInput)
                initInputRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                initInputs[taskDef] = initInputRefs
            else:
                initInputRefs = None
            if connections.initOutputs:
                initOutputDSType = _makeDatasetType(connections.initOutput)
                initOutputs[taskDef] = [DatasetRef(initOutputDSType, DataCoordinate.makeEmpty(universe))]
            inputDSType = _makeDatasetType(connections.input)
            outputDSType = _makeDatasetType(connections.output)
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                quantumSet.add(
                    Quantum(
                        taskName=task.__qualname__,
                        dataId=dataId,
                        taskClass=task,
                        initInputs=initInputRefs,
                        inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                        outputs={outputDSType: [DatasetRef(outputDSType, dataId)]},
                    )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.packagesDSType = DatasetType("packages", universe.empty, storageClass="Packages")
        globalInitOutputs = [DatasetRef(self.packagesDSType, DataCoordinate.makeEmpty(universe))]
        self.qGraph = QuantumGraph(
            quantumMap,
            metadata=METADATA,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
        )
        self.universe = universe

    def testTaskGraph(self):
        """Every TaskDef used to build the graph appears in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every input quantum is wrapped by some node of the graph."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by node ID works, and an unknown UUID raises KeyError."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = uuid.uuid4()
        with self.assertRaises(KeyError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A pickle round trip yields an equal graph."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """Input quanta are exactly those of tasks[0] (R) and tasks[3] (U)."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]] | self.quantumMap[self.tasks[3]], inputs)

    def testOutputQuanta(self):
        """Output quanta are exactly those of tasks[2] (T) and tasks[3] (U)."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[2]] | self.quantumMap[self.tasks[3]], outputs)

    def testLength(self):
        """The graph holds two quanta (one per data ID) for each task."""
        self.assertEqual(len(self.qGraph), 2 * len(self.tasks))

    def testGetQuantaForTask(self):
        """Per-task quanta match those used to build the graph."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNumberOfQuantaForTask(self):
        """Per-task quantum counts match those used to build the graph."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getNumberOfQuantaForTask(task), len(self.quantumMap[task]))

    def testGetNodesForTask(self):
        """Per-task nodes wrap exactly the quanta used to build the graph."""
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        """``Dummy1Output`` is consumed by tasks[1] (S)."""
        self.assertEqual(
            tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0], self.tasks[1]
        )

    def testFindTasksWithOutput(self):
        """``Dummy1Output`` is produced by tasks[0] (R)."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """``Dummy1Output`` involves exactly its producer R and consumer S."""
        self.assertEqual(
            set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))), set(self.tasks[:2])
        )

    def testFindTaskDefByName(self):
        """Lookup by task class qualname finds the matching TaskDef."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0], self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Lookup by label finds the matching TaskDef."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """``Dummy1Input`` is used only by the quanta of tasks[0] (R)."""
        self.assertEqual(
            self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")), self.quantumMap[self.tasks[0]]
        )

    def testAllDatasetTypes(self):
        """allDatasetTypes covers every non-InitOutput connection name."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections, Dummy4Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                if not isinstance(connection, cT.InitOutput):
                    truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset keeps that quantum, the build ID and global init-outputs."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)
        self.assertEqual(len(subset.globalInitOutputRefs()), 1)

    def testSubsetToConnected(self):
        """Splitting into connected components yields the expected four pieces."""
        # Not connected: each of the two data IDs yields its own R -> S -> T
        # chain, and each of the two U quanta is isolated. That makes four
        # components in total.
        self.assertFalse(self.qGraph.isConnected)

        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 4)
        for cg in connectedGraphs:
            self.assertTrue(cg.isConnected)

        # Split out task[3] because it is expected to be on its own.
        for cg in connectedGraphs:
            if self.tasks[3] in cg.taskGraph:
                self.assertEqual(len(cg), 1)
            else:
                self.assertEqual(len(cg), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Total component membership over all nodes must equal the node count.
        count = sum(
            cg.checkQuantumInGraph(node.quantum) for node in self.qGraph for cg in connectedGraphs
        )
        self.assertEqual(len(self.qGraph), count)

        taskSets = {len(tg := s.taskGraph): set(tg) for s in connectedGraphs}
        for setLen, tskSet in taskSets.items():
            if setLen == 3:
                self.assertEqual(set(self.tasks[:-1]), tskSet)
            elif setLen == 1:
                self.assertEqual({self.tasks[-1]}, tskSet)
        for cg in connectedGraphs:
            if len(cg.taskGraph) == 1:
                continue
            allNodes = list(cg)
            node = cg.determineInputsToQuantumNode(allNodes[1])
            self.assertEqual({allNodes[0]}, node)

    def testDetermineOutputsOfQuantumNode(self):
        """The downstream nodes of R's nodes are exactly S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[1])
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineOutputsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of S's nodes are R's, T's and S's own nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0]) | self.qGraph.getNodesForTask(self.tasks[2])
        # outputs contain nodes tested for because it is a complete graph
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineConnectionsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors of S's nodes (including themselves) are R's and S's nodes."""
        testNodes = self.qGraph.getNodesForTask(self.tasks[1])
        matchNodes = self.qGraph.getNodesForTask(self.tasks[0])
        matchNodes |= set(testNodes)
        connections = set()
        for node in testNodes:
            connections |= set(self.qGraph.determineAncestorsOfQuantumNode(node))
        self.assertEqual(matchNodes, connections)

    def testFindCycle(self):
        """The test graph is acyclic."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Save/load through a file object, including partial loads and universe checks."""
        with tempfile.TemporaryFile(suffix=".qgraph") as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            nodeId = [n.nodeId for n in self.qGraph][0]
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(nodeId,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeId))
            self.assertEqual(len(restoreSub.globalInitOutputRefs()), 1)
            # Check that InitInput and InitOutput refs are restored correctly.
            for taskDef in restore.iterTaskGraph():
                if taskDef.label in ("S", "T"):
                    refs = restore.initInputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)
                if taskDef.label in ("R", "S", "T"):
                    refs = restore.initOutputRefs(taskDef)
                    self.assertIsNotNone(refs)
                    self.assertGreater(len(refs), 0)

            # Different universes.
            tmpFile.seek(0)
            different_config = self.config.copy()
            different_config["version"] = 1_000_000
            different_universe = DimensionUniverse(config=different_config)
            with self.assertLogs("lsst.daf.butler", "INFO"):
                QuantumGraph.load(tmpFile, different_universe)

            different_config["namespace"] = "incompatible"
            different_universe = DimensionUniverse(config=different_config)
            tmpFile.seek(0)
            with self.assertRaises(RuntimeError) as cm:
                QuantumGraph.load(tmpFile, different_universe)
            self.assertIn("not compatible with", str(cm.exception))

    def testSaveLoadUri(self):
        """Save/load by URI, including partial loads and error cases."""
        uri = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
                self.qGraph.saveUri(uri)
                restore = QuantumGraph.loadUri(uri)
                self.assertEqual(restore.metadata, METADATA)
                self.assertEqual(self.qGraph, restore)
                # Pick two distinct nodes at random for the partial loads.
                nodeIds = [n.nodeId for n in self.qGraph]
                nodeNumber, nodeNumber2 = random.sample(nodeIds, 2)
                restoreSub = QuantumGraph.loadUri(
                    uri, self.universe, nodes=(nodeNumber,), graphID=self.qGraph._buildId
                )
                self.assertEqual(len(restoreSub), 1)
                self.assertEqual(list(restoreSub)[0], restore.getQuantumNodeByNodeId(nodeNumber))
                # verify that more than one node works
                restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
                self.assertEqual(len(restoreSub), 2)
                self.assertEqual(
                    set(restoreSub),
                    {
                        restore.getQuantumNodeByNodeId(nodeNumber),
                        restore.getQuantumNodeByNodeId(nodeNumber2),
                    },
                )
                # verify an error when requesting a non existent node number
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

                # verify a graphID that does not match will be an error
                with self.assertRaises(ValueError):
                    QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    def testContains(self):
        """A node taken from the graph is reported as contained in it."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

    def testDimensionUniverseInSave(self):
        """The save header embeds the dimension universe configuration."""
        _, header = self.qGraph._buildSaveObject(returnHeader=True)
        # type ignore because buildSaveObject does not have method overload
        self.assertEqual(header["universe"], self.universe.dimensionConfig.toDict())  # type: ignore

496 

497 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks from ``lsst.utils.tests.MemoryTestCase`` unchanged."""

500 

501 

def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    ``module`` is accepted to match the hook signature but is not used.
    (Presumably invoked by pytest before this module's tests run.)
    """
    lsst.utils.tests.init()

504 

505 

# Script entry point: initialize the LSST test utilities, then run unittest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()