from dataclasses import dataclass, field
from typing import List
@dataclass
[docs]
class ModelMetrics:
"""Class containing the training and validation metrics."""
[docs]
train_acc: List[float] = field(default_factory=list)
[docs]
val_acc: List[float] = field(default_factory=list)
[docs]
train_loss: List[float] = field(default_factory=list)
[docs]
val_loss: List[float] = field(default_factory=list)
[docs]
epoch_mins: List[int] = field(default_factory=list)
[docs]
epoch_secs: List[int] = field(default_factory=list)
[docs]
def get_values(self):
"""Returns the training and validation accuracies and losses.
Returns
-------
tuple
A tuple containing the training accuracy and loss, and
validation accuracy and loss, respectively.
"""
return self.train_acc, self.train_loss, self.val_acc, self.val_loss
[docs]
def append(self, train_metrics, val_metrics, epoch_time):
"""Appends training information for an epoch.
Parameters
----------
train_metrics: tuple
The epoch training loss and accuracy.
val_metrics: tuple
The epoch validation loss and accuracy.
epoch_time: tuple
The number of minutes and seconds spent by the epoch.
"""
train_loss, train_acc = train_metrics
val_loss, val_acc = val_metrics
epoch_mins, epoch_secs = epoch_time
self.curr_epoch += 1
self.train_loss.append(train_loss)
self.train_acc.append(train_acc)
self.val_loss.append(val_loss)
self.val_acc.append(val_acc)
self.epoch_mins.append(epoch_mins)
self.epoch_secs.append(epoch_secs)
[docs]
def print_last(self):
"""Prints the metrics for the last epoch."""
epoch_mins, epoch_secs, train_loss, train_acc, val_loss, val_acc = (
self.epoch_mins[-1],
self.epoch_secs[-1],
self.train_loss[-1],
self.train_acc[-1],
self.val_loss[-1],
self.val_acc[-1],
)
print(f"Epoch: {self.curr_epoch:02} | Epoch Time: {epoch_mins}m {epoch_secs}s")
print(f"\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%")
print(f"\t Val. Loss: {val_loss:.3f} | Val. Acc: {val_acc*100:.2f}%")