Coverage for tests/jointcalTestBase.py: 13%
# This file is part of jointcal.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import copy
import os
import shutil

import lsst.afw.image.utils
from lsst.ctrl.mpexec import SimplePipelineExecutor
import lsst.daf.butler
import lsst.obs.base
import lsst.geom
from lsst.verify.bin.jobReporter import JobReporter

from lsst.jointcal import jointcal, utils

class JointcalTestBase:
    """
    Base class for jointcal tests, to genericize some test running and setup.

    Derive from this first, then from TestCase.
    """

    def set_output_dir(self):
        self.output_dir = os.path.join('.test', self.__class__.__name__, self.id().split('.')[-1])

    def setUp_base(self, center, radius,
                   match_radius=0.1*lsst.geom.arcseconds,
                   input_dir="",
                   all_visits=None,
                   other_args=None,
                   do_plot=False,
                   log_level=None,
                   where=""):
        """
        Call from your child class's setUp() to get the necessary variables built.

        Parameters
        ----------
        center : `lsst.geom.SpherePoint`
            Center of the reference catalog.
        radius : `lsst.geom.Angle`
            Radius from center to load reference catalog objects inside.
        match_radius : `lsst.geom.Angle`
            Matching radius when calculating the RMS of the result.
        input_dir : `str`
            Directory of the input butler repository.
        all_visits : `list` [`int`]
            List of the available visits to generate the parseAndRun arguments.
        other_args : `list` [`str`]
            Optional other arguments for the butler dataId.
        do_plot : `bool`
            Set to True for a comparison plot and some diagnostic numbers.
        log_level : `str`
            Set to the default log level you want jointcal to produce while the
            tests are running. See the developer docs about logging for valid
            levels: https://developer.lsst.io/coding/logging.html
        where : `str`
            Data ID query for pipetask specifying the data to run on.
        """
        self.path = os.path.dirname(__file__)

        self.center = center
        self.radius = radius
        self.jointcalStatistics = utils.JointcalStatistics(match_radius, verbose=True)
        self.input_dir = input_dir
        self.all_visits = all_visits
        if other_args is None:
            other_args = []
        self.other_args = other_args
        self.do_plot = do_plot
        self.log_level = log_level
        # Signal/Noise (flux/fluxErr) for sources to be included in the RMS cross-match.
        # 100 is a balance between good centroids and enough sources.
        self.flux_limit = 100

        # Individual tests may want to tweak the config that is passed to parseAndRun().
        self.config = None
        self.configfiles = []

        # Append `msg` arguments to assert failures.
        self.longMessage = True

        self.where = where

        # Ensure that the filter list is reset for each test so that we avoid
        # confusion or contamination from other instruments.
        lsst.obs.base.FilterDefinitionCollection.reset()

        self.set_output_dir()
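
    # Illustrative sketch (not part of the base class): a minimal subclass
    # calling setUp_base() from its own setUp(), as the class docstring
    # describes. The dataset path, visit numbers, and sky position below are
    # hypothetical placeholders.
    #
    #   class JointcalTestExample(JointcalTestBase, lsst.utils.tests.TestCase):
    #       def setUp(self):
    #           center = lsst.geom.SpherePoint(320.367746, 0.3696, lsst.geom.degrees)
    #           radius = 5*lsst.geom.degrees
    #           self.setUp_base(center, radius,
    #                           input_dir="/path/to/testdata_jointcal/hsc",
    #                           all_visits=[903334, 903336, 903338],
    #                           where="instrument='HSC'")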

    def tearDown(self):
        shutil.rmtree(self.output_dir, ignore_errors=True)

        if getattr(self, 'reference', None) is not None:
            del self.reference
        if getattr(self, 'oldWcsList', None) is not None:
            del self.oldWcsList
        if getattr(self, 'jointcalTask', None) is not None:
            del self.jointcalTask
        if getattr(self, 'jointcalStatistics', None) is not None:
            del self.jointcalStatistics
        if getattr(self, 'config', None) is not None:
            del self.config

    def _testJointcalTask(self, nCatalogs, dist_rms_relative, dist_rms_absolute, pa1,
                          metrics=None):
        """
        Test parseAndRun for jointcal on nCatalogs.

        Checks relative and absolute astrometric error (arcsec) and photometric
        repeatability (PA1 from the SRD).

        Parameters
        ----------
        nCatalogs : `int`
            Number of catalogs to run jointcal on. Used to construct the "id"
            field for parseAndRun.
        dist_rms_relative : `astropy.Quantity`
            Maximum relative astrometric rms post-jointcal for the test to pass.
        dist_rms_absolute : `astropy.Quantity`
            Maximum absolute astrometric rms post-jointcal for the test to pass.
        pa1 : `float`
            Maximum PA1 (from Table 14 of the Science Requirements Document:
            https://ls.st/LPM-17) post-jointcal for the test to pass.
        metrics : `dict`, optional
            Dictionary of 'metricName': value to test jointcal's result.metrics
            against.

        Returns
        -------
        dataRefs : `list` [`lsst.daf.persistence.ButlerDataRef`]
            The dataRefs that were processed.
        """

        resultFull = self._runJointcalTask(nCatalogs, metrics=metrics)
        result = resultFull.resultList[0].result  # shorten this very long thing

        def compute_statistics(refObjLoader):
            refCat = refObjLoader.loadSkyCircle(self.center,
                                                self.radius,
                                                result.defaultFilter.bandLabel,
                                                epoch=result.epoch).refCat
            rms_result = self.jointcalStatistics.compute_rms(result.dataRefs, refCat)
            # Make plots before testing, if requested, so we still get plots if tests fail.
            if self.do_plot:
                self._plotJointcalTask(result.dataRefs, result.oldWcsList)
            return rms_result

        # we now have different astrometry/photometry refcats, so have to
        # do these calculations separately
        if self.jointcalStatistics.do_astrometry:
            refObjLoader = result.astrometryRefObjLoader
            # preserve do_photometry for the next `if`
            temp = copy.copy(self.jointcalStatistics.do_photometry)
            self.jointcalStatistics.do_photometry = False
            rms_result = compute_statistics(refObjLoader)
            self.jointcalStatistics.do_photometry = temp  # restore do_photometry

            if dist_rms_relative is not None and dist_rms_absolute is not None:
                self.assertLess(rms_result.dist_relative, dist_rms_relative)
                self.assertLess(rms_result.dist_absolute, dist_rms_absolute)

        if self.jointcalStatistics.do_photometry:
            refObjLoader = result.photometryRefObjLoader
            self.jointcalStatistics.do_astrometry = False
            rms_result = compute_statistics(refObjLoader)

            if pa1 is not None:
                self.assertLess(rms_result.pa1, pa1)

        return result.dataRefs
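
    # Illustrative sketch (not part of the base class): a hypothetical
    # gen2-style test method exercising _testJointcalTask. The thresholds and
    # metric names/values are placeholders, and a real `metrics` dict needs one
    # entry per metric that jointcal measures (assumes `import astropy.units as u`
    # at module scope).
    #
    #   def test_jointcalTask_2_visits(self):
    #       metrics = {'collected_astrometry_refStars': 825,
    #                  'selected_astrometry_refStars': 350}
    #       self._testJointcalTask(2,
    #                              dist_rms_relative=25e-3*u.arcsecond,
    #                              dist_rms_absolute=50e-3*u.arcsecond,
    #                              pa1=0.05,
    #                              metrics=metrics)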

    def _runJointcalTask(self, nCatalogs, metrics=None):
        """
        Run jointcalTask on nCatalogs, with the most basic tests.

        Tests for a non-empty result list, and that the basic metrics are correct.

        Parameters
        ----------
        nCatalogs : `int`
            Number of catalogs to test on.
        metrics : `dict`, optional
            Dictionary of 'metricName': value to test jointcal's result.metrics
            against.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            The structure returned by jointcalTask.run().
        """
        visits = '^'.join(str(v) for v in self.all_visits[:nCatalogs])
        if self.log_level is not None:
            self.other_args.extend(['--loglevel', 'jointcal=%s'%self.log_level])

        # Place the default configfile first, so that specific subclass configfiles are applied after it.
        test_config = os.path.join(self.path, 'config/config.py')
        configfiles = [test_config] + self.configfiles

        args = [self.input_dir, '--output', self.output_dir,
                '--clobber-versions', '--clobber-config',
                '--doraise', '--configfile', *configfiles,
                '--id', 'visit=%s'%visits]
        args.extend(self.other_args)
        result = jointcal.JointcalTask.parseAndRun(args=args, doReturnResults=True, config=self.config)
        self.assertNotEqual(result.resultList, [], 'resultList should not be empty')
        self.assertEqual(result.resultList[0].exitStatus, 0)
        job = result.resultList[0].result.job
        self._test_metrics(job.measurements, metrics)

        return result

    def _plotJointcalTask(self, data_refs, oldWcsList):
        """
        Plot the results of a jointcal run.

        Parameters
        ----------
        data_refs : `list` [`lsst.daf.persistence.ButlerDataRef`]
            The dataRefs that were processed.
        oldWcsList : `list` [`lsst.afw.geom.SkyWcs`]
            The original WCS from each dataRef.
        """
        plot_dir = os.path.join('.test', self.__class__.__name__, 'plots')
        if not os.path.isdir(plot_dir):
            os.mkdir(plot_dir)
        self.jointcalStatistics.make_plots(data_refs, oldWcsList, name=self.id(), outdir=plot_dir)
        print("Plots saved to: {}".format(plot_dir))

    def _test_metrics(self, result, expect):
        """Test a dictionary of "metrics" against those returned by jointcal.py.

        Parameters
        ----------
        result : `dict`
            Result metric dictionary from jointcal.py.
        expect : `dict`
            Expected metric dictionary; set a value to None to not test it.
        """
        for key in result:
            if expect[key.metric] is not None:
                value = result[key].quantity.value
                if isinstance(value, float):
                    self.assertFloatsAlmostEqual(value, expect[key.metric], msg=key.metric, rtol=1e-5)
                else:
                    self.assertEqual(value, expect[key.metric], msg=key.metric)
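
    # Illustrative sketch (not part of the base class): a hypothetical `expect`
    # dictionary for _test_metrics. Setting a value to None skips the check for
    # that metric; the names and numbers here are placeholders only.
    #
    #   expect = {'astrometry_final_chi2': 1150.62,
    #             'photometry_final_chi2': None,  # measured, but not checked
    #             'selected_astrometry_ccdImages': 6}
    #   self._test_metrics(job.measurements, expect)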

    def _importRepository(self, instrument, exportPath, exportFile):
        """Import a gen3 test repository into a test repo under self.output_dir.

        Parameters
        ----------
        instrument : `str`
            Full string name for the instrument.
        exportPath : `str`
            Path to the location of the exported repository to import from.
            This path must contain an `exports.yaml` file containing the
            description of the exported gen3 repo that will be imported.
        exportFile : `str`
            Filename of the export data.
        """
        self.repo = os.path.join(self.output_dir, 'testrepo')

        # Make the repo and retrieve a writeable Butler.
        _ = lsst.daf.butler.Butler.makeRepo(self.repo)
        butler = lsst.daf.butler.Butler(self.repo, writeable=True)
        # Register the instrument.
        instrInstance = lsst.obs.base.utils.getInstrument(instrument)
        instrInstance.register(butler.registry)
        # Import the exportFile.
        butler.import_(directory=exportPath, filename=exportFile,
                       transfer='symlink',
                       skip_dimensions={'instrument', 'detector', 'physical_filter'})

    def _runPipeline(self, repo,
                     inputCollections, outputCollection,
                     configFiles=None, configOptions=None,
                     registerDatasetTypes=False, whereSuffix=None,
                     nJobs=1):
        """Run a pipeline via the SimplePipelineExecutor.

        Parameters
        ----------
        repo : `str`
            Gen3 Butler repository to read from/write to.
        inputCollections : `list` [`str`]
            Input collections to read from, as would be passed to "-i".
            For example, ["refcats", "HSC/runs/tests", "HSC/calib"].
        outputCollection : `str`
            String to use for the "-o" output collection. For example,
            "HSC/testdata/jointcal".
        configFiles : `list` [`str`], optional
            List of jointcal config files to use.
        configOptions : `dict` [`str`], optional
            Individual jointcal config options (field: value) to override.
        registerDatasetTypes : `bool`, optional
            Set "--register-dataset-types" when running the pipeline.
        whereSuffix : `str`, optional
            Additional parameters to the ``where`` pipetask statement.
        nJobs : `int`, optional
            Number of quanta expected to be run.

        Returns
        -------
        job : `lsst.verify.Job`
            Job containing the metric measurements from this test run.
        """
        config = lsst.jointcal.JointcalConfig()
        for file in configFiles or []:
            config.load(file)
        for key, value in (configOptions or {}).items():
            setattr(config, key, value)
        # Fall back to the base `where` query if no suffix was supplied.
        where = ' '.join((self.where, whereSuffix)) if whereSuffix is not None else self.where
        lsst.daf.butler.cli.cliLog.CliLog.initLog(False)
        butler = SimplePipelineExecutor.prep_butler(repo,
                                                    inputs=inputCollections,
                                                    output=outputCollection,
                                                    output_run=outputCollection.replace("all", "jointcal"))
        executor = SimplePipelineExecutor.from_task_class(lsst.jointcal.JointcalTask,
                                                          config=config,
                                                          where=where,
                                                          butler=butler)
        executor.run(register_dataset_types=registerDatasetTypes)
        # JobReporter bundles all metrics in the collection into one job.
        jobs = JobReporter(repo, outputCollection, "jointcal", "", "jointcal").run()
        # Tests usually produce a single job; `nJobs` allows for exceptions.
        self.assertEqual(len(jobs), nJobs)
        return list(jobs.values())[0]
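
    # Illustrative sketch (not part of the base class): a hypothetical direct
    # call to _runPipeline, assuming self.repo was already populated by
    # _importRepository. The collection names, config override, and where-clause
    # suffix are placeholders.
    #
    #   job = self._runPipeline(self.repo,
    #                           inputCollections=["refcats/gen2",
    #                                             "HSC/testdata",
    #                                             "HSC/calib/unbounded"],
    #                           outputCollection="HSC/testdata/all",
    #                           configFiles=[os.path.join(self.path, "config/config-gen3.py")],
    #                           configOptions={"doPhotometry": False},
    #                           registerDatasetTypes=True,
    #                           whereSuffix="and tract=9697")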

    def _runGen3Jointcal(self, instrumentClass, instrumentName,
                         configFiles=None, configOptions=None, whereSuffix=None,
                         metrics=None, nJobs=1):
        """Create a Butler repo and run jointcal on it.

        Parameters
        ----------
        instrumentClass : `str`
            The fully-qualified class name of the instrument to be registered in
            the new repo. For example, "lsst.obs.subaru.HyperSuprimeCam".
        instrumentName : `str`
            The name of the instrument as it appears in the repo collections.
            For example, "HSC".
        configFiles : `list` [`str`], optional
            List of jointcal config files to use.
        configOptions : `dict` [`str`], optional
            Individual jointcal config options (field: value) to override.
        whereSuffix : `str`, optional
            Additional parameters to the ``where`` pipetask statement.
        metrics : `dict`, optional
            Dictionary of 'metricName': value to test jointcal's result.metrics
            against.
        nJobs : `int`, optional
            Number of quanta expected to be run.
        """
        self._importRepository(instrumentClass,
                               self.input_dir,
                               os.path.join(self.input_dir, "exports.yaml"))
        # TODO post-RFC-741: the names of these collections will have to change
        # once testdata_jointcal is updated to reflect the collection
        # conventions in RFC-741 (no ticket for that change yet).
        inputCollections = ["refcats/gen2",
                            f"{instrumentName}/testdata",
                            f"{instrumentName}/calib/unbounded"]

        configs = [os.path.join(self.path, "config/config-gen3.py")]
        configs.extend(self.configfiles or [])
        configs.extend(configFiles or [])
        job = self._runPipeline(self.repo,
                                inputCollections,
                                f"{instrumentName}/testdata/all",
                                configFiles=configs,
                                configOptions=configOptions,
                                registerDatasetTypes=True,
                                whereSuffix=whereSuffix,
                                nJobs=nJobs)

        if metrics:
            self._test_metrics(job.measurements, metrics)
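
    # Illustrative sketch (not part of the base class): a hypothetical
    # gen3-style test method built on _runGen3Jointcal. The metric names/values
    # and the where-clause suffix are placeholders; a real `metrics` dict needs
    # one entry per metric that jointcal measures.
    #
    #   def test_jointcalTask_2_visits_gen3(self):
    #       metrics = {'collected_astrometry_refStars': 825,
    #                  'astrometry_final_chi2': None}  # None: skip this check
    #       self._runGen3Jointcal("lsst.obs.subaru.HyperSuprimeCam", "HSC",
    #                             whereSuffix="and tract=9697",
    #                             metrics=metrics)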