Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26from typing import Iterable 

27import unittest 

28import random 

29from lsst.daf.butler import DimensionUniverse 

30 

31from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

32 DatasetTypeName, IncompatibleGraphError) 

33import lsst.pipe.base.connectionTypes as cT 

34from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

35from lsst.pex.config import Field 

36from lsst.pipe.base.graph.quantumNode import NodeId, BuildId, QuantumNode 

37import lsst.utils.tests 

38 

try:
    import boto3
    from moto import mock_s3
except ImportError:
    # Either boto3 or moto is missing: mark S3 support unavailable (tests
    # check ``boto3`` to decide whether to skip) and install a stand-in
    # decorator so ``@mock_s3`` still applies cleanly.
    boto3 = None

    def mock_s3(cls):
        """Return ``cls`` untouched; substitute for moto's ``mock_s3``
        used only when moto cannot be imported.
        """
        return cls

# Metadata attached to the test graph and expected back after a round trip.
METADATA = dict(a=[1, 2, 3])

51 

52 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first pipeline stage.

    Reads ``Dummy1Input`` and writes ``Dummy1Output`` (both over ``A``/``B``),
    and produces the init-output ``Dummy1InitOutput``.
    """

    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy1Input",
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy1Output",
        storageClass="ExposureF",
    )

65 

66 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Config for the first stage, wired to `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

69 

70 

class Dummy1PipelineTask(PipelineTask):
    """Stub task for the first stage; only its ``ConfigClass`` is defined here."""

    ConfigClass = Dummy1Config

73 

74 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second pipeline stage.

    Consumes the first stage's ``Dummy1InitOutput`` and ``Dummy1Output``;
    produces ``Dummy2InitOutput`` and ``Dummy2Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy1InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy1Output",
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy2Output",
        storageClass="ExposureF",
    )

90 

91 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Config for the second stage, wired to `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

94 

95 

class Dummy2PipelineTask(PipelineTask):
    """Stub task for the second stage; only its ``ConfigClass`` is defined here."""

    ConfigClass = Dummy2Config

98 

99 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third pipeline stage.

    Consumes the second stage's ``Dummy2InitOutput`` and ``Dummy2Output``;
    produces ``Dummy3InitOutput`` and ``Dummy3Output``.
    """

    initInput = cT.InitInput(
        doc="n/a",
        name="Dummy2InitOutput",
        storageClass="ExposureF",
    )
    initOutput = cT.InitOutput(
        doc="n/a",
        name="Dummy3InitOutput",
        storageClass="ExposureF",
    )
    input = cT.Input(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy2Output",
        storageClass="ExposureF",
    )
    output = cT.Output(
        doc="n/a",
        dimensions=("A", "B"),
        name="Dummy3Output",
        storageClass="ExposureF",
    )

115 

116 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Config for the third stage, wired to `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

119 

120 

class Dummy3PipelineTask(PipelineTask):
    """Stub task for the third stage; only its ``ConfigClass`` is defined here."""

    ConfigClass = Dummy3Config

123 

124 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a three-task pipeline (labels ``R``, ``S``, ``T``) in
    which each task consumes the previous task's output dataset type. It uses
    a minimal universe with two integer-keyed dimensions ``A`` and ``B``. Each
    task gets two quanta, for data IDs ``{A: 1, B: 2}`` and ``{A: 3, B: 4}``.
    The graph therefore holds six nodes that form two independent chains.
    """

    def setUp(self):
        dimensionConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=dimensionConfig)
        # Mapping of TaskDef to the set of quanta that task will run.
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            # Named distinctly so it does not shadow the dimension config above.
            taskConfig = task.ConfigClass()
            taskDef = TaskDef(f"__main__.{task.__qualname__}", taskConfig, task, label)
            tasks.append(taskDef)
            quantumSet = set()
            connections = taskDef.connections
            for a, b in ((1, 2), (3, 4)):
                if connections.initInputs:
                    initInputDSType = DatasetType(connections.initInput.name,
                                                  (),
                                                  storageClass=connections.initInput.storageClass,
                                                  universe=universe)
                    initRefs = [DatasetRef(initInputDSType,
                                           DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                # One data ID shared by the input ref, the output ref and the quantum.
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                inputDSType = DatasetType(connections.input.name,
                                          connections.input.dimensions,
                                          storageClass=connections.input.storageClass,
                                          universe=universe,
                                          )
                inputRefs = [DatasetRef(inputDSType, dataId)]
                outputDSType = DatasetType(connections.output.name,
                                           connections.output.dimensions,
                                           storageClass=connections.output.storageClass,
                                           universe=universe,
                                           )
                outputRefs = [DatasetRef(outputDSType, dataId)]
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: inputRefs},
                            outputs={outputDSType: outputRefs}
                            )
                )
            quantumMap[taskDef] = quantumSet
        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap, metadata=METADATA)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        # This is a hack for the unit test since the qualified name will be
        # different as it will be __main__ here, but qualified to the
        # unittest module name when restored.
        # Updates the TaskDef keys of both graphs in place, stripping each
        # taskName down to its last dotted component; pairs keys by position.
        for saved, loaded in zip(graph1._quanta.keys(),
                                 graph2._quanta.keys()):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        for taskDef in self.quantumMap.keys():
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        # A NodeId carrying a different build id must be rejected.
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        # Three tasks times two data IDs.
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testGetNodesForTask(self):
        for task in self.tasks:
            nodes: Iterable[QuantumNode] = self.qGraph.getNodesForTask(task)
            quanta_in_node = {n.quantum for n in nodes}
            self.assertEqual(quanta_in_node, self.quantumMap[task])

    def testFindTasksWithInput(self):
        self.assertEqual(tuple(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output")))[0],
                         self.tasks[1])

    def testFindTasksWithOutput(self):
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"),
                         self.tasks[0])

    def testFindQuantaWIthDSType(self):
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Every node of the full graph must land in exactly one component.
        count = 0
        for node in self.qGraph:
            if connectedGraphs[0].checkQuantumInGraph(node.quantum):
                count += 1
            if connectedGraphs[1].checkQuantumInGraph(node.quantum):
                count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(set(self.tasks), set(connectedGraphs[0].taskGraph))
        self.assertEqual(set(self.tasks), set(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        # Previously a duplicated check at the end of testSubsetToConnected.
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        uri = None
        try:
            # delete=False keeps the file after its handle is closed, so it can
            # be reopened by name below. Closing first matters because
            # reopening a still-open NamedTemporaryFile by name is
            # platform-dependent. The finally clause removes the file.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
                uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self.assertEqual(restore.metadata, METADATA)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Two distinct random node numbers for the partial loads below.
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))
            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})
            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            if uri is not None:
                os.remove(uri)

        # An unrecognized file extension is rejected.
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        # Test loading a quantum graph from a mock s3 store
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

433 

434 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Runs the checks inherited from `lsst.utils.tests.MemoryTestCase`;
    adds nothing of its own.
    """

437 

438 

def setup_module(module):
    """pytest-style module setup hook: call ``lsst.utils.tests.init()``.

    The ``module`` argument is required by the hook signature but unused.
    """
    testUtils = lsst.utils.tests
    testUtils.init()

441 

442 

# Script entry point: initialize the LSST test utilities, then run unittest.
# (Removes coverage-report annotation text that had been fused onto the
# ``if`` line and made it invalid Python.)
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()