Coverage for python / felis / db / _variants.py: 45%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:49 +0000

1"""Handle variant overrides for a Felis column.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import re 

27from collections.abc import Mapping 

28from types import MappingProxyType 

29from typing import Any 

30 

31from sqlalchemy import types 

32from sqlalchemy.types import TypeEngine 

33 

34from ..datamodel import Column 

35from ._dialects import get_dialect_module, get_supported_dialects 

36 

# Only ``make_variant_dict`` is exported; the underscore-prefixed helpers
# below are internal.
__all__ = [
    "make_variant_dict",
]

38 

39 

def _create_column_variant_overrides() -> dict[str, str]:
    """Build the lookup from override field names to dialect names.

    For every dialect key reported by ``get_supported_dialects()``, a field
    name of the form ``<dialect>_datatype`` (for example ``mysql_datatype``)
    is mapped to the bare dialect name.

    Returns
    -------
    column_variant_overrides : `dict` [ `str`, `str` ]
        Mapping of ``<dialect>_datatype`` field names to dialect names.

    Notes
    -----
    This function is intended for internal use only.
    """
    # Iterating the supported-dialects mapping yields its keys (dialect names).
    return {f"{name}_datatype": name for name in get_supported_dialects()}

56 

57 

# Computed once at import time and wrapped in a read-only proxy so callers
# cannot mutate the shared lookup table.
_COLUMN_VARIANT_OVERRIDES: Mapping[str, str] = MappingProxyType(
    _create_column_variant_overrides()
)
"""Read-only map of ``<dialect>_datatype`` field names to dialect names."""

60 

61 

def _get_column_variant_overrides() -> Mapping[str, str]:
    """Get the mapping of column variant override fields to dialect names.

    Returns
    -------
    column_variant_overrides : `~collections.abc.Mapping` [ `str`, `str` ]
        A read-only mapping of override field names (e.g.
        ``mysql_datatype``) to their dialect name. The same shared
        module-level mapping object is returned on every call; it is a
        `types.MappingProxyType` and cannot be modified.
    """
    return _COLUMN_VARIANT_OVERRIDES

71 

72 

def _get_column_variant_override(field_name: str) -> str:
    """Look up the dialect name for an override field such as
    ``mysql_datatype``.

    Parameters
    ----------
    field_name
        Name of the column field holding a dialect-specific datatype
        override.

    Returns
    -------
    dialect_name : `str`
        The name of the dialect the override applies to.

    Raises
    ------
    ValueError
        Raised if ``field_name`` is not a known column variant override.
    """
    # Values in the lookup are dialect-name strings, so `None` only signals
    # a missing key; this avoids a separate membership test plus indexing.
    dialect_name = _COLUMN_VARIANT_OVERRIDES.get(field_name)
    if dialect_name is None:
        raise ValueError(f"Field name {field_name} not found in column variant overrides")
    return dialect_name

90 

91 

92_length_regex = re.compile(r"\((\d+)\)") 

93"""A regular expression that is looking for numbers within parentheses.""" 

94 

95 

def _process_variant_override(dialect_name: str, variant_override_str: str) -> types.TypeEngine:
    """Build the SQLAlchemy type described by a dialect override string.

    The text before the first ``(`` is looked up as an attribute of the
    module returned by ``get_dialect_module(dialect_name)``. Any integers
    captured by ``_length_regex`` are passed to that attribute positionally.

    Parameters
    ----------
    dialect_name
        The name of the dialect whose module is searched for the type.
    variant_override_str
        The string representation of the variant override, e.g.
        ``VARCHAR(255)``.

    Returns
    -------
    variant_type : `~sqlalchemy.types.TypeEngine`
        The instantiated variant type for the given dialect.

    Raises
    ------
    ValueError
        Raised if the type name is not found in the dialect module.
    """
    dialect_module = get_dialect_module(dialect_name)
    type_name = variant_override_str.partition("(")[0]

    if type_name not in dir(dialect_module):
        raise ValueError(f"Type {type_name} not found in dialect {dialect_name}")
    type_factory = getattr(dialect_module, type_name)

    # No parenthesised integers means the type is constructed without arguments.
    params_match = _length_regex.search(variant_override_str)
    type_args = [int(arg) for arg in params_match.group(1).split(",")] if params_match else []
    return type_factory(*type_args)

132 

133 

def make_variant_dict(column_obj: Column) -> dict[str, TypeEngine[Any]]:
    """Handle variant overrides for a `felis.datamodel.Column`.

    Every field of ``column_obj`` whose name is a known variant override
    (such as ``mysql_datatype``) and whose value is not `None` is converted
    into a SQLAlchemy type for the corresponding dialect.

    Parameters
    ----------
    column_obj
        The column object from which to build the variant dictionary.

    Returns
    -------
    `dict` [ `str`, `~sqlalchemy.types.TypeEngine` ]
        Mapping of dialect name to the `sqlalchemy.types.TypeEngine`
        override for that dialect (e.g., for mysql, postgresql, etc).
        Empty if the column sets no overrides.
    """
    known_overrides = _get_column_variant_overrides()
    variants: dict[str, TypeEngine[Any]] = {}
    # Iterating the column yields (field name, value) pairs.
    for name, override in column_obj:
        if name not in known_overrides or override is None:
            continue
        dialect_name = _get_column_variant_override(name)
        variants[dialect_name] = _process_variant_override(dialect_name, override)
    return variants