Source code for src.superphot_plus.trainer_base

import os
import shutil

import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from astropy.cosmology import Planck13 as cosmo
from superphot_plus.file_utils import get_posterior_samples
from superphot_plus.format_data_ztf import import_labels_only
from superphot_plus.model.data import PosteriorSamplesGroup
from superphot_plus.supernova_class import SupernovaClass as SnClass
from superphot_plus.format_data_ztf import retrieve_posterior_set

class TrainerBase:
    """Trainer base class.

    Stores the configuration shared by classifier trainers (sampler,
    redshift usage, K-fold setup, optional binary target label) and
    provides helpers that load posterior samples and split them into
    train / validation / test sets.
    """

    def __init__(
        self,
        config_name,
        fits_dir,  # TODO: make optional from config
        sampler="dynesty",
        model_type="LightGBM",
        include_redshift=True,
        probs_file=None,
        n_folds=10,
        target_label=None,
    ):
        """Store the trainer configuration and build the K-fold splitter.

        Parameters
        ----------
        config_name : str
            Name of the configuration; stored as ``self.config_name``.
        fits_dir : str
            Directory handed to the label / posterior loaders.
        sampler : str, optional
            Name of the fitting method. Defaults to "dynesty".
        model_type : str, optional
            Stored as ``self.model_type``. Defaults to "LightGBM".
        include_redshift : bool, optional
            Whether redshift information is used. Defaults to True.
        probs_file : str, optional
            Stored as ``self.probs_file``. Defaults to None.
        n_folds : int, optional
            Number of K-folds; values below 1 are clamped to 1.
            Defaults to 10.
        target_label : str, optional
            When set, classification is binary (target vs. everything
            else). Defaults to None.
        """
        # Supernova class types
        # TODO: replace with supernova_class enumeration
        self.allowed_types = ["SN Ia", "SN II", "SN IIn", "SLSN-I", "SN Ibc"]
        self.target_label = target_label

        # Fitting method
        self.sampler = sampler
        self.include_redshift = include_redshift

        # Parameter indices passed to PosteriorSamplesGroup as
        # ``ignore_param_idxs``; index 0 is additionally ignored when
        # redshift information is not used.
        if self.include_redshift:
            self.skipped_params = [3, 14]
        else:
            self.skipped_params = [0, 3, 14]

        self.probs_file = probs_file
        self.fits_dir = fits_dir
        self.config_name = config_name

        self.models, self.configs = [], []
        self.model_type = model_type

        # generate k-folds
        self.n_folds = max(int(n_folds), 1)
        self.random_seed = 42  # TODO: un-hard code this
        self.chisq_cutoff = 1.2

        # A single fold cannot be stratified, so no splitter is built.
        if self.n_folds > 1:
            self.kf = StratifiedKFold(
                self.n_folds, random_state=self.random_seed, shuffle=True
            )
        else:
            self.kf = None

    @staticmethod
    def _resolve_input_csvs(input_csvs):
        """Return ``input_csvs``, or the module-level default if None."""
        if input_csvs is not None:
            return input_csvs
        # ``INPUT_CSVS`` is not imported by this module; use it only if
        # it has been made available at module level, otherwise fail
        # with an actionable message instead of a bare NameError.
        default_csvs = globals().get("INPUT_CSVS")
        if default_csvs is None:
            raise ValueError(
                "input_csvs was not provided and no default INPUT_CSVS is "
                "defined; pass a list of input CSV file paths."
            )
        return default_csvs

    def _load_all_data(self, input_csvs):
        """Load labels and posteriors and wrap them in one sample group."""
        names, labels, redshifts = self.load_csv(input_csvs=input_csvs)
        if not self.include_redshift:
            redshifts = None

        all_post_objs = retrieve_posterior_set(
            names,
            self.fits_dir,
            sampler=self.sampler,
            redshifts=redshifts,
            labels=labels,
            chisq_cutoff=self.chisq_cutoff,
        )
        return PosteriorSamplesGroup(
            all_post_objs,
            use_redshift_info=self.include_redshift,
            ignore_param_idxs=self.skipped_params,
            random_seed=self.random_seed,
        )

    def k_fold_split_train_test(self, kf, input_csvs=None):
        """Reads data and splits into n K-folds.

        Outputs n sets of train/test sets.

        Parameters
        ----------
        kf : object
            Splitter exposing ``split(X, y)`` (e.g. the StratifiedKFold
            built in ``__init__``). If None, ``self.kf`` is used.
        input_csvs : list of str
            List of input CSV file paths.

        Returns
        -------
        list of 2-tuples
            N sets of the train data and the test data.

        Raises
        ------
        ValueError
            If no splitter is available, or if ``input_csvs`` is None
            and no default list is defined.
        """
        if kf is None:
            kf = self.kf
        if kf is None:
            raise ValueError(
                "No K-fold splitter available: pass `kf` or create the "
                "trainer with n_folds > 1."
            )

        all_data = self._load_all_data(self._resolve_input_csvs(input_csvs))

        if self.target_label is not None:
            all_data.make_binary(target_label=self.target_label)

        k_fold_datasets = []
        for groups in kf.split(all_data.names, all_data.labels):
            train_data, test_data = all_data.split(split_indices=groups)
            k_fold_datasets.append((train_data, test_data))

        return k_fold_datasets

    def load_csv(self, input_csvs):
        """Load names, labels and redshifts from the input CSV files.

        Parameters
        ----------
        input_csvs : list of str
            List of input CSV file paths.

        Returns
        -------
        tuple
            The names, labels and redshifts returned by
            ``import_labels_only`` for the allowed supernova types.
        """
        names, labels, redshifts = import_labels_only(
            input_csvs=input_csvs,
            allowed_types=self.allowed_types,
            fits_dir=self.fits_dir,
            sampler=self.sampler,
        )
        return names, labels, redshifts

    def split_train_test(self, input_csvs=None):
        """Reads data and splits it into training and testing sets.

        Parameters
        ----------
        input_csvs : list of str
            List of input CSV file paths.

        Returns
        -------
        tuple
            The train data and the test data.

        Raises
        ------
        ValueError
            If ``input_csvs`` is None and no default list is defined.
        """
        all_data = self._load_all_data(self._resolve_input_csvs(input_csvs))

        # NOTE(review): unlike k_fold_split_train_test, labels are not
        # made binary here when target_label is set - confirm intended.
        train_data, test_data = all_data.split(split_frac=0.1)
        return train_data, test_data

    def generate_train_data(self, train_data, goal_per_class, train_index, val_index):
        """Extracts and processes the data for training and validation.

        Splits ``train_data`` into training and validation groups with
        ``split_frac=0.1``, oversamples each group and converts the
        labels to class values.

        Parameters
        ----------
        train_data : PosteriorSamplesGroup
            Group of posterior samples available for training.
        goal_per_class : int
            Currently unused; kept for interface compatibility.
        train_index : np.ndarray
            Currently unused; kept for interface compatibility.
        val_index : np.ndarray
            Currently unused; kept for interface compatibility.

        Returns
        -------
        tuple
            A tuple containing the final training features and
            respective classes, and validation features and respective
            classes.
        """
        train_data, val_data = train_data.split(split_frac=0.1)

        train_features, train_labels = train_data.oversample(5)
        val_features, val_labels = val_data.oversample(5)

        if self.target_label is None:
            train_classes = SnClass.get_classes_from_labels(train_labels)
            val_classes = SnClass.get_classes_from_labels(val_labels)
        else:
            # Binary problem: 1 for the target label, 0 otherwise.
            train_classes = np.where(
                np.asarray(train_labels) == self.target_label, 1, 0
            )
            val_classes = np.where(
                np.asarray(val_labels) == self.target_label, 1, 0
            )

        return train_features, train_classes, val_features, val_classes

    def generate_test_data(self, test_data: "PosteriorSamplesGroup"):
        """Extracts the features, classes and names used for testing.

        Parameters
        ----------
        test_data : PosteriorSamplesGroup
            Group of posterior samples held out for testing.

        Returns
        -------
        tuple
            A tuple containing the test features, the class of every
            object repeated ``num_draws`` times, and the name of every
            object repeated ``num_draws`` times.
        """
        # Read the attributes directly, as k_fold_split_train_test does,
        # rather than unpacking the group as if it were a tuple.
        test_names = test_data.names
        test_labels = test_data.labels

        if self.target_label is None:
            test_classes = SnClass.get_classes_from_labels(test_labels)
        else:
            # Binary problem: 1 for the target label, 0 otherwise.
            test_classes = np.where(
                np.asarray(test_labels) == self.target_label, 1, 0
            )

        test_features = test_data.features

        # One entry per posterior draw: each object's class and name is
        # repeated num_draws times, keeping objects contiguous.
        os_classes = np.repeat(test_classes, test_data.num_draws)
        os_names = np.repeat(test_names, test_data.num_draws)

        return test_features, os_classes, os_names