Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

21 

22from itertools import chain 

23import os 

24import pickle 

25import tempfile 

26import unittest 

27import random 

28from lsst.daf.butler import DimensionUniverse 

29 

30from lsst.pipe.base import (QuantumGraph, TaskDef, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, 

31 DatasetTypeName, IncompatibleGraphError) 

32import lsst.pipe.base.connectionTypes as cT 

33from lsst.daf.butler import Quantum, DatasetRef, DataCoordinate, DatasetType, Config 

34from lsst.pex.config import Field 

35from lsst.pipe.base.graph.quantumNode import NodeId, BuildId 

36import lsst.utils.tests 

37 

# boto3 and moto are optional: they are only needed by the S3 round-trip test.
try:
    import boto3
    from moto import mock_s3
except ImportError:
    # ``boto3 = None`` is what the S3 test's ``skipIf`` guard checks for.
    boto3 = None

    def mock_s3(cls):
        """Return the decorated object unchanged.

        Stand-in for ``moto.mock_s3`` so that the ``@mock_s3`` decoration in
        this module still works when the import above fails.
        """
        return cls

48 

49 

class Dummy1Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the first dummy task.

    Declares one init-output, one input and one output, all with storage
    class ``ExposureF``; the input and output use dimensions ("A", "B").
    """

    initOutput = cT.InitOutput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Input",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

62 

63 

class Dummy1Config(PipelineTaskConfig, pipelineConnections=Dummy1Connections):
    """Configuration for the first dummy task, bound to `Dummy1Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

66 

67 

class Dummy1PipelineTask(PipelineTask):
    """First dummy task; it only declares its configuration class."""

    ConfigClass = Dummy1Config

70 

71 

class Dummy2Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the second dummy task.

    Its init-input and input are named after the first dummy task's
    init-output ("Dummy1InitOutput") and output ("Dummy1Output").
    """

    initInput = cT.InitInput(
        name="Dummy1InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy1Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

87 

88 

class Dummy2Config(PipelineTaskConfig, pipelineConnections=Dummy2Connections):
    """Configuration for the second dummy task, bound to `Dummy2Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

91 

92 

class Dummy2PipelineTask(PipelineTask):
    """Second dummy task; it only declares its configuration class."""

    ConfigClass = Dummy2Config

95 

96 

class Dummy3Connections(PipelineTaskConnections, dimensions=("A", "B")):
    """Connections of the third dummy task.

    Its init-input and input are named after the second dummy task's
    init-output ("Dummy2InitOutput") and output ("Dummy2Output").
    """

    initInput = cT.InitInput(
        name="Dummy2InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    initOutput = cT.InitOutput(
        name="Dummy3InitOutput",
        storageClass="ExposureF",
        doc="n/a",
    )
    input = cT.Input(
        name="Dummy2Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )
    output = cT.Output(
        name="Dummy3Output",
        storageClass="ExposureF",
        doc="n/a",
        dimensions=("A", "B"),
    )

112 

113 

class Dummy3Config(PipelineTaskConfig, pipelineConnections=Dummy3Connections):
    """Configuration for the third dummy task, bound to `Dummy3Connections`."""

    conf1 = Field(doc="dummy config", dtype=int, default=1)

116 

117 

class Dummy3PipelineTask(PipelineTask):
    """Third dummy task; it only declares its configuration class."""

    ConfigClass = Dummy3Config

120 

121 

class QuantumGraphTestCase(unittest.TestCase):
    """Tests the various functions of a quantum graph.

    ``setUp`` builds a graph from the three dummy tasks (labels "R", "S" and
    "T"), each with one quantum for data ID (A=1, B=2) and one for
    (A=3, B=4), i.e. six quanta in total.

    Some test method names contain historical misspellings; they are kept
    unchanged so that selecting tests by name keeps working.
    """

    def setUp(self):
        """Create the dimension universe, the per-task quanta and the graph."""
        universeConfig = Config({
            "version": 1,
            "skypix": {
                "common": "htm7",
                "htm": {
                    "class": "lsst.sphgeom.HtmPixelization",
                    "max_level": 24,
                }
            },
            "elements": {
                "A": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                },
                "B": {
                    "keys": [{
                        "name": "id",
                        "type": "int",
                    }],
                    "storage": {
                        "cls": "lsst.daf.butler.registry.dimensions.table.TableDimensionRecordStorage",
                    },
                }
            },
            "packers": {}
        })
        universe = DimensionUniverse(config=universeConfig)

        # need to make a mapping of TaskDef to set of quantum
        quantumMap = {}
        tasks = []
        for task, label in ((Dummy1PipelineTask, "R"), (Dummy2PipelineTask, "S"), (Dummy3PipelineTask, "T")):
            taskDef = TaskDef(f"__main__.{task.__qualname__}", task.ConfigClass(), task, label)
            tasks.append(taskDef)
            connections = taskDef.connections

            # The dataset types depend only on the task, not on the data ID,
            # so build them once per task.
            if connections.initInputs:
                initInputDSType = DatasetType(connections.initInput.name,
                                              (),
                                              storageClass=connections.initInput.storageClass,
                                              universe=universe)
            else:
                initInputDSType = None
            inputDSType = DatasetType(connections.input.name,
                                      connections.input.dimensions,
                                      storageClass=connections.input.storageClass,
                                      universe=universe)
            outputDSType = DatasetType(connections.output.name,
                                       connections.output.dimensions,
                                       storageClass=connections.output.storageClass,
                                       universe=universe)

            quantumSet = set()
            for a, b in ((1, 2), (3, 4)):
                dataId = DataCoordinate.standardize({"A": a, "B": b}, universe=universe)
                if initInputDSType is not None:
                    # A fresh list per quantum, as each quantum is handed its own.
                    initRefs = [DatasetRef(initInputDSType, DataCoordinate.makeEmpty(universe))]
                else:
                    initRefs = None
                quantumSet.add(
                    Quantum(taskName=task.__qualname__,
                            dataId=dataId,
                            taskClass=task,
                            initInputs=initRefs,
                            inputs={inputDSType: [DatasetRef(inputDSType, dataId)]},
                            outputs={outputDSType: [DatasetRef(outputDSType, dataId)]})
                )
            quantumMap[taskDef] = quantumSet

        self.tasks = tasks
        self.quantumMap = quantumMap
        self.qGraph = QuantumGraph(quantumMap)
        self.universe = universe

    def _cleanGraphs(self, graph1, graph2):
        """Strip module qualification from the task names in both graphs.

        This is a hack for the unit test since the qualified name will be
        different as it will be ``__main__`` here, but qualified to the
        unittest module name when restored. Both graphs are updated in
        place.
        """
        for saved, loaded in zip(graph1._quanta, graph2._quanta):
            saved.taskName = saved.taskName.split('.')[-1]
            loaded.taskName = loaded.taskName.split('.')[-1]

    def testTaskGraph(self):
        """Every task definition used to build the graph is in its task graph."""
        for taskDef in self.quantumMap:
            self.assertIn(taskDef, self.qGraph.taskGraph)

    def testGraph(self):
        """Every quantum used to build the graph is held by one of its nodes."""
        graphSet = {q.quantum for q in self.qGraph.graph}
        for quantum in chain.from_iterable(self.quantumMap.values()):
            self.assertIn(quantum, graphSet)

    def testGetQuantumNodeByNodeId(self):
        """Lookup by node ID works, and a foreign build ID is rejected."""
        inputQuanta = tuple(self.qGraph.inputQuanta)
        node = self.qGraph.getQuantumNodeByNodeId(inputQuanta[0].nodeId)
        self.assertEqual(node, inputQuanta[0])
        wrongNode = NodeId(15, BuildId("alternative build Id"))
        with self.assertRaises(IncompatibleGraphError):
            self.qGraph.getQuantumNodeByNodeId(wrongNode)

    def testPickle(self):
        """A graph restored from a pickle compares equal to the original."""
        stringify = pickle.dumps(self.qGraph)
        restore: QuantumGraph = pickle.loads(stringify)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)

    def testInputQuanta(self):
        """The graph's input quanta are those of the first task."""
        inputs = {q.quantum for q in self.qGraph.inputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[0]], inputs)

    def testOutputtQuanta(self):
        """The graph's output quanta are those of the last task."""
        outputs = {q.quantum for q in self.qGraph.outputQuanta}
        self.assertEqual(self.quantumMap[self.tasks[-1]], outputs)

    def testLength(self):
        """Three tasks with two quanta each give a graph of length six."""
        self.assertEqual(len(self.qGraph), 6)

    def testGetQuantaForTask(self):
        """Each task definition maps back to the quanta it was built with."""
        for task in self.tasks:
            self.assertEqual(self.qGraph.getQuantaForTask(task), self.quantumMap[task])

    def testFindTasksWithInput(self):
        """The first task found consuming "Dummy1Output" is the second task."""
        found = next(iter(self.qGraph.findTasksWithInput(DatasetTypeName("Dummy1Output"))))
        self.assertEqual(found, self.tasks[1])

    def testFindTasksWithOutput(self):
        """The task producing "Dummy1Output" is the first task."""
        self.assertEqual(self.qGraph.findTaskWithOutput(DatasetTypeName("Dummy1Output")), self.tasks[0])

    def testTaskWithDSType(self):
        """"Dummy1Output" is associated with the first two tasks."""
        self.assertEqual(set(self.qGraph.tasksWithDSType(DatasetTypeName("Dummy1Output"))),
                         set(self.tasks[:2]))

    def testFindTaskDefByName(self):
        """Lookup by task name returns the first task definition first."""
        self.assertEqual(self.qGraph.findTaskDefByName(Dummy1PipelineTask.__qualname__)[0],
                         self.tasks[0])

    def testFindTaskDefByLabel(self):
        """Label "R" resolves to the first task definition."""
        self.assertEqual(self.qGraph.findTaskDefByLabel("R"), self.tasks[0])

    def testFindQuantaWIthDSType(self):
        """The quanta associated with "Dummy1Input" are those of the first task."""
        self.assertEqual(self.qGraph.findQuantaWithDSType(DatasetTypeName("Dummy1Input")),
                         self.quantumMap[self.tasks[0]])

    def testAllDatasetTypes(self):
        """The graph reports exactly the connection names of the dummy tasks."""
        allDatasetTypes = set(self.qGraph.allDatasetTypes)
        truth = set()
        for conClass in (Dummy1Connections, Dummy2Connections, Dummy3Connections):
            for connection in conClass.allConnections.values():  # type: ignore
                truth.add(connection.name)
        self.assertEqual(allDatasetTypes, truth)

    def testSubset(self):
        """A one-node subset holds that node's quantum and keeps the build ID."""
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset(allNodes[0])
        self.assertEqual(len(subset), 1)
        subsetList = list(subset)
        self.assertEqual(allNodes[0].quantum, subsetList[0].quantum)
        self.assertEqual(self.qGraph._buildId, subset._buildId)

    def testIsConnected(self):
        """The full graph is disconnected; a subset of one chain is connected."""
        # False because there are two quantum chains for two distinct sets of
        # dimensions
        self.assertFalse(self.qGraph.isConnected)
        # make a broken subset
        allNodes = list(self.qGraph)
        subset = self.qGraph.subset((allNodes[0], allNodes[1]))
        # True because we subset to only one chain of graphs
        self.assertTrue(subset.isConnected)

    def testSubsetToConnected(self):
        """Splitting into connected graphs gives two distinct graphs of three."""
        connectedGraphs = self.qGraph.subsetToConnected()
        self.assertEqual(len(connectedGraphs), 2)
        self.assertTrue(connectedGraphs[0].isConnected)
        self.assertTrue(connectedGraphs[1].isConnected)

        self.assertEqual(len(connectedGraphs[0]), 3)
        self.assertEqual(len(connectedGraphs[1]), 3)

        self.assertNotEqual(connectedGraphs[0], connectedGraphs[1])

        # Memberships summed over both subgraphs must total the size of the
        # full graph.
        count = 0
        for node in self.qGraph:
            for connected in (connectedGraphs[0], connectedGraphs[1]):
                if connected.checkQuantumInGraph(node.quantum):
                    count += 1
        self.assertEqual(len(self.qGraph), count)

        self.assertEqual(self.tasks, list(connectedGraphs[0].taskGraph))
        self.assertEqual(self.tasks, list(connectedGraphs[1].taskGraph))

    def testDetermineInputsToQuantumNode(self):
        """The only input to the second node is the first node."""
        allNodes = list(self.qGraph)
        inputs = self.qGraph.determineInputsToQuantumNode(allNodes[1])
        self.assertEqual({allNodes[0]}, inputs)

    def testDetermineOutputsOfQuantumNode(self):
        """The first output found for the second node is the third node."""
        allNodes = list(self.qGraph)
        node = next(iter(self.qGraph.determineOutputsOfQuantumNode(allNodes[1])))
        self.assertEqual(allNodes[2], node)

    def testDetermineConnectionsOfQuantum(self):
        """The connections of the second node match the first three nodes."""
        allNodes = list(self.qGraph)
        connections = self.qGraph.determineConnectionsOfQuantumNode(allNodes[1])
        self.assertEqual(list(connections), list(self.qGraph.subset(allNodes[:3])))

    def testDetermineAnsestorsOfQuantumNode(self):
        """The ancestors graph of the third node matches the first three nodes."""
        allNodes = list(self.qGraph)
        ancestors = self.qGraph.determineAncestorsOfQuantumNode(allNodes[2])
        self.assertEqual(list(ancestors), list(self.qGraph.subset(allNodes[:3])))

    def testFindCycle(self):
        """The test graph contains no cycle."""
        self.assertFalse(self.qGraph.findCycle())

    def testSaveLoad(self):
        """Round-trip the graph through a file object, fully and for one node."""
        with tempfile.TemporaryFile(suffix='.qgraph') as tmpFile:
            self.qGraph.save(tmpFile)
            tmpFile.seek(0)
            restore = QuantumGraph.load(tmpFile, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
            # Load in just one node
            tmpFile.seek(0)
            restoreSub = QuantumGraph.load(tmpFile, self.universe, nodes=(0,))
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testSaveLoadUri(self):
        """Round-trip the graph through a local URI, including partial loads."""
        # Only the unique file name is needed; close the handle before the
        # graph is written to that path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".qgraph") as tmpFile:
            uri = tmpFile.name
        try:
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)

            # two distinct, randomly chosen node numbers
            nodeNumber, nodeNumber2 = random.sample(range(len(self.qGraph)), 2)

            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber,),
                                              graphID=self.qGraph._buildId)
            self.assertEqual(len(restoreSub), 1)
            self.assertEqual(list(restoreSub)[0],
                             restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore.graphID)))

            # verify that more than one node works
            restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(nodeNumber, nodeNumber2))
            self.assertEqual(len(restoreSub), 2)
            self.assertEqual(set(restoreSub),
                             {restore.getQuantumNodeByNodeId(NodeId(nodeNumber, restore._buildId)),
                              restore.getQuantumNodeByNodeId(NodeId(nodeNumber2, restore._buildId))})

            # verify an error when requesting a non-existent node number
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, nodes=(99,))

            # verify a graphID that does not match will be an error
            with self.assertRaises(ValueError):
                QuantumGraph.loadUri(uri, self.universe, graphID="NOTRIGHT")
        finally:
            os.remove(uri)

        # a URI without the expected extension is rejected
        with self.assertRaises(TypeError):
            self.qGraph.saveUri("test.notgraph")

    @unittest.skipIf(not boto3, "Warning: boto3 AWS SDK not found!")
    @mock_s3
    def testSaveLoadUriS3(self):
        """Round-trip the graph through a mock S3 store."""
        conn = boto3.resource('s3', region_name="us-east-1")
        conn.create_bucket(Bucket='testBucket')
        uri = "s3://testBucket/qgraph.qgraph"
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
        restoreSub = QuantumGraph.loadUri(uri, self.universe, nodes=(0,))
        self.assertEqual(len(restoreSub), 1)
        self.assertEqual(list(restoreSub)[0],
                         restore.getQuantumNodeByNodeId(NodeId(0, restore._buildId)))

    def testContains(self):
        """A node obtained by iterating the graph is contained in the graph."""
        firstNode = next(iter(self.qGraph))
        self.assertIn(firstNode, self.qGraph)

423 

424 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

427 

428 

def setup_module(module):
    """Module-level setup hook; calls ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module : module
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

431 

432 

if __name__ == "__main__":
    # Executed as a script: call lsst.utils.tests.init() and then hand
    # control over to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()