# Source code for src.superphot_plus.trainer

import os

import numpy as np
from sklearn.model_selection import train_test_split

from superphot_plus.file_paths import PROBS_FILE
from superphot_plus.format_data_ztf import normalize_features, tally_each_class
from superphot_plus.model.classifier import SuperphotClassifier
from superphot_plus.model.config import ModelConfig
from superphot_plus.model.data import TestData, TrainData, ZtfData
from superphot_plus.plotting.classifier_results import plot_model_metrics
from superphot_plus.plotting.confusion_matrices import plot_matrices
from superphot_plus.supernova_class import SupernovaClass as SnClass
from superphot_plus.trainer_base import TrainerBase
from superphot_plus.utils import (
    create_dataset,
    extract_wrong_classifications,
    log_metrics_to_tensorboard,
    write_metrics_to_file,
)


class SuperphotTrainer(TrainerBase):
    """
    Trains a classifier and evaluates it on a test holdout set.

    The model may be trained from scratch using a specified configuration
    or be loaded from a previous checkpoint stored on disk. In both
    scenarios the model is evaluated on a test holdout set and metrics
    are generated. Training uses a single stratified 90/10
    train/validation split of the training data.

    Parameters
    ----------
    config_name : str
        The base name of the model configuration to load. The files
        ``<config_name>.yaml`` and (when loading a checkpoint)
        ``<config_name>.pt`` are looked up under the models directory.
    sampler : str
        The type of sampler used for the lightcurve fits. Defaults to
        "dynesty".
    include_redshift : bool
        If True, includes redshift data for training.
    probs_file : str
        The file where test probabilities are written. Defaults to
        PROBS_FILE.
    """

    def __init__(
        self,
        config_name,
        sampler="dynesty",
        include_redshift=True,
        probs_file=PROBS_FILE,
    ):
        super().__init__(sampler, include_redshift, probs_file)

        self.config_name = config_name

        # Both are populated later by setup_model() / train().
        self.model, self.config = None, None

        self.create_output_dirs(delete_prev=False)

    def setup_model(self, load_checkpoint=False):
        """Reads model configuration from disk and loads the saved
        checkpoint if load_checkpoint flag was enabled.

        ``self.config`` is always replaced by the configuration read from
        disk. ``self.model`` is only assigned when ``load_checkpoint`` is
        true; otherwise it keeps whatever value it already had.

        Parameters
        ----------
        load_checkpoint : bool
            If true, load pretrained model checkpoint.
        """
        path = os.path.join(self.models_dir, self.config_name)
        model_file, config_file = f"{path}.pt", f"{path}.yaml"

        config = ModelConfig.from_file(config_file)

        if load_checkpoint:
            # The configuration returned alongside the model is discarded
            # in favour of the one read just above.
            self.model, _ = SuperphotClassifier.load(model_file, config_file)

        self.config = config

    def run(self, input_csvs=None, extract_wc=False, load_checkpoint=False):
        """Runs the machine learning workflow.

        Splits the data into training and test sets, trains the model on
        the training set (unless a model is already available, e.g. from
        a checkpoint) and evaluates it on the test holdout set. Metrics
        are plotted and logged to files.

        Parameters
        ----------
        input_csvs : list of str
            The list of training CSV files. Defaults to None, which is
            forwarded unchanged to ``split_train_test``.
        extract_wc : bool
            If true, assumes all sample fit plots are saved in
            FIT_PLOTS_FOLDER. Copies plots of wrongly classified samples
            to separate folder for manual followup. Defaults to False.
        load_checkpoint : bool
            If true, load pretrained model checkpoint.
        """
        train_data, test_data = self.split_train_test(input_csvs)

        # Loads model and config
        self.setup_model(load_checkpoint)

        # Only train when no model is available yet.
        if self.model is None:
            self.train(train_data)

        # Evaluate model on test dataset
        self.evaluate(test_data, extract_wc)

    def train(self, train_data: ZtfData):
        """Trains the model with a specific set of hyperparameters.

        Requires ``self.config`` to have been set (see ``setup_model``).
        On completion ``self.model`` holds the trained classifier, the
        checkpoint has been saved under the models directory and the
        training/validation metrics have been plotted and logged.

        Parameters
        ----------
        train_data : ZtfData
            Contains the ZTF object names, classes and redshifts for
            training.
        """
        run_id = "final"

        tally_each_class(train_data.labels)  # original tallies

        # Split data into training and validation sets (stratified, 10%
        # held out for validation).
        train_index, val_index = train_test_split(
            np.arange(0, len(train_data.labels)),
            stratify=train_data.labels,
            test_size=0.1,
        )

        train_features, train_classes, val_features, val_classes = self.generate_train_data(
            train_data=train_data,
            goal_per_class=self.config.goal_per_class,
            train_index=train_index,
            val_index=val_index,
        )

        # Validation features are normalized with the statistics obtained
        # from the training features.
        train_features, mean, std = normalize_features(train_features)
        val_features, mean, std = normalize_features(val_features, mean, std)

        train_dataset = create_dataset(train_features, train_classes)
        val_dataset = create_dataset(val_features, val_classes)

        self.config.set_non_tunable_params(
            input_dim=train_features.shape[1],
            output_dim=len(self.allowed_types),
            norm_means=mean.tolist(),
            norm_stddevs=std.tolist(),
        )

        self.model = SuperphotClassifier.create(self.config)

        # Train and validate multi-layer perceptron
        metrics = self.model.train_and_validate(
            train_data=TrainData(train_dataset, val_dataset),
            num_epochs=self.config.num_epochs,
        )

        # Save model checkpoint
        self.model.save(self.models_dir)

        # Plot training and validation metrics
        plot_model_metrics(
            metrics=metrics,
            num_epochs=self.config.num_epochs,
            plot_name=run_id,
            metrics_dir=self.metrics_dir,
        )

        # Log average metrics per epoch to plot on Tensorboard.
        log_metrics_to_tensorboard(metrics=[metrics], config=self.config, trial_id=run_id)

    def evaluate(self, test_data: ZtfData, extract_wc=False):
        """Evaluates a pretrained model on the test holdout set.

        Parameters
        ----------
        test_data : ZtfData
            Contains the ZTF object names, classes and redshifts for
            testing.
        extract_wc : bool
            If true, assumes all sample fit plots are saved in
            FIT_PLOTS_FOLDER. Copies plots of wrongly classified samples
            to separate folder for manual followup. Defaults to False.

        Returns
        -------
        tuple
            A tuple containing the test ground truth labels, the
            respective predicted labels and a boolean mask flagging the
            predictions whose probability exceeded 0.7.

        Raises
        ------
        ValueError
            If no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("Cannot evaluate uninitialized model.")

        test_features, test_classes, test_names, test_group_idxs = self.generate_test_data(
            test_data=test_data
        )

        # NOTE(review): the test features are normalized with statistics
        # computed from the test set itself, whereas train() reuses the
        # training mean/std for the validation set. Presumably the stored
        # training statistics should be applied here too -- confirm
        # against ModelConfig before changing.
        test_features, _, _ = normalize_features(test_features)

        results = self.model.evaluate(
            test_data=TestData(test_features, test_classes, test_names, test_group_idxs),
            probs_csv_path=self.probs_file,
        )

        # zip() over the single `results` iterable wraps each of its four
        # entries in a 1-tuple, which np.hstack then unwraps again.
        true_classes, _, pred_classes, pred_probs = zip(results)

        true_classes = np.hstack(true_classes)
        pred_classes = np.hstack(pred_classes)
        pred_probs_above_07 = np.hstack(pred_probs) > 0.7

        true_classes = SnClass.get_labels_from_classes(true_classes)
        pred_classes = SnClass.get_labels_from_classes(pred_classes)

        # Log evaluation metrics
        write_metrics_to_file(
            config=self.config,
            true_classes=true_classes,
            pred_classes=pred_classes,
            prob_above_07=pred_probs_above_07,
            log_file=self.classify_log_file,
        )

        plot_matrices(
            config=self.config,
            true_classes=true_classes,
            pred_classes=pred_classes,
            prob_above_07=pred_probs_above_07,
            cm_folder=self.cm_folder,
        )

        if extract_wc:
            extract_wrong_classifications(
                true_classes=true_classes,
                pred_classes=pred_classes,
                ztf_test_names=test_data.names,
            )

        # The docstring has always promised these values; previously the
        # method fell off the end and returned None.
        return true_classes, pred_classes, pred_probs_above_07