Coverage for tests/test_sql.py: 19%
77 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-13 09:59 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-13 09:59 +0000
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24from collections.abc import Mapping, MutableMapping
25from typing import Any, Optional, cast
27import sqlalchemy
28import yaml
30from felis import DEFAULT_FRAME
31from felis.db import sqltypes
32from felis.sql import SQLVisitor
# Absolute directory containing this test module; fixtures live below it.
TESTDIR = os.path.abspath(os.path.dirname(__file__))
# YAML schema definition consumed by the visitor tests.
TEST_YAML = os.path.join(os.path.join(TESTDIR, "data"), "test.yml")
def _get_unique_constraint(table: sqlalchemy.schema.Table) -> Optional[sqlalchemy.schema.UniqueConstraint]:
    """Return the single unique constraint defined on a table.

    Only `sqlalchemy.schema.UniqueConstraint` instances are considered;
    other constraints on the table (e.g. primary or foreign keys) are
    ignored.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        Table to inspect.

    Returns
    -------
    constraint : `sqlalchemy.schema.UniqueConstraint` or `None`
        The table's only unique constraint, or `None` if it defines none.

    Raises
    ------
    TypeError
        Raised if the table defines more than one unique constraint.
    """
    uniques = [
        constraint
        for constraint in table.constraints
        if isinstance(constraint, sqlalchemy.schema.UniqueConstraint)
    ]
    if len(uniques) > 1:
        # Say "unique" explicitly: other constraint kinds were filtered out
        # above, so they never trigger this error.
        raise TypeError(f"More than one unique constraint defined for table {table}")
    return uniques[0] if uniques else None
def _get_indices(table: sqlalchemy.schema.Table) -> Mapping[str, sqlalchemy.schema.Index]:
    """Build a lookup of a table's indices keyed by index name.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        Table whose ``indexes`` are collected.

    Returns
    -------
    indices : `~collections.abc.Mapping` [`str`, `sqlalchemy.schema.Index`]
        Index objects keyed by their ``name`` attribute.
    """
    by_name: dict[str, sqlalchemy.schema.Index] = {}
    for idx in table.indexes:
        # cast() only narrows the static type for the checker; it does not
        # validate that the index actually has a name.
        by_name[cast(str, idx.name)] = idx
    return by_name
class VisitorTestCase(unittest.TestCase):
    """Tests for the SQLVisitor class."""

    # Schema definition loaded from ``TEST_YAML`` and merged with
    # ``DEFAULT_FRAME``; re-assigned for every test in ``setUp``.
    schema_obj: MutableMapping[str, Any] = {}

    def setUp(self) -> None:
        """Load the schema definition from the test YAML file."""
        # Explicit encoding so the fixture is decoded the same way regardless
        # of the platform's default locale.
        with open(TEST_YAML, encoding="utf-8") as test_yaml:
            self.schema_obj = yaml.load(test_yaml, Loader=yaml.SafeLoader)
        self.schema_obj.update(DEFAULT_FRAME)

    def _assert_column_types(self, table: sqlalchemy.schema.Table, expected: tuple[type, ...]) -> None:
        """Check that the table's columns, in definition order, have the
        expected types (a ``sqlalchemy.types.Variant`` is also accepted).

        Parameters
        ----------
        table : `sqlalchemy.schema.Table`
            Table whose columns are checked.
        expected : `tuple` [`type`, ...]
            Expected type class for each column, in column order.
        """
        columns = list(table.columns.values())
        # zip() silently drops unmatched items, so compare lengths first.
        self.assertEqual(len(columns), len(expected))
        for column, ctype in zip(columns, expected):
            self.assertIsInstance(column.type, (ctype, sqlalchemy.types.Variant))

    def test_make_metadata(self) -> None:
        """Generate sqlalchemy metadata using the SQLVisitor class and check
        the resulting tables, columns, constraints and indices.
        """
        visitor = SQLVisitor()
        schema = visitor.visit_schema(self.schema_obj)
        self.assertIsNotNone(schema)
        self.assertEqual(schema.name, "sdqa")
        self.assertIsNotNone(schema.tables)
        self.assertIsNotNone(schema.graph_index)
        self.assertIsNotNone(schema.metadata)

        table_names = [
            "sdqa_ImageStatus",
            "sdqa_Metric",
            "sdqa_Rating_ForAmpVisit",
            "sdqa_Rating_CcdVisit",
            "sdqa_Threshold",
        ]

        # Look at metadata tables: MetaData has no default schema, but table
        # keys are still qualified with the "sdqa" schema name.
        self.assertIsNone(schema.metadata.schema)
        tables = schema.metadata.tables
        self.assertCountEqual(tables.keys(), [f"sdqa.{table}" for table in table_names])

        # Check schema.tables attribute.
        self.assertCountEqual([table.name for table in schema.tables], table_names)

        # Check tables in graph index; the same Table objects must be shared.
        for table_name in table_names:
            self.assertIs(schema.graph_index[f"#{table_name}"], tables[f"sdqa.{table_name}"])

        # Details of sdqa_ImageStatus table.
        table = tables["sdqa.sdqa_ImageStatus"]
        self.assertCountEqual(table.columns.keys(), ["sdqa_imageStatusId", "statusName", "definition"])
        self.assertTrue(table.columns["sdqa_imageStatusId"].primary_key)
        self.assertFalse(table.indexes)
        self._assert_column_types(
            table,
            (sqlalchemy.types.SMALLINT, sqlalchemy.types.VARCHAR, sqlalchemy.types.VARCHAR),
        )

        # Details of sdqa_Metric table.
        table = tables["sdqa.sdqa_Metric"]
        self.assertCountEqual(
            table.columns.keys(), ["sdqa_metricId", "metricName", "physicalUnits", "dataType", "definition"]
        )
        self.assertTrue(table.columns["sdqa_metricId"].primary_key)
        self.assertFalse(table.indexes)
        self._assert_column_types(
            table,
            (
                sqlalchemy.types.SMALLINT,
                sqlalchemy.types.VARCHAR,
                sqlalchemy.types.VARCHAR,
                sqlalchemy.types.CHAR,
                sqlalchemy.types.VARCHAR,
            ),
        )
        # It defines a unique constraint. assertIsNotNone reports a proper
        # test failure; the bare assert only narrows the type for checkers.
        unique = _get_unique_constraint(table)
        self.assertIsNotNone(unique)
        assert unique is not None, "Constraint must be defined"
        self.assertEqual(unique.name, "UQ_sdqaMetric_metricName")
        self.assertCountEqual(unique.columns, [table.columns["metricName"]])

        # Details of sdqa_Rating_ForAmpVisit table.
        table = tables["sdqa.sdqa_Rating_ForAmpVisit"]
        self.assertCountEqual(
            table.columns.keys(),
            [
                "sdqa_ratingId",
                "sdqa_metricId",
                "sdqa_thresholdId",
                "ampVisitId",
                "metricValue",
                "metricSigma",
            ],
        )
        self.assertTrue(table.columns["sdqa_ratingId"].primary_key)
        self._assert_column_types(
            table,
            (
                sqlalchemy.types.BIGINT,
                sqlalchemy.types.SMALLINT,
                sqlalchemy.types.SMALLINT,
                sqlalchemy.types.BIGINT,
                sqltypes.DOUBLE,
                sqltypes.DOUBLE,
            ),
        )
        unique = _get_unique_constraint(table)
        self.assertIsNotNone(unique)
        assert unique is not None, "Constraint must be defined"
        self.assertEqual(unique.name, "UQ_sdqaRatingForAmpVisit_metricId_ampVisitId")
        self.assertCountEqual(unique.columns, [table.columns["sdqa_metricId"], table.columns["ampVisitId"]])
        # It has three indices.
        indices = _get_indices(table)
        self.assertCountEqual(
            indices.keys(),
            [
                "IDX_sdqaRatingForAmpVisit_metricId",
                "IDX_sdqaRatingForAmpVisit_thresholdId",
                "IDX_sdqaRatingForAmpVisit_ampVisitId",
            ],
        )
        self.assertCountEqual(
            indices["IDX_sdqaRatingForAmpVisit_metricId"].columns,
            [schema.graph_index["#sdqa_Rating_ForAmpVisit.sdqa_metricId"]],
        )
        # And a foreign key referencing sdqa_Metric table.
        self.assertEqual(len(table.foreign_key_constraints), 1)
        fk = next(iter(table.foreign_key_constraints))
        self.assertEqual(fk.name, "FK_sdqa_Rating_ForAmpVisit_sdqa_Metric")
        self.assertCountEqual(fk.columns, [table.columns["sdqa_metricId"]])
        self.assertIs(fk.referred_table, tables["sdqa.sdqa_Metric"])
if __name__ == "__main__":
    # Run the tests when this module is executed directly as a script.
    unittest.main()