src.superphot_plus.trainer_base
Module Contents
Classes
TrainerBase – Trainer base class.
- class TrainerBase(config_name, fits_dir, sampler='dynesty', model_type='LightGBM', include_redshift=True, probs_file=None, n_folds=10, target_label=None)
Trainer base class.
- k_fold_split_train_test(kf, input_csvs=None)
Reads the data and splits it into K folds, returning one train/test pair per fold.
- Parameters:
input_csvs (list of str) – List of input CSV file paths.
- Returns:
One (train data, test data) pair for each fold.
- Return type:
list of 2-tuples
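To illustrate what a list of per-fold (train, test) pairs looks like, here is a minimal sketch of a plain, unstratified K-fold split over object names. The helper `k_fold_pairs` and the object names are hypothetical; the type of `kf` is not documented here, so this sketch does not call it and should not be read as the method's implementation.

```python
import numpy as np


def k_fold_pairs(names, n_folds, seed=0):
    """Hypothetical helper: split object names into n_folds (train, test) pairs.

    Every name lands in exactly one test fold; the remaining folds form the
    matching training set. Requires n_folds >= 2.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(names))        # shuffle once before splitting
    folds = np.array_split(order, n_folds)     # near-equal index chunks
    pairs = []
    for i, test_idx in enumerate(folds):
        # Training indices are every fold except the current test fold.
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        pairs.append(([names[k] for k in train_idx], [names[k] for k in test_idx]))
    return pairs


names = [f"ZTF_obj_{i}" for i in range(10)]
pairs = k_fold_pairs(names, n_folds=5)
for train, test in pairs:
    print(len(train), len(test))   # prints "8 2" for each of the 5 folds
```

Each object appears in exactly one test set, so per-fold test predictions can be concatenated into a single out-of-fold evaluation.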
- split_train_test(input_csvs=None)
Reads data and splits it into training and testing sets.
- Parameters:
input_csvs (list of str) – List of input CSV file paths.
- Returns:
The train data and the test data.
- Return type:
tuple
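Because the supernova classes are imbalanced, a single hold-out split is often stratified so that every class is represented in the test set. The sketch below shows a stratified split; the 10% test fraction, the helper name `stratified_split`, and the use of stratification itself are illustrative assumptions, not documented behaviour of `split_train_test`.

```python
import random
from collections import defaultdict


def stratified_split(names, labels, test_fraction=0.1, seed=0):
    """Hypothetical helper: hold out about test_fraction of each class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for name, label in zip(names, labels):
        by_class[label].append(name)
    train, test = [], []
    for label in sorted(by_class):             # deterministic class order
        members = list(by_class[label])
        rng.shuffle(members)
        # Keep at least one test object per class so rare types are evaluated.
        n_test = max(1, round(len(members) * test_fraction))
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


names = [f"ZTF_obj_{i}" for i in range(30)]
labels = ["SN Ia"] * 20 + ["SN II"] * 10
train, test = stratified_split(names, labels)
print(len(train), len(test))   # 27 3
```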
- generate_train_data(train_data, goal_per_class, train_index, val_index)
Extracts and processes the data for training and validation. Oversamples the features to counter the supernova class imbalance and adjusts them to their log distributions.
- Parameters:
train_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for training.
goal_per_class (int) – The target number of samples per supernova class after oversampling.
train_index (np.ndarray) – The indices for the training data samples.
val_index (np.ndarray) – The indices for the validation data samples.
- Returns:
A tuple containing the final training features and their classes, followed by the validation features and their classes.
- Return type:
tuple
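The two processing steps described above, class balancing toward `goal_per_class` and a log transform of the features, can be sketched as follows. This is a simplification under stated assumptions: the oversampling here is plain random resampling with replacement (the library may use a different scheme), the choice of which columns to log (`log_columns`) is hypothetical, and the `train_index`/`val_index` selection is not shown.

```python
import numpy as np


def balance_and_log(features, labels, goal_per_class, log_columns, seed=0):
    """Hypothetical helper: resample each class to goal_per_class rows, then
    take log10 of the selected feature columns.

    Rows are drawn with replacement, so minority classes are oversampled; in
    this simplification a class larger than the goal is also cut down to it.
    """
    rng = np.random.default_rng(seed)
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    out_x, out_y = [], []
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        chosen = rng.choice(idx, size=goal_per_class, replace=True)
        out_x.append(features[chosen])
        out_y.append(labels[chosen])
    x = np.concatenate(out_x)                  # a fresh copy; input untouched
    # Assumed: the logged columns are strictly positive.
    x[:, log_columns] = np.log10(x[:, log_columns])
    return x, np.concatenate(out_y)


features = [[10.0, 1.0], [100.0, 2.0], [1000.0, 3.0], [10.0, 4.0]]
labels = ["SN Ia", "SN Ia", "SN Ia", "SN II"]
x, y = balance_and_log(features, labels, goal_per_class=5, log_columns=[0])
print(x.shape)   # (10, 2)
```

The single "SN II" object is repeated five times, so both classes contribute equally to training despite the 3:1 imbalance in the input.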
- generate_test_data(test_data: superphot_plus.model.data.PosteriorSamplesGroup)
Extracts and processes the data for testing, adjusting the features to their log distributions.
- Parameters:
test_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for testing.
- Returns:
A tuple containing the final test features and respective classes, the corresponding test ZTF object names and test group indices.
- Return type:
tuple
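One plausible reading of the returned "test group indices" is a per-row index that maps each sample back to its ZTF object, which is useful when each object contributes several posterior samples. The sketch below assumes exactly that; the input layout (`samples_by_object`), the helper name, and the interpretation of the group indices are all hypothetical. Unlike the training path, no oversampling is applied, only the log transform.

```python
import numpy as np


def flatten_test_samples(samples_by_object, log_columns):
    """Hypothetical helper: stack per-object posterior samples into one matrix.

    samples_by_object maps an object name to (label, 2-D array of samples).
    Returns the feature matrix, per-row labels, the object names, and a
    per-row group index pointing back into the names list.
    """
    names = list(samples_by_object)
    feats, labels, groups = [], [], []
    for g, name in enumerate(names):
        label, samples = samples_by_object[name]
        samples = np.asarray(samples, dtype=float)
        feats.append(samples)
        labels.extend([label] * len(samples))
        groups.extend([g] * len(samples))
    x = np.concatenate(feats)
    # No oversampling at test time; only the log transform is applied.
    x[:, log_columns] = np.log10(x[:, log_columns])
    return x, np.array(labels), names, np.array(groups)


samples = {
    "ZTF_a": ("SN Ia", [[10.0, 0.5], [100.0, 0.7]]),
    "ZTF_b": ("SN II", [[1000.0, 0.1]]),
}
x, y, names, groups = flatten_test_samples(samples, log_columns=[0])
# Average a per-row quantity back to one value per object using the groups.
per_object = np.bincount(groups, weights=x[:, 1]) / np.bincount(groups)
```

With group indices in hand, per-sample classifier outputs can be averaged into one prediction per object in the same way.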