Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2SNObject_tests: 

3A class containing tests to check critical functionality for SNObject.py 

4 

5The following functionality is tested: 

6 

7 - SED (flambda) for unextincted SEDs in SNCosmo and SNObject 

8 - SED (flambda) for MW extincted SEDs in SNCosmo and SNObject (independent 

9 implementations of extinction using OD94 model.) 

10 - Band Flux for extincted SED in r Band 

11 - Band Mag for extincted SED in r Band 

12 

13SNIaCatalog_tests: 

14A class containing tests to check critical functionality for SNIaCatalog 

15""" 

16from __future__ import print_function 

17from builtins import map 

18from builtins import str 

19from builtins import range 

20import os 

21import sqlite3 

22import numpy as np 

23import unittest 

24import tempfile 

25import shutil 

26 

27# External packages used 

28# import pandas as pd 

29from pandas.testing import assert_frame_equal 

30import sncosmo 

31import astropy 

32 

33 

34# Lsst Sims Dependencies 

35import lsst.utils.tests 

36from lsst.sims.utils.CodeUtilities import sims_clean_up 

37from lsst.utils import getPackageDir 

38from lsst.sims.photUtils.PhotometricParameters import PhotometricParameters 

39from lsst.sims.photUtils import BandpassDict 

40from lsst.sims.utils import ObservationMetaData 

41from lsst.sims.utils import spatiallySample_obsmetadata as sample_obsmetadata 

42from lsst.sims.catUtils.utils import ObservationMetaDataGenerator 

43from lsst.sims.catalogs.db import CatalogDBObject, fileDBObject 

44 

45# Routines Being Tested 

46from lsst.sims.catUtils.supernovae import SNObject 

47from lsst.sims.catUtils.mixins import SNIaCatalog 

48from lsst.sims.catUtils.utils import SNIaLightCurveGenerator 

49 

50# 2016 July 28 

51# For some reason, the Jenkins slaves used for continuous integration 

52# cannot properly load the astropy config directories used by sncosmo. 

53# To prevent this from crashing every build, we will test whether 

54# the directories can be accessed and, if they cannot, use unittest.skipIf() 

55# to skip all of the unit tests in this file. 

56from astropy.config import get_config_dir 

57 

# Probe the astropy config directory once at import time.  If it cannot be
# resolved (see the Jenkins note above), every test class in this file is
# skipped through ``unittest.skipIf(_skip_sn_tests, ...)``.
_skip_sn_tests = False
try:
    get_config_dir()
except Exception:
    # Catch ordinary errors only: a bare ``except`` would also swallow
    # KeyboardInterrupt / SystemExit raised while the module is imported.
    _skip_sn_tests = True

# Absolute path of the directory containing this test file; the catalog
# tests create their temporary scratch directories beneath it.
ROOT = os.path.abspath(os.path.dirname(__file__))

65 

66 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    Following the pytest/nose ``setup_module`` naming convention, this is
    presumably run once before the module's tests; ``module`` is unused.
    """
    lsst.utils.tests.init()

69 

70 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNObject_tests(unittest.TestCase):
    """
    Tests of SNObject SEDs, band fluxes/magnitudes and attribute settings.

    Where possible, SNObject results are cross-checked against the
    equivalent SNCosmo model and against ``lsst.sims.photUtils``.
    """

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()

    def setUp(self):
        """
        Set up the fixtures shared by the tests.

        SN_blank: an SNObject with MW extinction switched off (MWebv = 0).
        SN_extincted: an SNObject with the same SALT2 parameters whose MW
            extinction is not explicitly overridden.
        """
        # A range of wavelengths in Angstroms
        self.wave = np.arange(3000., 12000., 50.)
        # The same wavelengths in nm
        self.wavenm = self.wave / 10.
        # Time to be used as peak
        self.mjdobs = 571190

        # Check that we can set up an SED with no extinction
        self.SN_blank = SNObject()
        self.SN_blank.setCoords(ra=30., dec=-60.)
        self.SN_blank.set(z=0.96, t0=571181, x1=2.66, c=0.353, x0=1.796e-6)
        self.SN_blank.set_MWebv(0.)

        self.SN_extincted = SNObject(ra=30., dec=-60.)
        self.SN_extincted.set(z=0.96, t0=571181, x1=2.66, c=0.353,
                              x0=1.796112e-06)

        self.SNCosmoModel = self.SN_extincted.equivalentSNCosmoModel()
        self.rectify_photParams = PhotometricParameters()
        self.lsstBandPass = BandpassDict.loadTotalBandpassesFromFiles()
        # SNCosmo copy of the LSST r band; the LSST bandpass wavelengths
        # are handed to SNCosmo as nm.
        self.SNCosmoBP = sncosmo.Bandpass(wave=self.lsstBandPass['r'].wavelen,
                                          trans=self.lsstBandPass['r'].sb,
                                          wave_unit=astropy.units.Unit('nm'),
                                          name='lsst_r')

    def tearDown(self):
        del self.SNCosmoBP
        del self.SN_blank
        del self.SN_extincted

    def test_SNstatenotEmpty(self):
        """
        Check that the state of SNObject, stored in self.SNstate, has valid
        entries for all keys and does not contain keys with None values.
        """
        state = self.SN_extincted.SNstate
        for key in state:
            self.assertIsNotNone(state[key], msg='SNstate[%r] is None' % key)

    def test_attributeDefaults(self):
        """
        Check the defaults and the setter properties for rectifySED and
        modelOutSideTemporalRange.
        """
        snobj = SNObject(ra=30., dec=-60., source='salt2')
        self.assertEqual(snobj.rectifySED, True)
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')

        snobj.rectifySED = False
        # The original passed ``False`` as assertFalse's *msg* argument.
        self.assertFalse(snobj.rectifySED)
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')

    def test_raisingerror_forunimplementedmodelOutSideRange(self):
        """
        Check that the correct error is raised if the user tries to assign
        an un-implemented model value to
        `sims.catUtils.supernovae.SNObject.modelOutSideTemporalRange`.
        """
        snobj = SNObject(ra=30., dec=-60., source='salt2')
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')
        with self.assertRaises(ValueError) as context:
            snobj.modelOutSideTemporalRange = 'False'
        # Checked outside the ``with`` block: statements after the raising
        # line inside it would never execute.
        self.assertEqual('Model not implemented, defaulting to zero method\n',
                         context.exception.args[0])

    def test_rectifiedSED(self):
        """
        Check for an extreme case that the SN SEDs are being rectified.

        With rectification off, the SALT2 parameters below are expected to
        produce negative SED values at some epochs.  With rectification on,
        the SEDs at exactly those epochs must be non-negative and give a
        non-negative ADU count.
        """
        snobj = SNObject(ra=30., dec=-60., source='salt2')
        snobj.set(z=0.96, t0=self.mjdobs, x1=-3., x0=1.8e-6)
        rband = self.lsstBandPass['r']

        snobj.rectifySED = False
        times = np.arange(self.mjdobs - 50., self.mjdobs + 150., 1.)
        badTimes = []
        for time in times:
            sed = snobj.SNObjectSED(time=time, bandpass=rband)
            if np.any(sed.flambda < 0.):
                badTimes.append(time)
        # Check that there are negative SEDs
        self.assertGreater(len(badTimes), 0)

        snobj.rectifySED = True
        for time in badTimes:
            sed = snobj.SNObjectSED(time=time, bandpass=rband)
            self.assertGreaterEqual(sed.calcADU(bandpass=rband,
                                                photParams=self.rectify_photParams),
                                    0.)
            self.assertFalse(np.any(sed.flambda < 0.))

    def test_ComparebandFluxes2photUtils(self):
        """
        The SNObject.catsimBandFlux computation uses the sims.photUtils.Sed
        band flux computation under the hood.  This test makes sure that
        these definitions are in sync.
        """
        rband = self.lsstBandPass['r']
        snobject_r = self.SN_extincted.catsimBandFlux(bandpassobject=rband,
                                                      time=self.mjdobs)

        # `sims.photUtils.Sed`
        sed = self.SN_extincted.SNObjectSED(time=self.mjdobs, bandpass=rband)
        sedflux = sed.calcFlux(bandpass=rband)
        # 3631 Jy is the AB zero point; this assumes calcFlux reports Jy and
        # catsimBandFlux reports maggies.
        np.testing.assert_allclose(snobject_r, sedflux / 3631.0)

    def test_CompareBandFluxes2SNCosmo(self):
        """
        Compare the r band flux at a particular time computed in SNObject and
        SNCosmo for MW-extincted SEDs.  While the underlying SED is obtained
        from SNCosmo, the integration with the bandpass is an independent
        calculation in SNCosmo and catsim.
        """
        time = self.mjdobs
        catsim_r = self.SN_extincted.catsimBandFlux(
            bandpassobject=self.lsstBandPass['r'],
            time=time)
        sncosmo_r = self.SNCosmoModel.bandflux(band=self.SNCosmoBP,
                                               time=time, zpsys='ab',
                                               zp=0.)
        np.testing.assert_allclose(sncosmo_r, catsim_r, rtol=1.0e-4)

    def test_CompareBandMags2SNCosmo(self):
        """
        Compare the r band magnitude at a particular time computed in
        SNObject and SNCosmo for MW-extincted SEDs.  Should work whenever the
        flux comparison above works.
        """
        time = self.mjdobs
        catsim_r = self.SN_extincted.catsimBandMag(
            bandpassobject=self.lsstBandPass['r'],
            time=time)
        sncosmo_r = self.SNCosmoModel.bandmag(band=self.SNCosmoBP,
                                              time=time, magsys='ab')
        np.testing.assert_allclose(sncosmo_r, catsim_r, rtol=1.0e-5)

    def test_CompareExtinctedSED2SNCosmo(self):
        """
        Compare the extincted SEDs in SNCosmo and SNObject.

        This is slightly less trivial than comparing unextincted SEDs, because
        the extinction in SNObject uses different code from SNCosmo.  However,
        it still uses the same value of MWEBV rather than reading it off a map.
        """
        SNObjectSED = self.SN_extincted.SNObjectSED(time=self.mjdobs,
                                                    wavelen=self.wavenm)

        # The bluest 10 wavelength points are excluded from the comparison.
        # The factor of 10 presumably converts SNCosmo's per-Angstrom flux
        # density to catsim's per-nm flambda.
        SNCosmoSED = self.SNCosmoModel.flux(time=self.mjdobs,
                                            wave=self.wave[10:]) * 10.
        np.testing.assert_allclose(SNObjectSED.flambda[10:], SNCosmoSED,
                                   rtol=1.0e-7)

    def test_CompareUnextinctedSED2SNCosmo(self):
        """
        Compare the unextincted flux densities in SNCosmo and SNObject.  This
        is merely a sanity check, as SNObject uses SNCosmo under the hood.
        """
        # Same wavelength slice and per-Angstrom -> per-nm factor as the
        # extincted comparison above.
        SNCosmoFluxDensity = self.SN_blank.flux(wave=self.wave[10:],
                                                time=self.mjdobs) * 10.

        unextincted_sed = self.SN_blank.SNObjectSED(time=self.mjdobs,
                                                    wavelen=self.wavenm)

        SNObjectFluxDensity = unextincted_sed.flambda[10:]
        np.testing.assert_allclose(SNCosmoFluxDensity, SNObjectFluxDensity,
                                   rtol=1.0e-4)

    def test_redshift(self):
        """
        Test that the redshift method works as expected.

        If we redshift an SN from its original redshift orig_z to new_z,
        where new_z is smaller (larger) than orig_z, then:
         - 1. x0 increases (decreases)
         - 2. the source peak absolute magnitude in the BessellB band stays
           the same
        """
        from astropy.cosmology import FlatLambdaCDM
        cosmo = FlatLambdaCDM(H0=70., Om0=0.3)

        orig_z = self.SN_extincted.get('z')
        orig_x0 = self.SN_extincted.get('x0')
        peakabsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                         cosmo=cosmo)

        lowz = orig_z * 0.5
        highz = orig_z * 2.0

        # Test case for lower redshift
        self.SN_extincted.redshift(z=lowz, cosmo=cosmo)
        low_x0 = self.SN_extincted.get('x0')
        lowPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                            cosmo=cosmo)

        # Test 1.
        self.assertGreater(low_x0, orig_x0)
        # Test 2.
        self.assertAlmostEqual(peakabsMag, lowPeakAbsMag, places=14)

        # Test case for higher redshift
        self.SN_extincted.redshift(z=highz, cosmo=cosmo)
        high_x0 = self.SN_extincted.get('x0')
        HiPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                           cosmo=cosmo)

        # Test 1.
        self.assertLess(high_x0, orig_x0)
        # Test 2.
        self.assertAlmostEqual(peakabsMag, HiPeakAbsMag, places=14)

    def test_bandFluxErrorWorks(self):
        """
        Test that band flux errors work even if the flux is negative.
        """
        e = self.SN_extincted.catsimBandFluxError(self.mjdobs,
                                                  self.lsstBandPass['r'],
                                                  m5=24.5, fluxinMaggies=-1.0)
        # ``np.float`` was only an alias of the builtin ``float`` and was
        # removed in NumPy 1.24; np.float64 values still pass this check.
        self.assertIsInstance(e, float)
        # Finite means neither inf nor nan.
        self.assertTrue(np.isfinite(e), msg='flux error is %r' % e)

151 with self.assertRaises(ValueError) as context: 

152 snobj.modelOutSideTemporalRange = 'False' 

153 self.assertEqual('Model not implemented, defaulting to zero method\n', 

154 context.exception.args[0]) 

155 

156 def test_rectifiedSED(self): 

157 """ 

158 Check for an extreme case that the SN seds are being rectified. This is 

159 done by setting up an extreme case where there will be negative seds, and 

160 checking that this is indeed the case, and checking that they are not 

161 negative if rectified. 

162 """ 

163 

164 snobj = SNObject(ra=30., dec=-60., source='salt2') 

165 snobj.set(z=0.96, t0=self.mjdobs, x1=-3., x0=1.8e-6) 

166 snobj.rectifySED = False 

167 times = np.arange(self.mjdobs - 50., self.mjdobs + 150., 1.) 

168 badTimes = [] 

169 for time in times: 

170 sed = snobj.SNObjectSED(time=time, 

171 bandpass=self.lsstBandPass['r']) 

172 if any(sed.flambda < 0.): 

173 badTimes.append(time) 

174 # Check that there are negative SEDs 

175 assert(len(badTimes) > 0) 

176 snobj.rectifySED = True 

177 for time in badTimes: 

178 sed = snobj.SNObjectSED(time=time, 

179 bandpass=self.lsstBandPass['r']) 

180 self.assertGreaterEqual(sed.calcADU(bandpass=self.lsstBandPass['r'], 

181 photParams=self.rectify_photParams), 0.) 

182 self.assertFalse(any(sed.flambda < 0.)) 

183 

184 def test_ComparebandFluxes2photUtils(self): 

185 """ 

186 The SNObject.catsimBandFlux computation uses the sims.photUtils.sed 

187 band flux computation under the hood. This test makes sure that these 

188 definitions are in sync 

189 """ 

190 

191 snobject_r = self.SN_extincted.catsimBandFlux( 

192 bandpassobject=self.lsstBandPass['r'], 

193 time=self.mjdobs) 

194 

195 # `sims.photUtils.Sed` 

196 sed = self.SN_extincted.SNObjectSED(time=self.mjdobs, 

197 bandpass=self.lsstBandPass['r']) 

198 sedflux = sed.calcFlux(bandpass=self.lsstBandPass['r']) 

199 np.testing.assert_allclose(snobject_r, sedflux / 3631.0) 

200 

201 def test_CompareBandFluxes2SNCosmo(self): 

202 """ 

203 Compare the r band flux at a particular time computed in SNObject and 

204 SNCosmo for MW-extincted SEDs. While the underlying sed is obtained 

205 from SNCosmo the integration with the bandpass is an independent 

206 calculation in SNCosmo and catsim 

207 """ 

208 

209 times = self.mjdobs 

210 catsim_r = self.SN_extincted.catsimBandFlux( 

211 bandpassobject=self.lsstBandPass['r'], 

212 time=times) 

213 sncosmo_r = self.SNCosmoModel.bandflux(band=self.SNCosmoBP, 

214 time=times, zpsys='ab', 

215 zp=0.) 

216 np.testing.assert_allclose(sncosmo_r, catsim_r, rtol=1.0e-4) 

217 

218 def test_CompareBandMags2SNCosmo(self): 

219 """ 

220 Compare the r band flux at a particular time computed in SNObject and 

221 SNCosmo for MW-extincted SEDs. Should work whenever the flux comparison 

222 above works. 

223 """ 

224 times = self.mjdobs 

225 catsim_r = self.SN_extincted.catsimBandMag( 

226 bandpassobject=self.lsstBandPass['r'], 

227 time=times) 

228 sncosmo_r = self.SNCosmoModel.bandmag(band=self.SNCosmoBP, 

229 time=times, magsys='ab') 

230 np.testing.assert_allclose(sncosmo_r, catsim_r, rtol=1.0e-5) 

231 

232 def test_CompareExtinctedSED2SNCosmo(self): 

233 """ 

234 Compare the extincted SEDS in SNCosmo and SNObject. Slightly more 

235 non-trivial than comparing unextincted SEDS, as the extinction in 

236 SNObject uses different code from SNCosmo. However, this is still 

237 using the same values of MWEBV, rather than reading it off a map. 

238 """ 

239 SNObjectSED = self.SN_extincted.SNObjectSED(time=self.mjdobs, 

240 wavelen=self.wavenm) 

241 

242 SNCosmoSED = self.SNCosmoModel.flux(time=self.mjdobs, wave=self.wave[10:]) \ 

243 * 10. 

244 np.testing.assert_allclose(SNObjectSED.flambda[10:], SNCosmoSED, 

245 rtol=1.0e-7) 

246 

247 def test_CompareUnextinctedSED2SNCosmo(self): 

248 """ 

249 Compares the unextincted flux Densities in SNCosmo and SNObject. This 

250 is mereley a sanity check as SNObject uses SNCosmo under the hood. 

251 """ 

252 

253 SNCosmoFluxDensity = self.SN_blank.flux(wave=self.wave[10:], 

254 time=self.mjdobs) * 10. 

255 

256 unextincted_sed = self.SN_blank.SNObjectSED(time=self.mjdobs, 

257 wavelen=self.wavenm) 

258 

259 SNObjectFluxDensity = unextincted_sed.flambda[10:] 

260 np.testing.assert_allclose(SNCosmoFluxDensity, SNObjectFluxDensity, 

261 rtol=1.0e-4) 

262 

263 def test_redshift(self): 

264 """ 

265 test that the redshift method works as expected by checking that 

266 if we redshift a SN from its original redshift orig_z to new_z where 

267 new_z is smaller (larger) than orig_z: 

268 - 1. x0 increases (decreases) 

269 - 2. source peak absolute magnitude in BesselB band stays the same 

270 """ 

271 from astropy.cosmology import FlatLambdaCDM 

272 cosmo = FlatLambdaCDM(H0=70., Om0=0.3) 

273 

274 orig_z = self.SN_extincted.get('z') 

275 orig_x0 = self.SN_extincted.get('x0') 

276 peakabsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB', cosmo=cosmo) 

277 

278 lowz = orig_z * 0.5 

279 highz = orig_z * 2.0 

280 

281 # Test Case for lower redshift 

282 self.SN_extincted.redshift(z=lowz, cosmo=cosmo) 

283 low_x0 = self.SN_extincted.get('x0') 

284 lowPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB', cosmo=cosmo) 

285 

286 # Test 1. 

287 self.assertGreater(low_x0, orig_x0) 

288 # Test 2. 

289 self.assertAlmostEqual(peakabsMag, lowPeakAbsMag, places=14) 

290 

291 # Test Case for higher redshift 

292 self.SN_extincted.redshift(z=highz, cosmo=cosmo) 

293 high_x0 = self.SN_extincted.get('x0') 

294 HiPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB', cosmo=cosmo) 

295 

296 # Test 1. 

297 self.assertLess(high_x0, orig_x0) 

298 # Test 2. 

299 self.assertAlmostEqual(peakabsMag, HiPeakAbsMag, places=14) 

300 

301 def test_bandFluxErrorWorks(self): 

302 """ 

303 test that bandflux errors work even if the flux is negative 

304 """ 

305 times = self.mjdobs 

306 

307 e = self.SN_extincted.catsimBandFluxError(times, 

308 self.lsstBandPass['r'], 

309 m5=24.5, fluxinMaggies=-1.0) 

310 assert isinstance(e, np.float) 

311 print(e) 

312 assert not(np.isinf(e) or np.isnan(e)) 

313 

314 

315 

316 

317 

318 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNIaCatalog_tests(unittest.TestCase):
    """
    Tests of SNIaCatalog, run against a small fake galaxy database written
    to a scratch sqlite file and against OpSim pointings.
    """

    @classmethod
    def setUpClass(cls):

        # Set directory where scratch work will be done
        cls.scratchDir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')

        # ObsMetaData instance with the spatial window within which we will
        # put galaxies in a fake galaxy catalog
        cls.obsMetaDataforCat = ObservationMetaData(boundType='circle',
                                                    boundLength=np.degrees(0.25),
                                                    pointingRA=np.degrees(0.13),
                                                    pointingDec=np.degrees(-1.2),
                                                    bandpassName=['r'], mjd=49350.)

        # Randomly generate cls.size galaxy positions within the spatial
        # window of obsMetaDataforCat
        cls.dbname = os.path.join(cls.scratchDir, 'galcat.db')
        cls.size = 1000
        cls.GalaxyPositionSamps = sample_obsmetadata(obsmetadata=cls.obsMetaDataforCat,
                                                     size=cls.size)

        # Create a galaxy table overlapping the obsMetaData spatial bounds,
        # using the positions sampled above and the database cls.dbname
        vals = cls._createFakeGalaxyDB()
        cls.valName = os.path.join(cls.scratchDir, 'valsFromTest.dat')
        with open(cls.valName, 'w') as f:
            for ra, dec in zip(vals[0], vals[1]):
                f.write(str(np.radians(ra)) + ' ' + str(np.radians(dec)) + '\n')

        # Read it into a CatalogDBObject galDB
        class MyGalaxyCatalog(CatalogDBObject):
            '''
            Create a CatalogDBObject connecting to a local sqlite database
            '''

            objid = 'mytestgals'
            tableid = 'gals'
            idColKey = 'id'
            objectTypeId = 0
            appendint = 10000
            database = cls.dbname
            raColName = 'raJ2000'
            decColName = 'decJ2000'
            driver = 'sqlite'

            # The table stores ra, dec in degrees; these column expressions
            # convert them back to radians
            columns = [('id', 'id', int),
                       ('raJ2000', 'raJ2000 * PI()/ 180. '),
                       ('decJ2000', 'decJ2000 * PI()/ 180.'),
                       ('redshift', 'redshift')]

        cls.galDB = MyGalaxyCatalog(database=cls.dbname)

        # Generate a set of ObservationMetaData outputs that overlap
        # the galaxies in space
        opsimPath = os.path.join(getPackageDir('sims_data'), 'OpSimData')
        opsimDB = os.path.join(opsimPath, 'opsimblitz1_1133_sqlite.db')

        generator = ObservationMetaDataGenerator(database=opsimDB)
        cls.obsMetaDataResults = generator.getObservationMetaData(limit=100,
                                                                  fieldRA=(5.0, 8.0),
                                                                  fieldDec=(-85., -60.),
                                                                  expMJD=(49300., 49400.),
                                                                  boundLength=0.15,
                                                                  boundType='circle')

        sncatalog = SNIaCatalog(db_obj=cls.galDB,
                                obs_metadata=cls.obsMetaDataResults[6],
                                column_outputs=['t0', 'flux_u', 'flux_g',
                                                'flux_r', 'flux_i', 'flux_z',
                                                'flux_y', 'mag_u', 'mag_g',
                                                'mag_r', 'mag_i', 'mag_z',
                                                'mag_y', 'adu_u', 'adu_g',
                                                'adu_r', 'adu_i', 'adu_z',
                                                'adu_y', 'mwebv'])
        sncatalog.suppressDimSN = True
        sncatalog.midSurveyTime = sncatalog.mjdobs - 20.
        sncatalog.snFrequency = 1.0
        cls.fullCatalog = os.path.join(cls.scratchDir, 'testSNCatalogTest.dat')
        sncatalog.write_catalog(cls.fullCatalog)

        # Create SN catalogs based on galDB, with explosion times
        # overlapping the times in obsMetaData
        cls.fnameList = cls._writeManySNCatalogs(cls.obsMetaDataResults)

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        del cls.galDB
        cls.cleanDB(cls.dbname)
        if os.path.exists(cls.valName):
            os.unlink(cls.valName)

        for fname in cls.fnameList:
            if os.path.exists(fname):
                os.unlink(fname)

        if os.path.exists(cls.fullCatalog):
            os.unlink(cls.fullCatalog)
        if os.path.exists(cls.scratchDir):
            shutil.rmtree(cls.scratchDir, ignore_errors=True)

    def test_writingfullCatalog(self):
        """
        Check that a full catalog of SN has more than one line.
        """
        with open(self.fullCatalog, 'r') as f:
            numLines = sum(1 for _ in f)

        self.assertGreater(numLines, 1)

    @staticmethod
    def buildLCfromInstanceCatFilenames(fnamelist):
        """
        Read the catalogs in ``fnamelist`` and group their rows by SN id.

        Parameters
        ----------
        fnamelist : list of str
            Paths of ', '-separated catalog files with a ``#snid`` column.

        Returns
        -------
        pandas GroupBy of the concatenated rows, keyed on integer ``snid``.
        """
        # External packages used
        import pandas as pd
        # A multi-character separator is only supported by the python
        # parser; naming it explicitly avoids pandas' fallback ParserWarning.
        dfs = [pd.read_csv(fname, index_col=None, sep=', ', engine='python')
               for fname in fnamelist]
        all_lcsDumped = pd.concat(dfs)
        # The first header column is read as '#snid'.
        all_lcsDumped.rename(columns={'#snid': 'snid'}, inplace=True)
        all_lcsDumped['snid'] = all_lcsDumped['snid'].astype(int)
        return all_lcsDumped.groupby('snid')

    def test_drawReproducibility(self):
        """
        Check that when the same SN (i.e. with the same snid) is observed with
        different pointings leading to different instance catalogs, the
        values of its properties remain the same.
        """
        lcs = self.buildLCfromInstanceCatFilenames(self.fnameList)

        props = ['snid', 'snra', 'sndec', 'z', 'x0', 'x1', 'c',
                 'cosmologicalDistanceModulus', 'mwebv']
        s = "Testing Equality across {0:2d} pointings for reported properties"
        s += " of SN {1:8d} of the property "
        for key in lcs.groups:
            df = lcs.get_group(key)
            for prop in props:
                print(s.format(len(df), df.snid.iloc[0]) + prop)
                np.testing.assert_equal(len(df[prop].unique()), 1)

    def test_redrawingCatalog(self):
        """
        Test that regenerating the catalogs from the same pointings, visited
        in shuffled order, reproduces identical light curves.
        """
        from random import shuffle
        import copy

        obsMetaDataResults = copy.deepcopy(self.obsMetaDataResults)
        shuffle(obsMetaDataResults)
        fnameList = self._writeManySNCatalogs(obsMetaDataResults,
                                              suffix='.v2.dat')
        try:
            newlcs = self.buildLCfromInstanceCatFilenames(fnameList)
            oldlcs = self.buildLCfromInstanceCatFilenames(self.fnameList)

            for key in oldlcs.groups:
                # Sort copies rather than sorting the group slices in place,
                # which pandas flags as chained assignment.
                df_old = oldlcs.get_group(key).sort_values(['time', 'band'])
                df_new = newlcs.get_group(key).sort_values(['time', 'band'])
                s = "Testing equality for SNID {0:8d} with {1:2d} datapoints"
                print(s.format(df_new.snid.iloc[0], len(df_old)))
                assert_frame_equal(df_new, df_old)
        finally:
            # Remove the extra catalogs even when an assertion fails.
            for fname in fnameList:
                if os.path.exists(fname):
                    os.unlink(fname)

    def test_obsMetaDataGeneration(self):
        """
        Check the number of OpSim pointings selected in setUpClass.
        """
        numObs = len(self.obsMetaDataResults)
        self.assertEqual(numObs, 15)

    @staticmethod
    def coords(x):
        return np.radians(x.summary['unrefractedRA']),\
            np.radians(x.summary['unrefractedDec'])

    @classmethod
    def _writeManySNCatalogs(cls, obsMetaDataResults, suffix=''):
        """
        Write one SNIaCatalog per ObservationMetaData into cls.scratchDir.

        Returns the list of written file names, "SNCatalog_<index><suffix>".
        """
        fnameList = []
        for obsindex, obsMetaData in enumerate(obsMetaDataResults):

            cols = ['t0', 'mwebv', 'time', 'band', 'flux', 'flux_err',
                    'mag', 'mag_err', 'cosmologicalDistanceModulus']
            newCatalog = SNIaCatalog(db_obj=cls.galDB, obs_metadata=obsMetaData,
                                     column_outputs=cols)
            newCatalog.midSurveyTime = 49350
            newCatalog.averageRate = 1.
            newCatalog.suppressDimSN = False
            s = "{0:d}".format(obsindex)
            fname = os.path.join(cls.scratchDir, "SNCatalog_" + s + suffix)
            newCatalog.write_catalog(fname)
            fnameList.append(fname)
        return fnameList

    @classmethod
    def _createFakeGalaxyDB(cls):
        '''
        Create a local sqlite galaxy database at ``cls.dbname``.

        The database has a table ``gals`` with columns id, raJ2000, decJ2000
        (degrees, taken from ``cls.GalaxyPositionSamps``) and redshift (drawn
        uniformly on [0, 1) with a fixed seed), and ``cls.size`` rows.

        Returns
        -------
        The position samples ``cls.GalaxyPositionSamps``.
        '''
        dbname = cls.dbname
        samps = cls.GalaxyPositionSamps
        size = cls.size
        cls.cleanDB(dbname)

        # Seed the global numpy RNG so the redshifts are reproducible.
        np.random.seed(1)
        rows = []
        for count in range(size):
            # Main database should have values in degrees
            galid = 1000000 + count
            rows.append((galid, samps[0][count], samps[1][count],
                         np.random.uniform()))

        conn = sqlite3.connect(dbname)
        try:
            curs = conn.cursor()
            curs.execute('CREATE TABLE if not exists gals '
                         '(id INT, raJ2000 FLOAT, decJ2000 FLOAT, redshift FLOAT)')
            if rows:
                exec_str = cls.insertfromdata(tablename='gals', records=rows,
                                              multiple=True)
                curs.executemany(exec_str, rows)
            conn.commit()
        finally:
            conn.close()
        return samps

    @staticmethod
    def cleanDB(dbname, verbose=True):
        """
        Delete the database dbname from disk.

        Parameters
        ----------
        dbname: string, mandatory
            name (abs path) of the database to be deleted
        verbose: Bool, optional, defaults to True
            if True, print whether the file was deleted or did not exist
        """
        if os.path.exists(dbname):
            if verbose:
                print("deleting database ", dbname)
            os.unlink(dbname)
        else:
            if verbose:
                print('database ', dbname, ' does not exist')

    @staticmethod
    def insertfromdata(tablename, records, multiple=True):
        """
        Construct a parameterized INSERT statement for a sqlite3 table.

        Parameters
        ----------
        tablename: str, mandatory
            Name of the table in the database (trusted; it is concatenated
            into the SQL).
        records: sequence
            If ``multiple`` is True, a sequence of records, and the first
            record sets the number of columns. Otherwise a single record.
        multiple: bool, optional, defaults to True

        Returns
        -------
        str of the form 'INSERT INTO <table> VALUES ( ?, ?, ...)'.
        """
        if multiple:
            lst = records[0]
        else:
            lst = records
        s = 'INSERT INTO ' + str(tablename) + ' VALUES '
        s += "( " + ", ".join(["?"] * len(lst)) + ")"
        return s

608 

609 

class SNIaLightCurveControlCatalog(SNIaCatalog):
    """
    Control catalog used to cross-check SNIaLightCurveGenerator output.

    Each row is (uniqueId, flux, flux_err, redshift).  The SN population
    settings match those the tests assign to ``gen.sn_universe``.
    """
    _midSurveyTime = 49000.0
    _snFrequency = 0.001
    column_outputs = ['uniqueId', 'flux', 'flux_err', 'redshift']
    # Prefixed with this file's path, presumably so the catalog_type does
    # not clash with catalog types registered elsewhere.
    catalog_type = __file__ + 'sn_ia_lc_cat'

615 

616 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNIaLightCurveTest(unittest.TestCase):
    """
    Tests of SNIaLightCurveGenerator.

    Light curves from the generator are compared against
    SNIaLightCurveControlCatalog instance catalogs for the same pointings.
    """

    @classmethod
    def setUpClass(cls):

        rng = np.random.RandomState(99)
        n_sne = 100
        ra_list = rng.random_sample(n_sne) * 7.0 + 78.0
        dec_list = rng.random_sample(n_sne) * 4.0 - 69.0
        zz_list = rng.random_sample(n_sne) * 1.0 + 0.05

        cls.scratchDir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.input_cat_name = os.path.join(cls.scratchDir, "sne_input_cat.txt")

        # Columns: id; ra (deg); dec (deg); ra (rad); dec (rad); redshift
        with open(cls.input_cat_name, "w") as output_file:
            for ix in range(n_sne):
                output_file.write("%d;%.12f;%.12f;%.12f;%.12f;%.12f\n"
                                  % (ix + 1, ra_list[ix], dec_list[ix],
                                     np.radians(ra_list[ix]), np.radians(dec_list[ix]),
                                     zz_list[ix]))

        # ``np.int`` / ``np.float`` were aliases of the builtins and were
        # removed in NumPy 1.24; the builtins give the same dtypes.
        dtype = np.dtype([('id', int),
                          ('raDeg', float), ('decDeg', float),
                          ('raJ2000', float), ('decJ2000', float),
                          ('redshift', float)])

        cls.db = fileDBObject(cls.input_cat_name, delimiter=';',
                              runtable='test', dtype=dtype,
                              idColKey='id')

        cls.db.raColName = 'raDeg'
        cls.db.decColName = 'decDeg'
        cls.db.objectTypeId = 873

        cls.opsimDb = os.path.join(getPackageDir("sims_data"), "OpSimData")
        cls.opsimDb = os.path.join(cls.opsimDb, "opsimblitz1_1133_sqlite.db")

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        if os.path.exists(cls.input_cat_name):
            os.unlink(cls.input_cat_name)
        if os.path.exists(cls.scratchDir):
            shutil.rmtree(cls.scratchDir, ignore_errors=True)

    def _assert_matches_control(self, lc, obs, sn):
        """
        Assert that light curve ``lc`` agrees with control-catalog row ``sn``.

        ``sn`` is a row of SNIaLightCurveControlCatalog (uniqueId, flux,
        flux_err, redshift).  The light-curve point closest in time to
        ``obs`` must match its MJD, flux and flux error to within 1e-7.
        """
        dex = np.argmin(np.abs(lc['mjd'] - obs.mjd.TAI))
        self.assertLess(np.abs(lc['mjd'][dex] - obs.mjd.TAI), 1.0e-7)
        self.assertLess(np.abs(lc['flux'][dex] - sn[1]), 1.0e-7)
        self.assertLess(np.abs(lc['error'][dex] - sn[2]), 1.0e-7,
                        msg='%e vs %e' % (lc['error'][dex], sn[2]))

    def test_sne_light_curves(self):
        """
        Generate some supernova light curves.  Verify that they come up with
        the same magnitudes and uncertainties as supernova catalogs.
        """
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)
        bandpass = 'r'

        pointings = gen.get_pointings(raRange, decRange, bandpass=bandpass)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, _ = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        for group in pointings:
            self.assertGreater(len(group), 1)
            for obs in group:
                cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
                for sn in cat.iter_catalog():
                    if sn[1] > 0.0:
                        self._assert_matches_control(lc_dict[sn[0]][bandpass],
                                                     obs, sn)

    def test_sne_light_curves_z_cut(self):
        """
        Generate some supernova light curves with a cutoff in redshift.
        Verify that they come up with the same magnitudes and uncertainties
        as supernova catalogs, and that objects with z > z_cutoff are not
        returned.
        """
        z_cut = 0.9

        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)
        gen.z_cutoff = z_cut

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)
        bandpass = 'r'

        pointings = gen.get_pointings(raRange, decRange, bandpass=bandpass)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, _ = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        over_z = 0

        for group in pointings:
            self.assertGreater(len(group), 1)
            for obs in group:
                cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
                for sn in cat.iter_catalog():
                    if sn[1] <= 0.0:
                        continue
                    if sn[3] > z_cut:
                        self.assertNotIn(sn[0], lc_dict)
                        over_z += 1
                    else:
                        self._assert_matches_control(lc_dict[sn[0]][bandpass],
                                                     obs, sn)

        # Make sure the cut was actually exercised.
        self.assertGreater(over_z, 0)

    def test_sne_multiband_light_curves(self):
        """
        Generate some supernova light curves.  Verify that they come up with
        the same magnitudes and uncertainties as supernova catalogs.  Use
        multiband light curves.
        """
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)

        pointings = gen.get_pointings(raRange, decRange, bandpass=('r', 'z'))
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, _ = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        obs_gen = ObservationMetaDataGenerator(database=self.opsimDb, driver='sqlite')
        control_obs_r = obs_gen.getObservationMetaData(fieldRA=raRange, fieldDec=decRange,
                                                       telescopeFilter='r', boundLength=1.75)

        control_obs_z = obs_gen.getObservationMetaData(fieldRA=raRange, fieldDec=decRange,
                                                       telescopeFilter='z', boundLength=1.75)

        self.assertGreater(len(control_obs_r), 0)
        self.assertGreater(len(control_obs_z), 0)

        for band, control_obs in (('r', control_obs_r), ('z', control_obs_z)):
            ct = 0
            for obs in control_obs:
                cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
                for sn in cat.iter_catalog():
                    if sn[1] > 0.0:
                        ct += 1
                        self._assert_matches_control(lc_dict[sn[0]][band],
                                                     obs, sn)
            # Each band must contribute at least one checked point.
            self.assertGreater(ct, 0, msg='no detections in band %s' % band)

    def test_limit_sne_light_curves(self):
        """
        Test that we can limit the number of light curves returned per
        field of view.
        """
        lc_limit = 2
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)

        pointings = gen.get_pointings(raRange, decRange, bandpass=('r', 'z'))

        control_lc, _ = gen.light_curves_from_pointings(pointings)
        test_lc, _ = gen.light_curves_from_pointings(pointings, lc_per_field=lc_limit)
        self.assertGreater(len(control_lc), len(test_lc))
        self.assertLessEqual(len(test_lc), lc_limit * len(pointings))

811 

812 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """
    Runs the checks defined by lsst.utils.tests.MemoryTestCase unchanged
    (presumably resource/leak checks); nothing is overridden here.
    """

815 

if __name__ == '__main__':
    # Initialise the LSST test utilities before handing control to unittest.
    lsst.utils.tests.init()
    unittest.main()