src.superphot_plus.trainer_base

Module Contents

Classes

TrainerBase

Trainer base class.

class TrainerBase(config_name, fits_dir, sampler='dynesty', model_type='LightGBM', include_redshift=True, probs_file=None, n_folds=10, target_label=None)

Trainer base class.

k_fold_split_train_test(kf, input_csvs=None)

Reads data and splits it into K folds. Outputs one train/test pair per fold.

Parameters:

input_csvs (list of str) – List of input CSV file paths.

Returns:

One pair of train data and test data per fold.

Return type:

list of 2-tuples
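
To make the return shape concrete, here is a minimal pure-Python sketch that builds one (train, test) 2-tuple per fold. It is a stand-in for illustration only, not this class's implementation: the helper `k_fold_pairs` and the object names are hypothetical, and the real method presumably derives its folds from the `kf` argument rather than from contiguous slices.

```python
def k_fold_pairs(items, n_folds):
    """Split items into n_folds contiguous folds and return a list of (train, test) 2-tuples."""
    if not 2 <= n_folds <= len(items):
        raise ValueError("n_folds must be between 2 and len(items)")
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    base, extra = divmod(len(items), n_folds)
    pairs = []
    start = 0
    for fold in range(n_folds):
        stop = start + base + (1 if fold < extra else 0)
        test = items[start:stop]              # this fold is held out
        train = items[:start] + items[stop:]  # every other fold is used for training
        pairs.append((train, test))
        start = stop
    return pairs


names = [f"obj_{i}" for i in range(10)]  # hypothetical object names
pairs = k_fold_pairs(names, n_folds=5)
```

Each item appears in exactly one test set, so iterating over `pairs` evaluates every object once.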

load_csv(input_csvs)

Load CSV data.

split_train_test(input_csvs=None)

Reads data and splits it into training and testing sets.

Parameters:

input_csvs (list of str) – List of input CSV file paths.

Returns:

The train data and the test data.

Return type:

tuple
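
The following sketch shows one way a single train/test split can be produced. It assumes a reproducible shuffled split; the helper `simple_train_test_split`, its `test_fraction` and `seed` parameters, and the sample data are all hypothetical and not taken from this module, whose actual splitting strategy is not documented here.

```python
import random


def simple_train_test_split(names, labels, test_fraction=0.1, seed=0):
    """Return ((train_names, train_labels), (test_names, test_labels))."""
    if len(names) != len(labels):
        raise ValueError("names and labels must have the same length")
    indices = list(range(len(names)))
    # A dedicated Random instance keeps the shuffle reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    n_test = max(1, round(len(indices) * test_fraction))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train = ([names[i] for i in train_idx], [labels[i] for i in train_idx])
    test = ([names[i] for i in test_idx], [labels[i] for i in test_idx])
    return train, test


names = [f"obj_{i}" for i in range(20)]                        # hypothetical object names
labels = ["SN Ia" if i % 2 == 0 else "SN II" for i in range(20)]  # hypothetical classes
train, test = simple_train_test_split(names, labels, test_fraction=0.25, seed=42)
```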

generate_train_data(train_data, goal_per_class, train_index, val_index)

Extracts and processes the data for training and validation. Oversamples the features to address the supernova class imbalance and adjusts them to their log distributions.

Parameters:
  • train_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for training.

  • goal_per_class (int) – The number of samples for each supernova class (for oversampling).

  • train_index (np.ndarray) – The indices for the training data samples.

  • val_index (np.ndarray) – The indices for the validation data samples.

Returns:

A tuple containing the final training features and respective classes, and validation features and respective classes.

Return type:

tuple
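
The two processing steps described above, oversampling up to `goal_per_class` and moving features to log space, can be sketched as follows. This is an illustration under stated assumptions, not the method's actual code: it uses plain random oversampling with replacement, applies `log10` to every feature (which assumes all values are positive), and the helper `oversample_and_log` and the sample data are hypothetical. The real method may oversample differently and may transform only a subset of the features.

```python
import math
import random


def oversample_and_log(features, labels, goal_per_class, seed=0):
    """Resample each class to goal_per_class rows, then log10-transform every feature."""
    rng = random.Random(seed)
    # Group the feature rows by class label.
    by_class = {}
    for row, label in zip(features, labels):
        by_class.setdefault(label, []).append(row)

    out_features, out_labels = [], []
    for label, rows in sorted(by_class.items()):
        # Draw with replacement so that minority classes reach goal_per_class rows.
        sampled = rng.choices(rows, k=goal_per_class)
        for row in sampled:
            out_features.append([math.log10(value) for value in row])
            out_labels.append(label)
    return out_features, out_labels


# Hypothetical, strictly positive features for three objects in two classes.
features = [[1.0, 10.0], [100.0, 1000.0], [10.0, 10.0]]
labels = ["SN Ia", "SN Ia", "SN II"]
balanced_features, balanced_labels = oversample_and_log(features, labels, goal_per_class=4)
```

After the call, both classes contribute the same number of rows, so a classifier trained on the result does not see the original imbalance.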

generate_test_data(test_data: superphot_plus.model.data.PosteriorSamplesGroup)

Extracts and processes the data for testing, adjusting the features to their log distributions.

Parameters:

test_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for testing.

Returns:

A tuple containing the final test features and respective classes, the corresponding test ZTF object names and test group indices.

Return type:

tuple