Coverage for tests / test_loadDiaCatalogs.py: 31%
81 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:54 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 08:54 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import astropy.units
24import numpy as np
25import tempfile
26import unittest
27import yaml
29from lsst.ap.association import LoadDiaCatalogsTask
30from lsst.ap.association.utils import getMidpointFromTimespan, readSchemaFromApdb
31from lsst.dax.apdb import Apdb, ApdbSql, ApdbTables
32from lsst.utils import getPackageDir
33import lsst.utils.tests
34from utils_tests import makeExposure, makeDiaObjects, makeDiaSources, makeDiaForcedSources, makeRegionTime, \
35 getRegion
def _data_file_name(basename, module_name):
    """Build the path of a file inside a package's ``data`` directory.

    Parameters
    ----------
    basename : `str`
        File name, relative to the package's ``data`` directory.
    module_name : `str`
        Name of the LSST stack package whose directory is looked up with
        `lsst.utils.getPackageDir`.

    Returns
    -------
    data_file_path : `str`
        The path ``<package dir>/data/<basename>``.
    """
    package_root = getPackageDir(module_name)
    data_dir = os.path.join(package_root, "data")
    return os.path.join(data_dir, basename)
class TestLoadDiaCatalogs(unittest.TestCase):
    """Tests for `LoadDiaCatalogsTask` backed by a temporary SQLite APDB.

    ``setUp`` creates a fresh APDB, stores simulated DiaObjects, DiaSources
    and DiaForcedSources in it, and each test checks what the loader reads
    back.
    """

    def setUp(self):
        # Create an instance of random generator with fixed seed so the
        # simulated catalogs are reproducible between runs.
        rng = np.random.default_rng(1234)

        # Back the APDB with a temporary SQLite file next to this test file.
        # Cleanups run in LIFO order, so the descriptor is closed before the
        # file is removed.
        self.db_file_fd, self.db_file = tempfile.mkstemp(
            dir=os.path.dirname(__file__))
        self.addCleanup(os.remove, self.db_file)
        self.addCleanup(os.close, self.db_file_fd)

        self.apdbConfig = ApdbSql.init_database(db_url="sqlite:///" + self.db_file)
        # The task reads its APDB configuration from a file (see
        # ``_makeConfig``), so persist the config to a temporary file.
        self.config_file = tempfile.NamedTemporaryFile()
        self.addCleanup(self.config_file.close)
        self.apdbConfig.save(self.config_file.name)
        self.apdb = Apdb.from_config(self.apdbConfig)
        self.schema = readSchemaFromApdb(self.apdb)

        self.exposure = makeExposure(False, False)
        self.regionTime = makeRegionTime(exposure=self.exposure)
        self.dateTime = getMidpointFromTimespan(self.regionTime.timespan)

        self.diaObjects = makeDiaObjects(20, self.exposure, rng)
        self.diaSources = makeDiaSources(
            100, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)
        self.diaForcedSources = makeDiaForcedSources(
            200, self.diaObjects["diaObjectId"].to_numpy(), self.exposure, rng)

        # Store the test diaSources as though they were observed a month
        # before the current exposure.
        dateTime = self.regionTime.timespan.begin.tai - 30 * astropy.units.day
        self.apdb.store(dateTime,
                        self.diaObjects,
                        self.diaSources,
                        self.diaForcedSources)

        # These columns are not in the DPDD, yet do appear in DiaSource.yaml.
        # We don't need to check them against the default APDB schema.
        self.ignoreColumns = ["band", "bboxSize", "isDipole", "flags"]

    def _makeConfig(self, **kwargs):
        """Return a loader config pointing at this test's APDB config file.

        Parameters
        ----------
        **kwargs
            Extra config overrides passed to ``config.update``.
        """
        config = LoadDiaCatalogsTask.ConfigClass()
        config.apdb_config_url = self.config_file.name
        config.update(**kwargs)
        return config

    def _makeLoader(self, **kwargs):
        """Return a `LoadDiaCatalogsTask` built from ``_makeConfig(**kwargs)``."""
        return LoadDiaCatalogsTask(config=self._makeConfig(**kwargs))

    def testRun(self):
        """Test the full run method for the loader.
        """
        diaLoader = self._makeLoader()
        result = diaLoader.run(self.regionTime)

        self.assertEqual(len(result.diaObjects), len(self.diaObjects))
        self.assertEqual(len(result.diaSources), len(self.diaSources))
        self.assertEqual(len(result.diaForcedSources),
                         len(self.diaForcedSources))

    def testLoadDiaObjects(self):
        """Test that the correct number of diaObjects are loaded.
        """
        diaLoader = self._makeLoader()
        region = getRegion(self.exposure)
        diaObjects = diaLoader.loadDiaObjects(region,
                                              self.schema)
        self.assertEqual(len(diaObjects), len(self.diaObjects))

    def testLoadDiaForcedSources(self):
        """Test that the correct number of diaForcedSources are loaded.
        """
        diaLoader = self._makeLoader()
        region = getRegion(self.exposure)
        diaForcedSources = diaLoader.loadDiaForcedSources(
            self.diaObjects,
            region,
            self.dateTime,
            self.schema)
        self.assertEqual(len(diaForcedSources), len(self.diaForcedSources))

    def testLoadDiaSources(self):
        """Test that the correct number of diaSources are loaded.

        Also check that they can be properly loaded both by location and
        ``diaObjectId``.
        """
        diaLoader = self._makeLoader()

        region = getRegion(self.exposure)
        diaSources = diaLoader.loadDiaSources(self.diaObjects,
                                              region,
                                              self.dateTime,
                                              self.schema)
        self.assertEqual(len(diaSources), len(self.diaSources))

    def test_apdbSchema(self):
        """Test that the default DiaSource schema from dax_apdb agrees with the
        column names defined here in ap_association/data/DiaSource.yaml.

        Every (non-ignored) functor column in the YAML file must be a column
        of the APDB DiaSource table, and the APDB table must have at least
        one additional column (strict subset, as before).
        """
        tableDef = self.apdb.tableDef(ApdbTables.DiaSource)
        apdbSchemaColumns = {column.name for column in tableDef.columns}
        ignoreColumns = set(self.ignoreColumns)

        functorFile = _data_file_name("DiaSource.yaml", "ap_association")
        nDocuments = 0
        with open(functorFile, encoding="utf-8") as yaml_stream:
            # safe_load_all is lazy: documents must be consumed while the
            # stream is still open.
            for functor in yaml.safe_load_all(yaml_stream):
                nDocuments += 1
                diaSourceColumns = set(functor['funcs']) - ignoreColumns
                missingColumns = diaSourceColumns - apdbSchemaColumns
                # Name the offending columns on failure instead of dumping
                # both (large) sets.
                self.assertLess(
                    diaSourceColumns, apdbSchemaColumns,
                    msg="DiaSource.yaml columns missing from the APDB schema: "
                        f"{sorted(missingColumns)}")
        # Guard against a vacuous pass if the file yields no documents.
        self.assertGreater(nDocuments, 0,
                           msg=f"No YAML documents found in {functorFile}")
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Leak check run by the LSST test framework after the other tests."""
def setup_module(module):
    """Pytest module-level hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `types.ModuleType`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: set up the LSST test utilities, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()