Coverage for tests/test_metadata.py: 9%

107 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-20 02:40 -0700

1# This file is part of felis. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23import unittest 

24 

25import yaml 

26from sqlalchemy import ( 

27 CheckConstraint, 

28 Constraint, 

29 ForeignKeyConstraint, 

30 Index, 

31 MetaData, 

32 PrimaryKeyConstraint, 

33 UniqueConstraint, 

34 create_engine, 

35) 

36 

37from felis import datamodel as dm 

38from felis.datamodel import Schema 

39from felis.metadata import DatabaseContext, MetaDataBuilder, get_datatype_with_variants 

40 

# Absolute path of the directory that contains this test module.
TESTDIR = os.path.abspath(os.path.dirname(__file__))
# Directory of test fixtures, located next to this module.
_DATA_DIR = os.path.join(TESTDIR, "data")
# Felis schema YAML file loaded by the tests in this module.
TEST_YAML = os.path.join(_DATA_DIR, "sales.yaml")

43 

44 

class MetaDataTestCase(unittest.TestCase):
    """Test creation of SQLAlchemy `MetaData` from a `Schema`."""

    def setUp(self) -> None:
        """Create an in-memory SQLite database and load the test data."""
        self.engine = create_engine("sqlite://")
        # Release the engine and its connection pool when the test finishes.
        self.addCleanup(self.engine.dispose)
        with open(TEST_YAML, encoding="utf-8") as data:
            self.yaml_data = yaml.safe_load(data)

    def connection(self):
        """Return a new connection to the test database engine."""
        return self.engine.connect()

    def test_create_all(self):
        """Create all tables in the schema using the metadata object and a
        sqlite connection.

        Check that the reflected `MetaData` from the database matches that
        which was created by the `MetaDataBuilder`.
        """

        def _sorted_indexes(indexes: set[Index]) -> list[Index]:
            """Return the indexes as a list sorted by name."""
            return sorted(indexes, key=lambda i: i.name)

        def _sorted_constraints(constraints: set[Constraint]) -> list[Constraint]:
            """Return a list of constraints sorted by name, with the
            `PrimaryKeyConstraint` objects filtered out.
            """
            return sorted(
                [c for c in constraints if not isinstance(c, PrimaryKeyConstraint)], key=lambda c: c.name
            )

        with self.connection() as connection:
            schema = Schema.model_validate(self.yaml_data)
            # NOTE(review): "main" presumably targets SQLite's default database
            # so that reflection below finds the created tables — confirm.
            schema.name = "main"
            builder = MetaDataBuilder(schema)
            md = builder.build()

            ctx = DatabaseContext(md, connection)

            ctx.create_all()

            # Reflect what was actually created so it can be compared with
            # the metadata produced by the builder.
            md_db = MetaData()
            md_db.reflect(connection, schema=schema.name)

            self.assertEqual(md_db.tables.keys(), md.tables.keys())

            for md_table_name, md_table in md.tables.items():
                md_db_table = md_db.tables[md_table_name]
                self.assertEqual(md_table.columns.keys(), md_db_table.columns.keys())
                for md_column_name in md_table.columns.keys():
                    md_column = md_table.columns[md_column_name]
                    md_db_column = md_db_table.columns[md_column_name]
                    self.assertEqual(type(md_column.type), type(md_db_column.type))
                    self.assertEqual(md_column.nullable, md_db_column.nullable)
                    self.assertEqual(md_column.primary_key, md_db_column.primary_key)

                # Both tables must either have constraints or both have none.
                # (Previously both sides of the "none" case checked md_table.)
                self.assertEqual(
                    bool(md_table.constraints),
                    bool(md_db_table.constraints),
                    "Constraints not created correctly",
                )
                if md_table.constraints:
                    self.assertEqual(len(md_table.constraints), len(md_db_table.constraints))
                    md_constraints = _sorted_constraints(md_table.constraints)
                    md_db_constraints = _sorted_constraints(md_db_table.constraints)
                    for md_constraint, md_db_constraint in zip(md_constraints, md_db_constraints):
                        self.assertEqual(md_constraint.name, md_db_constraint.name)
                        self.assertEqual(md_constraint.deferrable, md_db_constraint.deferrable)
                        self.assertEqual(md_constraint.initially, md_db_constraint.initially)
                        if isinstance(md_constraint, ForeignKeyConstraint):
                            md_fk: ForeignKeyConstraint = md_constraint
                            md_db_fk: ForeignKeyConstraint = md_db_constraint
                            self.assertEqual(md_fk.referred_table.name, md_db_fk.referred_table.name)
                            self.assertEqual(md_fk.column_keys, md_db_fk.column_keys)
                        elif isinstance(md_constraint, UniqueConstraint):
                            md_uniq: UniqueConstraint = md_constraint
                            md_db_uniq: UniqueConstraint = md_db_constraint
                            self.assertEqual(md_uniq.columns.keys(), md_db_uniq.columns.keys())
                        elif isinstance(md_constraint, CheckConstraint):
                            md_check: CheckConstraint = md_constraint
                            md_db_check: CheckConstraint = md_db_constraint
                            self.assertEqual(str(md_check.sqltext), str(md_db_check.sqltext))

                # Both tables must either have indexes or both have none.
                self.assertEqual(
                    bool(md_table.indexes),
                    bool(md_db_table.indexes),
                    "Indexes not created correctly",
                )
                if md_table.indexes:
                    md_indexes = _sorted_indexes(md_table.indexes)
                    md_db_indexes = _sorted_indexes(md_db_table.indexes)
                    self.assertEqual(len(md_indexes), len(md_db_indexes))
                    # Pair the name-sorted lists; the raw index sets are
                    # unordered and would pair unrelated indexes.
                    for md_index, md_db_index in zip(md_indexes, md_db_indexes):
                        self.assertEqual(md_index.name, md_db_index.name)
                        self.assertEqual(md_index.columns.keys(), md_db_index.columns.keys())

    def test_builder(self):
        """Test that the `MetaData` object created by the `MetaDataBuilder`
        matches the Felis `Schema` used to build it.
        """
        sch = Schema.model_validate(self.yaml_data)
        bld = MetaDataBuilder(sch, apply_schema_to_tables=False, apply_schema_to_metadata=False)
        md = bld.build()

        self.assertEqual(len(sch.tables), len(md.tables))
        self.assertEqual([table.name for table in sch.tables], list(md.tables.keys()))
        for table in sch.tables:
            md_table = md.tables[table.name]
            self.assertEqual(table.name, md_table.name)
            self.assertEqual(len(table.columns), len(md_table.columns))
            for column in table.columns:
                md_table_column = md_table.columns[column.name]
                datatype = get_datatype_with_variants(column)
                self.assertEqual(type(datatype), type(md_table_column.type))
                if column.nullable is not None:
                    self.assertEqual(column.nullable, md_table_column.nullable)
            for constraint in table.constraints:
                md_constraint = next(
                    (mdc for mdc in md_table.constraints if mdc.name == constraint.name), None
                )
                self.assertIsNotNone(
                    md_constraint,
                    f"Constraint {constraint.name!r} missing from table {table.name!r}",
                )
                if isinstance(constraint, dm.ForeignKeyConstraint):
                    self.assertIsInstance(md_constraint, ForeignKeyConstraint)
                    # assertEqual, not assertTrue: assertTrue(a, b) treats b
                    # as the failure message and never compares the columns.
                    self.assertEqual(
                        sorted([sch[column_id].name for column_id in constraint.columns]),
                        sorted(md_constraint.columns.keys()),
                    )
                elif isinstance(constraint, dm.UniqueConstraint):
                    self.assertEqual(
                        sorted([sch[column_id].name for column_id in constraint.columns]),
                        sorted(md_constraint.columns.keys()),
                    )
                elif isinstance(constraint, dm.CheckConstraint):
                    self.assertEqual(constraint.expression, str(md_constraint.sqltext))
            for index in table.indexes:
                md_index = next((mdi for mdi in md_table.indexes if mdi.name == index.name), None)
                self.assertIsNotNone(
                    md_index,
                    f"Index {index.name!r} missing from table {table.name!r}",
                )
                self.assertEqual(
                    sorted([sch[column_id].name for column_id in index.columns]),
                    sorted(md_index.columns.keys()),
                )
            if table.primary_key:
                # The primary key may be a single column id or a list of ids.
                if isinstance(table.primary_key, str):
                    primary_keys = [sch[table.primary_key].name]
                else:
                    primary_keys = [sch[pk].name for pk in table.primary_key]
                for primary_key in primary_keys:
                    self.assertTrue(md_table.columns[primary_key].primary_key)

190 

191 

# Allow running this module directly as well as via a test runner.
if __name__ == "__main__":
    unittest.main()