Coverage for python/felis/datamodel.py: 31%

723 statements  

coverage.py v7.13.5, created at 2026-05-07 08:14 +0000

1"""Define Pydantic data models for Felis.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import json 

27import logging 

28import sys 

29from collections.abc import Sequence 

30from enum import StrEnum, auto 

31from operator import itemgetter 

32from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar 

33 

34import yaml 

35from astropy import units as units # type: ignore 

36from astropy.io.votable import ucd # type: ignore 

37from lsst.resources import ResourcePath, ResourcePathExpression 

38from pydantic import ( 

39 BaseModel, 

40 ConfigDict, 

41 Field, 

42 PrivateAttr, 

43 ValidationError, 

44 ValidationInfo, 

45 field_serializer, 

46 field_validator, 

47 model_validator, 

48) 

49from pydantic_core import InitErrorDetails 

50 

51from .db._dialects import get_supported_dialects, string_to_typeengine 

52from .db._sqltypes import get_type_func 

53from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode 

54 

55logger = logging.getLogger(__name__) 

56 

57__all__ = ( 

58 "BaseObject", 

59 "CheckConstraint", 

60 "Column", 

61 "ColumnOverrides", 

62 "ColumnResourceRef", 

63 "Constraint", 

64 "DataType", 

65 "ForeignKeyConstraint", 

66 "Index", 

67 "Resource", 

68 "Schema", 

69 "SchemaVersion", 

70 "Table", 

71 "UniqueConstraint", 

72) 

73 

74CONFIG = ConfigDict( 

75 populate_by_name=True, # Populate attributes by name. 

76 extra="forbid", # Do not allow extra fields. 

77 str_strip_whitespace=True, # Strip whitespace from string fields. 

78 use_enum_values=False, # Do not use enum values during serialization. 

79) 

80"""Pydantic model configuration as described in: 

81https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict 

82""" 

83 

84DESCR_MIN_LENGTH = 3 

85"""Minimum length for a description field.""" 

86 

87DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)] 

88"""Type for a description, which must be three or more characters long.""" 

89 

90 

91class BaseObject(BaseModel): 

92 """Base model. 

93 

94 All classes representing objects in the Felis data model should inherit 

95 from this class. 

96 """ 

97 

98 model_config = CONFIG 

99 """Pydantic model configuration.""" 

100 

101 name: str 

102 """Name of the database object.""" 

103 

104 id: str = Field(alias="@id") 

105 """Unique identifier of the database object.""" 

106 

107 description: DescriptionStr | None = None 

108 """Description of the database object.""" 

109 

110 votable_utype: str | None = Field(None, alias="votable:utype") 

111 """VOTable utype (usage-specific or unique type) of the object.""" 

112 

113 @model_validator(mode="after") 

114 def check_description(self, info: ValidationInfo) -> BaseObject: 

115 """Check that the description is present if required. 

116 

117 Parameters 

118 ---------- 

119 info 

120 Validation context used to determine if the check is enabled. 

121 

122 Returns 

123 ------- 

124 `BaseObject` 

125 The object being validated. 

126 """ 

127 context = info.context 

128 if not context or not context.get("check_description", False): 

129 return self 

130 if self.description is None or self.description == "": 

131 raise ValueError("Description is required and must be non-empty") 

132 if len(self.description) < DESCR_MIN_LENGTH: 

133 raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long") 

134 return self 

135 

136 
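# Illustrative sketch (hypothetical call, not part of the original source): the
# description check above is opt-in and is enabled through the Pydantic
# validation context, e.g.
#
#     Schema.model_validate(yaml_data, context={"check_description": True})
#
# turns a missing or too-short ``description`` on any object into a validation
# error.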

137class DataType(StrEnum): 

138 """``Enum`` representing the data types supported by Felis.""" 

139 

140 boolean = auto() 

141 byte = auto() 

142 short = auto() 

143 int = auto() 

144 long = auto() 

145 float = auto() 

146 double = auto() 

147 char = auto() 

148 string = auto() 

149 unicode = auto() 

150 text = auto() 

151 binary = auto() 

152 timestamp = auto() 

153 

154 
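# Illustrative sketch (not part of the original source): because ``DataType`` is a
# ``StrEnum`` built with ``auto()``, each member's value is its lowercase name, so
# members compare equal to plain strings and round-trip cleanly through YAML/JSON:
#
#     >>> DataType("double") is DataType.double
#     True
#     >>> str(DataType.double)
#     'double'
#     >>> DataType.double == "double"
#     True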

155def validate_ivoa_ucd(ivoa_ucd: str) -> str: 

156 """Validate IVOA UCD values. 

157 

158 Parameters 

159 ---------- 

160 ivoa_ucd 

161 IVOA UCD value to check. 

162 

163 Returns 

164 ------- 

165 `str` 

166 The IVOA UCD value if it is valid. 

167 

168 Raises 

169 ------ 

170 ValueError 

171 If the IVOA UCD value is invalid. 

172 """ 

173 if ivoa_ucd is not None: 

174 try: 

175 ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd) 

176 except ValueError as e: 

177 raise ValueError(f"Invalid IVOA UCD: {e}") 

178 return ivoa_ucd 

179 

180 
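# Usage sketch (illustrative; the UCD strings below are hypothetical examples):
# the helper returns the value unchanged when it parses against the IVOA
# controlled vocabulary and raises ``ValueError`` otherwise:
#
#     >>> validate_ivoa_ucd("pos.eq.ra;meta.main")
#     'pos.eq.ra;meta.main'
#     >>> validate_ivoa_ucd("not.a.real.word")
#     Traceback (most recent call last):
#     ...
#     ValueError: Invalid IVOA UCD: ...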

181class Column(BaseObject): 

182 """Column model.""" 

183 

184 datatype: DataType 

185 """Datatype of the column.""" 

186 

187 length: int | None = Field(None, gt=0) 

188 """Length of the column.""" 

189 

190 precision: int | None = Field(None, ge=0) 

191 """The numerical precision of the column. 

192 

193 For timestamps, this is the number of fractional digits retained in the 

194 seconds field. 

195 """ 

196 

197 nullable: bool = True 

198 """Whether the column can be ``NULL``.""" 

199 

200 value: str | int | float | bool | None = None 

201 """Default value of the column.""" 

202 

203 autoincrement: bool | None = None 

204 """Whether the column is autoincremented.""" 

205 

206 ivoa_ucd: str | None = Field(None, alias="ivoa:ucd") 

207 """IVOA UCD of the column.""" 

208 

209 fits_tunit: str | None = Field(None, alias="fits:tunit") 

210 """FITS TUNIT of the column.""" 

211 

212 ivoa_unit: str | None = Field(None, alias="ivoa:unit") 

213 """IVOA unit of the column.""" 

214 

215 tap_column_index: int | None = Field(None, alias="tap:column_index") 

216 """TAP_SCHEMA column index of the column.""" 

217 

218 tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1) 

219 """Whether this is a TAP_SCHEMA principal column.""" 

220 

221 votable_arraysize: int | str | None = Field(None, alias="votable:arraysize") 

222 """VOTable arraysize of the column.""" 

223 

224 tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1) 

225 """TAP_SCHEMA indication that this column is defined by an IVOA standard. 

226 """ 

227 

228 votable_xtype: str | None = Field(None, alias="votable:xtype") 

229 """VOTable xtype (extended type) of the column.""" 

230 

231 votable_datatype: str | None = Field(None, alias="votable:datatype") 

232 """VOTable datatype of the column.""" 

233 

234 mysql_datatype: str | None = Field(None, alias="mysql:datatype") 

235 """MySQL datatype override on the column.""" 

236 

237 postgresql_datatype: str | None = Field(None, alias="postgresql:datatype") 

238 """PostgreSQL datatype override on the column.""" 

239 

240 _is_resource_ref: bool = PrivateAttr(False) 

241 """Whether this column is a resource reference column.""" 

242 

243 @model_validator(mode="after") 

244 def check_value(self) -> Column: 

245 """Check that the default value is valid. 

246 

247 Returns 

248 ------- 

249 `Column` 

250 The column being validated. 

251 """ 

252 if (value := self.value) is not None: 

253 if value is not None and self.autoincrement is True: 

254 raise ValueError("Column cannot have both a default value and be autoincremented") 

255 felis_type = FelisType.felis_type(self.datatype) 

256 if felis_type.is_numeric: 

257 if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int): 

258 raise ValueError("Default value must be an int for integer type columns") 

259 elif felis_type in (Float, Double) and not isinstance(value, float): 

260 raise ValueError("Default value must be a decimal number for float and double columns") 

261 elif felis_type in (String, Char, Unicode, Text): 

262 if not isinstance(value, str): 

263 raise ValueError("Default value must be a string for string columns") 

264 if not len(value): 

265 raise ValueError("Default value must be a non-empty string for string columns") 

266 elif felis_type is Boolean and not isinstance(value, bool): 

267 raise ValueError("Default value must be a boolean for boolean columns") 

268 return self 

269 

270 @field_validator("ivoa_ucd") 

271 @classmethod 

272 def check_ivoa_ucd(cls, ivoa_ucd: str) -> str: 

273 """Check that IVOA UCD values are valid. 

274 

275 Parameters 

276 ---------- 

277 ivoa_ucd 

278 IVOA UCD value to check. 

279 

280 Returns 

281 ------- 

282 `str` 

283 The IVOA UCD value if it is valid. 

284 """ 

285 return validate_ivoa_ucd(ivoa_ucd) 

286 

287 @model_validator(mode="after") 

288 def check_units(self) -> Column: 

289 """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid 

290 units according to astropy. Only one may be provided. 

291 

292 Returns 

293 ------- 

294 `Column` 

295 The column being validated. 

296 

297 Raises 

298 ------ 

299 ValueError 

300 Raised if both FITS and IVOA units are provided, or if the unit is 

301 invalid. 

302 """ 

303 fits_unit = self.fits_tunit 

304 ivoa_unit = self.ivoa_unit 

305 

306 if fits_unit and ivoa_unit: 

307 raise ValueError("Column cannot have both FITS and IVOA units") 

308 unit = fits_unit or ivoa_unit 

309 

310 if unit is not None: 

311 try: 

312 units.Unit(unit) 

313 except ValueError as e: 

314 raise ValueError(f"Invalid unit: {e}") 

315 

316 return self 

317 

318 @model_validator(mode="before") 

319 @classmethod 

320 def check_length(cls, values: dict[str, Any]) -> dict[str, Any]: 

321 """Check that a valid length is provided for sized types. 

322 

323 Parameters 

324 ---------- 

325 values 

326 Values of the column. 

327 

328 Returns 

329 ------- 

330 `dict` [ `str`, `Any` ] 

331 The values of the column. 

332 

333 Raises 

334 ------ 

335 ValueError 

336 Raised if a length is not provided for a sized type. 

337 """ 

338 datatype = values.get("datatype") 

339 if datatype is None: 

340 # Skip this validation if datatype is not provided 

341 return values 

342 length = values.get("length") 

343 felis_type = FelisType.felis_type(datatype) 

344 if felis_type.is_sized and length is None: 

345 raise ValueError( 

346 f"Length must be provided for type '{datatype}'" 

347 + (f" in column '{values['@id']}'" if "@id" in values else "") 

348 ) 

349 elif not felis_type.is_sized and length is not None: 

350 logger.warning( 

351 f"The datatype '{datatype}' does not support a specified length" 

352 + (f" in column '{values['@id']}'" if "@id" in values else "") 

353 ) 

354 return values 

355 

356 @model_validator(mode="after") 

357 def check_redundant_datatypes(self, info: ValidationInfo) -> Column: 

358 """Check for redundant datatypes on columns. 

359 

360 Parameters 

361 ---------- 

362 info 

363 Validation context used to determine if the check is enabled. 

364 

365 Returns 

366 ------- 

367 `Column` 

368 The column being validated. 

369 

370 Raises 

371 ------ 

372 ValueError 

373 Raised if a datatype override is redundant. 

374 """ 

375 context = info.context 

376 if not context or not context.get("check_redundant_datatypes", False): 

377 return self 

378 if all( 

 379 getattr(self, f"{dialect}_datatype", None) is None 

380 for dialect in get_supported_dialects().keys() 

381 ): 

382 return self 

383 

384 datatype = self.datatype 

385 length: int | None = self.length or None 

386 

387 datatype_func = get_type_func(datatype) 

388 felis_type = FelisType.felis_type(datatype) 

389 if felis_type.is_sized: 

390 datatype_obj = datatype_func(length) 

391 else: 

392 datatype_obj = datatype_func() 

393 

394 for dialect_name, dialect in get_supported_dialects().items(): 

395 db_annotation = f"{dialect_name}_datatype" 

396 if datatype_string := self.model_dump().get(db_annotation): 

397 db_datatype_obj = string_to_typeengine(datatype_string, dialect, length) 

398 if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect): 

399 raise ValueError( 

400 "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format( 

401 db_annotation, 

402 datatype_string, 

403 self.datatype, 

404 self.id, 

405 "" if length is None else f" with length {length}", 

406 ) 

407 ) 

408 else: 

409 logger.debug( 

410 f"Type override of 'datatype: {self.datatype}' " 

411 f"with '{db_annotation}: {datatype_string}' in column '{self.id}' " 

412 f"compiled to '{datatype_obj.compile(dialect)}' and " 

413 f"'{db_datatype_obj.compile(dialect)}'" 

414 ) 

415 return self 

416 

417 @model_validator(mode="after") 

418 def check_precision(self) -> Column: 

419 """Check that precision is only valid for timestamp columns. 

420 

421 Returns 

422 ------- 

423 `Column` 

424 The column being validated. 

425 """ 

426 if self.precision is not None and self.datatype != "timestamp": 

427 raise ValueError("Precision is only valid for timestamp columns") 

428 return self 

429 

430 @model_validator(mode="before") 

431 @classmethod 

432 def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]: 

433 """Set the default value for the ``votable_arraysize`` field, which 

434 corresponds to ``arraysize`` in the IVOA VOTable standard. 

435 

436 Parameters 

437 ---------- 

438 values 

439 Values of the column. 

440 info 

441 Validation context used to determine if the check is enabled. 

442 

443 Returns 

444 ------- 

445 `dict` [ `str`, `Any` ] 

446 The values of the column. 

447 

448 Notes 

449 ----- 

450 Following the IVOA VOTable standard, an ``arraysize`` of 1 should not 

451 be used. 

452 """ 

453 if values.get("name", None) is None or values.get("datatype", None) is None: 

454 # Skip bad column data that will not validate 

455 return values 

456 context = info.context if info.context else {} 

457 arraysize = values.get("votable:arraysize", None) 

458 if arraysize is None: 

459 length = values.get("length", None) 

460 datatype = values.get("datatype") 

461 if length is not None and length > 1: 

462 # Following the IVOA standard, arraysize of 1 is disallowed 

463 if datatype == "char": 

464 arraysize = str(length) 

465 elif datatype in ("string", "unicode", "binary"): 

466 if context.get("force_unbounded_arraysize", False): 

467 arraysize = "*" 

468 logger.debug( 

469 f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype " 

470 + f"'{values['datatype']}' and length '{length}'" 

471 ) 

472 else: 

473 arraysize = f"{length}*" 

474 elif datatype in ("timestamp", "text"): 

475 arraysize = "*" 

476 if arraysize is not None: 

477 values["votable:arraysize"] = arraysize 

478 logger.debug( 

479 f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'" 

480 + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'" 

481 ) 

482 else: 

483 logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'") 

484 if isinstance(values["votable:arraysize"], int): 

485 logger.warning( 

486 f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is " 

487 + "deprecated" 

488 ) 

489 values["votable:arraysize"] = str(arraysize) 

490 return values 

491 

492 @field_serializer("datatype") 

493 def serialize_datatype(self, value: DataType) -> str: 

494 """Convert `DataType` to string when serializing to JSON/YAML. 

495 

496 Parameters 

497 ---------- 

498 value 

499 The `DataType` value to serialize. 

500 

501 Returns 

502 ------- 

503 `str` 

504 The serialized `DataType` value. 

505 """ 

506 return str(value) 

507 

508 @field_validator("datatype", mode="before") 

509 @classmethod 

510 def deserialize_datatype(cls, value: str) -> DataType: 

511 """Convert string back into `DataType` when loading from JSON/YAML. 

512 

513 Parameters 

514 ---------- 

515 value 

516 The string value to deserialize. 

517 

518 Returns 

519 ------- 

520 `DataType` 

521 The deserialized `DataType` value. 

522 """ 

523 return DataType(value) 

524 

525 @model_validator(mode="after") 

526 def check_votable_xtype(self) -> Column: 

527 """Set the default value for the ``votable_xtype`` field, which 

528 corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable 

529 standard. 

530 

531 Returns 

532 ------- 

533 `Column` 

534 The column being validated. 

535 

536 Notes 

537 ----- 

538 This is currently only set automatically for the Felis ``timestamp`` 

539 datatype. 

540 """ 

541 if self.datatype == DataType.timestamp and self.votable_xtype is None: 

542 self.votable_xtype = "timestamp" 

543 return self 

544 

545 def _update_from_overrides(self, overrides: ColumnOverrides) -> None: 

546 """Update the column attributes from the given overrides. 

547 

548 Parameters 

549 ---------- 

550 overrides 

 551 The column overrides to apply to this column. 

552 

553 Notes 

554 ----- 

555 Using ``model_fields_set`` allows updating only the fields that are 

556 explicitly set in the `overrides` object. This prevents overwriting 

557 existing column attributes which were not explicitly provided. 

558 """ 

559 if overrides.model_fields_set: 

560 logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set) 

561 for field in overrides.model_fields_set: 

562 setattr(self, field, getattr(overrides, field)) 

563 

564 
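# Illustrative sketch (hypothetical values, not from the original file): a sized
# type such as ``char`` requires a ``length``, and ``votable:arraysize`` defaults
# from that length during validation:
#
#     >>> col = Column.model_validate(
#     ...     {"name": "band", "@id": "#Object.band", "datatype": "char", "length": 8}
#     ... )
#     >>> col.votable_arraysize
#     '8'
#     >>> col.nullable
#     True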

565class Constraint(BaseObject): 

566 """Table constraint model.""" 

567 

568 deferrable: bool = False 

569 """Whether this constraint will be declared as deferrable.""" 

570 

571 initially: Literal["IMMEDIATE", "DEFERRED"] | None = None 

572 """Value for ``INITIALLY`` clause; only used if `deferrable` is 

573 `True`.""" 

574 

575 @model_validator(mode="after") 

576 def check_deferrable(self) -> Constraint: 

577 """Check that the ``INITIALLY`` clause is only used if `deferrable` is 

578 `True`. 

579 

580 Returns 

581 ------- 

582 `Constraint` 

583 The constraint being validated. 

584 """ 

585 if self.initially is not None and not self.deferrable: 

586 raise ValueError("INITIALLY clause can only be used if deferrable is True") 

587 return self 

588 

589 

590class CheckConstraint(Constraint): 

591 """Table check constraint model.""" 

592 

593 type: Literal["Check"] = Field("Check", alias="@type") 

594 """Type of the constraint.""" 

595 

596 expression: str 

597 """Expression for the check constraint.""" 

598 

599 @field_serializer("type") 

600 def serialize_type(self, value: str) -> str: 

601 """Ensure '@type' is included in serialized output. 

602 

603 Parameters 

604 ---------- 

605 value 

606 The value to serialize. 

607 

608 Returns 

609 ------- 

610 `str` 

611 The serialized value. 

612 """ 

613 return value 

614 

615 

616class UniqueConstraint(Constraint): 

617 """Table unique constraint model.""" 

618 

619 type: Literal["Unique"] = Field("Unique", alias="@type") 

620 """Type of the constraint.""" 

621 

622 columns: list[str] 

623 """Columns in the unique constraint.""" 

624 

625 @field_serializer("type") 

626 def serialize_type(self, value: str) -> str: 

627 """Ensure '@type' is included in serialized output. 

628 

629 Parameters 

630 ---------- 

631 value 

632 The value to serialize. 

633 

634 Returns 

635 ------- 

636 `str` 

637 The serialized value. 

638 """ 

639 return value 

640 

641 

642class ForeignKeyConstraint(Constraint): 

643 """Table foreign key constraint model. 

644 

645 This constraint is used to define a foreign key relationship between two 

646 tables in the schema. There must be at least one column in the 

647 `columns` list, and at least one column in the `referenced_columns` list 

648 or a validation error will be raised. 

649 

650 Notes 

651 ----- 

652 These relationships will be reflected in the TAP_SCHEMA ``keys`` and 

653 ``key_columns`` data. 

654 """ 

655 

656 type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type") 

657 """Type of the constraint.""" 

658 

659 columns: list[str] = Field(min_length=1) 

660 """The columns comprising the foreign key.""" 

661 

662 referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1) 

663 """The columns referenced by the foreign key.""" 

664 

665 on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None 

666 """Action to take when the referenced row is deleted.""" 

667 

668 on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None 

669 """Action to take when the referenced row is updated.""" 

670 

671 @field_serializer("type") 

672 def serialize_type(self, value: str) -> str: 

673 """Ensure '@type' is included in serialized output. 

674 

675 Parameters 

676 ---------- 

677 value 

678 The value to serialize. 

679 

680 Returns 

681 ------- 

682 `str` 

683 The serialized value. 

684 """ 

685 return value 

686 

687 @model_validator(mode="after") 

688 def check_column_lengths(self) -> ForeignKeyConstraint: 

689 """Check that the `columns` and `referenced_columns` lists have the 

690 same length. 

691 

692 Returns 

693 ------- 

694 `ForeignKeyConstraint` 

695 The foreign key constraint being validated. 

696 

697 Raises 

698 ------ 

699 ValueError 

700 Raised if the `columns` and `referenced_columns` lists do not have 

701 the same length. 

702 """ 

703 if len(self.columns) != len(self.referenced_columns): 

704 raise ValueError( 

705 "Columns and referencedColumns must have the same length for a ForeignKey constraint" 

706 ) 

707 return self 

708 

709 

710_ConstraintType = Annotated[ 

711 CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type") 

712] 

713"""Type alias for a constraint type.""" 

714 

715 
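# Illustrative sketch (hypothetical data): the ``@type`` field is the Pydantic
# discriminator, so an entry in a table's ``constraints`` list such as
#
#     {"name": "fk_visit", "@id": "#fk_visit", "@type": "ForeignKey",
#      "columns": ["#exposure.visit_id"], "referencedColumns": ["#visit.id"]}
#
# is parsed as a ``ForeignKeyConstraint``, while ``"@type": "Unique"`` or
# ``"@type": "Check"`` select the other constraint models.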

716class Index(BaseObject): 

717 """Table index model. 

718 

719 An index can be defined on either columns or expressions, but not both. 

720 """ 

721 

722 columns: list[str] | None = None 

723 """Columns in the index.""" 

724 

725 expressions: list[str] | None = None 

726 """Expressions in the index.""" 

727 

728 @model_validator(mode="before") 

729 @classmethod 

730 def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]: 

731 """Check that columns or expressions are specified, but not both. 

732 

733 Parameters 

734 ---------- 

735 values 

736 Values of the index. 

737 

738 Returns 

739 ------- 

740 `dict` [ `str`, `Any` ] 

741 The values of the index. 

742 

743 Raises 

744 ------ 

745 ValueError 

746 Raised if both columns and expressions are specified, or if neither 

 747 is specified. 

748 """ 

749 if "columns" in values and "expressions" in values: 

750 raise ValueError("Defining columns and expressions is not valid") 

751 elif "columns" not in values and "expressions" not in values: 

752 raise ValueError("Must define columns or expressions") 

753 return values 

754 

755 
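# Illustrative sketch (hypothetical data): an index names either columns or
# expressions, so
#
#     {"name": "idx_ra", "@id": "#idx_ra", "columns": ["#Object.ra"]}
#
# validates, whereas supplying both ``columns`` and ``expressions`` (or neither)
# raises a ``ValueError`` in ``check_columns_or_expressions``.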

756ColumnRef: TypeAlias = str 

757"""Type alias for a column reference.""" 

758 

759 

760class ColumnGroup(BaseObject): 

761 """Column group model.""" 

762 

763 columns: list[ColumnRef | Column] = Field(..., min_length=1) 

764 """Columns in the group.""" 

765 

766 ivoa_ucd: str | None = Field(None, alias="ivoa:ucd") 

767 """IVOA UCD of the column.""" 

768 

769 table: Table | None = Field(None, exclude=True) 

770 """Reference to the parent table.""" 

771 

772 @field_validator("ivoa_ucd") 

773 @classmethod 

774 def check_ivoa_ucd(cls, ivoa_ucd: str) -> str: 

775 """Check that IVOA UCD values are valid. 

776 

777 Parameters 

778 ---------- 

779 ivoa_ucd 

780 IVOA UCD value to check. 

781 

782 Returns 

783 ------- 

784 `str` 

785 The IVOA UCD value if it is valid. 

786 """ 

787 return validate_ivoa_ucd(ivoa_ucd) 

788 

789 @model_validator(mode="after") 

790 def check_unique_columns(self) -> ColumnGroup: 

791 """Check that the columns list contains unique items. 

792 

793 Returns 

794 ------- 

795 `ColumnGroup` 

796 The column group being validated. 

797 """ 

798 column_ids = [col if isinstance(col, str) else col.id for col in self.columns] 

799 if len(column_ids) != len(set(column_ids)): 

800 raise ValueError("Columns in the group must be unique") 

801 return self 

802 

803 def _dereference_columns(self) -> None: 

804 """Dereference ColumnRef to Column objects.""" 

805 if self.table is None: 

806 raise ValueError("ColumnGroup must have a reference to its parent table") 

807 

808 dereferenced_columns: list[ColumnRef | Column] = [] 

809 for col in self.columns: 

810 if isinstance(col, str): 

811 # Dereference ColumnRef to Column object 

812 try: 

813 col_obj = self.table._find_column_by_id(col) 

814 except KeyError as e: 

815 raise ValueError(f"Column '{col}' not found in table '{self.table.name}'") from e 

816 dereferenced_columns.append(col_obj) 

817 else: 

818 dereferenced_columns.append(col) 

819 

820 self.columns = dereferenced_columns 

821 

822 @field_serializer("columns") 

823 def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]: 

824 """Serialize columns as their IDs. 

825 

826 Parameters 

827 ---------- 

828 columns 

829 The columns to serialize. 

830 

831 Returns 

832 ------- 

833 `list` [ `str` ] 

834 The serialized column IDs. 

835 """ 

836 return [col if isinstance(col, str) else col.id for col in columns] 

837 

838 

839class ColumnOverrides(BaseModel): 

840 """Allowed overrides for a referenced column. 

841 

842 Notes 

843 ----- 

844 All of these fields are optional. Values of None may be explicitly set to 

845 override the corresponding attribute in the referenced column but only 

846 for certain fields (see validation in `_check_non_nullable_overrides`). 

847 """ 

848 

849 model_config = CONFIG.copy() 

850 

851 datatype: DataType | None = None 

852 """New datatype for the column.""" 

853 

854 length: int | None = None 

855 """New length for the column.""" 

856 

857 description: str | None = None 

858 """New description for the column.""" 

859 

860 nullable: bool | None = None 

861 """New nullable flag for the column.""" 

862 

863 tap_principal: int | None = Field(default=None, alias="tap:principal") 

864 """Override for the TAP_SCHEMA 'principal' flag.""" 

865 

866 tap_column_index: int | None = Field(default=None, alias="tap:column_index") 

867 """Override for the TAP_SCHEMA column index.""" 

868 

869 @model_validator(mode="before") 

870 @classmethod 

871 def _check_non_nullable_overrides(cls, data: Any) -> Any: 

872 """Check that certain fields are not overridden to null.""" 

873 if not isinstance(data, dict): 

874 return data 

875 non_nullable_fields = ("datatype", "length", "nullable", "tap_principal") 

876 for name in non_nullable_fields: 

877 if name in data and data[name] is None: 

878 raise ValueError(f"The '{name}' field cannot be overridden to null") 

879 return data 

880 

881 @field_serializer("datatype") 

882 def serialize_datatype(self, value: DataType | None) -> str | None: 

883 """Convert `DataType` to string when serializing to JSON/YAML. 

884 

885 Parameters 

886 ---------- 

887 value 

888 The `DataType` value to serialize, or None. 

889 

890 Returns 

891 ------- 

892 `str` | None 

893 The serialized `DataType` value, or None if the input was None. 

894 """ 

895 if value is None: 

896 return None 

897 return str(value) 

898 

899 @field_validator("datatype", mode="before") 

900 @classmethod 

901 def deserialize_datatype(cls, value: str | None) -> DataType | None: 

902 """Convert string back into `DataType` when loading from JSON/YAML. 

903 

904 Parameters 

905 ---------- 

906 value 

907 The string value to deserialize, or None. 

908 

909 Returns 

910 ------- 

911 `DataType` | None 

912 The deserialized `DataType` value, or None if the input was None. 

913 """ 

914 if value is None: 

915 return None 

916 return DataType(value) 

917 

918 
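# Illustrative sketch (hypothetical values): only fields that are explicitly
# provided appear in ``model_fields_set``, which is what
# ``Column._update_from_overrides`` iterates over, so attributes not mentioned in
# the overrides are left untouched on the referenced column:
#
#     >>> overrides = ColumnOverrides(description="Local description", nullable=False)
#     >>> sorted(overrides.model_fields_set)
#     ['description', 'nullable']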

919class ColumnResourceRef(BaseModel): 

920 """A column which is dervived from an external resource.""" 

921 

922 ref_name: str | None = None 

923 """Name of the referenced column in the resource 

924 (if different from the key).""" 

925 

926 overrides: ColumnOverrides | None = None 

927 """Optional overrides of the referenced column's attributes.""" 

928 

929 

930# Type aliases for the nested mapping structure of referenced columns 

931ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None] 

932ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap] 

933ResourceMap: TypeAlias = dict[str, ResourceTableMap] 

934 

935 
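# Illustrative sketch (hypothetical schema content) of the nesting these aliases
# describe: resource name -> table name -> local column name -> optional
# ``ColumnResourceRef``:
#
#     column_refs = {
#         "base_schema": {                      # key into Schema.resources
#             "Object": {                       # table in the referenced schema
#                 "ra": None,                   # reuse column 'ra' as-is
#                 "dec_deg": ColumnResourceRef(
#                     ref_name="dec",           # referenced column has another name
#                     overrides=ColumnOverrides(description="Declination in degrees."),
#                 ),
#             }
#         }
#     }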

936class Table(BaseObject): 

937 """Table model.""" 

938 

939 primary_key: str | list[str] | None = Field(None, alias="primaryKey") 

940 """Primary key of the table.""" 

941 

942 tap_table_index: int | None = Field(None, alias="tap:table_index") 

943 """IVOA TAP_SCHEMA table index of the table.""" 

944 

945 mysql_engine: str | None = Field("MyISAM", alias="mysql:engine") 

946 """MySQL engine to use for the table.""" 

947 

948 mysql_charset: str | None = Field(None, alias="mysql:charset") 

949 """MySQL charset to use for the table.""" 

950 

951 columns: list[Column] = Field(default_factory=list) 

952 """Columns in the table.""" 

953 

954 column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs") 

955 """Referenced columns from external resources.""" 

956 

957 column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups") 

958 """Column groups in the table.""" 

959 

960 constraints: list[_ConstraintType] = Field(default_factory=list) 

961 """Constraints on the table.""" 

962 

963 indexes: list[Index] = Field(default_factory=list) 

964 """Indexes on the table.""" 

965 

966 @field_validator("columns", mode="after") 

967 @classmethod 

968 def check_unique_column_names(cls, columns: list[Column]) -> list[Column]: 

969 """Check that column names are unique. 

970 

971 Parameters 

972 ---------- 

973 columns 

974 The columns to check. 

975 

976 Returns 

977 ------- 

978 `list` [ `Column` ] 

979 The columns if they are unique. 

980 

981 Raises 

982 ------ 

983 ValueError 

984 Raised if column names are not unique. 

985 """ 

986 if len(columns) != len(set(column.name for column in columns)): 

987 raise ValueError("Column names must be unique") 

988 return columns 

989 

990 @model_validator(mode="after") 

991 def check_tap_table_index(self, info: ValidationInfo) -> Table: 

992 """Check that the table has a TAP table index. 

993 

994 Parameters 

995 ---------- 

996 info 

997 Validation context used to determine if the check is enabled. 

998 

999 Returns 

1000 ------- 

1001 `Table` 

1002 The table being validated. 

1003 

1004 Raises 

1005 ------ 

1006 ValueError 

 1007 Raised if the table is missing a TAP table index. 

1008 """ 

1009 context = info.context 

1010 if not context or not context.get("check_tap_table_indexes", False): 

1011 return self 

1012 if self.tap_table_index is None: 

1013 raise ValueError("Table is missing a TAP table index") 

1014 return self 

1015 

1016 @model_validator(mode="after") 

1017 def check_tap_principal(self, info: ValidationInfo) -> Table: 

1018 """Check that at least one column is flagged as 'principal' for TAP 

1019 purposes. 

1020 

1021 Parameters 

1022 ---------- 

1023 info 

1024 Validation context used to determine if the check is enabled. 

1025 

1026 Returns 

1027 ------- 

1028 `Table` 

1029 The table being validated. 

1030 

1031 Raises 

1032 ------ 

1033 ValueError 

1034 Raised if the table is missing a column flagged as 'principal'. 

1035 """ 

1036 context = info.context 

1037 if not context or not context.get("check_tap_principal", False): 

1038 return self 

1039 for col in self.columns: 

1040 if col.tap_principal == 1: 

1041 return self 

1042 raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'") 

1043 

1044 def _find_column_by_id(self, id: str) -> Column: 

1045 """Find a column by ID. 

1046 

1047 Parameters 

1048 ---------- 

1049 id 

1050 The ID of the column to find. 

1051 

1052 Returns 

1053 ------- 

1054 `Column` 

1055 The column with the given ID. 

1056 

1057 Raises 

1058 ------ 

 1059 KeyError 

1060 Raised if the column is not found. 

1061 """ 

1062 for column in self.columns: 

1063 if column.id == id: 

1064 return column 

1065 raise KeyError(f"Column '{id}' not found in table '{self.name}'") 

1066 

1067 def _find_column_by_name(self, name: str) -> Column: 

1068 for column in self.columns: 

1069 if column.name == name: 

1070 return column 

1071 raise KeyError(f"Column '{name}' not found in table '{self.name}'") 

1072 

1073 @model_validator(mode="after") 

1074 def dereference_column_groups(self: Table) -> Table: 

1075 """Dereference columns in column groups. 

1076 

1077 Returns 

1078 ------- 

1079 `Table` 

1080 The table with dereferenced column groups. 

1081 """ 

1082 for group in self.column_groups: 

1083 group.table = self 

1084 group._dereference_columns() 

1085 return self 

1086 

1087 @field_serializer("columns") 

1088 def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]: 

1089 """Serialize only non-resource columns.""" 

1090 return [ 

1091 col.model_dump( 

1092 by_alias=True, 

1093 exclude_none=True, 

1094 exclude_defaults=True, 

1095 ) 

1096 for col in columns 

1097 if not col._is_resource_ref 

1098 ] 

1099 

1100 

1101class SchemaVersion(BaseModel): 

1102 """Schema version model.""" 

1103 

1104 current: str 

1105 """The current version of the schema.""" 

1106 

1107 compatible: list[str] = Field(default_factory=list) 

1108 """The compatible versions of the schema.""" 

1109 

1110 read_compatible: list[str] = Field(default_factory=list) 

1111 """The read compatible versions of the schema.""" 

1112 

1113 
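# Illustrative sketch (hypothetical values): ``Schema.version`` accepts either a
# bare version string or this structured form:
#
#     >>> SchemaVersion(current="1.2.0", compatible=["1.1.0"], read_compatible=["0.9.0"]).current
#     '1.2.0'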

1114class SchemaIdVisitor: 

1115 """Visit a schema and build the map of IDs to objects. 

1116 

1117 Notes 

1118 ----- 

1119 Duplicates are added to a set when they are encountered, which can be 

1120 accessed via the ``duplicates`` attribute. The presence of duplicates will 

1121 not throw an error. Only the first object with a given ID will be added to 

1122 the map, but this should not matter, since a ``ValidationError`` will be 

1123 thrown by the ``model_validator`` method if any duplicates are found in the 

1124 schema. 

1125 """ 

1126 

1127 def __init__(self) -> None: 

1128 """Create a new SchemaVisitor.""" 

1129 self.schema: Schema | None = None 

1130 self.duplicates: set[str] = set() 

1131 

1132 def add(self, obj: BaseObject) -> None: 

1133 """Add an object to the ID map. 

1134 

1135 Parameters 

1136 ---------- 

1137 obj 

1138 The object to add to the ID map. 

1139 """ 

1140 if hasattr(obj, "id"): 

1141 obj_id = getattr(obj, "id") 

1142 if self.schema is not None: 

1143 if obj_id in self.schema._id_map: 

1144 self.duplicates.add(obj_id) 

1145 else: 

1146 self.schema._id_map[obj_id] = obj 

1147 

1148 def visit_schema(self, schema: Schema) -> None: 

1149 """Visit the objects in a schema and build the ID map. 

1150 

1151 Parameters 

1152 ---------- 

1153 schema 

1154 The schema object to visit. 

1155 

1156 Notes 

1157 ----- 

1158 This will set an internal variable pointing to the schema object. 

1159 """ 

1160 self.schema = schema 

1161 self.duplicates.clear() 

1162 self.add(self.schema) 

1163 for table in self.schema.tables: 

1164 self.visit_table(table) 

1165 

1166 def visit_table(self, table: Table) -> None: 

1167 """Visit a table object. 

1168 

1169 Parameters 

1170 ---------- 

1171 table 

1172 The table object to visit. 

1173 """ 

1174 self.add(table) 

1175 for column in table.columns: 

1176 self.visit_column(column) 

1177 for constraint in table.constraints: 

1178 self.visit_constraint(constraint) 

1179 

1180 def visit_column(self, column: Column) -> None: 

1181 """Visit a column object. 

1182 

1183 Parameters 

1184 ---------- 

1185 column 

1186 The column object to visit. 

1187 """ 

1188 self.add(column) 

1189 

1190 def visit_constraint(self, constraint: Constraint) -> None: 

1191 """Visit a constraint object. 

1192 

1193 Parameters 

1194 ---------- 

1195 constraint 

1196 The constraint object to visit. 

1197 """ 

1198 self.add(constraint) 

1199 

1200 

1201T = TypeVar("T", bound=BaseObject) 

1202 

1203 

1204def _strip_ids(data: Any) -> Any: 

1205 """Recursively strip '@id' fields from a dictionary or list. 

1206 

1207 Parameters 

1208 ---------- 

1209 data 

1210 The data to strip IDs from, which can be a dictionary, list, or any 

1211 other type. Other types will be returned unchanged. 

1212 """ 

1213 if isinstance(data, dict): 

1214 data.pop("@id", None) 

1215 for k, v in data.items(): 

1216 data[k] = _strip_ids(v) 

1217 return data 

1218 elif isinstance(data, list): 

1219 return [_strip_ids(item) for item in data] 

1220 else: 

1221 return data 

1222 

1223 
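# Illustrative sketch (hypothetical data): the helper removes every ``@id`` key in
# place, recursing through nested dictionaries and lists:
#
#     >>> _strip_ids({"@id": "#s", "tables": [{"@id": "#t", "name": "t1"}]})
#     {'tables': [{'name': 't1'}]}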

1224def _append_error( 

1225 errors: list[InitErrorDetails], 

1226 loc: tuple, 

1227 input_value: Any, 

1228 error_message: str, 

1229 error_type: str = "value_error", 

1230) -> None: 

1231 """Append an error to the errors list. 

1232 

1233 Parameters 

1234 ---------- 

1235 errors : list[InitErrorDetails] 

1236 The list of errors to append to. 

1237 loc : tuple 

1238 The location of the error in the schema. 

1239 input_value : Any 

1240 The input value that caused the error. 

1241 error_message : str 

1242 The error message to include in the context. 

1243 """ 

1244 errors.append( 

1245 { 

1246 "type": error_type, 

1247 "loc": loc, 

1248 "input": input_value, 

1249 "ctx": {"error": error_message}, 

1250 } 

1251 ) 

1252 

1253 

1254class Resource(BaseModel): 

1255 """A resource definition referencing an external schema.""" 

1256 

1257 uri: str = Field(..., description="Resource URI or path") 

1258 """URI of the schema resource which may be a local path, ``resource://``, 

1259 or remote URL.""" 

1260 

1261 

1262class Schema(BaseObject, Generic[T]): 

1263 """Database schema model. 

1264 

1265 This represents a database schema, which contains one or more tables. 

1266 """ 

1267 

1268 version: SchemaVersion | str | None = None 

1269 """The version of the schema.""" 

1270 

1271 tables: Sequence[Table] 

1272 """The tables in the schema.""" 

1273 

1274 resources: dict[str, Resource] = Field(default_factory=dict) 

1275 """External resources referenced by this schema.""" 

1276 

1277 _id_map: dict[str, Any] = PrivateAttr(default_factory=dict) 

1278 """Map of IDs to objects.""" 

1279 

1280 _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict) 

1281 """Map of resource names to loaded schemas.""" 

1282 

1283 @model_validator(mode="after") 

1284 def _load_resources(self: Schema, info: ValidationInfo) -> Schema: 

1285 """Load external resources referenced by this schema into an internal 

1286 mapping of resource names to their `Schema` objects. 

1287 

1288 Returns 

1289 ------- 

1290 `Schema` 

1291 The schema being validated. 

1292 

1293 Raises 

1294 ------ 

1295 ValueError 

1296 Raised if a resource cannot be loaded. 

1297 """ 

1298 if info.context: 

1299 context = info.context.copy() 

1300 # Ignore this flag for loading the resources themselves 

1301 context.pop("dereference_resources", None) 

1302 else: 

1303 context = {} 

1304 

1305 for resource_name, resource in self.resources.items(): 

1306 uri = resource.uri 

1307 try: 

1308 loaded_schema = Schema.from_uri(uri, context=context) 

1309 self._resource_map[resource_name] = loaded_schema 

1310 logger.debug(f"Loaded resource '{resource_name}' from URI '{uri}'") 

1311 except Exception as e: 

1312 raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e 

1313 return self 

1314 

1315 def _find_table_by_name(self, name: str) -> Table: 

1316 """Find a table by name. 

1317 

1318 Parameters 

1319 ---------- 

1320 name 

1321 The name of the table to find. 

1322 

1323 Returns 

1324 ------- 

1325 `Table` 

1326 The table with the given name. 

1327 

1328 Raises 

1329 ------ 

1330 KeyError 

1331 Raised if the table is not found. 

1332 """ 

1333 for table in self.tables: 

1334 if table.name == name: 

1335 return table 

1336 raise KeyError(f"Table '{name}' not found in schema '{self.name}'") 

1337 

1338 @model_validator(mode="after") 

1339 def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema: 

1340 """Dereference columns from external resources and add them to the 

1341 tables in this schema. 

1342 """ 

1343 context = info.context 

1344 column_ref_index_increment: int | None = None 

1345 dereference_resources = False 

1346 if context is not None: 

1347 dereference_resources = context.get("dereference_resources", False) 

1348 column_ref_index_increment = context.get("column_ref_index_increment", None) 

1349 

1350 for table in self.tables: 

1351 if column_refs := table.column_refs: 

1352 for resource_name, tables in column_refs.items(): 

1353 resource_schema = self._resource_map.get(resource_name) 

1354 if resource_schema is None: 

1355 raise ValueError(f"Schema resource '{resource_name}' was not found in resources") 

1356 self._process_column_refs( 

1357 table, 

1358 tables, 

1359 resource_schema, 

1360 dereference_resources, 

1361 column_ref_index_increment, 

1362 ) 

1363 if dereference_resources and len(table.column_refs) > 0: 

1364 # Clear column refs in table if fully dereferencing 

1365 logger.debug( 

1366 f"Clearing columnRefs in table '{table.name}' after dereferencing resource columns" 

1367 ) 

1368 table.column_refs = {} 

1369 return self 

1370 

1371 @classmethod 

1372 def _process_column_refs( 

1373 cls, 

1374 table: Table, 

1375 ref_tables: ResourceTableMap, 

1376 resource_schema: Schema, 

1377 dereference_resources: bool = False, 

1378 column_ref_index_increment: int | None = None, 

1379 ) -> None: 

1380 """Process column references from an external resource and add them 

1381 to the given table as columns. 

1382 """ 

1383 current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1 

1384 

1385 for table_name, columns in ref_tables.items(): 

1386 try: 

1387 resource_table = resource_schema._find_table_by_name(table_name) 

1388 except KeyError as e: 

1389 raise ValueError( 

1390 f"Table '{table_name}' not found in resource '{resource_schema.name}'" 

1391 ) from e 

1392 for local_column_name, column_ref in columns.items(): 

1393 if column_ref is not None and column_ref.ref_name is not None: 

1394 # Use specified ref_name 

1395 ref_column_name = column_ref.ref_name 

1396 else: 

1397 # Use the local column name if no ref_name 

1398 # specified 

1399 ref_column_name = local_column_name 

1400 

1401 # Check if referenced column exists in resource 

1402 try: 

1403 base_column = resource_table._find_column_by_name(ref_column_name) 

1404 except KeyError: 

1405 # The ref_name is specified but column is not 

1406 # found 

1407 if column_ref is not None and column_ref.ref_name is not None: 

1408 raise ValueError( 

1409 f"Column '{ref_column_name}' not found in table '{table_name}' " 

1410 f"from resource '{resource_schema.name}'" 

1411 ) 

1412 # The ref_name is not specified and the local 

1413 # column name is not found 

1414 raise ValueError( 

1415 f"Column '{local_column_name}' not found in table '{table_name}' " 

1416 f"from resource '{resource_schema.name}' and no ref_name provided" 

1417 ) 

1418 

1419 # Create a copy of the base column 

1420 column_copy = base_column.model_copy() 

1421 

1422 # Set the local name (key from the mapping) 

1423 column_copy.name = local_column_name 

1424 

1425 if not dereference_resources: 

1426 # Flag the column as a resource reference so it will not be 

1427 # written out during serialization 

1428 column_copy._is_resource_ref = True 

1429 

1430 # Apply overrides to the referenced column definition 

1431 overrides = column_ref.overrides if column_ref is not None else None 

1432 if overrides is not None: 

1433 column_copy._update_from_overrides(overrides) 

1434 

1435 # Manually set the ID of the copied column as ID generation has 

1436 # already occurred by now 

1437 column_copy.id = f"{table.id}.{local_column_name}" 

1438 

1439 # Apply automatic assignment of 'tap:column_index', if enabled 

1440 if column_ref_index_increment is not None: 

1441 if (not overrides) or (overrides.tap_column_index is None): 

1442 column_copy.tap_column_index = current_column_index 

1443 current_column_index += column_ref_index_increment 

1444 logger.debug( 

1445 f"Automatically assigned 'tap:column_index' {column_copy.tap_column_index} to " 

1446 f"column '{local_column_name}' in table '{table_name}' from resource " 

1447 f"'{resource_schema.name}'" 

1448 ) 

1449 else: 

1450 # Skip automatic assignment of 'tap:column_index' if it 

1451 # is already overridden 

1452 logger.debug( 

1453 f"Skipping automatic assignment of 'tap:column_index' for column " 

1454 f"'{local_column_name}' in table '{table_name}' from resource " 

1455 f"'{resource_schema.name}' as it is already overridden to " 

1456 f"{column_copy.tap_column_index}" 

1457 ) 

1458 table.columns.append(column_copy) 

1459 logger.debug( 

1460 f"Dereferenced column '{local_column_name}' from table '{table_name}' " 

1461 f"in resource '{resource_schema.name}' into table '{table.name}'" 

1462 ) 

1463 

1464 @model_validator(mode="before") 

1465 @classmethod 

1466 def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]: 

1467 """Generate IDs for objects that do not have them. 

1468 

1469 Parameters 

1470 ---------- 

1471 values 

1472 The values of the schema. 

1473 info 

1474 Validation context used to determine if ID generation is enabled. 

1475 

1476 Returns 

1477 ------- 

1478 `dict` [ `str`, `Any` ] 

1479 The values of the schema with generated IDs. 

1480 """ 

1481 context = info.context 

1482 if not context or not context.get("id_generation", False): 

1483 logger.debug("Skipping ID generation") 

1484 return values 

1485 schema_name = values["name"] 

1486 if "@id" not in values: 

1487 values["@id"] = f"#{schema_name}" 

1488 logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'") 

1489 if "tables" in values: 

1490 for table in values["tables"]: 

1491 if "@id" not in table: 

1492 table["@id"] = f"#{table['name']}" 

1493 logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'") 

1494 if "columns" in table: 

1495 for column in table["columns"]: 

1496 if "@id" not in column: 

1497 column["@id"] = f"#{table['name']}.{column['name']}" 

1498 logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'") 

1499 if "columnGroups" in table: 

1500 for column_group in table["columnGroups"]: 

1501 if "@id" not in column_group: 

1502 column_group["@id"] = f"#{table['name']}.{column_group['name']}" 

1503 logger.debug( 

1504 f"Generated ID '{column_group['@id']}' for column group " 

1505 f"'{column_group['name']}'" 

1506 ) 

1507 if "constraints" in table: 

1508 for constraint in table["constraints"]: 

1509 if "@id" not in constraint: 

1510 constraint["@id"] = f"#{constraint['name']}" 

1511 logger.debug( 

1512 f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'" 

1513 ) 

1514 if "indexes" in table: 

1515 for index in table["indexes"]: 

1516 if "@id" not in index: 

1517 index["@id"] = f"#{index['name']}" 

1518 logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'") 

1519 return values 

1520 

1521 @field_validator("tables", mode="after") 

1522 @classmethod 

1523 def check_unique_table_names(cls, tables: list[Table]) -> list[Table]: 

1524 """Check that table names are unique. 

1525 

1526 Parameters 

1527 ---------- 

1528 tables 

1529 The tables to check. 

1530 

1531 Returns 

1532 ------- 

1533 `list` [ `Table` ] 

1534 The tables if they are unique. 

1535 

1536 Raises 

1537 ------ 

1538 ValueError 

1539 Raised if table names are not unique. 

1540 """ 

1541 if len(tables) != len(set(table.name for table in tables)): 

1542 raise ValueError("Table names must be unique") 

1543 return tables 

1544 

1545 @model_validator(mode="after") 

1546 def check_tap_table_indexes(self, info: ValidationInfo) -> Schema: 

1547 """Check that the TAP table indexes are unique. 

1548 

1549 Parameters 

1550 ---------- 

1551 info 

1552 The validation context used to determine if the check is enabled. 

1553 

1554 Returns 

1555 ------- 

1556 `Schema` 

1557 The schema being validated. 

1558 """ 

1559 context = info.context 

1560 if not context or not context.get("check_tap_table_indexes", False): 

1561 return self 

 1562 table_indices = set() 

1563 for table in self.tables: 

1564 table_index = table.tap_table_index 

1565 if table_index is not None: 

 1566 if table_index in table_indices: 

1567 raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema") 

 1568 table_indices.add(table_index) 

1569 return self 

1570 

1571 @model_validator(mode="after") 

1572 def check_unique_constraint_names(self: Schema) -> Schema: 

1573 """Check for duplicate constraint names in the schema. 

1574 

1575 Returns 

1576 ------- 

1577 `Schema` 

1578 The schema being validated. 

1579 

1580 Raises 

1581 ------ 

1582 ValueError 

1583 Raised if duplicate constraint names are found in the schema. 

1584 """ 

1585 constraint_names = set() 

1586 duplicate_names = [] 

1587 

1588 for table in self.tables: 

1589 for constraint in table.constraints: 

1590 constraint_name = constraint.name 

1591 if constraint_name in constraint_names: 

1592 duplicate_names.append(constraint_name) 

1593 else: 

1594 constraint_names.add(constraint_name) 

1595 

1596 if duplicate_names: 

1597 raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}") 

1598 

1599 return self 

1600 

1601 @model_validator(mode="after") 

1602 def check_unique_index_names(self: Schema) -> Schema: 

1603 """Check for duplicate index names in the schema. 

1604 

1605 Returns 

1606 ------- 

1607 `Schema` 

1608 The schema being validated. 

1609 

1610 Raises 

1611 ------ 

1612 ValueError 

1613 Raised if duplicate index names are found in the schema. 

1614 """ 

1615 index_names = set() 

1616 duplicate_names = [] 

1617 

1618 for table in self.tables: 

1619 for index in table.indexes: 

1620 index_name = index.name 

1621 if index_name in index_names: 

1622 duplicate_names.append(index_name) 

1623 else: 

1624 index_names.add(index_name) 

1625 

1626 if duplicate_names: 

1627 raise ValueError(f"Duplicate index names found in schema: {duplicate_names}") 

1628 

1629 return self 

1630 

1631 @model_validator(mode="after") 

1632 def create_id_map(self: Schema) -> Schema: 

1633 """Create a map of IDs to objects. 

1634 

1635 Returns 

1636 ------- 

1637 `Schema` 

1638 The schema with the ID map created. 

1639 

1640 Raises 

1641 ------ 

1642 ValueError 

1643 Raised if duplicate identifiers are found in the schema. 

1644 """ 

1645 if self._id_map: 

1646 logger.debug("Ignoring call to create_id_map() - ID map was already populated") 

1647 return self 

1648 visitor: SchemaIdVisitor = SchemaIdVisitor() 

1649 visitor.visit_schema(self) 

1650 if len(visitor.duplicates): 

1651 raise ValueError( 

1652 "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n" 

1653 ) 

1654 logger.debug("Created ID map with %d entries", len(self._id_map)) 

1655 return self 

1656 

1657 def _validate_column_id( 

1658 self: Schema, 

1659 column_id: str, 

1660 loc: tuple, 

1661 errors: list[InitErrorDetails], 

1662 ) -> None: 

1663 """Validate a column ID from a constraint and append errors if invalid. 

1664 

1665 Parameters 

1666 ---------- 

1669 column_id : str 

1670 The column ID to validate. 

1671 loc : tuple 

1672 The location of the error in the schema. 

1673 errors : list[InitErrorDetails] 

1674 The list of errors to append to. 

1675 """ 

1676 if column_id not in self: 

1677 _append_error( 

1678 errors, 

1679 loc, 

1680 column_id, 

1681 f"Column ID '{column_id}' not found in schema", 

1682 ) 

1683 elif not isinstance(self[column_id], Column): 

1684 _append_error( 

1685 errors, 

1686 loc, 

1687 column_id, 

1688 f"ID '{column_id}' does not refer to a Column object", 

1689 ) 

1690 

1691 def _validate_foreign_key_column( 

1692 self: Schema, 

1693 column_id: str, 

1694 table: Table, 

1695 loc: tuple, 

1696 errors: list[InitErrorDetails], 

1697 ) -> None: 

1698 """Validate a foreign key column ID from a constraint and append errors 

1699 if invalid. 

1700 

1701 Parameters 

1702 ---------- 

1703 schema : Schema 

1704 The schema being validated. 

1705 column_id : str 

1706 The foreign key column ID to validate. 

1707 loc : tuple 

1708 The location of the error in the schema. 

1709 errors : list[InitErrorDetails] 

1710 The list of errors to append to. 

1711 """ 

1712 try: 

1713 table._find_column_by_id(column_id) 

1714 except KeyError: 

1715 _append_error( 

1716 errors, 

1717 loc, 

1718 column_id, 

1719 f"Column '{column_id}' not found in table '{table.name}'", 

1720 ) 

1721 

1722 @model_validator(mode="after") 

1723 def check_constraints(self: Schema) -> Schema: 

1724 """Check constraint objects for validity. This needs to be deferred 

1725 until after the schema is fully loaded and the ID map is created. 

1726 

1727 Raises 

1728 ------ 

1729 pydantic.ValidationError 

1730 Raised if any constraints are invalid. 

1731 

1732 Returns 

1733 ------- 

1734 `Schema` 

1735 The schema being validated. 

1736 """ 

1737 errors: list[InitErrorDetails] = [] 

1738 

1739 for table_index, table in enumerate(self.tables): 

1740 for constraint_index, constraint in enumerate(table.constraints): 

1741 column_ids: list[str] = [] 

1742 referenced_column_ids: list[str] = [] 

1743 

1744 if isinstance(constraint, ForeignKeyConstraint): 

1745 column_ids += constraint.columns 

1746 referenced_column_ids += constraint.referenced_columns 

1747 elif isinstance(constraint, UniqueConstraint): 

1748 column_ids += constraint.columns 

1749 # No extra checks are required on CheckConstraint objects. 

1750 

1751 # Validate the foreign key columns 

1752 for column_id in column_ids: 

1753 self._validate_column_id( 

1754 column_id, 

1755 ( 

1756 "tables", 

1757 table_index, 

1758 "constraints", 

1759 constraint_index, 

1760 "columns", 

1761 column_id, 

1762 ), 

1763 errors, 

1764 ) 

1765 # Check that the foreign key column is within the source 

1766 # table. 

1767 self._validate_foreign_key_column( 

1768 column_id, 

1769 table, 

1770 ( 

1771 "tables", 

1772 table_index, 

1773 "constraints", 

1774 constraint_index, 

1775 "columns", 

1776 column_id, 

1777 ), 

1778 errors, 

1779 ) 

1780 

1781 # Validate the primary key (reference) columns 

1782 for referenced_column_id in referenced_column_ids: 

1783 self._validate_column_id( 

1784 referenced_column_id, 

1785 ( 

1786 "tables", 

1787 table_index, 

1788 "constraints", 

1789 constraint_index, 

1790 "referenced_columns", 

1791 referenced_column_id, 

1792 ), 

1793 errors, 

1794 ) 

1795 

1796 if errors: 

1797 raise ValidationError.from_exception_data("Schema validation failed", errors) 

1798 

1799 return self 

1800 

1801 def __getitem__(self, id: str) -> BaseObject: 

1802 """Get an object by its ID. 

1803 

1804 Parameters 

1805 ---------- 

1806 id 

1807 The ID of the object to get. 

1808 

1809 Raises 

1810 ------ 

1811 KeyError 

1812 Raised if the object with the given ID is not found in the schema. 

1813 """ 

1814 if id not in self: 

1815 raise KeyError(f"Object with ID '{id}' not found in schema") 

1816 return self._id_map[id] 

1817 

1818 def __contains__(self, id: str) -> bool: 

1819 """Check if an object with the given ID is in the schema. 

1820 

1821 Parameters 

1822 ---------- 

1823 id 

1824 The ID of the object to check. 

1825 """ 

1826 return id in self._id_map 

1827 

1828 def find_object_by_id(self, id: str, obj_type: type[T]) -> T: 

1829 """Find an object with the given type by its ID. 

1830 

1831 Parameters 

1832 ---------- 

1833 id 

1834 The ID of the object to find. 

1835 obj_type 

1836 The type of the object to find. 

1837 

1838 Returns 

1839 ------- 

1840 BaseObject 

1841 The object with the given ID and type. 

1842 

1843 Raises 

1844 ------ 

1845 KeyError 

1846 If the object with the given ID is not found in the schema. 

1847 TypeError 

1848 If the object that is found does not have the right type. 

1849 

1850 Notes 

1851 ----- 

1852 The actual return type is the user-specified argument ``T``, which is 

1853 expected to be a subclass of `BaseObject`. 

1854 """ 

1855 obj = self[id] 

1856 if not isinstance(obj, obj_type): 

1857 raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'") 

1858 return obj 

1859 

1860 def get_table_by_column(self, column: Column) -> Table: 

1861 """Find the table that contains a column. 

1862 

1863 Parameters 

1864 ---------- 

1865 column 

1866 The column to find. 

1867 

1868 Returns 

1869 ------- 

1870 `Table` 

1871 The table that contains the column. 

1872 

1873 Raises 

1874 ------ 

1875 ValueError 

1876 If the column is not found in any table. 

1877 """ 

1878 for table in self.tables: 

1879 if column in table.columns: 

1880 return table 

1881 raise ValueError(f"Column '{column.name}' not found in any table") 

1882 

1883 @classmethod 

1884 def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema: 

1885 """Load a `Schema` from a string representing a ``ResourcePath``. 

1886 

1887 Parameters 

1888 ---------- 

1889 resource_path 

1890 The ``ResourcePath`` pointing to a YAML file. 

1891 context 

1892 Pydantic context to be used in validation. 

1893 

1894 Returns 

1895 ------- 

1896 `str` 

1897 The ID of the object. 

1898 

1899 Raises 

1900 ------ 

1901 yaml.YAMLError 

1902 Raised if there is an error loading the YAML data. 

1903 ValueError 

1904 Raised if there is an error reading the resource. 

1905 pydantic.ValidationError 

1906 Raised if the schema fails validation. 

1907 """ 

1908 try: 

1909 rp_stream = ResourcePath(resource_path).read() 

1910 except Exception as e: 

1911 raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e 

1912 yaml_data = yaml.safe_load(rp_stream) 

1913 return Schema.model_validate(yaml_data, context=context) 

1914 

1915 @classmethod 

1916 def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema: 

1917 """Load a `Schema` from a file stream which should contain YAML data. 

1918 

1919 Parameters 

1920 ---------- 

1921 source 

1922 The file stream to read from. 

1923 context 

1924 Pydantic context to be used in validation. 

1925 

1926 Returns 

1927 ------- 

1928 `Schema` 

1929 The Felis schema loaded from the stream. 

1930 

1931 Raises 

1932 ------ 

1933 yaml.YAMLError 

1934 Raised if there is an error loading the YAML file. 

1935 pydantic.ValidationError 

1936 Raised if the schema fails validation. 

1937 """ 

1938 logger.debug("Loading schema from: '%s'", source) 

1939 yaml_data = yaml.safe_load(source) 

1940 return Schema.model_validate(yaml_data, context=context) 

1941 

1942 def _model_dump(self, strip_ids: bool = False, sort_columns: bool = False) -> dict[str, Any]: 

1943 """Dump the schema as a dictionary with some default arguments 

1944 applied. 

1945 

1946 Parameters 

1947 ---------- 

1948 strip_ids 

1949 Whether to strip the IDs from the dumped data. Defaults to `False`. 

1950 sort_columns 

1951 Whether to sort columns alphabetically by name. Defaults to 

1952 `False`. 

1953 

1954 Returns 

1955 ------- 

1956 `dict` [ `str`, `Any` ] 

1957 The dumped schema data as a dictionary. 

1958 """ 

1959 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True) 

1960 if strip_ids: 

1961 data = _strip_ids(data) 

1962 if sort_columns: 

1963 for table in data.get("tables", []): 

1964 table["columns"] = sorted(table.get("columns", []), key=itemgetter("name")) 

1965 return data 

1966 

1967 def dump_yaml( 

1968 self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False 

1969 ) -> None: 

1970 """Pretty print the schema as YAML. 

1971 

1972 Parameters 

1973 ---------- 

1974 stream 

1975 The stream to write the YAML data to. 

1976 strip_ids 

1977 Whether to strip the IDs from the dumped data. Defaults to `False`. 

1978 sort_columns 

1979 Whether to sort columns alphabetically by name. Defaults to 

1980 `False`. 

1981 """ 

1982 data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns) 

1983 yaml.safe_dump( 

1984 data, 

1985 stream, 

1986 default_flow_style=False, 

1987 sort_keys=False, 

1988 ) 

1989 

1990 def dump_json( 

1991 self, stream: IO[str] = sys.stdout, strip_ids: bool = False, sort_columns: bool = False 

1992 ) -> None: 

1993 """Pretty print the schema as JSON. 

1994 

1995 Parameters 

1996 ---------- 

1997 stream 

1998 The stream to write the JSON data to. 

1999 strip_ids 

2000 Whether to strip the IDs from the dumped data. Defaults to `False`. 

2001 sort_columns 

2002 Whether to sort columns alphabetically by name. Defaults to 

2003 `False`. 

2004 """ 

2005 data = self._model_dump(strip_ids=strip_ids, sort_columns=sort_columns) 

2006 json.dump( 

2007 data, 

2008 stream, 

2009 indent=4, 

2010 sort_keys=False, 

2011 )
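# Usage sketch (illustrative; the file name and context flag are hypothetical): a
# schema is typically loaded from YAML and written back out with the helpers
# defined above:
#
#     schema = Schema.from_uri("my_schema.yaml", context={"id_generation": True})
#     schema.dump_yaml(strip_ids=True, sort_columns=True)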