Coverage for tests/test_apdb.py: 23%

99 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-15 10:33 +0000

# This file is part of analysis_ap.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21 

22import os 

23import unittest 

24 

25import lsst.utils.tests 

26import pandas as pd 

27from lsst.analysis.ap.apdb import ApdbSqliteQuery 

28 

29 

class TestApdbSqlite(lsst.utils.tests.TestCase):
    """Tests of ``ApdbSqliteQuery`` against the SQLite APDB fixture
    ``data/apdb.sqlite3`` that lives next to this test file.

    The expected row counts and IDs below are specific to that fixture.
    """

    def setUp(self):
        # Build a fresh query object for every test, so that tests which change
        # query settings (e.g. set_excluded_diaSource_flags) cannot leak that
        # state into other tests.
        datadir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/")
        apdb_file = os.path.join(datadir, "apdb.sqlite3")
        self.apdb = ApdbSqliteQuery(apdb_file, instrument="LSSTCam-imSim")

    @staticmethod
    def _where_sql(query):
        """Return the WHERE clause of ``query`` compiled to an SQL string with
        bound values rendered inline, for comparison with an expected string.
        """
        return str(query.whereclause.compile(compile_kwargs={"literal_binds": True}))

    def test_load_sources(self):
        result = self.apdb.load_sources(limit=None)
        self.assertEqual(len(result), 290)
        # spot check a few fields
        self.assertEqual(result['diaSourceId'][0], 506428274000265217)
        self.assertEqual(result['diaObjectId'][14], 506428274000265241)
        self.assertEqual(result['detector'][0], 168)
        self.assertEqual(result['visit'][0], 943296)

        # check using a query limit
        result = self.apdb.load_sources(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_sources_exclude_flags(self):
        # Test that we load the expected number of diaSources.
        # (There are 19 diaSources of the 290 that should be excluded.)
        result = self.apdb.load_sources(exclude_flagged=True)
        self.assertEqual(len(result), 271)

    def test_load_sources_for_object(self):
        # This diaObject has 2 diaSources: check that both are loaded, and
        # spot-check the ID of the first one.
        result = self.apdb.load_sources_for_object(506428274000265388)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaSourceId'][0], 506428274000265388)

    def test_load_forced_sources_for_object(self):
        # Same diaObject as in test_load_sources_for_object; it was found to
        # have 2 constituent diaForcedSources.
        result = self.apdb.load_forced_sources_for_object(506428274000265388)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaForcedSourceId'][0], 506428274000265354)

    def test_load_sources_for_object_exclude_flags(self):
        # diaObject chosen from inspection to have 2 diaSources, both of which
        # are flagged by `diaSource_flags_exclude`.
        result = self.apdb.load_sources_for_object(506428274000265285)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaSourceId'][0], 506428274000265285)
        self.assertEqual(result['diaSourceId'][1], 527736141479149663)
        # With flag exclusion enabled, both diaSources are dropped.
        result = self.apdb.load_sources_for_object(506428274000265285,
                                                   exclude_flagged=True)
        self.assertEqual(len(result), 0)

    def test_load_objects(self):
        result = self.apdb.load_objects(limit=None)
        self.assertEqual(len(result), 259)
        # spot check a few fields
        self.assertNotIn("diaSourceId", result.columns)
        self.assertEqual(result['diaObjectId'][0], 506428274000265217)
        self.assertIn("validityStart", result.columns)

        result = self.apdb.load_objects(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_forced_sources(self):
        result = self.apdb.load_forced_sources(limit=None)
        self.assertEqual(len(result), 376)
        # spot check a few fields
        self.assertEqual(result['diaObjectId'][0], 506428274000265217)
        self.assertEqual(result['diaForcedSourceId'][0], 506428274000265217)
        self.assertEqual(result['detector'][0], 168)
        self.assertEqual(result['visit'][0], 943296)

        result = self.apdb.load_forced_sources(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_source(self):
        result = self.apdb.load_source(506428274000265217)
        # spot check a few fields
        self.assertEqual(result['diaSourceId'], 506428274000265217)
        self.assertEqual(result['diaObjectId'], 506428274000265217)
        self.assertEqual(result['band'], 'r')

        with self.assertRaisesRegex(RuntimeError, "diaSourceId=54321 not found"):
            self.apdb.load_source(54321)

    def test_load_object(self):
        result = self.apdb.load_object(506428274000265228)
        # spot check a few fields
        self.assertEqual(result['diaObjectId'], 506428274000265228)
        self.assertFloatsAlmostEqual(result['ra'], 55.7887299103902, rtol=1e-15)

        with self.assertRaisesRegex(RuntimeError, "diaObjectId=54321 not found"):
            self.apdb.load_object(54321)

    def test_load_forced_source(self):
        result = self.apdb.load_forced_source(506428274000265224)
        # spot check a few fields
        self.assertEqual(result['diaForcedSourceId'], 506428274000265224)
        self.assertEqual(result['diaObjectId'], 506428274000265228)

        with self.assertRaisesRegex(RuntimeError, "diaForcedSourceId=54321 not found"):
            self.apdb.load_forced_source(54321)

    def test_make_flag_exclusion_clause(self):
        # Test clause generation with default flag list.
        table = self.apdb._tables["DiaSource"]
        query = table.select()
        query = self.apdb._make_flag_exclusion_query(query, table, self.apdb.diaSource_flags_exclude)
        # Check that the SQL query literal string does the flag exclusion.
        expected_where = ('NOT ("DiaSource"."pixelFlags_bad" = 1 '
                          'OR "DiaSource"."pixelFlags_suspect" = 1 '
                          'OR "DiaSource"."pixelFlags_saturatedCenter" = 1 '
                          'OR "DiaSource"."pixelFlags_interpolated" = 1 '
                          'OR "DiaSource"."pixelFlags_interpolatedCenter" = 1 '
                          'OR "DiaSource"."pixelFlags_edge" = 1)')
        self.assertEqual(self._where_sql(query), expected_where)

    def test_set_excluded_diaSource_flags(self):
        with self.assertRaisesRegex(ValueError, "flag not a real flag not included"):
            self.apdb.set_excluded_diaSource_flags(['not a real flag'])

        self.apdb.set_excluded_diaSource_flags(['pixelFlags_streak'])
        table = self.apdb._tables["DiaSource"]
        query = table.select()
        query = self.apdb._make_flag_exclusion_query(query, table, self.apdb.diaSource_flags_exclude)
        # Check that the SQL query does a non-default flag exclusion.
        expected_where = '"DiaSource"."pixelFlags_streak" != 1'
        self.assertEqual(self._where_sql(query), expected_where)

    def test_fill_from_instrument(self):
        # An empty series should be unchanged. Pin dtype=object explicitly:
        # it is the pandas>=2 default for an empty Series, and older pandas
        # warns when no dtype is given.
        empty = pd.Series(dtype=object)
        self.apdb._fill_from_instrument(empty)
        self.assertTrue(empty.equals(pd.Series(dtype=object)))

164 

165 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Pull in the checks provided by ``lsst.utils.tests.MemoryTestCase``
    (presumably resource-leak detection — see that class); no tests of its
    own are defined here.
    """

168 

169 

def setup_module(module):
    """Module-level setup hook (pytest xunit style).

    Initializes ``lsst.utils.tests`` before the tests in this module run.
    The ``module`` argument is supplied by the test runner and is unused.
    """
    lsst.utils.tests.init()

172 

173 

if __name__ == "__main__":
    # Direct execution path: initialize the LSST test utilities first, then
    # let unittest discover and run the test cases defined in this module.
    lsst.utils.tests.init()
    unittest.main()