Coverage for python/felis/db/_variants.py: 45%
44 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:49 +0000
1"""Handle variant overrides for a Felis column."""
3# This file is part of felis.
4#
5# Developed for the LSST Data Management System.
6# This product includes software developed by the LSST Project
7# (https://www.lsst.org).
8# See the COPYRIGHT file at the top-level directory of this distribution
9# for details of code ownership.
10#
11# This program is free software: you can redistribute it and/or modify
12# it under the terms of the GNU General Public License as published by
13# the Free Software Foundation, either version 3 of the License, or
14# (at your option) any later version.
15#
16# This program is distributed in the hope that it will be useful,
17# but WITHOUT ANY WARRANTY; without even the implied warranty of
18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19# GNU General Public License for more details.
20#
21# You should have received a copy of the GNU General Public License
22# along with this program. If not, see <https://www.gnu.org/licenses/>.
24from __future__ import annotations
26import re
27from collections.abc import Mapping
28from types import MappingProxyType
29from typing import Any
31from sqlalchemy import types
32from sqlalchemy.types import TypeEngine
34from ..datamodel import Column
35from ._dialects import get_dialect_module, get_supported_dialects
__all__ = ["make_variant_dict"]


def _create_column_variant_overrides() -> dict[str, str]:
    """Build the lookup from override field names to dialect names.

    Every supported dialect ``<name>`` contributes exactly one entry, which
    maps the column field ``<name>_datatype`` back to ``<name>``.

    Returns
    -------
    column_variant_overrides : `dict` [ `str`, `str` ]
        Override field name (e.g. ``mysql_datatype``) mapped to the name of
        the dialect it applies to.

    Notes
    -----
    This function is intended for internal use only.
    """
    return {f"{name}_datatype": name for name in get_supported_dialects().keys()}


# Wrapped in a read-only proxy so the shared table cannot be mutated by callers.
_COLUMN_VARIANT_OVERRIDES = MappingProxyType(_create_column_variant_overrides())
"""Map of column variant overrides to their dialect name."""
def _get_column_variant_overrides() -> Mapping[str, str]:
    """Get the mapping of column variant overrides.

    Returns
    -------
    column_variant_overrides : `~collections.abc.Mapping` [ `str`, `str` ]
        A read-only mapping of override field names (e.g. ``mysql_datatype``)
        to their dialect name. The shared module-level table is returned
        directly, not a copy.
    """
    return _COLUMN_VARIANT_OVERRIDES
def _get_column_variant_override(field_name: str) -> str:
    """Get the dialect name from an override field name on the column like
    ``mysql_datatype``.

    Parameters
    ----------
    field_name
        Name of the column field holding the datatype override, e.g.
        ``mysql_datatype``.

    Returns
    -------
    dialect_name : `str`
        The name of the dialect.

    Raises
    ------
    ValueError
        Raised if the field name is not found in the column variant overrides.
    """
    try:
        return _COLUMN_VARIANT_OVERRIDES[field_name]
    except KeyError:
        # Callers expect ValueError; the underlying KeyError adds no context.
        raise ValueError(f"Field name {field_name} not found in column variant overrides") from None
92_length_regex = re.compile(r"\((\d+)\)")
93"""A regular expression that is looking for numbers within parentheses."""
96def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
97 """Get the variant type for the given dialect.
99 Parameters
100 ----------
101 dialect_name
102 The name of the dialect to create.
103 variant_override_str
104 The string representation of the variant override.
106 Returns
107 -------
108 variant_type : `~sqlalchemy.types.TypeEngine`
109 The variant type for the given dialect.
111 Raises
112 ------
113 ValueError
114 Raised if the type is not found in the dialect.
116 Notes
117 -----
118 This function converts a string representation of a variant override
119 into a `sqlalchemy.types.TypeEngine` object.
120 """
121 dialect = get_dialect_module(dialect_name)
122 variant_type_name = variant_override_str.split("(")[0]
124 # Process Variant Type
125 if variant_type_name not in dir(dialect):
126 raise ValueError(f"Type {variant_type_name} not found in dialect {dialect_name}")
127 variant_type = getattr(dialect, variant_type_name)
128 length_params = []
129 if match := _length_regex.search(variant_override_str):
130 length_params.extend([int(i) for i in match.group(1).split(",")])
131 return variant_type(*length_params)
def make_variant_dict(column_obj: Column) -> dict[str, TypeEngine[Any]]:
    """Handle variant overrides for a `felis.datamodel.Column`.

    Each field of the column that is a known variant override field and
    whose value is not `None` contributes one entry. The entry is keyed by
    dialect name and holds the `sqlalchemy.types.TypeEngine` parsed from the
    override string (e.g., for mysql, postgresql, etc).

    Parameters
    ----------
    column_obj
        The column object from which to build the variant dictionary.

    Returns
    -------
    `dict` [ `str`, `~sqlalchemy.types.TypeEngine` ]
        Mapping of dialect name to the variant type for that dialect. It is
        empty when the column sets no overrides.
    """
    overrides = _get_column_variant_overrides()
    variants: dict[str, TypeEngine[Any]] = {}
    for name, setting in column_obj:
        # Skip unset values and fields that are not variant overrides.
        if setting is None or name not in overrides:
            continue
        target_dialect = _get_column_variant_override(name)
        variants[target_dialect] = _process_variant_override(target_dialect, setting)
    return variants