Coverage for tests/test_apdb.py: 23% (99 statements)
coverage.py v7.5.1, created at 2024-05-18 12:10 +0000
1# This file is part of analysis_ap.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
25import lsst.utils.tests
26import pandas as pd
27from lsst.analysis.ap.apdb import ApdbSqliteQuery
class TestApdbSqlite(lsst.utils.tests.TestCase):
    """Tests for `ApdbSqliteQuery` against the sqlite APDB fixture in
    ``tests/data/apdb.sqlite3``.

    The expected row counts and IDs below are specific to that fixture.
    """

    def setUp(self):
        datadir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/")
        apdb_file = os.path.join(datadir, "apdb.sqlite3")
        self.apdb = ApdbSqliteQuery(apdb_file, instrument="LSSTCam-imSim")

    def _flag_exclusion_where_clause(self):
        """Return, as a literal SQL string, the WHERE clause produced by
        applying the currently configured ``diaSource_flags_exclude`` to a
        plain ``SELECT`` on the DiaSource table.
        """
        table = self.apdb._tables["DiaSource"]
        query = table.select()
        query = self.apdb._make_flag_exclusion_query(query, table, self.apdb.diaSource_flags_exclude)
        # literal_binds inlines the bound values so the clause can be compared as plain text.
        return str(query.whereclause.compile(compile_kwargs={"literal_binds": True}))

    def test_load_sources(self):
        """Load all diaSources, spot-check fields, then check ``limit``."""
        result = self.apdb.load_sources(limit=None)
        self.assertEqual(len(result), 290)
        # spot check a few fields
        self.assertEqual(result['diaSourceId'][0], 506428274000265217)
        self.assertEqual(result['diaObjectId'][14], 506428274000265241)
        self.assertEqual(result['detector'][0], 168)
        self.assertEqual(result['visit'][0], 943296)

        # check using a query limit
        result = self.apdb.load_sources(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_sources_exclude_flags(self):
        """Flagged diaSources are dropped when ``exclude_flagged=True``."""
        # There are 19 diaSources of the 290 that should be excluded.
        result = self.apdb.load_sources(exclude_flagged=True)
        self.assertEqual(len(result), 271)

    def test_load_sources_for_object(self):
        """Load the diaSources of one specific diaObject."""
        # This diaObject has 2 diaSources; both should be returned.
        result = self.apdb.load_sources_for_object(506428274000265388)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaSourceId'][0], 506428274000265388)

    def test_load_forced_sources_for_object(self):
        """Load the diaForcedSources of the same diaObject."""
        # This diaObject was found to have 2 constituent diaForcedSources.
        result = self.apdb.load_forced_sources_for_object(506428274000265388)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaForcedSourceId'][0], 506428274000265354)

    def test_load_sources_for_object_exclude_flags(self):
        """Flag exclusion also applies when loading sources for one object."""
        # diaObject chosen from inspection to have 2 flagged diaSources.
        result = self.apdb.load_sources_for_object(506428274000265285)
        self.assertEqual(len(result), 2)
        self.assertEqual(result['diaSourceId'][0], 506428274000265285)
        self.assertEqual(result['diaSourceId'][1], 527736141479149663)
        # Both of this diaObject's diaSources have `diaSource_flags_exclude`
        # flags set, so excluding flagged sources should return nothing.
        result = self.apdb.load_sources_for_object(506428274000265285,
                                                   exclude_flagged=True)
        self.assertEqual(len(result), 0)

    def test_load_objects(self):
        """Load all diaObjects, spot-check columns/fields, then check ``limit``."""
        result = self.apdb.load_objects(limit=None)
        self.assertEqual(len(result), 259)
        # spot check a few fields
        self.assertNotIn("diaSourceId", result.columns)
        self.assertEqual(result['diaObjectId'][0], 506428274000265217)
        self.assertIn("validityStart", result.columns)

        result = self.apdb.load_objects(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_forced_sources(self):
        """Load all diaForcedSources, spot-check fields, then check ``limit``."""
        result = self.apdb.load_forced_sources(limit=None)
        self.assertEqual(len(result), 376)
        # spot check a few fields
        self.assertEqual(result['diaObjectId'][0], 506428274000265217)
        self.assertEqual(result['diaForcedSourceId'][0], 506428274000265217)
        self.assertEqual(result['detector'][0], 168)
        self.assertEqual(result['visit'][0], 943296)

        result = self.apdb.load_forced_sources(limit=2)
        self.assertEqual(len(result), 2)

    def test_load_source(self):
        """Load a single diaSource by ID; an unknown ID raises RuntimeError."""
        result = self.apdb.load_source(506428274000265217)
        # spot check a few fields
        self.assertEqual(result['diaSourceId'], 506428274000265217)
        self.assertEqual(result['diaObjectId'], 506428274000265217)
        self.assertEqual(result['band'], 'r')

        with self.assertRaisesRegex(RuntimeError, "diaSourceId=54321 not found"):
            self.apdb.load_source(54321)

    def test_load_object(self):
        """Load a single diaObject by ID; an unknown ID raises RuntimeError."""
        result = self.apdb.load_object(506428274000265228)
        # spot check a few fields
        self.assertEqual(result['diaObjectId'], 506428274000265228)
        self.assertFloatsAlmostEqual(result['ra'], 55.7887299103902, rtol=1e-15)

        with self.assertRaisesRegex(RuntimeError, "diaObjectId=54321 not found"):
            self.apdb.load_object(54321)

    def test_load_forced_source(self):
        """Load a single diaForcedSource by ID; an unknown ID raises RuntimeError."""
        result = self.apdb.load_forced_source(506428274000265224)
        # spot check a few fields
        self.assertEqual(result['diaForcedSourceId'], 506428274000265224)
        self.assertEqual(result['diaObjectId'], 506428274000265228)

        with self.assertRaisesRegex(RuntimeError, "diaForcedSourceId=54321 not found"):
            self.apdb.load_forced_source(54321)

    def test_make_flag_exclusion_clause(self):
        """Clause generation with the default flag list."""
        # Check that the SQL query literal string does the flag exclusion.
        queryString = ('"DiaSource"."pixelFlags_bad" = false '
                       'AND "DiaSource"."pixelFlags_suspect" = false '
                       'AND "DiaSource"."pixelFlags_saturatedCenter" = false '
                       'AND "DiaSource"."pixelFlags_interpolated" = false '
                       'AND "DiaSource"."pixelFlags_interpolatedCenter" = false '
                       'AND "DiaSource"."pixelFlags_edge" = false')
        self.assertEqual(self._flag_exclusion_where_clause(), queryString)

    def test_set_excluded_diaSource_flags(self):
        """Unknown flags are rejected; valid ones replace the default list."""
        with self.assertRaisesRegex(ValueError, "flag not a real flag not included"):
            self.apdb.set_excluded_diaSource_flags(['not a real flag'])

        self.apdb.set_excluded_diaSource_flags(['pixelFlags_streak'])
        # Check that the SQL query does a non-default flag exclusion.
        queryString = '"DiaSource"."pixelFlags_streak" = false'
        self.assertEqual(self._flag_exclusion_where_clause(), queryString)

    def test_fill_from_instrument(self):
        """An empty series passed to ``_fill_from_instrument`` is left unchanged."""
        empty = pd.Series()
        self.apdb._fill_from_instrument(empty)
        self.assertTrue(empty.equals(pd.Series()))
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Hook this module into lsst.utils' MemoryTestCase checks.

    All behavior is inherited from the base class; nothing is overridden.
    """
def setup_module(module):
    """Module-level setup hook: initialize lsst.utils test support.

    ``module`` is presumably the module object supplied by pytest's
    ``setup_module`` hook; it is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Script entry point: initialize lsst.utils test support first, then let
    # unittest discover and run the test cases in this module.
    lsst.utils.tests.init()
    unittest.main()