Coverage for tests/test_sql.py: 18%
75 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-22 01:55 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
24from collections.abc import Mapping, MutableMapping
25from typing import Any, Optional
27import sqlalchemy
28import yaml
30from felis import DEFAULT_FRAME
31from felis.sql import SQLVisitor
# Absolute path of the directory holding this test module.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
# YAML schema definition (under the "data" subdirectory) used by the tests below.
TEST_YAML = os.path.join(TESTDIR, "data", "test.yml")
def _get_unique_constraint(table: sqlalchemy.schema.Table) -> Optional[sqlalchemy.schema.UniqueConstraint]:
    """Return the single unique constraint defined on a table.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        Table whose ``constraints`` are inspected.

    Returns
    -------
    constraint : `sqlalchemy.schema.UniqueConstraint` or `None`
        The table's only unique constraint, or `None` if it defines none.

    Raises
    ------
    TypeError
        Raised if the table defines more than one unique constraint.
    """
    uniques = [
        constraint
        for constraint in table.constraints
        if isinstance(constraint, sqlalchemy.schema.UniqueConstraint)
    ]
    if len(uniques) > 1:
        # Only UniqueConstraint instances are counted above, so say so in the
        # message; other constraint kinds on the table are irrelevant here.
        raise TypeError(f"More than one unique constraint defined for table {table}")
    return uniques[0] if uniques else None
def _get_indices(table: sqlalchemy.schema.Table) -> Mapping[str, sqlalchemy.schema.Index]:
    """Build a lookup of a table's indices.

    Parameters
    ----------
    table : `sqlalchemy.schema.Table`
        Table whose ``indexes`` are collected.

    Returns
    -------
    indices : `dict`
        Mapping from each index's ``name`` to the index object.
    """
    by_name: dict = {}
    for idx in table.indexes:
        by_name[idx.name] = idx
    return by_name
class VisitorTestCase(unittest.TestCase):
    """Tests for the SQLVisitor class."""

    # Parsed contents of TEST_YAML merged with DEFAULT_FRAME. Assigned in
    # setUp() so each test gets a freshly loaded copy; declared without a
    # class-level default to avoid a shared mutable class attribute.
    schema_obj: MutableMapping[str, Any]

    def setUp(self) -> None:
        """Load data from test file."""
        with open(TEST_YAML, encoding="utf-8") as test_yaml:
            self.schema_obj = yaml.load(test_yaml, Loader=yaml.SafeLoader)
        self.schema_obj.update(DEFAULT_FRAME)

    def _assert_variant_columns(self, table: sqlalchemy.schema.Table) -> None:
        """Check that every column of ``table`` has a `sqlalchemy.types.Variant` type."""
        for column in table.columns.values():
            with self.subTest(table=table.name, column=column.name):
                self.assertIsInstance(column.type, sqlalchemy.types.Variant)

    def _require_unique(self, table: sqlalchemy.schema.Table) -> sqlalchemy.schema.UniqueConstraint:
        """Return the single unique constraint of ``table``, failing the test if it has none."""
        unique = _get_unique_constraint(table)
        self.assertIsNotNone(unique, "Constraint must be defined")
        # assertIsNotNone() is the real check (it survives ``python -O``); the
        # bare assert only narrows Optional for static type checkers.
        assert unique is not None
        return unique

    def test_make_metadata(self) -> None:
        """Generate sqlalchemy metadata using SQLVisitor class"""
        visitor = SQLVisitor()
        schema = visitor.visit_schema(self.schema_obj)
        self.assertIsNotNone(schema)
        self.assertEqual(schema.name, "sdqa")
        self.assertIsNotNone(schema.tables)
        self.assertIsNotNone(schema.graph_index)
        self.assertIsNotNone(schema.metadata)

        table_names = [
            "sdqa_ImageStatus",
            "sdqa_Metric",
            "sdqa_Rating_ForAmpVisit",
            "sdqa_Rating_CcdVisit",
            "sdqa_Threshold",
        ]

        # Look at metadata tables: the MetaData has no default schema, and
        # table keys carry the "sdqa." prefix.
        self.assertIsNone(schema.metadata.schema)
        tables = schema.metadata.tables
        self.assertCountEqual(tables.keys(), [f"sdqa.{table}" for table in table_names])

        # Check schema.tables attribute.
        self.assertCountEqual([table.name for table in schema.tables], table_names)

        # Check tables in graph index; they must be the very same objects.
        for table_name in table_names:
            self.assertIs(schema.graph_index[f"#{table_name}"], tables[f"sdqa.{table_name}"])

        # Details of sdqa_ImageStatus table.
        table = tables["sdqa.sdqa_ImageStatus"]
        self.assertCountEqual(table.columns.keys(), ["sdqa_imageStatusId", "statusName", "definition"])
        self.assertTrue(table.columns["sdqa_imageStatusId"].primary_key)
        self.assertFalse(table.indexes)
        self._assert_variant_columns(table)

        # Details of sdqa_Metric table.
        table = tables["sdqa.sdqa_Metric"]
        self.assertCountEqual(
            table.columns.keys(), ["sdqa_metricId", "metricName", "physicalUnits", "dataType", "definition"]
        )
        self.assertTrue(table.columns["sdqa_metricId"].primary_key)
        self.assertFalse(table.indexes)
        self._assert_variant_columns(table)
        # It defines a unique constraint.
        unique = self._require_unique(table)
        self.assertEqual(unique.name, "UQ_sdqaMetric_metricName")
        self.assertCountEqual(unique.columns, [table.columns["metricName"]])

        # Details of sdqa_Rating_ForAmpVisit table.
        table = tables["sdqa.sdqa_Rating_ForAmpVisit"]
        self.assertCountEqual(
            table.columns.keys(),
            [
                "sdqa_ratingId",
                "sdqa_metricId",
                "sdqa_thresholdId",
                "ampVisitId",
                "metricValue",
                "metricSigma",
            ],
        )
        self.assertTrue(table.columns["sdqa_ratingId"].primary_key)
        self._assert_variant_columns(table)
        unique = self._require_unique(table)
        self.assertEqual(unique.name, "UQ_sdqaRatingForAmpVisit_metricId_ampVisitId")
        self.assertCountEqual(unique.columns, [table.columns["sdqa_metricId"], table.columns["ampVisitId"]])
        # It has a bunch of indices.
        indices = _get_indices(table)
        self.assertCountEqual(
            indices.keys(),
            [
                "IDX_sdqaRatingForAmpVisit_metricId",
                "IDX_sdqaRatingForAmpVisit_thresholdId",
                "IDX_sdqaRatingForAmpVisit_ampVisitId",
            ],
        )
        self.assertCountEqual(
            indices["IDX_sdqaRatingForAmpVisit_metricId"].columns,
            [schema.graph_index["#sdqa_Rating_ForAmpVisit.sdqa_metricId"]],
        )
        # And a foreign key referencing sdqa_Metric table.
        self.assertEqual(len(table.foreign_key_constraints), 1)
        fk = next(iter(table.foreign_key_constraints))
        self.assertEqual(fk.name, "FK_sdqa_Rating_ForAmpVisit_sdqa_Metric")
        self.assertCountEqual(fk.columns, [table.columns["sdqa_metricId"]])
        self.assertIs(fk.referred_table, tables["sdqa.sdqa_Metric"])
# Allow running this module directly (``python test_sql.py``) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()