Coverage for python / felis / datamodel.py: 31%

719 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:49 +0000

1"""Define Pydantic data models for Felis.""" 

2 

3# This file is part of felis. 

4# 

5# Developed for the LSST Data Management System. 

6# This product includes software developed by the LSST Project 

7# (https://www.lsst.org). 

8# See the COPYRIGHT file at the top-level directory of this distribution 

9# for details of code ownership. 

10# 

11# This program is free software: you can redistribute it and/or modify 

12# it under the terms of the GNU General Public License as published by 

13# the Free Software Foundation, either version 3 of the License, or 

14# (at your option) any later version. 

15# 

16# This program is distributed in the hope that it will be useful, 

17# but WITHOUT ANY WARRANTY; without even the implied warranty of 

18# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

19# GNU General Public License for more details. 

20# 

21# You should have received a copy of the GNU General Public License 

22# along with this program. If not, see <https://www.gnu.org/licenses/>. 

23 

24from __future__ import annotations 

25 

26import json 

27import logging 

28import sys 

29from collections.abc import Sequence 

30from enum import StrEnum, auto 

31from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar 

32 

33import yaml 

34from astropy import units as units # type: ignore 

35from astropy.io.votable import ucd # type: ignore 

36from lsst.resources import ResourcePath, ResourcePathExpression 

37from pydantic import ( 

38 BaseModel, 

39 ConfigDict, 

40 Field, 

41 PrivateAttr, 

42 ValidationError, 

43 ValidationInfo, 

44 field_serializer, 

45 field_validator, 

46 model_validator, 

47) 

48from pydantic_core import InitErrorDetails 

49 

50from .db._dialects import get_supported_dialects, string_to_typeengine 

51from .db._sqltypes import get_type_func 

52from .types import Boolean, Byte, Char, Double, FelisType, Float, Int, Long, Short, String, Text, Unicode 

53 

# Module-level logger following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)

# Public API of this module. Note: "Resource", "Schema", and "SchemaVersion"
# are defined later in the file.
__all__ = (
    "BaseObject",
    "CheckConstraint",
    "Column",
    "ColumnOverrides",
    "ColumnResourceRef",
    "Constraint",
    "DataType",
    "ForeignKeyConstraint",
    "Index",
    "Resource",
    "Schema",
    "SchemaVersion",
    "Table",
    "UniqueConstraint",
)

# Shared configuration applied to every model via ``BaseObject.model_config``.
CONFIG = ConfigDict(
    populate_by_name=True,  # Populate attributes by name.
    extra="forbid",  # Do not allow extra fields.
    str_strip_whitespace=True,  # Strip whitespace from string fields.
    use_enum_values=False,  # Do not use enum values during serialization.
)
"""Pydantic model configuration as described in:
https://docs.pydantic.dev/2.0/api/config/#pydantic.config.ConfigDict
"""

DESCR_MIN_LENGTH = 3
"""Minimum length for a description field."""

# Annotated alias enforcing the minimum description length at validation time.
DescriptionStr: TypeAlias = Annotated[str, Field(min_length=DESCR_MIN_LENGTH)]
"""Type for a description, which must be three or more characters long."""

88 

89 

class BaseObject(BaseModel):
    """Base model.

    All classes representing objects in the Felis data model should inherit
    from this class.
    """

    model_config = CONFIG
    """Pydantic model configuration."""

    name: str
    """Name of the database object."""

    id: str = Field(alias="@id")
    """Unique identifier of the database object."""

    description: DescriptionStr | None = None
    """Description of the database object."""

    votable_utype: str | None = Field(None, alias="votable:utype")
    """VOTable utype (usage-specific or unique type) of the object."""

    @model_validator(mode="after")
    def check_description(self, info: ValidationInfo) -> BaseObject:
        """Require a sufficiently long description when the check is enabled.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `BaseObject`
            The object being validated.
        """
        ctx = info.context or {}
        if not ctx.get("check_description", False):
            return self
        descr = self.description
        if not descr:
            raise ValueError("Description is required and must be non-empty")
        if len(descr) < DESCR_MIN_LENGTH:
            raise ValueError(f"Description must be at least {DESCR_MIN_LENGTH} characters long")
        return self

134 

135 

136class DataType(StrEnum): 

137 """``Enum`` representing the data types supported by Felis.""" 

138 

139 boolean = auto() 

140 byte = auto() 

141 short = auto() 

142 int = auto() 

143 long = auto() 

144 float = auto() 

145 double = auto() 

146 char = auto() 

147 string = auto() 

148 unicode = auto() 

149 text = auto() 

150 binary = auto() 

151 timestamp = auto() 

152 

153 

154def validate_ivoa_ucd(ivoa_ucd: str) -> str: 

155 """Validate IVOA UCD values. 

156 

157 Parameters 

158 ---------- 

159 ivoa_ucd 

160 IVOA UCD value to check. 

161 

162 Returns 

163 ------- 

164 `str` 

165 The IVOA UCD value if it is valid. 

166 

167 Raises 

168 ------ 

169 ValueError 

170 If the IVOA UCD value is invalid. 

171 """ 

172 if ivoa_ucd is not None: 

173 try: 

174 ucd.parse_ucd(ivoa_ucd, check_controlled_vocabulary=True, has_colon=";" in ivoa_ucd) 

175 except ValueError as e: 

176 raise ValueError(f"Invalid IVOA UCD: {e}") 

177 return ivoa_ucd 

178 

179 

class Column(BaseObject):
    """Column model."""

    datatype: DataType
    """Datatype of the column."""

    length: int | None = Field(None, gt=0)
    """Length of the column."""

    precision: int | None = Field(None, ge=0)
    """The numerical precision of the column.

    For timestamps, this is the number of fractional digits retained in the
    seconds field.
    """

    nullable: bool = True
    """Whether the column can be ``NULL``."""

    value: str | int | float | bool | None = None
    """Default value of the column."""

    autoincrement: bool | None = None
    """Whether the column is autoincremented."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    fits_tunit: str | None = Field(None, alias="fits:tunit")
    """FITS TUNIT of the column."""

    ivoa_unit: str | None = Field(None, alias="ivoa:unit")
    """IVOA unit of the column."""

    tap_column_index: int | None = Field(None, alias="tap:column_index")
    """TAP_SCHEMA column index of the column."""

    tap_principal: int | None = Field(0, alias="tap:principal", ge=0, le=1)
    """Whether this is a TAP_SCHEMA principal column."""

    votable_arraysize: int | str | None = Field(None, alias="votable:arraysize")
    """VOTable arraysize of the column."""

    tap_std: int | None = Field(0, alias="tap:std", ge=0, le=1)
    """TAP_SCHEMA indication that this column is defined by an IVOA standard.
    """

    votable_xtype: str | None = Field(None, alias="votable:xtype")
    """VOTable xtype (extended type) of the column."""

    votable_datatype: str | None = Field(None, alias="votable:datatype")
    """VOTable datatype of the column."""

    mysql_datatype: str | None = Field(None, alias="mysql:datatype")
    """MySQL datatype override on the column."""

    postgresql_datatype: str | None = Field(None, alias="postgresql:datatype")
    """PostgreSQL datatype override on the column."""

    _is_resource_ref: bool = PrivateAttr(False)
    """Whether this column is a resource reference column."""

    @model_validator(mode="after")
    def check_value(self) -> Column:
        """Check that the default value is valid.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if the default value is incompatible with the datatype or
            the column is autoincremented.
        """
        if (value := self.value) is not None:
            if self.autoincrement is True:
                raise ValueError("Column cannot have both a default value and be autoincremented")
            felis_type = FelisType.felis_type(self.datatype)
            if felis_type.is_numeric:
                if felis_type in (Byte, Short, Int, Long) and not isinstance(value, int):
                    raise ValueError("Default value must be an int for integer type columns")
                elif felis_type in (Float, Double) and not isinstance(value, float):
                    raise ValueError("Default value must be a decimal number for float and double columns")
            elif felis_type in (String, Char, Unicode, Text):
                if not isinstance(value, str):
                    raise ValueError("Default value must be a string for string columns")
                if not len(value):
                    raise ValueError("Default value must be a non-empty string for string columns")
            elif felis_type is Boolean and not isinstance(value, bool):
                raise ValueError("Default value must be a boolean for boolean columns")
        return self

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_units(self) -> Column:
        """Check that the ``fits:tunit`` or ``ivoa:unit`` field has valid
        units according to astropy. Only one may be provided.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if both FITS and IVOA units are provided, or if the unit is
            invalid.
        """
        fits_unit = self.fits_tunit
        ivoa_unit = self.ivoa_unit

        if fits_unit and ivoa_unit:
            raise ValueError("Column cannot have both FITS and IVOA units")
        unit = fits_unit or ivoa_unit

        if unit is not None:
            try:
                units.Unit(unit)
            except ValueError as e:
                raise ValueError(f"Invalid unit: {e}")

        return self

    @model_validator(mode="before")
    @classmethod
    def check_length(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Check that a valid length is provided for sized types.

        Parameters
        ----------
        values
            Values of the column.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Raises
        ------
        ValueError
            Raised if a length is not provided for a sized type.
        """
        datatype = values.get("datatype")
        if datatype is None:
            # Skip this validation if datatype is not provided
            return values
        length = values.get("length")
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized and length is None:
            raise ValueError(
                f"Length must be provided for type '{datatype}'"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        elif not felis_type.is_sized and length is not None:
            logger.warning(
                f"The datatype '{datatype}' does not support a specified length"
                + (f" in column '{values['@id']}'" if "@id" in values else "")
            )
        return values

    @model_validator(mode="after")
    def check_redundant_datatypes(self, info: ValidationInfo) -> Column:
        """Check for redundant datatypes on columns.

        Parameters
        ----------
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `Column`
            The column being validated.

        Raises
        ------
        ValueError
            Raised if a datatype override is redundant.
        """
        context = info.context
        if not context or not context.get("check_redundant_datatypes", False):
            return self
        # Skip the check entirely when no dialect-specific overrides are set.
        # (The attribute names use underscores, e.g. ``mysql_datatype``; the
        # colon form is only the serialization alias.)
        if all(
            getattr(self, f"{dialect}_datatype", None) is None
            for dialect in get_supported_dialects().keys()
        ):
            return self

        datatype = self.datatype
        length: int | None = self.length or None

        datatype_func = get_type_func(datatype)
        felis_type = FelisType.felis_type(datatype)
        if felis_type.is_sized:
            datatype_obj = datatype_func(length)
        else:
            datatype_obj = datatype_func()

        for dialect_name, dialect in get_supported_dialects().items():
            db_annotation = f"{dialect_name}_datatype"
            if datatype_string := self.model_dump().get(db_annotation):
                db_datatype_obj = string_to_typeengine(datatype_string, dialect, length)
                if datatype_obj.compile(dialect) == db_datatype_obj.compile(dialect):
                    raise ValueError(
                        "'{}: {}' is a redundant override of 'datatype: {}' in column '{}'{}".format(
                            db_annotation,
                            datatype_string,
                            self.datatype,
                            self.id,
                            "" if length is None else f" with length {length}",
                        )
                    )
                else:
                    logger.debug(
                        f"Type override of 'datatype: {self.datatype}' "
                        f"with '{db_annotation}: {datatype_string}' in column '{self.id}' "
                        f"compiled to '{datatype_obj.compile(dialect)}' and "
                        f"'{db_datatype_obj.compile(dialect)}'"
                    )
        return self

    @model_validator(mode="after")
    def check_precision(self) -> Column:
        """Check that precision is only valid for timestamp columns.

        Returns
        -------
        `Column`
            The column being validated.
        """
        if self.precision is not None and self.datatype is not DataType.timestamp:
            raise ValueError("Precision is only valid for timestamp columns")
        return self

    @model_validator(mode="before")
    @classmethod
    def check_votable_arraysize(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Set the default value for the ``votable_arraysize`` field, which
        corresponds to ``arraysize`` in the IVOA VOTable standard.

        Parameters
        ----------
        values
            Values of the column.
        info
            Validation context used to determine if the check is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the column.

        Notes
        -----
        Following the IVOA VOTable standard, an ``arraysize`` of 1 should not
        be used.
        """
        if values.get("name", None) is None or values.get("datatype", None) is None:
            # Skip bad column data that will not validate
            return values
        context = info.context if info.context else {}
        arraysize = values.get("votable:arraysize", None)
        if arraysize is None:
            length = values.get("length", None)
            datatype = values.get("datatype")
            if length is not None and length > 1:
                # Following the IVOA standard, arraysize of 1 is disallowed
                if datatype == "char":
                    arraysize = str(length)
                elif datatype in ("string", "unicode", "binary"):
                    if context.get("force_unbounded_arraysize", False):
                        arraysize = "*"
                        logger.debug(
                            f"Forced VOTable's 'arraysize' to '*' on column '{values['name']}' with datatype "
                            + f"'{values['datatype']}' and length '{length}'"
                        )
                    else:
                        arraysize = f"{length}*"
            elif datatype in ("timestamp", "text"):
                arraysize = "*"
            if arraysize is not None:
                values["votable:arraysize"] = arraysize
                logger.debug(
                    f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'"
                    + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'"
                )
        else:
            logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'")
            if isinstance(values["votable:arraysize"], int):
                logger.warning(
                    f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is "
                    + "deprecated"
                )
                values["votable:arraysize"] = str(arraysize)
        return values

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType) -> str:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize.

        Returns
        -------
        `str`
            The serialized `DataType` value.
        """
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str) -> DataType:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize.

        Returns
        -------
        `DataType`
            The deserialized `DataType` value.
        """
        return DataType(value)

    @model_validator(mode="after")
    def check_votable_xtype(self) -> Column:
        """Set the default value for the ``votable_xtype`` field, which
        corresponds to an Extended Datatype or ``xtype`` in the IVOA VOTable
        standard.

        Returns
        -------
        `Column`
            The column being validated.

        Notes
        -----
        This is currently only set automatically for the Felis ``timestamp``
        datatype.
        """
        if self.datatype == DataType.timestamp and self.votable_xtype is None:
            self.votable_xtype = "timestamp"
        return self

    def _update_from_overrides(self, overrides: ColumnOverrides) -> None:
        """Update the column attributes from the given overrides.

        Parameters
        ----------
        overrides
            The column overrides to apply or `None` to skip applying overrides.

        Notes
        -----
        Using ``model_fields_set`` allows updating only the fields that are
        explicitly set in the `overrides` object. This prevents overwriting
        existing column attributes which were not explicitly provided.
        """
        if overrides.model_fields_set:
            logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set)
            for field in overrides.model_fields_set:
                setattr(self, field, getattr(overrides, field))

562 

563 

class Constraint(BaseObject):
    """Table constraint model."""

    deferrable: bool = False
    """Whether this constraint will be declared as deferrable."""

    initially: Literal["IMMEDIATE", "DEFERRED"] | None = None
    """Value for ``INITIALLY`` clause; only used if `deferrable` is
    `True`."""

    @model_validator(mode="after")
    def check_deferrable(self) -> Constraint:
        """Validate that an ``INITIALLY`` clause is only present when the
        constraint is deferrable.

        Returns
        -------
        `Constraint`
            The constraint being validated.
        """
        has_initially = self.initially is not None
        if has_initially and not self.deferrable:
            raise ValueError("INITIALLY clause can only be used if deferrable is True")
        return self

587 

588 

class CheckConstraint(Constraint):
    """Table check constraint model."""

    type: Literal["Check"] = Field("Check", alias="@type")
    """Type of the constraint."""

    expression: str
    """Expression for the check constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through so it appears in
        serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value

613 

614 

class UniqueConstraint(Constraint):
    """Table unique constraint model."""

    type: Literal["Unique"] = Field("Unique", alias="@type")
    """Type of the constraint."""

    columns: list[str]
    """Columns in the unique constraint."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through so it appears in
        serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value

639 

640 

class ForeignKeyConstraint(Constraint):
    """Table foreign key constraint model.

    This constraint is used to define a foreign key relationship between two
    tables in the schema. There must be at least one column in the
    `columns` list, and at least one column in the `referenced_columns` list
    or a validation error will be raised.

    Notes
    -----
    These relationships will be reflected in the TAP_SCHEMA ``keys`` and
    ``key_columns`` data.
    """

    type: Literal["ForeignKey"] = Field("ForeignKey", alias="@type")
    """Type of the constraint."""

    columns: list[str] = Field(min_length=1)
    """The columns comprising the foreign key."""

    referenced_columns: list[str] = Field(alias="referencedColumns", min_length=1)
    """The columns referenced by the foreign key."""

    on_delete: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is deleted."""

    on_update: Literal["CASCADE", "SET NULL", "SET DEFAULT", "RESTRICT", "NO ACTION"] | None = None
    """Action to take when the referenced row is updated."""

    @field_serializer("type")
    def serialize_type(self, value: str) -> str:
        """Pass the '@type' discriminator through so it appears in
        serialized output.

        Parameters
        ----------
        value
            The value to serialize.

        Returns
        -------
        `str`
            The serialized value.
        """
        return value

    @model_validator(mode="after")
    def check_column_lengths(self) -> ForeignKeyConstraint:
        """Require `columns` and `referenced_columns` to pair up one-to-one.

        Returns
        -------
        `ForeignKeyConstraint`
            The foreign key constraint being validated.

        Raises
        ------
        ValueError
            Raised if the `columns` and `referenced_columns` lists do not have
            the same length.
        """
        n_local = len(self.columns)
        n_remote = len(self.referenced_columns)
        if n_local != n_remote:
            raise ValueError(
                "Columns and referencedColumns must have the same length for a ForeignKey constraint"
            )
        return self

707 

708 

# Discriminated union over the concrete constraint models; the "@type" field
# ("Check" / "ForeignKey" / "Unique") selects which model is instantiated.
_ConstraintType = Annotated[
    CheckConstraint | ForeignKeyConstraint | UniqueConstraint, Field(discriminator="type")
]
"""Type alias for a constraint type."""

713 

714 

class Index(BaseObject):
    """Table index model.

    An index can be defined on either columns or expressions, but not both.
    """

    columns: list[str] | None = None
    """Columns in the index."""

    expressions: list[str] | None = None
    """Expressions in the index."""

    @model_validator(mode="before")
    @classmethod
    def check_columns_or_expressions(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Require exactly one of ``columns`` or ``expressions``.

        Parameters
        ----------
        values
            Values of the index.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the index.

        Raises
        ------
        ValueError
            Raised if both columns and expressions are specified, or if neither
            are specified.
        """
        has_columns = "columns" in values
        has_expressions = "expressions" in values
        if has_columns and has_expressions:
            raise ValueError("Defining columns and expressions is not valid")
        if not has_columns and not has_expressions:
            raise ValueError("Must define columns or expressions")
        return values

753 

754 

# A column reference is simply the referenced column's "@id" string; it is
# resolved to a Column object by ColumnGroup._dereference_columns.
ColumnRef: TypeAlias = str
"""Type alias for a column reference."""

757 

758 

class ColumnGroup(BaseObject):
    """Column group model."""

    columns: list[ColumnRef | Column] = Field(..., min_length=1)
    """Columns in the group."""

    ivoa_ucd: str | None = Field(None, alias="ivoa:ucd")
    """IVOA UCD of the column."""

    table: Table | None = Field(None, exclude=True)
    """Reference to the parent table."""

    @field_validator("ivoa_ucd")
    @classmethod
    def check_ivoa_ucd(cls, ivoa_ucd: str) -> str:
        """Check that IVOA UCD values are valid.

        Parameters
        ----------
        ivoa_ucd
            IVOA UCD value to check.

        Returns
        -------
        `str`
            The IVOA UCD value if it is valid.
        """
        return validate_ivoa_ucd(ivoa_ucd)

    @model_validator(mode="after")
    def check_unique_columns(self) -> ColumnGroup:
        """Check that the columns list contains unique items.

        Returns
        -------
        `ColumnGroup`
            The column group being validated.
        """
        seen: set[str] = set()
        for col in self.columns:
            ident = col if isinstance(col, str) else col.id
            if ident in seen:
                raise ValueError("Columns in the group must be unique")
            seen.add(ident)
        return self

    def _dereference_columns(self) -> None:
        """Dereference ColumnRef to Column objects."""
        if self.table is None:
            raise ValueError("ColumnGroup must have a reference to its parent table")

        resolved: list[ColumnRef | Column] = []
        for entry in self.columns:
            if not isinstance(entry, str):
                # Already a Column object; keep it unchanged.
                resolved.append(entry)
                continue
            # Replace the string reference with the actual Column object.
            try:
                resolved.append(self.table._find_column_by_id(entry))
            except KeyError as e:
                raise ValueError(f"Column '{entry}' not found in table '{self.table.name}'") from e

        self.columns = resolved

    @field_serializer("columns")
    def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]:
        """Serialize columns as their IDs.

        Parameters
        ----------
        columns
            The columns to serialize.

        Returns
        -------
        `list` [ `str` ]
            The serialized column IDs.
        """
        ids: list[str] = []
        for col in columns:
            ids.append(col if isinstance(col, str) else col.id)
        return ids

836 

837 

class ColumnOverrides(BaseModel):
    """Allowed overrides for a referenced column.

    Notes
    -----
    All of these fields are optional. Values of None may be explicitly set to
    override the corresponding attribute in the referenced column but only
    for certain fields (see validation in `_check_non_nullable_overrides`).
    """

    model_config = CONFIG.copy()

    datatype: DataType | None = None
    """New datatype for the column."""

    length: int | None = None
    """New length for the column."""

    description: str | None = None
    """New description for the column."""

    nullable: bool | None = None
    """New nullable flag for the column."""

    tap_principal: int | None = Field(default=None, alias="tap:principal")
    """Override for the TAP_SCHEMA 'principal' flag."""

    tap_column_index: int | None = Field(default=None, alias="tap:column_index")
    """Override for the TAP_SCHEMA column index."""

    @model_validator(mode="before")
    @classmethod
    def _check_non_nullable_overrides(cls, data: Any) -> Any:
        """Check that certain fields are not overridden to null.

        Because ``populate_by_name=True`` allows input to use either the
        field name or its alias, both forms are checked; otherwise an
        explicit ``{"tap:principal": null}`` would bypass the guard.
        """
        if not isinstance(data, dict):
            return data
        non_nullable_fields = ("datatype", "length", "nullable", "tap_principal", "tap:principal")
        for name in non_nullable_fields:
            if name in data and data[name] is None:
                raise ValueError(f"The '{name}' field cannot be overridden to null")
        return data

    @field_serializer("datatype")
    def serialize_datatype(self, value: DataType | None) -> str | None:
        """Convert `DataType` to string when serializing to JSON/YAML.

        Parameters
        ----------
        value
            The `DataType` value to serialize, or None.

        Returns
        -------
        `str` | None
            The serialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return str(value)

    @field_validator("datatype", mode="before")
    @classmethod
    def deserialize_datatype(cls, value: str | None) -> DataType | None:
        """Convert string back into `DataType` when loading from JSON/YAML.

        Parameters
        ----------
        value
            The string value to deserialize, or None.

        Returns
        -------
        `DataType` | None
            The deserialized `DataType` value, or None if the input was None.
        """
        if value is None:
            return None
        return DataType(value)

916 

917 

class ColumnResourceRef(BaseModel):
    """A column which is derived from an external resource."""

    ref_name: str | None = None
    """Name of the referenced column in the resource
    (if different from the key)."""

    overrides: ColumnOverrides | None = None
    """Optional overrides of the referenced column's attributes."""

927 

928 

# Type aliases for the nested mapping structure of referenced columns:
# resource name -> table name -> column name -> optional per-column ref info.
# A value of None in ResourceColumnMap means the column is taken as-is.
ResourceColumnMap: TypeAlias = dict[str, ColumnResourceRef | None]
ResourceTableMap: TypeAlias = dict[str, ResourceColumnMap]
ResourceMap: TypeAlias = dict[str, ResourceTableMap]

933 

934 

class Table(BaseObject):
    """Table model."""

    primary_key: str | list[str] | None = Field(None, alias="primaryKey")
    """Primary key of the table."""

    tap_table_index: int | None = Field(None, alias="tap:table_index")
    """IVOA TAP_SCHEMA table index of the table."""

    # NOTE(review): the default engine is "MyISAM" rather than None —
    # confirm this is still the intended default for new schemas.
    mysql_engine: str | None = Field("MyISAM", alias="mysql:engine")
    """MySQL engine to use for the table."""

    mysql_charset: str | None = Field(None, alias="mysql:charset")
    """MySQL charset to use for the table."""

    columns: list[Column] = Field(default_factory=list)
    """Columns in the table."""

    column_refs: ResourceMap = Field(default_factory=dict, alias="columnRefs")
    """Referenced columns from external resources."""

    column_groups: list[ColumnGroup] = Field(default_factory=list, alias="columnGroups")
    """Column groups in the table."""

    constraints: list[_ConstraintType] = Field(default_factory=list)
    """Constraints on the table."""

    indexes: list[Index] = Field(default_factory=list)
    """Indexes on the table."""

964 

965 @field_validator("columns", mode="after") 

966 @classmethod 

967 def check_unique_column_names(cls, columns: list[Column]) -> list[Column]: 

968 """Check that column names are unique. 

969 

970 Parameters 

971 ---------- 

972 columns 

973 The columns to check. 

974 

975 Returns 

976 ------- 

977 `list` [ `Column` ] 

978 The columns if they are unique. 

979 

980 Raises 

981 ------ 

982 ValueError 

983 Raised if column names are not unique. 

984 """ 

985 if len(columns) != len(set(column.name for column in columns)): 

986 raise ValueError("Column names must be unique") 

987 return columns 

988 

989 @model_validator(mode="after") 

990 def check_tap_table_index(self, info: ValidationInfo) -> Table: 

991 """Check that the table has a TAP table index. 

992 

993 Parameters 

994 ---------- 

995 info 

996 Validation context used to determine if the check is enabled. 

997 

998 Returns 

999 ------- 

1000 `Table` 

1001 The table being validated. 

1002 

1003 Raises 

1004 ------ 

1005 ValueError 

1006 Raised If the table is missing a TAP table index. 

1007 """ 

1008 context = info.context 

1009 if not context or not context.get("check_tap_table_indexes", False): 

1010 return self 

1011 if self.tap_table_index is None: 

1012 raise ValueError("Table is missing a TAP table index") 

1013 return self 

1014 

1015 @model_validator(mode="after") 

1016 def check_tap_principal(self, info: ValidationInfo) -> Table: 

1017 """Check that at least one column is flagged as 'principal' for TAP 

1018 purposes. 

1019 

1020 Parameters 

1021 ---------- 

1022 info 

1023 Validation context used to determine if the check is enabled. 

1024 

1025 Returns 

1026 ------- 

1027 `Table` 

1028 The table being validated. 

1029 

1030 Raises 

1031 ------ 

1032 ValueError 

1033 Raised if the table is missing a column flagged as 'principal'. 

1034 """ 

1035 context = info.context 

1036 if not context or not context.get("check_tap_principal", False): 

1037 return self 

1038 for col in self.columns: 

1039 if col.tap_principal == 1: 

1040 return self 

1041 raise ValueError(f"Table '{self.name}' is missing at least one column designated as 'tap:principal'") 

1042 

    def _find_column_by_id(self, id: str) -> Column:
        """Find a column by ID.

        Parameters
        ----------
        id
            The ID of the column to find.

        Returns
        -------
        `Column`
            The column with the given ID.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.id == id:
                return column
        raise KeyError(f"Column '{id}' not found in table '{self.name}'")

1065 

    def _find_column_by_name(self, name: str) -> Column:
        """Find a column by name.

        Parameters
        ----------
        name
            The name of the column to find.

        Returns
        -------
        `Column`
            The column with the given name.

        Raises
        ------
        KeyError
            Raised if the column is not found.
        """
        for column in self.columns:
            if column.name == name:
                return column
        raise KeyError(f"Column '{name}' not found in table '{self.name}'")

1071 

1072 @model_validator(mode="after") 

1073 def dereference_column_groups(self: Table) -> Table: 

1074 """Dereference columns in column groups. 

1075 

1076 Returns 

1077 ------- 

1078 `Table` 

1079 The table with dereferenced column groups. 

1080 """ 

1081 for group in self.column_groups: 

1082 group.table = self 

1083 group._dereference_columns() 

1084 return self 

1085 

1086 @field_serializer("columns") 

1087 def _serialize_columns(self, columns: list[Column]) -> list[dict[str, Any]]: 

1088 """Serialize only non-resource columns.""" 

1089 return [ 

1090 col.model_dump( 

1091 by_alias=True, 

1092 exclude_none=True, 

1093 exclude_defaults=True, 

1094 ) 

1095 for col in columns 

1096 if not col._is_resource_ref 

1097 ] 

1098 

1099 

class SchemaVersion(BaseModel):
    """Schema version model.

    Captures the current version string plus lists of versions this schema
    is (read-)compatible with.
    """

    current: str
    """The current version of the schema."""

    compatible: list[str] = Field(default_factory=list)
    """The compatible versions of the schema."""

    read_compatible: list[str] = Field(default_factory=list)
    """The read compatible versions of the schema."""

1112 

class SchemaIdVisitor:
    """Visit a schema and build the map of IDs to objects.

    Notes
    -----
    Duplicates are added to a set when they are encountered, which can be
    accessed via the ``duplicates`` attribute. The presence of duplicates will
    not throw an error. Only the first object with a given ID will be added to
    the map, but this should not matter, since a ``ValidationError`` will be
    thrown by the ``model_validator`` method if any duplicates are found in the
    schema.
    """

    def __init__(self) -> None:
        """Create a new SchemaVisitor."""
        self.schema: Schema | None = None
        self.duplicates: set[str] = set()

    def add(self, obj: BaseObject) -> None:
        """Add an object to the ID map.

        Parameters
        ----------
        obj
            The object to add to the ID map.
        """
        # Sentinel distinguishes a missing ``id`` attribute from one set to
        # None, matching the original hasattr-based behavior.
        _missing = object()
        obj_id = getattr(obj, "id", _missing)
        if obj_id is _missing or self.schema is None:
            return
        id_map = self.schema._id_map
        if obj_id in id_map:
            self.duplicates.add(obj_id)
        else:
            id_map[obj_id] = obj

    def visit_schema(self, schema: Schema) -> None:
        """Visit the objects in a schema and build the ID map.

        Parameters
        ----------
        schema
            The schema object to visit.

        Notes
        -----
        This will set an internal variable pointing to the schema object.
        """
        self.schema = schema
        self.duplicates.clear()
        self.add(schema)
        for table in schema.tables:
            self.visit_table(table)

    def visit_table(self, table: Table) -> None:
        """Visit a table object and its columns and constraints.

        Parameters
        ----------
        table
            The table object to visit.
        """
        self.add(table)
        for column in table.columns:
            self.visit_column(column)
        for constraint in table.constraints:
            self.visit_constraint(constraint)

    def visit_column(self, column: Column) -> None:
        """Visit a column object.

        Parameters
        ----------
        column
            The column object to visit.
        """
        self.add(column)

    def visit_constraint(self, constraint: Constraint) -> None:
        """Visit a constraint object.

        Parameters
        ----------
        constraint
            The constraint object to visit.
        """
        self.add(constraint)

1198 

1199 

# Type variable used by Schema.find_object_by_id() for typed lookups.
T = TypeVar("T", bound=BaseObject)

1201 

1202 

1203def _strip_ids(data: Any) -> Any: 

1204 """Recursively strip '@id' fields from a dictionary or list. 

1205 

1206 Parameters 

1207 ---------- 

1208 data 

1209 The data to strip IDs from, which can be a dictionary, list, or any 

1210 other type. Other types will be returned unchanged. 

1211 """ 

1212 if isinstance(data, dict): 

1213 data.pop("@id", None) 

1214 for k, v in data.items(): 

1215 data[k] = _strip_ids(v) 

1216 return data 

1217 elif isinstance(data, list): 

1218 return [_strip_ids(item) for item in data] 

1219 else: 

1220 return data 

1221 

1222 

1223def _append_error( 

1224 errors: list[InitErrorDetails], 

1225 loc: tuple, 

1226 input_value: Any, 

1227 error_message: str, 

1228 error_type: str = "value_error", 

1229) -> None: 

1230 """Append an error to the errors list. 

1231 

1232 Parameters 

1233 ---------- 

1234 errors : list[InitErrorDetails] 

1235 The list of errors to append to. 

1236 loc : tuple 

1237 The location of the error in the schema. 

1238 input_value : Any 

1239 The input value that caused the error. 

1240 error_message : str 

1241 The error message to include in the context. 

1242 """ 

1243 errors.append( 

1244 { 

1245 "type": error_type, 

1246 "loc": loc, 

1247 "input": input_value, 

1248 "ctx": {"error": error_message}, 

1249 } 

1250 ) 

1251 

1252 

class Resource(BaseModel):
    """A resource definition referencing an external schema."""

    uri: str = Field(..., description="Resource URI or path")
    """URI of the schema resource which may be a local path, ``resource://``,
    or remote URL."""

1260 

class Schema(BaseObject, Generic[T]):
    """Database schema model.

    This represents a database schema, which contains one or more tables.
    """

    version: SchemaVersion | str | None = None
    """The version of the schema."""

    tables: Sequence[Table]
    """The tables in the schema."""

    resources: dict[str, Resource] = Field(default_factory=dict)
    """External resources referenced by this schema."""

    # Populated by the create_id_map validator after the model is built.
    _id_map: dict[str, Any] = PrivateAttr(default_factory=dict)
    """Map of IDs to objects."""

    # Populated by the _load_resources validator.
    _resource_map: dict[str, Schema] = PrivateAttr(default_factory=dict)
    """Map of resource names to loaded schemas."""

1281 

1282 @model_validator(mode="after") 

1283 def _load_resources(self: Schema, info: ValidationInfo) -> Schema: 

1284 """Load external resources referenced by this schema into an internal 

1285 mapping of resource names to their `Schema` objects. 

1286 

1287 Returns 

1288 ------- 

1289 `Schema` 

1290 The schema being validated. 

1291 

1292 Raises 

1293 ------ 

1294 ValueError 

1295 Raised if a resource cannot be loaded. 

1296 """ 

1297 if info.context: 

1298 context = info.context.copy() 

1299 # Ignore this flag for loading the resources themselves 

1300 context.pop("dereference_resources", None) 

1301 else: 

1302 context = {} 

1303 

1304 for resource_name, resource in self.resources.items(): 

1305 uri = resource.uri 

1306 try: 

1307 loaded_schema = Schema.from_uri(uri, context=context) 

1308 self._resource_map[resource_name] = loaded_schema 

1309 logger.debug(f"Loaded resource '{resource_name}' from URI '{uri}'") 

1310 except Exception as e: 

1311 raise ValueError(f"Failed to load resource '{resource_name}' from URI '{uri}': {e}") from e 

1312 return self 

1313 

1314 def _find_table_by_name(self, name: str) -> Table: 

1315 """Find a table by name. 

1316 

1317 Parameters 

1318 ---------- 

1319 name 

1320 The name of the table to find. 

1321 

1322 Returns 

1323 ------- 

1324 `Table` 

1325 The table with the given name. 

1326 

1327 Raises 

1328 ------ 

1329 KeyError 

1330 Raised if the table is not found. 

1331 """ 

1332 for table in self.tables: 

1333 if table.name == name: 

1334 return table 

1335 raise KeyError(f"Table '{name}' not found in schema '{self.name}'") 

1336 

1337 @model_validator(mode="after") 

1338 def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema: 

1339 """Dereference columns from external resources and add them to the 

1340 tables in this schema. 

1341 """ 

1342 context = info.context 

1343 column_ref_index_increment: int | None = None 

1344 dereference_resources = False 

1345 if context is not None: 

1346 dereference_resources = context.get("dereference_resources", False) 

1347 column_ref_index_increment = context.get("column_ref_index_increment", None) 

1348 

1349 for table in self.tables: 

1350 if column_refs := table.column_refs: 

1351 for resource_name, tables in column_refs.items(): 

1352 resource_schema = self._resource_map.get(resource_name) 

1353 if resource_schema is None: 

1354 raise ValueError(f"Schema resource '{resource_name}' was not found in resources") 

1355 self._process_column_refs( 

1356 table, 

1357 tables, 

1358 resource_schema, 

1359 dereference_resources, 

1360 column_ref_index_increment, 

1361 ) 

1362 if dereference_resources and len(table.column_refs) > 0: 

1363 # Clear column refs in table if fully dereferencing 

1364 logger.debug( 

1365 f"Clearing columnRefs in table '{table.name}' after dereferencing resource columns" 

1366 ) 

1367 table.column_refs = {} 

1368 return self 

1369 

    @classmethod
    def _process_column_refs(
        cls,
        table: Table,
        ref_tables: ResourceTableMap,
        resource_schema: Schema,
        dereference_resources: bool = False,
        column_ref_index_increment: int | None = None,
    ) -> None:
        """Process column references from an external resource and add them
        to the given table as columns.

        Parameters
        ----------
        table
            The local table receiving copies of the referenced columns.
        ref_tables
            Mapping of resource table name to a mapping of local column name
            to its (possibly None) column reference.
        resource_schema
            The already-loaded schema of the external resource.
        dereference_resources
            If True, copied columns are treated as ordinary columns; if
            False, they are flagged as resource references and excluded from
            serialization.
        column_ref_index_increment
            When not None, 'tap:column_index' values are auto-assigned
            starting at this value and stepping by it, unless overridden.

        Raises
        ------
        ValueError
            Raised if a referenced table or column cannot be found in the
            resource schema.
        """
        # Auto-assigned indexes start at the increment itself; -1 is a dummy
        # value that is never used when auto-assignment is disabled.
        current_column_index = column_ref_index_increment if column_ref_index_increment is not None else -1

        for table_name, columns in ref_tables.items():
            try:
                resource_table = resource_schema._find_table_by_name(table_name)
            except KeyError as e:
                raise ValueError(
                    f"Table '{table_name}' not found in resource '{resource_schema.name}'"
                ) from e
            for local_column_name, column_ref in columns.items():
                if column_ref is not None and column_ref.ref_name is not None:
                    # Use specified ref_name
                    ref_column_name = column_ref.ref_name
                else:
                    # Use the local column name if no ref_name
                    # specified
                    ref_column_name = local_column_name

                # Check if referenced column exists in resource
                try:
                    base_column = resource_table._find_column_by_name(ref_column_name)
                except KeyError:
                    # The ref_name is specified but column is not
                    # found
                    if column_ref is not None and column_ref.ref_name is not None:
                        raise ValueError(
                            f"Column '{ref_column_name}' not found in table '{table_name}' "
                            f"from resource '{resource_schema.name}'"
                        )
                    # The ref_name is not specified and the local
                    # column name is not found
                    raise ValueError(
                        f"Column '{local_column_name}' not found in table '{table_name}' "
                        f"from resource '{resource_schema.name}' and no ref_name provided"
                    )

                # Create a copy of the base column
                column_copy = base_column.model_copy()

                # Set the local name (key from the mapping)
                column_copy.name = local_column_name

                if not dereference_resources:
                    # Flag the column as a resource reference so it will not be
                    # written out during serialization
                    column_copy._is_resource_ref = True

                # Apply overrides to the referenced column definition
                overrides = column_ref.overrides if column_ref is not None else None
                if overrides is not None:
                    column_copy._update_from_overrides(overrides)

                # Manually set the ID of the copied column as ID generation has
                # already occurred by now
                column_copy.id = f"{table.id}.{local_column_name}"

                # Apply automatic assignment of 'tap:column_index', if enabled
                if column_ref_index_increment is not None:
                    if (not overrides) or (overrides.tap_column_index is None):
                        column_copy.tap_column_index = current_column_index
                        current_column_index += column_ref_index_increment
                        logger.debug(
                            f"Automatically assigned 'tap:column_index' {column_copy.tap_column_index} to "
                            f"column '{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}'"
                        )
                    else:
                        # Skip automatic assignment of 'tap:column_index' if it
                        # is already overridden
                        logger.debug(
                            f"Skipping automatic assignment of 'tap:column_index' for column "
                            f"'{local_column_name}' in table '{table_name}' from resource "
                            f"'{resource_schema.name}' as it is already overridden to "
                            f"{column_copy.tap_column_index}"
                        )
                table.columns.append(column_copy)
                logger.debug(
                    f"Dereferenced column '{local_column_name}' from table '{table_name}' "
                    f"in resource '{resource_schema.name}' into table '{table.name}'"
                )

    @model_validator(mode="before")
    @classmethod
    def generate_ids(cls, values: dict[str, Any], info: ValidationInfo) -> dict[str, Any]:
        """Generate IDs for objects that do not have them.

        Runs before model construction on the raw input mapping; IDs are
        derived from object names and only assigned where '@id' is absent.
        Enabled by the ``id_generation`` flag in the validation context.

        Parameters
        ----------
        values
            The values of the schema.
        info
            Validation context used to determine if ID generation is enabled.

        Returns
        -------
        `dict` [ `str`, `Any` ]
            The values of the schema with generated IDs.
        """
        context = info.context
        if not context or not context.get("id_generation", False):
            logger.debug("Skipping ID generation")
            return values
        schema_name = values["name"]
        if "@id" not in values:
            values["@id"] = f"#{schema_name}"
            logger.debug(f"Generated ID '{values['@id']}' for schema '{schema_name}'")
        if "tables" in values:
            for table in values["tables"]:
                if "@id" not in table:
                    table["@id"] = f"#{table['name']}"
                    logger.debug(f"Generated ID '{table['@id']}' for table '{table['name']}'")
                # Columns and column groups are namespaced by table name.
                if "columns" in table:
                    for column in table["columns"]:
                        if "@id" not in column:
                            column["@id"] = f"#{table['name']}.{column['name']}"
                            logger.debug(f"Generated ID '{column['@id']}' for column '{column['name']}'")
                if "columnGroups" in table:
                    for column_group in table["columnGroups"]:
                        if "@id" not in column_group:
                            column_group["@id"] = f"#{table['name']}.{column_group['name']}"
                            logger.debug(
                                f"Generated ID '{column_group['@id']}' for column group "
                                f"'{column_group['name']}'"
                            )
                # Constraints and indexes use their own (schema-unique) names.
                if "constraints" in table:
                    for constraint in table["constraints"]:
                        if "@id" not in constraint:
                            constraint["@id"] = f"#{constraint['name']}"
                            logger.debug(
                                f"Generated ID '{constraint['@id']}' for constraint '{constraint['name']}'"
                            )
                if "indexes" in table:
                    for index in table["indexes"]:
                        if "@id" not in index:
                            index["@id"] = f"#{index['name']}"
                            logger.debug(f"Generated ID '{index['@id']}' for index '{index['name']}'")
        return values

1519 

1520 @field_validator("tables", mode="after") 

1521 @classmethod 

1522 def check_unique_table_names(cls, tables: list[Table]) -> list[Table]: 

1523 """Check that table names are unique. 

1524 

1525 Parameters 

1526 ---------- 

1527 tables 

1528 The tables to check. 

1529 

1530 Returns 

1531 ------- 

1532 `list` [ `Table` ] 

1533 The tables if they are unique. 

1534 

1535 Raises 

1536 ------ 

1537 ValueError 

1538 Raised if table names are not unique. 

1539 """ 

1540 if len(tables) != len(set(table.name for table in tables)): 

1541 raise ValueError("Table names must be unique") 

1542 return tables 

1543 

1544 @model_validator(mode="after") 

1545 def check_tap_table_indexes(self, info: ValidationInfo) -> Schema: 

1546 """Check that the TAP table indexes are unique. 

1547 

1548 Parameters 

1549 ---------- 

1550 info 

1551 The validation context used to determine if the check is enabled. 

1552 

1553 Returns 

1554 ------- 

1555 `Schema` 

1556 The schema being validated. 

1557 """ 

1558 context = info.context 

1559 if not context or not context.get("check_tap_table_indexes", False): 

1560 return self 

1561 table_indicies = set() 

1562 for table in self.tables: 

1563 table_index = table.tap_table_index 

1564 if table_index is not None: 

1565 if table_index in table_indicies: 

1566 raise ValueError(f"Duplicate 'tap:table_index' value {table_index} found in schema") 

1567 table_indicies.add(table_index) 

1568 return self 

1569 

1570 @model_validator(mode="after") 

1571 def check_unique_constraint_names(self: Schema) -> Schema: 

1572 """Check for duplicate constraint names in the schema. 

1573 

1574 Returns 

1575 ------- 

1576 `Schema` 

1577 The schema being validated. 

1578 

1579 Raises 

1580 ------ 

1581 ValueError 

1582 Raised if duplicate constraint names are found in the schema. 

1583 """ 

1584 constraint_names = set() 

1585 duplicate_names = [] 

1586 

1587 for table in self.tables: 

1588 for constraint in table.constraints: 

1589 constraint_name = constraint.name 

1590 if constraint_name in constraint_names: 

1591 duplicate_names.append(constraint_name) 

1592 else: 

1593 constraint_names.add(constraint_name) 

1594 

1595 if duplicate_names: 

1596 raise ValueError(f"Duplicate constraint names found in schema: {duplicate_names}") 

1597 

1598 return self 

1599 

1600 @model_validator(mode="after") 

1601 def check_unique_index_names(self: Schema) -> Schema: 

1602 """Check for duplicate index names in the schema. 

1603 

1604 Returns 

1605 ------- 

1606 `Schema` 

1607 The schema being validated. 

1608 

1609 Raises 

1610 ------ 

1611 ValueError 

1612 Raised if duplicate index names are found in the schema. 

1613 """ 

1614 index_names = set() 

1615 duplicate_names = [] 

1616 

1617 for table in self.tables: 

1618 for index in table.indexes: 

1619 index_name = index.name 

1620 if index_name in index_names: 

1621 duplicate_names.append(index_name) 

1622 else: 

1623 index_names.add(index_name) 

1624 

1625 if duplicate_names: 

1626 raise ValueError(f"Duplicate index names found in schema: {duplicate_names}") 

1627 

1628 return self 

1629 

1630 @model_validator(mode="after") 

1631 def create_id_map(self: Schema) -> Schema: 

1632 """Create a map of IDs to objects. 

1633 

1634 Returns 

1635 ------- 

1636 `Schema` 

1637 The schema with the ID map created. 

1638 

1639 Raises 

1640 ------ 

1641 ValueError 

1642 Raised if duplicate identifiers are found in the schema. 

1643 """ 

1644 if self._id_map: 

1645 logger.debug("Ignoring call to create_id_map() - ID map was already populated") 

1646 return self 

1647 visitor: SchemaIdVisitor = SchemaIdVisitor() 

1648 visitor.visit_schema(self) 

1649 if len(visitor.duplicates): 

1650 raise ValueError( 

1651 "Duplicate IDs found in schema:\n " + "\n ".join(visitor.duplicates) + "\n" 

1652 ) 

1653 logger.debug("Created ID map with %d entries", len(self._id_map)) 

1654 return self 

1655 

    def _validate_column_id(
        self: Schema,
        column_id: str,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a column ID from a constraint and append errors if invalid.

        Parameters
        ----------
        column_id : str
            The column ID to validate.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        if column_id not in self:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column ID '{column_id}' not found in schema",
            )
        elif not isinstance(self[column_id], Column):
            _append_error(
                errors,
                loc,
                column_id,
                f"ID '{column_id}' does not refer to a Column object",
            )

    def _validate_foreign_key_column(
        self: Schema,
        column_id: str,
        table: Table,
        loc: tuple,
        errors: list[InitErrorDetails],
    ) -> None:
        """Validate a foreign key column ID from a constraint and append errors
        if invalid.

        Parameters
        ----------
        column_id : str
            The foreign key column ID to validate.
        table : Table
            The table that must contain the foreign key column.
        loc : tuple
            The location of the error in the schema.
        errors : list[InitErrorDetails]
            The list of errors to append to.
        """
        try:
            table._find_column_by_id(column_id)
        except KeyError:
            _append_error(
                errors,
                loc,
                column_id,
                f"Column '{column_id}' not found in table '{table.name}'",
            )

    @model_validator(mode="after")
    def check_constraints(self: Schema) -> Schema:
        """Check constraint objects for validity. This needs to be deferred
        until after the schema is fully loaded and the ID map is created.

        Raises
        ------
        pydantic.ValidationError
            Raised if any constraints are invalid.

        Returns
        -------
        `Schema`
            The schema being validated.
        """
        # Accumulate all problems so a single ValidationError reports them.
        errors: list[InitErrorDetails] = []

        for table_index, table in enumerate(self.tables):
            for constraint_index, constraint in enumerate(table.constraints):
                column_ids: list[str] = []
                referenced_column_ids: list[str] = []

                if isinstance(constraint, ForeignKeyConstraint):
                    column_ids += constraint.columns
                    referenced_column_ids += constraint.referenced_columns
                elif isinstance(constraint, UniqueConstraint):
                    column_ids += constraint.columns
                # No extra checks are required on CheckConstraint objects.

                # Validate the foreign key columns
                for column_id in column_ids:
                    self._validate_column_id(
                        column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )
                    # Check that the foreign key column is within the source
                    # table.
                    self._validate_foreign_key_column(
                        column_id,
                        table,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "columns",
                            column_id,
                        ),
                        errors,
                    )

                # Validate the primary key (reference) columns
                for referenced_column_id in referenced_column_ids:
                    self._validate_column_id(
                        referenced_column_id,
                        (
                            "tables",
                            table_index,
                            "constraints",
                            constraint_index,
                            "referenced_columns",
                            referenced_column_id,
                        ),
                        errors,
                    )

        if errors:
            raise ValidationError.from_exception_data("Schema validation failed", errors)

        return self

1799 

1800 def __getitem__(self, id: str) -> BaseObject: 

1801 """Get an object by its ID. 

1802 

1803 Parameters 

1804 ---------- 

1805 id 

1806 The ID of the object to get. 

1807 

1808 Raises 

1809 ------ 

1810 KeyError 

1811 Raised if the object with the given ID is not found in the schema. 

1812 """ 

1813 if id not in self: 

1814 raise KeyError(f"Object with ID '{id}' not found in schema") 

1815 return self._id_map[id] 

1816 

    def __contains__(self, id: str) -> bool:
        """Check if an object with the given ID is in the schema.

        Parameters
        ----------
        id
            The ID of the object to check.
        """
        return id in self._id_map

1826 

1827 def find_object_by_id(self, id: str, obj_type: type[T]) -> T: 

1828 """Find an object with the given type by its ID. 

1829 

1830 Parameters 

1831 ---------- 

1832 id 

1833 The ID of the object to find. 

1834 obj_type 

1835 The type of the object to find. 

1836 

1837 Returns 

1838 ------- 

1839 BaseObject 

1840 The object with the given ID and type. 

1841 

1842 Raises 

1843 ------ 

1844 KeyError 

1845 If the object with the given ID is not found in the schema. 

1846 TypeError 

1847 If the object that is found does not have the right type. 

1848 

1849 Notes 

1850 ----- 

1851 The actual return type is the user-specified argument ``T``, which is 

1852 expected to be a subclass of `BaseObject`. 

1853 """ 

1854 obj = self[id] 

1855 if not isinstance(obj, obj_type): 

1856 raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'") 

1857 return obj 

1858 

1859 def get_table_by_column(self, column: Column) -> Table: 

1860 """Find the table that contains a column. 

1861 

1862 Parameters 

1863 ---------- 

1864 column 

1865 The column to find. 

1866 

1867 Returns 

1868 ------- 

1869 `Table` 

1870 The table that contains the column. 

1871 

1872 Raises 

1873 ------ 

1874 ValueError 

1875 If the column is not found in any table. 

1876 """ 

1877 for table in self.tables: 

1878 if column in table.columns: 

1879 return table 

1880 raise ValueError(f"Column '{column.name}' not found in any table") 

1881 

1882 @classmethod 

1883 def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema: 

1884 """Load a `Schema` from a string representing a ``ResourcePath``. 

1885 

1886 Parameters 

1887 ---------- 

1888 resource_path 

1889 The ``ResourcePath`` pointing to a YAML file. 

1890 context 

1891 Pydantic context to be used in validation. 

1892 

1893 Returns 

1894 ------- 

1895 `str` 

1896 The ID of the object. 

1897 

1898 Raises 

1899 ------ 

1900 yaml.YAMLError 

1901 Raised if there is an error loading the YAML data. 

1902 ValueError 

1903 Raised if there is an error reading the resource. 

1904 pydantic.ValidationError 

1905 Raised if the schema fails validation. 

1906 """ 

1907 try: 

1908 rp_stream = ResourcePath(resource_path).read() 

1909 except Exception as e: 

1910 raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e 

1911 yaml_data = yaml.safe_load(rp_stream) 

1912 return Schema.model_validate(yaml_data, context=context) 

1913 

1914 @classmethod 

1915 def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema: 

1916 """Load a `Schema` from a file stream which should contain YAML data. 

1917 

1918 Parameters 

1919 ---------- 

1920 source 

1921 The file stream to read from. 

1922 context 

1923 Pydantic context to be used in validation. 

1924 

1925 Returns 

1926 ------- 

1927 `Schema` 

1928 The Felis schema loaded from the stream. 

1929 

1930 Raises 

1931 ------ 

1932 yaml.YAMLError 

1933 Raised if there is an error loading the YAML file. 

1934 pydantic.ValidationError 

1935 Raised if the schema fails validation. 

1936 """ 

1937 logger.debug("Loading schema from: '%s'", source) 

1938 yaml_data = yaml.safe_load(source) 

1939 return Schema.model_validate(yaml_data, context=context) 

1940 

1941 def _model_dump(self, strip_ids: bool = False) -> dict[str, Any]: 

1942 """Dump the schema as a dictionary with some default arguments 

1943 applied. 

1944 

1945 Parameters 

1946 ---------- 

1947 strip_ids 

1948 Whether to strip the IDs from the dumped data. Defaults to `False`. 

1949 

1950 Returns 

1951 ------- 

1952 `dict` [ `str`, `Any` ] 

1953 The dumped schema data as a dictionary. 

1954 """ 

1955 data = self.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True) 

1956 if strip_ids: 

1957 data = _strip_ids(data) 

1958 return data 

1959 

1960 def dump_yaml(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None: 

1961 """Pretty print the schema as YAML. 

1962 

1963 Parameters 

1964 ---------- 

1965 stream 

1966 The stream to write the YAML data to. 

1967 strip_ids 

1968 Whether to strip the IDs from the dumped data. Defaults to `False`. 

1969 """ 

1970 data = self._model_dump(strip_ids=strip_ids) 

1971 yaml.safe_dump( 

1972 data, 

1973 stream, 

1974 default_flow_style=False, 

1975 sort_keys=False, 

1976 ) 

1977 

1978 def dump_json(self, stream: IO[str] = sys.stdout, strip_ids: bool = False) -> None: 

1979 """Pretty print the schema as JSON. 

1980 

1981 Parameters 

1982 ---------- 

1983 stream 

1984 The stream to write the JSON data to. 

1985 strip_ids 

1986 Whether to strip the IDs from the dumped data. Defaults to `False`. 

1987 """ 

1988 data = self._model_dump(strip_ids=strip_ids) 

1989 json.dump( 

1990 data, 

1991 stream, 

1992 indent=4, 

1993 sort_keys=False, 

1994 )