Coverage for tests/test_metadata.py: 9%
107 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 02:49 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 02:49 -0700
1# This file is part of felis.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
23import unittest
25import yaml
26from sqlalchemy import (
27 CheckConstraint,
28 Constraint,
29 ForeignKeyConstraint,
30 Index,
31 MetaData,
32 PrimaryKeyConstraint,
33 UniqueConstraint,
34 create_engine,
35)
37from felis import datamodel as dm
38from felis.datamodel import Schema
39from felis.metadata import DatabaseContext, MetaDataBuilder, get_datatype_with_variants
# Absolute path of the directory that contains this test module. Fixture
# paths are derived from it so the tests do not depend on the working
# directory they are launched from.
TESTDIR = os.path.dirname(os.path.abspath(__file__))

# Felis schema fixture (YAML) that ``MetaDataTestCase.setUp`` loads.
TEST_YAML = os.path.join(TESTDIR, "data", "sales.yaml")
class MetaDataTestCase(unittest.TestCase):
    """Test creation of SQLAlchemy `MetaData` from a `Schema`."""

    def setUp(self) -> None:
        """Create an in-memory SQLite database and load the test data."""
        self.engine = create_engine("sqlite://")
        # Dispose of the engine's connection pool when the test finishes so
        # connections do not outlive the test that created them.
        self.addCleanup(self.engine.dispose)
        with open(TEST_YAML, encoding="utf-8") as data:
            self.yaml_data = yaml.safe_load(data)

    def connection(self):
        """Return a new connection to the database.

        The caller owns the connection and should close it, e.g. by using
        it as a context manager.
        """
        return self.engine.connect()

    @staticmethod
    def _sorted_indexes(indexes: "set[Index]") -> "list[Index]":
        """Return the indexes sorted by name."""
        return sorted(indexes, key=lambda i: i.name)

    @staticmethod
    def _sorted_constraints(constraints: "set[Constraint]") -> "list[Constraint]":
        """Return the constraints sorted by name, with the
        `PrimaryKeyConstraint` objects filtered out.
        """
        return sorted(
            (c for c in constraints if not isinstance(c, PrimaryKeyConstraint)),
            key=lambda c: c.name,
        )

    def _find_by_name(self, items, name):
        """Return the first object in ``items`` whose ``name`` equals ``name``.

        Fails the test with a descriptive message (instead of raising an
        `IndexError`) when no such object exists.
        """
        for item in items:
            if item.name == name:
                return item
        self.fail(f"No object named {name!r} was created")

    def _assert_columns_match(self, md_table, md_db_table) -> None:
        """Check that both tables have the same column names and that each
        column has the same type class, nullability and primary-key flag.
        """
        self.assertEqual(md_table.columns.keys(), md_db_table.columns.keys())
        for column_name in md_table.columns.keys():
            md_column = md_table.columns[column_name]
            md_db_column = md_db_table.columns[column_name]
            self.assertEqual(type(md_column.type), type(md_db_column.type))
            self.assertEqual(md_column.nullable, md_db_column.nullable)
            self.assertEqual(md_column.primary_key, md_db_column.primary_key)

    def _assert_constraints_match(self, md_table, md_db_table) -> None:
        """Check that the non-primary-key constraints of both tables match
        pairwise by name, deferral settings and type-specific details.
        """
        # Either both tables have constraints or neither does.
        self.assertEqual(
            bool(md_table.constraints),
            bool(md_db_table.constraints),
            "Constraints not created correctly",
        )
        if not md_table.constraints:
            return
        self.assertEqual(len(md_table.constraints), len(md_db_table.constraints))
        md_constraints = self._sorted_constraints(md_table.constraints)
        md_db_constraints = self._sorted_constraints(md_db_table.constraints)
        self.assertEqual(len(md_constraints), len(md_db_constraints))
        for md_constraint, md_db_constraint in zip(md_constraints, md_db_constraints):
            self.assertEqual(md_constraint.name, md_db_constraint.name)
            self.assertEqual(md_constraint.deferrable, md_db_constraint.deferrable)
            self.assertEqual(md_constraint.initially, md_db_constraint.initially)
            # Check the reflected constraint's type before reading
            # type-specific attributes, so a mismatch fails cleanly instead
            # of raising AttributeError.
            if isinstance(md_constraint, ForeignKeyConstraint):
                self.assertIsInstance(md_db_constraint, ForeignKeyConstraint)
                self.assertEqual(
                    md_constraint.referred_table.name, md_db_constraint.referred_table.name
                )
                self.assertEqual(md_constraint.column_keys, md_db_constraint.column_keys)
            elif isinstance(md_constraint, UniqueConstraint):
                self.assertIsInstance(md_db_constraint, UniqueConstraint)
                self.assertEqual(md_constraint.columns.keys(), md_db_constraint.columns.keys())
            elif isinstance(md_constraint, CheckConstraint):
                self.assertIsInstance(md_db_constraint, CheckConstraint)
                self.assertEqual(str(md_constraint.sqltext), str(md_db_constraint.sqltext))

    def _assert_indexes_match(self, md_table, md_db_table) -> None:
        """Check that both tables have the same indexes, paired by name,
        covering the same columns.
        """
        # Either both tables have indexes or neither does.
        self.assertEqual(
            bool(md_table.indexes),
            bool(md_db_table.indexes),
            "Indexes not created correctly",
        )
        md_indexes = self._sorted_indexes(md_table.indexes)
        md_db_indexes = self._sorted_indexes(md_db_table.indexes)
        self.assertEqual(len(md_indexes), len(md_db_indexes))
        # Iterate the sorted lists: the raw ``indexes`` attributes are sets,
        # so zipping them directly would pair indexes in arbitrary order.
        for md_index, md_db_index in zip(md_indexes, md_db_indexes):
            self.assertEqual(md_index.name, md_db_index.name)
            self.assertEqual(md_index.columns.keys(), md_db_index.columns.keys())

    def test_create_all(self) -> None:
        """Create all tables in the schema using the metadata object and a
        sqlite connection.

        Check that the reflected `MetaData` from the database matches that
        which was created by the `MetaDataBuilder`.
        """
        with self.connection() as connection:
            schema = Schema.model_validate(self.yaml_data)
            schema.name = "main"
            builder = MetaDataBuilder(schema)
            md = builder.build()

            ctx = DatabaseContext(md, connection)
            ctx.create_all()

            md_db = MetaData()
            md_db.reflect(connection, schema=schema.name)

            self.assertEqual(md_db.tables.keys(), md.tables.keys())

            for md_table_name in md.tables.keys():
                with self.subTest(table=md_table_name):
                    md_table = md.tables[md_table_name]
                    md_db_table = md_db.tables[md_table_name]
                    self._assert_columns_match(md_table, md_db_table)
                    self._assert_constraints_match(md_table, md_db_table)
                    self._assert_indexes_match(md_table, md_db_table)

    def test_builder(self) -> None:
        """Test that the `MetaData` object created by the `MetaDataBuilder`
        matches the Felis `Schema` used to build it.
        """
        sch = Schema.model_validate(self.yaml_data)
        bld = MetaDataBuilder(sch, apply_schema_to_tables=False, apply_schema_to_metadata=False)
        md = bld.build()

        self.assertEqual(len(sch.tables), len(md.tables))
        self.assertEqual([table.name for table in sch.tables], list(md.tables.keys()))
        for table in sch.tables:
            md_table = md.tables[table.name]
            self.assertEqual(table.name, md_table.name)
            self.assertEqual(len(table.columns), len(md_table.columns))
            for column in table.columns:
                md_table_column = md_table.columns[column.name]
                datatype = get_datatype_with_variants(column)
                self.assertEqual(type(datatype), type(md_table_column.type))
                if column.nullable is not None:
                    self.assertEqual(column.nullable, md_table_column.nullable)
            for constraint in table.constraints:
                md_constraint = self._find_by_name(md_table.constraints, constraint.name)
                if isinstance(constraint, dm.ForeignKeyConstraint):
                    self.assertIsInstance(md_constraint, ForeignKeyConstraint)
                    # Felis constraints reference columns by id; map them to
                    # names to compare with the SQLAlchemy column keys.
                    self.assertEqual(
                        sorted(sch[column_id].name for column_id in constraint.columns),
                        sorted(md_constraint.columns.keys()),
                    )
                elif isinstance(constraint, dm.UniqueConstraint):
                    self.assertIsInstance(md_constraint, UniqueConstraint)
                    self.assertEqual(
                        sorted(sch[column_id].name for column_id in constraint.columns),
                        sorted(md_constraint.columns.keys()),
                    )
                elif isinstance(constraint, dm.CheckConstraint):
                    self.assertEqual(constraint.expression, str(md_constraint.sqltext))
            for index in table.indexes:
                md_index = self._find_by_name(md_table.indexes, index.name)
                self.assertEqual(
                    sorted(sch[column_id].name for column_id in index.columns),
                    sorted(md_index.columns.keys()),
                )
            if table.primary_key:
                # ``primary_key`` is either a single column id or a list of ids.
                if isinstance(table.primary_key, str):
                    primary_keys = [sch[table.primary_key].name]
                else:
                    primary_keys = [sch[pk].name for pk in table.primary_key]
                for primary_key in primary_keys:
                    self.assertTrue(md_table.columns[primary_key].primary_key)
# Allow running this module directly, e.g. ``python tests/test_metadata.py``.
if __name__ == "__main__":
    unittest.main()