import dataclasses
from dataclasses import dataclass
from typing import List, Optional
import torch
import yaml
from typing_extensions import Self
# pylint: disable=too-many-instance-attributes
@dataclass
[docs]class ModelConfig:
"""Holds information about the specific training
configuration of a model. The default values are
sampled by ray tune for parameter optimization."""
[docs] output_dim: Optional[int] = None
[docs] normalization_means: Optional[List[float]] = None
[docs] normalization_stddevs: Optional[List[float]] = None
# Tunable parameters
[docs] neurons_per_layer: Optional[int] = None
[docs] num_hidden_layers: Optional[int] = None
[docs] goal_per_class: Optional[int] = None
[docs] num_folds: Optional[int] = None
[docs] num_epochs: Optional[int] = None
[docs] batch_size: Optional[int] = None
[docs] learning_rate: Optional[float] = None
[docs] best_val_loss: Optional[float] = None
[docs] device = torch.device("cpu")
[docs] def set_non_tunable_params(self, input_dim, output_dim, norm_means, norm_stddevs):
"""Adds information about the params that are not tunable."""
self.input_dim = input_dim
self.output_dim = output_dim
self.normalization_means = norm_means
self.normalization_stddevs = norm_stddevs
[docs] def set_best_val_loss(self, best_val_loss):
"""Sets the best validation loss from training."""
self.best_val_loss = best_val_loss
[docs] def write_to_file(self, file: str):
"""Save configuration data to a YAML file."""
args = dataclasses.asdict(self)
encoded_string = yaml.dump(args, sort_keys=False)
with open(file, "w", encoding="utf-8") as file_handle:
file_handle.write(encoded_string)
@classmethod
[docs] def from_file(cls, file: str) -> Self:
"""Load configuration data from a YAML file."""
with open(file, "r", encoding="utf-8") as file_handle:
metadata = yaml.safe_load(file_handle)
return cls(**metadata)