Source code for src.superphot_plus.trainer

import os

import numpy as np
import pandas as pd
import copy
from sklearn.model_selection import train_test_split

from superphot_plus.format_data_ztf import (
    normalize_features,
    tally_each_class,
    retrieve_posterior_set
)
from superphot_plus.model.mlp import SuperphotMLP
from superphot_plus.model.lightgbm import SuperphotLightGBM

from superphot_plus.config import SuperphotConfig
from superphot_plus.model.data import TestData, TrainData, PosteriorSamplesGroup
from superphot_plus.plotting.classifier_results import plot_model_metrics
from superphot_plus.plotting.confusion_matrices import plot_matrices
from superphot_plus.supernova_class import SupernovaClass as SnClass
from superphot_plus.trainer_base import TrainerBase
from superphot_plus.file_utils import get_posterior_samples
from superphot_plus.utils import (
    create_dataset,
    extract_wrong_classifications,
    log_metrics_to_tensorboard,
    write_metrics_to_file,
    epoch_time,
    save_test_probabilities,
)


[docs] class SuperphotTrainer(TrainerBase): """ Trains and evaluates models using K-Fold cross validation. The model may be trained from scratch using a specified configuration or be loaded from a previous checkpoint stored on disk. In both scenarios the model is evaluated on a test holdout set and metrics are generated. Parameters ---------- config_name : str The name of the pre-trained model configuration to load. This file should be located under the specified models directory. Defaults to None. sampler : str The type of sampler used for the lightcurve fits. Defaults to "dynesty". include_redshift : bool If True, includes redshift data for training. probs_file : str The file where test probabilities are written. Defaults to PROBS_FILE. """ def __init__( self, config_name, fits_dir, sampler="dynesty", model_type='LightGBM', include_redshift=True, probs_file=None, n_folds=10, target_label=None, ): super().__init__( config_name, fits_dir, sampler, model_type, include_redshift, probs_file, n_folds, target_label )
[docs] def setup_model(self, load_checkpoint=False): """Reads model configuration from disk and loads the saved checkpoint if load_checkpoint flag was enabled. Parameters ---------- load_checkpoint : bool If true, load pretrained model checkpoint. """ config = SuperphotConfig.from_file(self.config_name) path = os.path.join(config.models_dir, self.config_name.split('/')[-1].split('.')[0]) if self.probs_file is not None: config.probs_fn = self.probs_file for i in range(self.n_folds): if load_checkpoint: model_file, config_file = f"{path}_{i}.pt", f"{path}_{i}.yaml" model_i, config_i = self._load_model_instance(model_file, config_file) self.models.append(model_i) self.configs.append(config_i) else: self.models.append(None) self.configs.append(copy.deepcopy(config)) self.configs[-1].probs_fn = config.probs_fn % i self.load_checkpoint = load_checkpoint
[docs] def _load_model_instance(self, model_file, config_file): if self.model_type == 'LightGBM': return SuperphotLightGBM.load(model_file, config_file) elif self.model_type == 'MLP': return SuperphotMLP.load(model_file, config_file) else: raise ValueError
[docs] def _create_model_instance(self, config): if self.model_type == 'LightGBM': return SuperphotLightGBM(config, target_label=self.target_label) elif self.model_type == 'MLP': return SuperphotMLP.create(config) else: raise ValueError
[docs] def run(self, input_csvs=None, extract_wc=False, n_folds=1, load_checkpoint=False): """Runs the machine learning workflow. Trains the model on the whole training set and evaluates it on a test holdout set. Also allows K-fold cross-validation. Metrics are plotted and logged to files. Parameters ---------- input_csvs : list of str The list of training CSV files. Defaults to INPUT_CSVS. extract_wc : bool If true, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER. Copies plots of wrongly classified samples to separate folder for manual followup. Defaults to False. load_checkpoint : bool If true, load pretrained model checkpoint. """ # Loads model and config self.setup_model(load_checkpoint) if self.n_folds <= 1: k_folded_data = [self.split_train_test(input_csvs),] k_folded_data = self.k_fold_split_train_test(self.kf, input_csvs) for i in range(self.n_folds): print(f"Running fold {i}") train_data, test_data = k_folded_data[i] self.train(i, train_data) # Evaluate model on test dataset self.evaluate(i, test_data, extract_wc) # concatenate probs csvs concat_path = self.probs_file.replace("%d", "%s") % "concat" concat_df = pd.read_csv(self.probs_file % 0) concat_df['Fold'] = 0 for i in range(1, 10): new_df = pd.read_csv(self.probs_file % i) new_df['Fold'] = i concat_df = pd.concat( [concat_df, new_df], ignore_index=True ) concat_df.to_csv(concat_path, index=False)
[docs] def classify_single_light_curve(self, obj_name, fits_dir, sampler="dynesty"): """Given an object name, return classification probabilities based on the model fit and data. Parameters ---------- obj_name : str Name of the supernova. fits_dir : str Where model fit information is stored. sampler : str The MCMC sampler to use. Defaults to "dynesty". Returns ---------- np.ndarray The average probability for each SN type across all equally-weighted sets of fit parameters. """ post_features = get_posterior_samples( obj_name, fits_dir, sampler )[0] if np.median(post_features[:,-1]) > self.chisq_cutoff: return -1 * np.ones(len(self.allowed_types)) # normalize the log distributions post_features = np.delete(post_features, self.skipped_params, 1) probs_avg = np.zeros(len(self.allowed_types)) for model in self.models: # ensemble classifier probs = model.classify_from_fit_params(post_features) probs_avg += np.mean(probs, axis=0) return probs_avg / self.n_folds
[docs] def return_new_classifications(self, test_csv, fit_dir, save_file, output_dir=None, include_labels=False, sampler='dynesty'): """Return new classifications based on model and save probabilities to a CSV file. Parameters ---------- test_csv : str Path to the CSV file containing the test data. fit_dir : str Path to the directory containing the fit data. save_file : str File to store the new classification outputs. output_dir : str Path to the directory to store the classification outputs. include_labels : bool, optional If True, labels from the test data are included in the probability saving process. Defaults to False. """ filepath = save_file if output_dir is None else os.path.join(output_dir, save_file) df = pd.read_csv(test_csv) names = df.NAME.to_numpy() try: redshifts = df.Z.to_numpy() except: redshifts = -1 * np.ones(len(names)) if not self.include_redshift: redshifts = None if include_labels: labels = df.CLASS.to_numpy() else: labels = None posts = retrieve_posterior_set( names, fit_dir, sampler=sampler, redshifts=redshifts, labels=labels, chisq_cutoff=self.chisq_cutoff, ) test_data = PosteriorSamplesGroup( posts, ignore_param_idxs=self.skipped_params, use_redshift_info=self.include_redshift, random_seed=self.random_seed ) combined_probs = np.zeros((len(posts), len(self.allowed_types))) for k_fold in range(self.n_folds): probs = self.models[k_fold].classify_from_fit_params( test_data.features ) probs = probs.reshape(( len(posts), test_data.num_draws, len(self.allowed_types) )) combined_probs += np.mean(probs, axis=1) save_test_probabilities( test_data.names, combined_probs / self.n_folds, filepath, true_labels=test_data.labels, target_label=self.target_label )
[docs] def train(self, i: int, train_data: PosteriorSamplesGroup): """Trains the model with a specific set of hyperparameters. Parameters ---------- i : the k-fold index train_data : PosteriorSamplesGroup Contains the ZTF object names, classes and redshifts for training. """ run_id = f"final_{i}" #tally_each_class(train_data.labels) # original tallies # Split data into training and validation sets train_index, val_index = train_test_split( np.arange(0, len(train_data.labels)), stratify=train_data.labels, test_size=0.1 ) train_features, train_classes, val_features, val_classes = self.generate_train_data( train_data=train_data, goal_per_class=self.configs[0].goal_per_class, train_index=train_index, val_index=val_index, ) train_features, mean, std = normalize_features(train_features) val_features, mean, std = normalize_features(val_features, mean, std) train_dataset = create_dataset(train_features, train_classes) val_dataset = create_dataset(val_features, val_classes) if not self.load_checkpoint: self.configs[i].set_non_tunable_params( input_dim=train_features.shape[1], output_dim=len(self.allowed_types), norm_means=mean.tolist(), norm_stddevs=std.tolist(), ) self.models[i] = self._create_model_instance(self.configs[i]) # Train and validate multi-layer perceptron metrics = self.models[i].train_and_validate( train_data=TrainData(train_dataset, val_dataset), rng_seed=self.random_seed, num_epochs=self.configs[i].num_epochs, ) # Save model checkpoint prefix = os.path.join(self.configs[i].models_dir, self.config_name.split('/')[-1].split('.')[0]) self.models[i].save(prefix, suffix=i) # Plot training and validation metrics plot_model_metrics( metrics=metrics, plot_name=run_id, metrics_dir=self.configs[i].metrics_dir, )
# Log average metrics per epoch to plot on Tensorboard. #log_metrics_to_tensorboard(metrics=[metrics], config=self.configs[i], trial_id=run_id)
[docs] def evaluate(self, k_fold, test_data: PosteriorSamplesGroup, extract_wc=False): """Evaluates a pretrained model on the test holdout set. Parameters ---------- test_data : PosteriorSamplesGroup Contains the ZTF object names, classes and redshifts for testing. extract_wc : bool If true, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER. Copies plots of wrongly classified samples to separate folder for manual followup. Defaults to False. Returns ------- tuple A tuple containing the test ground truths, the respective predicted classes and the predicted classes for which classification confidence exceeded 70%. """ if self.models[k_fold] is None: raise ValueError("Cannot evaluate uninitialized model.") test_features, test_classes, test_names = self.generate_test_data( test_data=test_data ) mean = self.configs[k_fold].normalization_means std = self.configs[k_fold].normalization_stddevs test_features, _, _ = normalize_features(test_features, mean, std) results = self.models[k_fold].evaluate( test_data=TestData(test_features, test_classes, test_names), ) true_classes, _, pred_classes, pred_probs = zip(results) true_classes = np.hstack(true_classes) pred_classes = np.hstack(pred_classes) pred_probs_above_07 = np.hstack(pred_probs) > 0.7 true_classes = SnClass.get_labels_from_classes(true_classes) pred_classes = SnClass.get_labels_from_classes(pred_classes) # Log evaluation metrics write_metrics_to_file( config=self.configs[k_fold], true_classes=true_classes, pred_classes=pred_classes, prob_above_07=pred_probs_above_07, ) plot_matrices( config=self.configs[k_fold], true_classes=true_classes, pred_classes=pred_classes, prob_above_07=pred_probs_above_07, ) if extract_wc: extract_wrong_classifications( true_classes=true_classes, pred_classes=pred_classes, ztf_test_names=test_data.names, )