Source code for otoole.preprocess.validate_config
"""Validation methods for the user configuration file."""
import logging
from typing import List, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
# from pydantic import FieldValidationInfo
# Module-level logger; used below to report deprecated configuration fields.
logger = logging.getLogger(__name__)
class UserDefinedValue(BaseModel):
    """Represents any user defined value

    Base pydantic model for a single entry of the user configuration file.
    Unknown keys are rejected (``extra="forbid"``). The validators below
    check the entry ``type`` and the formatting of ``name``/``short_name``.
    All validation failures are raised as ``ValueError`` so that pydantic
    reports them as a ``ValidationError``.
    """

    model_config = ConfigDict(extra="forbid")

    # Identifier of the entry; checked for spaces, digits, special
    # characters and length by the validators below.
    name: str
    # Kind of entry; must be 'param', 'result' or 'set'.
    type: str
    # Data type name; the permitted values are enforced by subclasses.
    dtype: str
    # Set names the indices are checked against (see subclasses).
    # NOTE(review): presumably injected by the caller - confirm.
    defined_sets: Optional[List[str]] = None
    # Names of the sets that index this entry.
    indices: Optional[List[str]] = None
    # Default value of the entry.
    default: Optional[Union[int, float]] = None
    # Accepted by the schema; reported as deprecated for results.
    calculated: Optional[bool] = None
    # Alternative name, required when ``name`` exceeds 31 characters.
    short_name: Optional[str] = None

    @field_validator("type")
    @classmethod
    def check_param_type(cls, value, info):
        """Ensure ``type`` is one of 'param', 'result' or 'set'.

        Raises
        ------
        ValueError
            If ``value`` is any other string.
        """
        if value not in ["param", "result", "set"]:
            raise ValueError(
                f"{info.field_name} -> Type must be 'param', 'result', or 'set'"
            )
        return value

    @field_validator("name", "short_name")
    @classmethod  # for linting purposes
    def check_name_for_spaces(cls, value):
        """Reject names containing a space character.

        Raises
        ------
        ValueError
            If ``value`` contains a space.
        """
        # ``short_name`` is optional; an explicit None has nothing to check
        # and would otherwise raise a TypeError on the ``in`` test.
        if value is None:
            return value
        if " " in value:
            raise ValueError(f"{value} -> Name can not have spaces")
        return value

    @field_validator("name", "short_name")
    @classmethod
    def check_name_for_numbers(cls, value):
        """Reject names containing any digit.

        Raises
        ------
        ValueError
            If any character of ``value`` is a digit.
        """
        if value is None:
            return value
        if any(char.isdigit() for char in value):
            raise ValueError(f"{value} -> Name can not have digits")
        return value

    @field_validator("name", "short_name")
    @classmethod
    def check_name_for_special_chars(cls, value):
        """Reject names containing special characters other than underscore.

        Raises
        ------
        ValueError
            If ``value`` contains a space or any punctuation character
            except ``_``.
        """
        if value is None:
            return value
        # removed underscore from the recommended special char list; the
        # backslash is included so that every other punctuation character
        # is rejected, as the error message states.
        special_characters = " !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~"
        if any(c in special_characters for c in value):
            raise ValueError(
                f"{value} -> Name can not have special characters, except for underscores"
            )
        return value

    @model_validator(mode="before")
    @classmethod
    def check_name_length(cls, values):
        """Enforce the 31 character limit on ``name`` and ``short_name``.

        Runs on the raw input before field validation. A ``name`` longer
        than 31 characters is accepted only when a ``short_name`` is
        supplied, and the ``short_name`` itself may not exceed 31
        characters.
        NOTE(review): the origin of the 31 character limit is not visible
        here - confirm.

        Raises
        ------
        ValueError
            If ``name`` is too long without a ``short_name``, or if
            ``short_name`` is too long.
        """
        # Non-dict input and missing/non-string fields are left for
        # pydantic's own field validation to report, instead of failing
        # here with a KeyError or TypeError.
        if not isinstance(values, dict):
            return values

        name = values.get("name")
        short_name = values.get("short_name")

        if isinstance(name, str) and len(name) > 31 and short_name is None:
            raise ValueError(
                f"{name} -> Name is longer than 31 characters and no 'short_name' field provided"
            )
        if isinstance(short_name, str) and len(short_name) > 31:
            raise ValueError(
                f"{short_name} -> Name is longer than 31 characters"
            )
        return values
class UserDefinedSet(UserDefinedValue):
    """Represents a set

    Narrows the permitted ``dtype`` of a configuration entry to the two
    values allowed for sets.
    """

    @field_validator("dtype")
    @classmethod
    def check_dtype(cls, value, info):
        """Accept only 'str' or 'int' as the dtype of a set.

        Raises
        ------
        ValueError
            If ``value`` is neither 'str' nor 'int'.
        """
        allowed_dtypes = ("str", "int")
        if value in allowed_dtypes:
            return value
        raise ValueError(f"{info.field_name} -> Value must be a 'str' or 'int'")
class UserDefinedParameter(UserDefinedValue):
    """Represents a parameter

    Adds parameter-specific checks on top of the base entry validation:
    the permitted ``dtype`` values, that every index refers to a known
    set, and that the ``default`` value agrees with ``dtype``.
    """

    @field_validator("dtype")
    @classmethod
    def check_dtype(cls, value, info):
        """Accept only 'float' or 'int' as the dtype of a parameter.

        Raises
        ------
        ValueError
            If ``value`` is neither 'float' nor 'int'.
        """
        if value not in ["float", "int"]:
            raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'")
        return value

    @model_validator(mode="after")
    def check_index_in_set(self):
        """Ensure every entry of ``indices`` is listed in ``defined_sets``.

        Both fields are optional on the model. A missing ``indices`` is
        treated as "nothing to check" and a missing ``defined_sets`` as
        "no sets known", rather than failing with a TypeError.

        Raises
        ------
        ValueError
            If any index is not one of the defined sets.
        """
        known_sets = self.defined_sets if self.defined_sets is not None else []
        indices = self.indices if self.indices is not None else []
        if not all(i in known_sets for i in indices):
            raise ValueError(f"{self.name} -> Index not in user supplied sets")
        return self

    @model_validator(mode="after")
    def check_dtype_default(self):
        """Ensure the type of ``default`` agrees with the declared ``dtype``.

        The type name of the default value must equal ``dtype``, with one
        exception: an ``int`` default is accepted for a 'float' dtype.

        Raises
        ------
        ValueError
            If the default value's type does not match ``dtype``.
        """
        dtype_input = self.dtype
        dtype_default = type(self.default).__name__
        # allow ints to be cast as floats
        int_as_float = dtype_default == "int" and dtype_input == "float"
        if dtype_input != dtype_default and not int_as_float:
            raise ValueError(
                f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
            )
        return self
class UserDefinedResult(UserDefinedValue):
    """Represents a result

    Adds result-specific checks on top of the base entry validation: the
    permitted ``dtype`` values, a notice for deprecated fields, that every
    index refers to a known set, and that the ``default`` value agrees
    with ``dtype``.
    """

    @field_validator("dtype")
    @classmethod
    def check_dtype(cls, value, info):
        """Accept only 'float' or 'int' as the dtype of a result.

        Raises
        ------
        ValueError
            If ``value`` is neither 'float' nor 'int'.
        """
        if value not in ["float", "int"]:
            raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'")
        return value

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_values(cls, values):
        """Log a notice for each deprecated field present in the raw input.

        Runs before field validation and never rejects the input; the
        input is returned unchanged.
        """
        # Non-dict input is left for pydantic's own validation to report.
        if not isinstance(values, dict):
            return values

        deprecated = ["calculated", "Calculated"]
        # ``get`` keeps a missing 'name' from raising a KeyError here, so
        # that pydantic can report the missing field itself.
        name = values.get("name")
        for v in values:
            if v in deprecated:
                logger.info(
                    "%s -> Config file field of '%s' is deprecated. Remove '%s' to suppress this warning.",
                    name,
                    v,
                    v,
                )
        return values

    @model_validator(mode="after")
    def check_index_in_set(self):
        """Ensure every entry of ``indices`` is listed in ``defined_sets``.

        Both fields are optional on the model. A missing ``indices`` is
        treated as "nothing to check" and a missing ``defined_sets`` as
        "no sets known", rather than failing with a TypeError.

        Raises
        ------
        ValueError
            If any index is not one of the defined sets.
        """
        known_sets = self.defined_sets if self.defined_sets is not None else []
        indices = self.indices if self.indices is not None else []
        if not all(i in known_sets for i in indices):
            raise ValueError(f"{self.name} -> Index not in user supplied sets")
        return self

    @model_validator(mode="after")
    def check_dtype_default(self):
        """Ensure the type of ``default`` agrees with the declared ``dtype``.

        The type name of the default value must equal ``dtype``, with one
        exception: an ``int`` default is accepted for a 'float' dtype.

        Raises
        ------
        ValueError
            If the default value's type does not match ``dtype``.
        """
        dtype_input = self.dtype
        dtype_default = type(self.default).__name__
        # allow ints to be cast as floats
        int_as_float = dtype_default == "int" and dtype_input == "float"
        if dtype_input != dtype_default and not int_as_float:
            raise ValueError(
                f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
            )
        return self