Coverage for python / felis / diff.py: 25%
114 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:49 +0000
1"""Compare schemas and print the differences."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24import logging
25import pprint
26import re
27from collections.abc import Callable
28from typing import Any
30from alembic.autogenerate import compare_metadata
31from alembic.migration import MigrationContext
32from deepdiff.diff import DeepDiff
33from sqlalchemy import Engine, MetaData
35from .datamodel import Schema
36from .metadata import MetaDataBuilder
# Names exported by ``from felis.diff import *``.
__all__ = ["DatabaseDiff", "SchemaDiff"]

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# Raise the threshold of the "alembic" logger to WARNING to avoid unnecessary
# output.
logging.getLogger("alembic").setLevel(logging.WARNING)
class SchemaDiff:
    """
    Compute the differences between two schemas with DeepDiff.

    Each schema is dumped to a dictionary with ``None`` values excluded, and
    the two dictionaries are compared with item order ignored. The dumps are
    kept as ``dict1`` and ``dict2`` and the comparison result as ``diff``.

    Parameters
    ----------
    schema1
        The first schema to compare.
    schema2
        The second schema to compare.
    """

    def __init__(self, schema1: Schema, schema2: Schema):
        dump_options = {"exclude_none": True}
        self.dict1 = schema1.model_dump(**dump_options)
        self.dict2 = schema2.model_dump(**dump_options)
        self.diff = DeepDiff(self.dict1, self.dict2, ignore_order=True)

    def print(self) -> None:
        """
        Pretty-print the computed differences to standard output.
        """
        pprint.pprint(self.diff)

    @property
    def has_changes(self) -> bool:
        """
        Report whether the comparison found any differences.

        Returns
        -------
        bool
            True if the diff contains at least one entry, False otherwise.
        """
        return len(self.diff) != 0
class FormattedSchemaDiff(SchemaDiff):
    """
    Compare two schemas using DeepDiff and print the differences using a
    customized output format.

    Each reported change is printed as a header line holding the dotted path
    of the affected element, followed by a ``- `` line for removed or old
    content and/or a ``+ `` line for added or new content.

    Parameters
    ----------
    schema1
        The first schema to compare.
    schema2
        The second schema to compare.
    """

    # Matches one segment of a DeepDiff path: ``['name']`` (group 1) or
    # ``[3]`` (group 2). Compiled once instead of on every parse call.
    _PATH_SEGMENT = re.compile(r"\['([^']+)'\]|\[(\d+)\]")

    def __init__(self, schema1: Schema, schema2: Schema):
        super().__init__(schema1, schema2)

    def print(self) -> None:
        """
        Print the differences between the two schemas using a custom format.

        Only the change types listed in the handler table below are printed,
        in the order of that table; any other change type present in the diff
        is not reported by this method.
        """
        handlers: dict[str, Callable[[dict[str, Any]], None]] = {
            "values_changed": self._handle_values_changed,
            "iterable_item_added": self._handle_iterable_item_added,
            "iterable_item_removed": self._handle_iterable_item_removed,
            "dictionary_item_added": self._handle_dictionary_item_added,
            "dictionary_item_removed": self._handle_dictionary_item_removed,
        }

        for change_type, handler in handlers.items():
            if change_type in self.diff:
                handler(self.diff[change_type])

    def _print_header(self, id_dict: dict[str, Any], keys: list[int | str]) -> None:
        """
        Print the header line for one change: the dotted path of ``keys``.

        ``id_dict`` is currently unused. The object ID is not displayed for
        now because it is always just the schema ID.
        """
        print(f"{self._get_key_display(keys)}")

    def _handle_values_changed(self, changes: dict[str, Any]) -> None:
        """
        Print every changed value as its path, the old value and the new one.
        """
        for path, change in changes.items():
            keys = self._parse_deepdiff_path(path)
            self._print_header(self.dict1, keys)
            print(f"- {change['old_value']}")
            print(f"+ {change['new_value']}")

    def _handle_iterable_item_added(self, changes: dict[str, Any]) -> None:
        """
        Print every item added to an iterable as its path and the new item.
        """
        for path, value in changes.items():
            keys = self._parse_deepdiff_path(path)
            self._print_header(self.dict2, keys)
            print(f"+ {value}")

    def _handle_iterable_item_removed(self, changes: dict[str, Any]) -> None:
        """
        Print every item removed from an iterable as its path and the item.
        """
        for path, value in changes.items():
            keys = self._parse_deepdiff_path(path)
            self._print_header(self.dict1, keys)
            print(f"- {value}")

    def _handle_dictionary_item_added(self, changes: dict[str, Any]) -> None:
        """
        Print every added dictionary key under the path of its parent.
        """
        # Only iterate the paths: the added values themselves are not printed.
        for path in changes:
            keys = self._parse_deepdiff_path(path)
            # The last path segment is the added key; the rest is its parent.
            value = keys.pop()
            self._print_header(self.dict2, keys)
            print(f"+ {value}")

    def _handle_dictionary_item_removed(self, changes: dict[str, Any]) -> None:
        """
        Print every removed dictionary key under the path of its parent.
        """
        # Only iterate the paths: the removed values are not printed.
        for path in changes:
            keys = self._parse_deepdiff_path(path)
            # The last path segment is the removed key; the rest is its parent.
            value = keys.pop()
            self._print_header(self.dict1, keys)
            print(f"- {value}")

    @staticmethod
    def _get_id(values: dict, keys: list[str | int]) -> str:
        """
        Return the ``id`` of the innermost dictionary found along a path.

        Unused for now, pending updates to diff tool in DM-49446.

        Parameters
        ----------
        values
            The nested structure to walk, starting at its root.
        keys
            The path to follow: string keys for dictionaries and integer
            indexes for lists.

        Returns
        -------
        str
            The ``id`` value of the last dictionary on the path (including
            the final element) that has one.

        Raises
        ------
        ValueError
            If an integer index is outside the bounds of a list (negative
            indexes are rejected too), or if no dictionary on the path has
            an ``id`` entry.
        """
        value: list | dict = values
        last_id = None

        for key in keys:
            logger.debug("Processing key <%s> with type %s", key, type(key))
            logger.debug("Type of value: %s", type(value))
            if isinstance(value, dict) and "id" in value:
                last_id = value["id"]
            elif isinstance(value, list) and isinstance(key, int) and not 0 <= key < len(value):
                raise ValueError(f"Index '{key}' is out of range for list of length {len(value)}")
            # Descend exactly one level per key. Descending inside the list
            # branch as well would index the selected element a second time.
            value = value[key]

        # The element the path ends on may itself carry the ID.
        if isinstance(value, dict) and "id" in value:
            last_id = value["id"]

        if last_id is None:
            raise ValueError("No 'id' found in the specified path")
        return last_id

    @staticmethod
    def _get_key_display(keys: list[str | int]) -> str:
        """
        Join the path segments with dots, e.g. ``tables.0.name``.
        """
        return ".".join(str(k) for k in keys)

    @staticmethod
    def _parse_deepdiff_path(path: str) -> list[str | int]:
        """
        Split a DeepDiff path such as ``root['tables'][0]['name']`` into its
        segments.

        Returns
        -------
        list
            String keys as ``str`` and list indexes as ``int``, in path
            order. Text that does not match either segment form is skipped.
        """
        path = path.removeprefix("root")

        keys: list[str | int] = []
        for name, index in FormattedSchemaDiff._PATH_SEGMENT.findall(path):
            if name:  # String key
                keys.append(name)
            elif index:  # Integer index
                keys.append(int(index))

        return keys
class DatabaseDiff(SchemaDiff):
    """
    Compare a schema with a database and print the differences.

    The database is reflected into a fresh ``MetaData`` object, which is
    compared against the metadata built from the schema by means of alembic's
    ``compare_metadata``; the result is stored in ``diff``. The parent
    initializer is not called, so ``dict1`` and ``dict2`` are not set on
    instances of this class.

    Parameters
    ----------
    schema
        The schema to compare.
    engine
        The database engine to compare with.
    """

    def __init__(self, schema: Schema, engine: Engine):
        reflected = MetaData()
        # Everything below runs on the same open connection.
        with engine.connect() as conn:
            reflected.reflect(bind=conn)
            migration_opts = {"compare_type": True, "target_metadata": reflected}
            context = MigrationContext.configure(conn, opts=migration_opts)
            builder = MetaDataBuilder(schema, apply_schema_to_metadata=False)
            self.diff = compare_metadata(context, builder.build())

    def print(self) -> None:
        """
        Pretty-print the differences between the schema and the database.

        Nothing is printed when no differences were found.
        """
        if not self.has_changes:
            return
        pprint.pprint(self.diff)