import dataclasses
import os
from dataclasses import dataclass, field
from typing import List, Optional
import torch
import yaml
from typing_extensions import Self
# pylint: disable=too-many-instance-attributes
@dataclass
class SuperphotConfig:
    """Holds information about the specific training configuration of a model.

    The default values are sampled by ray tune for parameter optimization.

    On construction, ``__post_init__`` joins the directory and file
    attributes onto their parent directories when ``relative_dirs`` is true,
    and creates the output directories when ``create_dirs`` is true.
    """

    # When true, __post_init__ creates the output directories on disk.
    create_dirs: Optional[bool] = True
    # When true, __post_init__ joins the paths below onto their parent
    # directory (``data_dir`` or ``figs_dir``).
    relative_dirs: Optional[bool] = True

    # File paths
    data_dir: Optional[str] = "."
    fits_dir: Optional[str] = "fits"
    models_dir: Optional[str] = 'models'
    figs_dir: Optional[str] = 'figs'
    # The next four are joined onto ``figs_dir`` rather than ``data_dir``.
    metrics_dir: Optional[str] = 'metrics'
    fit_plots_dir: Optional[str] = 'fits'
    cm_dir: Optional[str] = 'confusion_matrices'
    wrongly_classified_dir: Optional[str] = 'wrongly_classified'
    log_fn: Optional[str] = 'results.log'
    probs_dir: Optional[str] = 'probabilities'
    # Joined onto ``probs_dir``; contains a ``%d`` placeholder.
    probs_fn: Optional[str] = 'probs_%d.csv'
    # Overwritten by from_file() with the config path minus its extension.
    prefix: Optional[str] = 'best-model'

    # single-target options
    target_label: Optional[str] = None
    prob_threshhold: Optional[float] = 0.5

    # Nontunable parameters (filled in by set_non_tunable_params)
    output_dim: Optional[int] = None
    normalization_means: Optional[List[float]] = None
    normalization_stddevs: Optional[List[float]] = None

    # Tunable parameters
    neurons_per_layer: Optional[int] = None
    num_hidden_layers: Optional[int] = None
    goal_per_class: Optional[int] = 4500
    num_folds: Optional[int] = None
    num_epochs: Optional[int] = None
    batch_size: Optional[int] = None
    learning_rate: Optional[float] = None
    # Filled in by set_best_val_loss.
    best_val_loss: Optional[float] = None

    # The two fields below were previously used without being declared:
    # ``input_csvs`` was read in __post_init__ (raising AttributeError) and
    # ``input_dim`` was assigned in set_non_tunable_params (so it was dropped
    # by dataclasses.asdict). They are declared after the pre-existing fields
    # so the positional order of the generated __init__ is unchanged.
    input_dim: Optional[int] = None
    input_csvs: Optional[List[str]] = field(default_factory=list)

    # Deliberately unannotated: this is a plain class attribute, not a
    # dataclass field, so it is neither an __init__ argument nor part of
    # dataclasses.asdict() / the YAML written by write_to_file.
    device = torch.device("cpu")

    def __post_init__(self):
        """Ensure subdirectory structure exists.

        When ``relative_dirs`` is true, every path attribute is rewritten in
        place as a path under its parent directory. When ``create_dirs`` is
        true, the resulting directories are created (existing ones are left
        alone).
        """
        if self.relative_dirs:
            self.fits_dir = os.path.join(self.data_dir, self.fits_dir)
            # ``input_csvs`` may be None (e.g. an explicit null in a YAML
            # file); treat that the same as an empty list.
            self.input_csvs = [
                os.path.join(self.data_dir, x) for x in (self.input_csvs or [])
            ]
            self.models_dir = os.path.join(self.data_dir, self.models_dir)
            # figs_dir must be resolved before the directories nested in it.
            self.figs_dir = os.path.join(self.data_dir, self.figs_dir)
            self.metrics_dir = os.path.join(self.figs_dir, self.metrics_dir)
            self.fit_plots_dir = os.path.join(self.figs_dir, self.fit_plots_dir)
            self.cm_dir = os.path.join(self.figs_dir, self.cm_dir)
            self.wrongly_classified_dir = os.path.join(
                self.figs_dir, self.wrongly_classified_dir
            )
            self.log_fn = os.path.join(self.data_dir, self.log_fn)
            # probs_dir must be resolved before probs_fn, which lives in it.
            self.probs_dir = os.path.join(self.data_dir, self.probs_dir)
            self.probs_fn = os.path.join(self.probs_dir, self.probs_fn)

        if self.create_dirs:
            for x_dir in (
                self.fits_dir,
                self.models_dir,
                self.figs_dir,
                self.metrics_dir,
                self.fit_plots_dir,
                self.cm_dir,
                self.wrongly_classified_dir,
                self.probs_dir,
            ):
                os.makedirs(x_dir, exist_ok=True)

    def set_non_tunable_params(self, input_dim, output_dim, norm_means, norm_stddevs):
        """Adds information about the params that are not tunable.

        Stores the arguments on ``input_dim``, ``output_dim``,
        ``normalization_means`` and ``normalization_stddevs`` respectively.
        """
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.normalization_means = norm_means
        self.normalization_stddevs = norm_stddevs

    def set_best_val_loss(self, best_val_loss):
        """Sets the best validation loss from training."""
        self.best_val_loss = best_val_loss

    def write_to_file(self, file: str):
        """Save configuration data to a YAML file.

        Only dataclass fields are written (via ``dataclasses.asdict``), in
        declaration order; the ``device`` class attribute is not included.

        Parameters
        ----------
        file : str
            Path of the YAML file to write (overwritten if it exists).
        """
        args = dataclasses.asdict(self)
        encoded_string = yaml.dump(args, sort_keys=False, default_flow_style=False)
        with open(file, "w", encoding="utf-8") as file_handle:
            file_handle.write(encoded_string)

    @classmethod
    def from_file(cls, file: str) -> Self:
        """Load configuration data from a YAML file.

        Parameters
        ----------
        file : str
            Path of the YAML file to read.

        Returns
        -------
        Self
            A new configuration built from the file's key/value pairs, with
            ``prefix`` set to ``file`` minus its extension.
        """
        with open(file, "r", encoding="utf-8") as file_handle:
            metadata = yaml.safe_load(file_handle)
        # Strip the extension whatever its length; slicing a fixed five
        # characters only worked for ".yaml" and mangled e.g. ".yml" paths.
        metadata['prefix'] = os.path.splitext(file)[0]
        # Paths saved by write_to_file have already been joined onto their
        # parent directories, so they must not be joined a second time.
        metadata['relative_dirs'] = False
        return cls(**metadata)