Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2SNObject_tests: 

3A Class containing tests to check critical functionality for SNObject.py 

4 

5The following functionality is tested: 

6 

7 - SED (flambda) for unextincted SEDs in SNCosmo and SNObject 

8 - SED (flambda) for MW extincted SEDs in SNCosmo and SNObject (independent 

9 implementations of extinction using OD94 model.) 

10 - Band Flux for extincted SED in r Band 

11 - Band Mag for extincted SED in r Band 

12 

13SNIaCatalog_tests: 

14A Class containing tests to check critical functionality for SNIaCatalog 

15""" 

16from __future__ import print_function 

17from builtins import map 

18from builtins import str 

19from builtins import range 

20import os 

21import sqlite3 

22import numpy as np 

23import unittest 

24import tempfile 

25import shutil 

26 

27# External packages used 

28# import pandas as pd 

29from pandas.util.testing import assert_frame_equal 

30import sncosmo 

31import astropy 

32 

33 

34# Lsst Sims Dependencies 

35import lsst.utils.tests 

36from lsst.sims.utils.CodeUtilities import sims_clean_up 

37from lsst.utils import getPackageDir 

38from lsst.sims.photUtils.PhotometricParameters import PhotometricParameters 

39from lsst.sims.photUtils import BandpassDict 

40from lsst.sims.utils import ObservationMetaData 

41from lsst.sims.utils import spatiallySample_obsmetadata as sample_obsmetadata 

42from lsst.sims.catUtils.utils import ObservationMetaDataGenerator 

43from lsst.sims.catalogs.db import CatalogDBObject, fileDBObject 

44 

45# Routines Being Tested 

46from lsst.sims.catUtils.supernovae import SNObject 

47from lsst.sims.catUtils.mixins import SNIaCatalog 

48from lsst.sims.catUtils.utils import SNIaLightCurveGenerator 

49 

50# 2016 July 28 

51# For some reason, the Jenkins slaves used for continuous integration 

52# cannot properly load the astropy config directories used by sncosmo. 

53# To prevent this from crashing every build, we will test whether 

54# the directories can be accessed and, if they cannot, use unittest.skipIf() 

55# to skip all of the unit tests in this file. 

56from astropy.config import get_config_dir 

57 

# Probe the astropy configuration directory once at import time.  The flag is
# consumed by the ``unittest.skipIf`` decorators on the test classes in this
# file.
_skip_sn_tests = False
try:
    get_config_dir()
except Exception:
    # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
    # and SystemExit still propagate while the module is being imported.
    _skip_sn_tests = True

# Directory containing this test module; scratch directories are created
# beneath it with ``tempfile.mkdtemp(dir=ROOT, ...)``.
ROOT = os.path.abspath(os.path.dirname(__file__))

65 

66 

def setup_module(module):
    """
    Initialise the LSST test utilities by calling ``lsst.utils.tests.init()``.

    The ``module`` argument is accepted but not used.
    """
    lsst.utils.tests.init()

69 

70 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNObject_tests(unittest.TestCase):
    """
    Tests of critical SNObject functionality: attribute defaults, SED
    rectification, agreement of SEDs / band fluxes / band magnitudes with
    SNCosmo, redshifting, and band-flux errors.
    """

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()

    def setUp(self):
        """
        Setup tests

        SN_blank: a SNObject with MW E(B-V) explicitly set to 0.
        SN_extincted: a SNObject built with coordinates and (nearly) the same
            model parameters, on which ``set_MWebv`` is not called here.
        SNCosmoModel: the equivalent SNCosmo model of SN_extincted.
        SNCosmoBP: the LSST r bandpass wrapped as a ``sncosmo.Bandpass``.
        """
        mydir = get_config_dir()
        print('===============================')
        print('===============================')
        print(mydir)
        print('===============================')
        print('===============================')

        # A range of wavelengths in Ang
        self.wave = np.arange(3000., 12000., 50.)
        # Equivalent wavelengths in nm
        self.wavenm = self.wave / 10.
        # Time at which the tests evaluate SEDs, fluxes and magnitudes
        self.mjdobs = 571190

        # Check that we can set up a SED
        # with no extinction
        self.SN_blank = SNObject()
        self.SN_blank.setCoords(ra=30., dec=-60.)
        self.SN_blank.set(z=0.96, t0=571181, x1=2.66, c=0.353, x0=1.796e-6)
        self.SN_blank.set_MWebv(0.)

        self.SN_extincted = SNObject(ra=30., dec=-60.)
        self.SN_extincted.set(z=0.96, t0=571181, x1=2.66, c=0.353,
                              x0=1.796112e-06)

        self.SNCosmoModel = self.SN_extincted.equivalentSNCosmoModel()
        self.rectify_photParams = PhotometricParameters()
        self.lsstBandPass = BandpassDict.loadTotalBandpassesFromFiles()
        self.SNCosmoBP = sncosmo.Bandpass(wave=self.lsstBandPass['r'].wavelen,
                                          trans=self.lsstBandPass['r'].sb,
                                          wave_unit=astropy.units.Unit('nm'),
                                          name='lsst_r')

    def tearDown(self):
        del self.SNCosmoBP
        del self.SN_blank
        del self.SN_extincted

    def test_SNstatenotEmpty(self):
        """
        Check that the state of SNObject, stored in self.SNstate has valid
        entries for all keys and does not contain keys with None type Values.
        """
        myDict = self.SN_extincted.SNstate
        for key in myDict:
            self.assertIsNotNone(myDict[key],
                                 msg='SNstate[%r] is None' % (key,))

    def test_attributeDefaults(self):
        """
        Check the defaults and the setter properties for rectifySED and
        modelOutSideTemporalRange
        """
        snobj = SNObject(ra=30., dec=-60., source='salt2')
        self.assertEqual(snobj.rectifySED, True)
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')

        snobj.rectifySED = False
        # The second positional argument of assertFalse is the failure
        # message, not an expected value, so none is passed here.
        self.assertFalse(snobj.rectifySED)
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')

    def test_raisingerror_forunimplementedmodelOutSideRange(self):
        """
        check that correct error is raised if the user tries to assign an
        un-implemented model value to
        `sims.catUtils.supernovae.SNObject.modelOutSideTemporalRange`
        """
        snobj = SNObject(ra=30., dec=-60., source='salt2')
        self.assertEqual(snobj.modelOutSideTemporalRange, 'zero')
        with self.assertRaises(ValueError) as context:
            snobj.modelOutSideTemporalRange = 'False'
        # The message check sits outside the ``with`` block: anything placed
        # after the raising assignment inside it would never execute.
        self.assertEqual('Model not implemented, defaulting to zero method\n',
                         context.exception.args[0])

    def test_rectifiedSED(self):
        """
        Check for an extreme case that the SN seds are being rectified. This is
        done by setting up an extreme case where there will be negative seds, and
        checking that this is indeed the case, and checking that they are not
        negative if rectified.
        """
        r_band = self.lsstBandPass['r']

        snobj = SNObject(ra=30., dec=-60., source='salt2')
        snobj.set(z=0.96, t0=self.mjdobs, x1=-3., x0=1.8e-6)
        snobj.rectifySED = False
        times = np.arange(self.mjdobs - 50., self.mjdobs + 150., 1.)

        # Collect the epochs at which the un-rectified SED goes negative
        badTimes = []
        for time in times:
            sed = snobj.SNObjectSED(time=time, bandpass=r_band)
            if any(sed.flambda < 0.):
                badTimes.append(time)

        # Check that there are negative SEDs
        self.assertGreater(len(badTimes), 0)

        snobj.rectifySED = True
        for time in badTimes:
            sed = snobj.SNObjectSED(time=time, bandpass=r_band)
            self.assertGreaterEqual(
                sed.calcADU(bandpass=r_band,
                            photParams=self.rectify_photParams), 0.)
            self.assertFalse(any(sed.flambda < 0.))

    def test_ComparebandFluxes2photUtils(self):
        """
        The SNObject.catsimBandFlux computation uses the sims.photUtils.sed
        band flux computation under the hood. This test makes sure that these
        definitions are in sync
        """
        snobject_r = self.SN_extincted.catsimBandFlux(
            bandpassobject=self.lsstBandPass['r'],
            time=self.mjdobs)

        # `sims.photUtils.Sed`
        sed = self.SN_extincted.SNObjectSED(time=self.mjdobs,
                                            bandpass=self.lsstBandPass['r'])
        sedflux = sed.calcFlux(bandpass=self.lsstBandPass['r'])
        # calcFlux is divided by 3631.0 before comparison with catsimBandFlux
        np.testing.assert_allclose(snobject_r, sedflux / 3631.0)

    def test_CompareBandFluxes2SNCosmo(self):
        """
        Compare the r band flux at a particular time computed in SNObject and
        SNCosmo for MW-extincted SEDs. While the underlying sed is obtained
        from SNCosmo the integration with the bandpass is an independent
        calculation in SNCosmo and catsim
        """
        times = self.mjdobs
        catsim_r = self.SN_extincted.catsimBandFlux(
            bandpassobject=self.lsstBandPass['r'],
            time=times)
        sncosmo_r = self.SNCosmoModel.bandflux(band=self.SNCosmoBP,
                                               time=times, zpsys='ab',
                                               zp=0.)
        np.testing.assert_allclose(sncosmo_r, catsim_r)

    def test_CompareBandMags2SNCosmo(self):
        """
        Compare the r band magnitude at a particular time computed in SNObject
        and SNCosmo for MW-extincted SEDs. Should work whenever the flux
        comparison above works.
        """
        times = self.mjdobs
        catsim_r = self.SN_extincted.catsimBandMag(
            bandpassobject=self.lsstBandPass['r'],
            time=times)
        sncosmo_r = self.SNCosmoModel.bandmag(band=self.SNCosmoBP,
                                              time=times, magsys='ab')
        np.testing.assert_allclose(sncosmo_r, catsim_r)

    def test_CompareExtinctedSED2SNCosmo(self):
        """
        Compare the extincted SEDS in SNCosmo and SNObject. Slightly more
        non-trivial than comparing unextincted SEDS, as the extinction in
        SNObject uses different code from SNCosmo. However, this is still
        using the same values of MWEBV, rather than reading it off a map.
        """
        SNObjectSED = self.SN_extincted.SNObjectSED(time=self.mjdobs,
                                                    wavelen=self.wavenm)

        # SNCosmo is evaluated on the Angstrom grid and scaled by 10 to be
        # compared with flambda evaluated on the nm grid.
        SNCosmoSED = self.SNCosmoModel.flux(time=self.mjdobs,
                                            wave=self.wave) * 10.

        np.testing.assert_allclose(SNObjectSED.flambda, SNCosmoSED,
                                   rtol=1.0e-7)

    def test_CompareUnextinctedSED2SNCosmo(self):
        """
        Compares the unextincted flux Densities in SNCosmo and SNObject. This
        is merely a sanity check as SNObject uses SNCosmo under the hood.
        """
        SNCosmoFluxDensity = self.SN_blank.flux(wave=self.wave,
                                                time=self.mjdobs) * 10.

        unextincted_sed = self.SN_blank.SNObjectSED(time=self.mjdobs,
                                                    wavelen=self.wavenm)

        SNObjectFluxDensity = unextincted_sed.flambda
        np.testing.assert_allclose(SNCosmoFluxDensity, SNObjectFluxDensity,
                                   rtol=1.0e-7)

    def test_redshift(self):
        """
        test that the redshift method works as expected by checking that
        if we redshift a SN from its original redshift orig_z to new_z where
        new_z is smaller (larger) than orig_z:
        - 1. x0 increases (decreases)
        - 2. source peak absolute magnitude in BessellB band stays the same
        """
        from astropy.cosmology import FlatLambdaCDM
        cosmo = FlatLambdaCDM(H0=70., Om0=0.3)

        orig_z = self.SN_extincted.get('z')
        orig_x0 = self.SN_extincted.get('x0')
        peakabsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                         cosmo=cosmo)

        lowz = orig_z * 0.5
        highz = orig_z * 2.0

        # Test Case for lower redshift
        self.SN_extincted.redshift(z=lowz, cosmo=cosmo)
        low_x0 = self.SN_extincted.get('x0')
        lowPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                            cosmo=cosmo)

        # Test 1.
        self.assertGreater(low_x0, orig_x0)
        # Test 2.
        self.assertAlmostEqual(peakabsMag, lowPeakAbsMag, places=14)

        # Test Case for higher redshift
        self.SN_extincted.redshift(z=highz, cosmo=cosmo)
        high_x0 = self.SN_extincted.get('x0')
        HiPeakAbsMag = self.SN_extincted.source_peakabsmag('BessellB', 'AB',
                                                           cosmo=cosmo)

        # Test 1.
        self.assertLess(high_x0, orig_x0)
        # Test 2.
        self.assertAlmostEqual(peakabsMag, HiPeakAbsMag, places=14)

    def test_bandFluxErrorWorks(self):
        """
        test that bandflux errors work even if the flux is negative
        """
        times = self.mjdobs

        e = self.SN_extincted.catsimBandFluxError(times,
                                                  self.lsstBandPass['r'],
                                                  m5=24.5, fluxinMaggies=-1.0)
        # ``np.float`` was only an alias of the builtin ``float`` and has been
        # removed from NumPy, so the builtin is used directly.
        self.assertIsInstance(e, float)
        print(e)
        self.assertFalse(np.isinf(e) or np.isnan(e))

314 

315 

316 

317 

318 

319 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNIaCatalog_tests(unittest.TestCase):
    """
    Tests of critical SNIaCatalog functionality, run against a fake sqlite
    galaxy database that is created in a scratch directory under ROOT.
    """

    @classmethod
    def setUpClass(cls):
        """
        Build the fake galaxy database, connect to it, generate the
        ObservationMetaData list, and write one full catalog plus one
        instance catalog per ObservationMetaData into the scratch directory.
        """
        # Set directory where scratch work will be done
        cls.scratchDir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')

        # ObsMetaData instance with spatial window within which we will
        # put galaxies in a fake galaxy catalog
        cls.obsMetaDataforCat = ObservationMetaData(boundType='circle',
                                                    boundLength=np.degrees(0.25),
                                                    pointingRA=np.degrees(0.13),
                                                    pointingDec=np.degrees(-1.2),
                                                    bandpassName=['r'], mjd=49350.)

        # Randomly generate cls.size Galaxy positions within the spatial window
        # of obsMetaDataforCat
        cls.dbname = os.path.join(cls.scratchDir, 'galcat.db')
        cls.size = 1000
        cls.GalaxyPositionSamps = sample_obsmetadata(obsmetadata=cls.obsMetaDataforCat,
                                                     size=cls.size)

        # Create a galaxy Table overlapping with the obsMetaData Spatial Bounds
        # using positions from the samples above and a database name given by
        # cls.dbname
        vals = cls._createFakeGalaxyDB()

        # Dump the sampled positions, converted to radians, one pair per line
        cls.valName = os.path.join(cls.scratchDir, 'valsFromTest.dat')
        with open(cls.valName, 'w') as f:
            for ra_val, dec_val in zip(vals[0], vals[1]):
                f.write(str(np.radians(ra_val)) + ' ' + str(np.radians(dec_val)) + '\n')

        # Read it into a CatalogDBObject galDB
        class MyGalaxyCatalog(CatalogDBObject):
            '''
            Create a CatalogDBObject connecting to a local sqlite database
            '''

            objid = 'mytestgals'
            tableid = 'gals'
            idColKey = 'id'
            objectTypeId = 0
            appendint = 10000
            database = cls.dbname
            raColName = 'raJ2000'
            decColName = 'decJ2000'
            driver = 'sqlite'

            # columns required to convert the ra, dec values in degrees
            # to radians again
            columns = [('id', 'id', int),
                       ('raJ2000', 'raJ2000 * PI()/ 180. '),
                       ('decJ2000', 'decJ2000 * PI()/ 180.'),
                       ('redshift', 'redshift')]

        cls.galDB = MyGalaxyCatalog(database=cls.dbname)

        # Generate a set of Observation MetaData Outputs that overlap
        # the galaxies in space
        opsimPath = os.path.join(getPackageDir('sims_data'), 'OpSimData')
        opsimDB = os.path.join(opsimPath, 'opsimblitz1_1133_sqlite.db')

        generator = ObservationMetaDataGenerator(database=opsimDB)
        cls.obsMetaDataResults = generator.getObservationMetaData(limit=100,
                                                                  fieldRA=(5.0, 8.0),
                                                                  fieldDec=(-85., -60.),
                                                                  expMJD=(49300., 49400.),
                                                                  boundLength=0.15,
                                                                  boundType='circle')

        sncatalog = SNIaCatalog(db_obj=cls.galDB,
                                obs_metadata=cls.obsMetaDataResults[6],
                                column_outputs=['t0', 'flux_u', 'flux_g',
                                                'flux_r', 'flux_i', 'flux_z',
                                                'flux_y', 'mag_u', 'mag_g',
                                                'mag_r', 'mag_i', 'mag_z',
                                                'mag_y', 'adu_u', 'adu_g',
                                                'adu_r', 'adu_i', 'adu_z',
                                                'adu_y', 'mwebv'])
        sncatalog.suppressDimSN = True
        sncatalog.midSurveyTime = sncatalog.mjdobs - 20.
        sncatalog.snFrequency = 1.0
        cls.fullCatalog = os.path.join(cls.scratchDir, 'testSNCatalogTest.dat')
        sncatalog.write_catalog(cls.fullCatalog)

        # Create a SNCatalog based on GalDB, and having times of explosions
        # overlapping the times in obsMetaData
        cls.fnameList = cls._writeManySNCatalogs(cls.obsMetaDataResults)

    @classmethod
    def tearDownClass(cls):
        """
        Release the database object and remove every file written by
        setUpClass, then the scratch directory itself.
        """
        sims_clean_up()
        del cls.galDB
        cls.cleanDB(cls.dbname)
        if os.path.exists(cls.valName):
            os.unlink(cls.valName)

        for fname in cls.fnameList:
            if os.path.exists(fname):
                os.unlink(fname)

        if os.path.exists(cls.fullCatalog):
            os.unlink(cls.fullCatalog)
        if os.path.exists(cls.scratchDir):
            shutil.rmtree(cls.scratchDir)

    def test_writingfullCatalog(self):
        """
        Check that a full catalog of SN has more than one line
        """
        with open(self.fullCatalog, 'r') as f:
            numLines = sum(1 for _ in f)

        self.assertGreater(numLines, 1)

    @staticmethod
    def buildLCfromInstanceCatFilenames(fnamelist):
        """
        Read the instance catalogs named in ``fnamelist`` and group their
        rows by supernova id.

        Parameters
        ----------
        fnamelist : iterable of str
            Paths of catalog files whose columns are separated by ', ' and
            whose id column is headed '#snid'.

        Returns
        -------
        pandas groupby object keyed on the integer 'snid' column of the
        concatenated catalogs.
        """
        # External packages used
        import pandas as pd

        # A multi-character separator is interpreted as a regular expression
        # and is only supported by the python parsing engine; request that
        # engine explicitly instead of relying on pandas' fallback.
        dfs = [pd.read_csv(fname, index_col=None, sep=', ', engine='python')
               for fname in fnamelist]
        all_lcsDumped = pd.concat(dfs)
        all_lcsDumped.rename(columns={'#snid': 'snid'}, inplace=True)
        all_lcsDumped['snid'] = all_lcsDumped['snid'].astype(int)
        lcs = all_lcsDumped.groupby('snid')

        return lcs

    def test_drawReproducibility(self):
        """
        Check that when the same SN (ie. with same snid) is observed with
        different pointings leading to different instance catalogs, the
        values of properties remain the same.
        """
        lcs = self.buildLCfromInstanceCatFilenames(self.fnameList)

        props = ['snid', 'snra', 'sndec', 'z', 'x0', 'x1', 'c',
                 'cosmologicalDistanceModulus', 'mwebv']
        s = "Testing Equality across {0:2d} pointings for reported properties"
        s += " of SN {1:8d} of the property "
        for key in lcs.groups:
            df = lcs.get_group(key)
            for prop in props:
                print(s.format(len(df), df.snid.iloc[0]) + prop)
                np.testing.assert_equal(len(df[prop].unique()), 1)

    def test_redrawingCatalog(self):
        """
        test that re-drawing the same catalogs from a shuffled copy of the
        ObservationMetaData list yields the same light curve for every SN.
        """
        from random import shuffle
        import copy

        obsMetaDataResults = copy.deepcopy(self.obsMetaDataResults)
        shuffle(obsMetaDataResults)
        fnameList = self._writeManySNCatalogs(obsMetaDataResults,
                                              suffix='.v2.dat')

        try:
            newlcs = self.buildLCfromInstanceCatFilenames(fnameList)
            oldlcs = self.buildLCfromInstanceCatFilenames(self.fnameList)

            for key in oldlcs.groups:
                # sort_values returns a new frame; sorting the get_group
                # result in place would operate on a slice of the parent.
                df_old = oldlcs.get_group(key).sort_values(['time', 'band'])
                df_new = newlcs.get_group(key).sort_values(['time', 'band'])
                s = "Testing equality for SNID {0:8d} with {1:2d} datapoints"
                print(s.format(df_new.snid.iloc[0], len(df_old)))
                assert_frame_equal(df_new, df_old)
        finally:
            # Remove the re-drawn catalogs even when a comparison fails
            for fname in fnameList:
                if os.path.exists(fname):
                    os.unlink(fname)

    def test_obsMetaDataGeneration(self):
        """
        Check the number of ObservationMetaData returned in setUpClass.
        """
        numObs = len(self.obsMetaDataResults)
        self.assertEqual(numObs, 15)

    @staticmethod
    def coords(x):
        """
        Return (ra, dec): ``x.summary['unrefractedRA']`` and
        ``x.summary['unrefractedDec']`` passed through ``np.radians``.
        """
        return np.radians(x.summary['unrefractedRA']),\
            np.radians(x.summary['unrefractedDec'])

    @classmethod
    def _writeManySNCatalogs(cls, obsMetaDataResults, suffix=''):
        """
        Write one SNIaCatalog per entry of ``obsMetaDataResults`` into the
        scratch directory.

        Parameters
        ----------
        obsMetaDataResults : sequence of ObservationMetaData
        suffix : str, optional
            Appended to each file name 'SNCatalog_<index>'.

        Returns
        -------
        list of the file names written, in input order.
        """
        cols = ['t0', 'mwebv', 'time', 'band', 'flux', 'flux_err',
                'mag', 'mag_err', 'cosmologicalDistanceModulus']

        fnameList = []
        for obsindex, obsMetaData in enumerate(obsMetaDataResults):
            newCatalog = SNIaCatalog(db_obj=cls.galDB, obs_metadata=obsMetaData,
                                     column_outputs=cols)
            newCatalog.midSurveyTime = 49350
            newCatalog.averageRate = 1.
            newCatalog.suppressDimSN = False
            s = "{0:d}".format(obsindex)
            fname = os.path.join(cls.scratchDir, "SNCatalog_" + s + suffix)
            newCatalog.write_catalog(fname)
            fnameList.append(fname)
        return fnameList

    @classmethod
    def _createFakeGalaxyDB(cls):
        '''
        Create a local sqlite galaxy database at cls.dbname with a table
        'gals' holding the columns id, raJ2000, decJ2000 and redshift, with
        cls.size rows whose positions are taken from cls.GalaxyPositionSamps.

        Any existing file at cls.dbname is deleted first. Redshifts are drawn
        with np.random.uniform() after seeding the global NumPy generator
        with 1.

        Returns
        -------
        cls.GalaxyPositionSamps, unchanged.
        '''
        dbname = cls.dbname
        samps = cls.GalaxyPositionSamps
        size = cls.size
        cls.cleanDB(dbname)
        conn = sqlite3.connect(dbname)
        try:
            curs = conn.cursor()
            curs.execute('CREATE TABLE if not exists gals '
                         '(id INT, raJ2000 FLOAT, decJ2000 FLOAT, redshift FLOAT)')

            seed = 1
            np.random.seed(seed)

            for count in range(size):
                gal_id = 1000000 + count

                # Main Database should have values in degrees
                ra = samps[0][count]
                dec = samps[1][count]
                redshift = np.random.uniform()
                row = (gal_id, ra, dec, redshift)
                exec_str = cls.insertfromdata(tablename='gals', records=row,
                                              multiple=False)
                curs.execute(exec_str, row)

            conn.commit()
        finally:
            # Close the connection even if an insert fails
            conn.close()
        return samps

    @staticmethod
    def cleanDB(dbname, verbose=True):
        """
        Deletes the database dbname from the disk.

        Parameters
        ----------
        dbname: string, mandatory
            name (abs path) of the database to be deleted
        verbose: Bool, optional, defaults to True
            if True, print whether the file was deleted or did not exist
        """
        if os.path.exists(dbname):
            if verbose:
                print("deleting database ", dbname)
            os.unlink(dbname)
        else:
            if verbose:
                print('database ', dbname, ' does not exist')

    @staticmethod
    def insertfromdata(tablename, records, multiple=True):
        """
        construct a parameterised INSERT statement for a sqlite3 database

        Parameters
        ----------
        tablename: str, mandatory
            Name of table in the database.
        records: a single record (sequence of values) if multiple is False,
            otherwise a sequence of records whose first element fixes the
            number of placeholders
        multiple: Bool, optional, defaults to True
            whether records holds several records or just one

        Returns
        -------
        str of the form 'INSERT INTO <tablename> VALUES ( ?, ?, ...)' with
        one '?' per value in a record
        """
        if multiple:
            lst = records[0]
        else:
            lst = records
        s = 'INSERT INTO ' + str(tablename) + ' VALUES '
        s += "( " + ", ".join(["?"] * len(lst)) + ")"
        return s

609 

610 

class SNIaLightCurveControlCatalog(SNIaCatalog):
    """
    SNIaCatalog subclass with a fixed set of output columns and fixed
    ``_midSurveyTime`` / ``_snFrequency`` values.
    """
    # Rows produced with these outputs are read positionally as
    # (uniqueId, flux, flux_err, redshift).
    column_outputs = ['uniqueId', 'flux', 'flux_err', 'redshift']
    catalog_type = __file__ + 'sn_ia_lc_cat'
    _midSurveyTime = 49000.0
    _snFrequency = 0.001

616 

617 

@unittest.skipIf(_skip_sn_tests, "cannot properly load astropy config dir")
class SNIaLightCurveTest(unittest.TestCase):
    """
    Tests comparing the light curves produced by SNIaLightCurveGenerator
    with the rows of SNIaLightCurveControlCatalog for the same pointings.
    """

    @classmethod
    def setUpClass(cls):
        """
        Write a ';'-delimited text catalog of 100 randomly placed objects
        into a scratch directory and open it as a fileDBObject.
        """
        rng = np.random.RandomState(99)
        n_sne = 100
        ra_list = rng.random_sample(n_sne) * 7.0 + 78.0
        dec_list = rng.random_sample(n_sne) * 4.0 - 69.0
        zz_list = rng.random_sample(n_sne) * 1.0 + 0.05

        cls.scratchDir = tempfile.mkdtemp(dir=ROOT, prefix='scratchSpace-')
        cls.input_cat_name = os.path.join(cls.scratchDir, "sne_input_cat.txt")

        # Columns: id; ra, dec as generated; ra, dec through np.radians; redshift
        with open(cls.input_cat_name, "w") as output_file:
            for ix in range(n_sne):
                output_file.write("%d;%.12f;%.12f;%.12f;%.12f;%.12f\n"
                                  % (ix + 1, ra_list[ix], dec_list[ix],
                                     np.radians(ra_list[ix]), np.radians(dec_list[ix]),
                                     zz_list[ix]))

        # ``np.int`` / ``np.float`` were aliases of the builtins and have been
        # removed from NumPy; the builtins give the same dtype.
        dtype = np.dtype([('id', int),
                          ('raDeg', float), ('decDeg', float),
                          ('raJ2000', float), ('decJ2000', float),
                          ('redshift', float)])

        cls.db = fileDBObject(cls.input_cat_name, delimiter=';',
                              runtable='test', dtype=dtype,
                              idColKey='id')

        cls.db.raColName = 'raDeg'
        cls.db.decColName = 'decDeg'
        cls.db.objectTypeId = 873

        cls.opsimDb = os.path.join(getPackageDir("sims_data"), "OpSimData")
        cls.opsimDb = os.path.join(cls.opsimDb, "opsimblitz1_1133_sqlite.db")

    @classmethod
    def tearDownClass(cls):
        sims_clean_up()
        if os.path.exists(cls.input_cat_name):
            os.unlink(cls.input_cat_name)
        if os.path.exists(cls.scratchDir):
            shutil.rmtree(cls.scratchDir)

    def _assert_lc_matches_control(self, lc, obs, sn):
        """
        Assert that the light-curve point closest in time to ``obs`` agrees
        with the control catalog row ``sn``.

        ``lc`` is indexed by 'mjd', 'flux' and 'error'; ``sn[1]`` and
        ``sn[2]`` are compared with the flux and error respectively. All
        three differences must be below 1.0e-7.
        """
        dex = np.argmin(np.abs(lc['mjd'] - obs.mjd.TAI))
        self.assertLess(np.abs(lc['mjd'][dex] - obs.mjd.TAI), 1.0e-7)
        self.assertLess(np.abs(lc['flux'][dex] - sn[1]), 1.0e-7)
        self.assertLess(np.abs(lc['error'][dex] - sn[2]), 1.0e-7,
                        msg='%e vs %e' % (lc['error'][dex], sn[2]))

    def test_sne_light_curves(self):
        """
        Generate some super nova light curves.  Verify that they come up with the same
        magnitudes and uncertainties as supernova catalogs.
        """
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)
        bandpass = 'r'

        pointings = gen.get_pointings(raRange, decRange, bandpass=bandpass)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, truth = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        for group in pointings:
            self.assertGreater(len(group), 1)
            for obs in group:
                cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
                for sn in cat.iter_catalog():
                    # Only rows with positive flux are compared
                    if sn[1] > 0.0:
                        self._assert_lc_matches_control(lc_dict[sn[0]][bandpass],
                                                        obs, sn)

    def test_sne_light_curves_z_cut(self):
        """
        Generate some super nova light curves.  Add a cutoff in redshift.
        Verify that they come up with the same magnitudes and uncertainties
        as supernova catalogs and that objects with z>z_cutoff are not returned.
        """
        z_cut = 0.9

        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)
        gen.z_cutoff = z_cut

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)
        bandpass = 'r'

        pointings = gen.get_pointings(raRange, decRange, bandpass=bandpass)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, truth = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        over_z = 0

        for group in pointings:
            self.assertGreater(len(group), 1)
            for obs in group:
                cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
                for sn in cat.iter_catalog():
                    if sn[1] <= 0.0:
                        continue
                    if sn[3] > z_cut:
                        # Objects beyond the cutoff must be absent
                        self.assertNotIn(sn[0], lc_dict)
                        over_z += 1
                    else:
                        self._assert_lc_matches_control(lc_dict[sn[0]][bandpass],
                                                        obs, sn)

        # Make sure the cutoff branch was actually exercised
        self.assertGreater(over_z, 0)

    def test_sne_multiband_light_curves(self):
        """
        Generate some super nova light curves.  Verify that they come up with the same
        magnitudes and uncertainties as supernova catalogs.  Use multiband light curves.
        """
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)

        pointings = gen.get_pointings(raRange, decRange, bandpass=('r', 'z'))
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001
        self.assertGreater(len(pointings), 1)
        lc_dict, truth = gen.light_curves_from_pointings(pointings)
        self.assertGreater(len(lc_dict), 0)

        obs_gen = ObservationMetaDataGenerator(database=self.opsimDb, driver='sqlite')
        control_obs_r = obs_gen.getObservationMetaData(fieldRA=raRange, fieldDec=decRange,
                                                       telescopeFilter='r', boundLength=1.75)

        control_obs_z = obs_gen.getObservationMetaData(fieldRA=raRange, fieldDec=decRange,
                                                       telescopeFilter='z', boundLength=1.75)

        self.assertGreater(len(control_obs_r), 0)
        self.assertGreater(len(control_obs_z), 0)

        ct_r = 0
        for obs in control_obs_r:
            cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
            for sn in cat.iter_catalog():
                if sn[1] > 0.0:
                    ct_r += 1
                    self._assert_lc_matches_control(lc_dict[sn[0]]['r'], obs, sn)

        self.assertGreater(ct_r, 0)

        ct_z = 0
        for obs in control_obs_z:
            cat = SNIaLightCurveControlCatalog(self.db, obs_metadata=obs)
            for sn in cat.iter_catalog():
                if sn[1] > 0.0:
                    ct_z += 1
                    self._assert_lc_matches_control(lc_dict[sn[0]]['z'], obs, sn)

        self.assertGreater(ct_z, 0)

    def test_limit_sne_light_curves(self):
        """
        Test that we can limit the number of light curves returned per field of view
        """
        lc_limit = 2
        gen = SNIaLightCurveGenerator(self.db, self.opsimDb)
        gen.sn_universe._midSurveyTime = 49000.0
        gen.sn_universe._snFrequency = 0.001

        raRange = (78.0, 85.0)
        decRange = (-69.0, -65.0)

        pointings = gen.get_pointings(raRange, decRange, bandpass=('r', 'z'))

        control_lc, truth = gen.light_curves_from_pointings(pointings)
        test_lc, truth = gen.light_curves_from_pointings(pointings, lc_per_field=lc_limit)
        self.assertGreater(len(control_lc), len(test_lc))
        self.assertLessEqual(len(test_lc), lc_limit*len(pointings))

812 

813 

class MemoryTestClass(lsst.utils.tests.MemoryTestCase):
    """Subclass of ``lsst.utils.tests.MemoryTestCase`` that adds nothing of its own."""

816 

# Script entry point: initialise the LSST test utilities, then run the tests.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()