src.superphot_plus.trainer_base
Module Contents
Classes
TrainerBase – Trainer base class.
- class TrainerBase(sampler='dynesty', include_redshift=True, probs_file=PROBS_FILE)
Trainer base class.
- create_output_dirs(delete_prev=True)
Ensures that the output directory structure exists.
- Parameters:
delete_prev (bool) – If True, deletes previous output logs.
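As an illustration of what "ensuring the output directory structure" can involve, here is a minimal stand-alone sketch. It is not the package's implementation: the `base_dir` and `subdirs` arguments, the subdirectory names, and the choice to wipe whole subdirectories when `delete_prev` is true are all assumptions made for the example.

```python
import shutil
import tempfile
from pathlib import Path


def create_output_dirs(base_dir, subdirs, delete_prev=True):
    """Hypothetical sketch: make sure base_dir/<name> exists for every name.

    When delete_prev is True, an existing subdirectory is removed first so
    output from a previous run does not mix with the new run.
    """
    base = Path(base_dir)
    created = []
    for name in subdirs:
        path = base / name
        if delete_prev and path.exists():
            shutil.rmtree(path)  # drop everything left by earlier runs
        path.mkdir(parents=True, exist_ok=True)  # no error if it already exists
        created.append(path)
    return created


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        models_dir, _ = create_output_dirs(tmp, ["models", "metrics"])
        (models_dir / "old.log").write_text("stale")
        create_output_dirs(tmp, ["models", "metrics"], delete_prev=True)
        print(sorted(models_dir.iterdir()))  # prints []
```

Using `exist_ok=True` makes the call idempotent when `delete_prev` is false, so repeated runs simply reuse the existing directories.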
- split_train_test(input_csvs=None)
Reads data and splits it into training and testing sets.
- Parameters:
input_csvs (list of str) – List of input CSV file paths.
- Returns:
The training data and the testing data.
- Return type:
tuple
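The reference above does not say how the split is made. The sketch below shows one common approach, a stratified split that keeps each class's proportions in both sets, purely as an illustration. The `label` column name, the `test_fraction` and `seed` parameters, and stratification itself are assumptions, not documented behaviour of `split_train_test`.

```python
import csv
import random
from collections import defaultdict


def read_rows(input_csvs):
    """Concatenate the rows of several CSV files into one list of dicts."""
    rows = []
    for path in input_csvs:
        with open(path, newline="") as f:
            rows.extend(csv.DictReader(f))
    return rows


def split_train_test(rows, test_fraction=0.1, label_key="label", seed=42):
    """Hypothetical sketch: stratified train/test split of row dicts."""
    rng = random.Random(seed)  # fixed seed for a reproducible split
    by_class = defaultdict(list)
    for row in rows:
        by_class[row[label_key]].append(row)

    train, test = [], []
    for members in by_class.values():
        members = members[:]  # copy before shuffling
        rng.shuffle(members)
        # Keep at least one test object per class when the class has >1 member.
        n_test = max(1, round(len(members) * test_fraction)) if len(members) > 1 else 0
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test
```

For example, `split_train_test(read_rows(["a.csv", "b.csv"]), test_fraction=0.2)` would return two lists of row dictionaries. Stratifying matters here because rare supernova classes could otherwise be missing from the test set entirely.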
- generate_train_data(train_data, goal_per_class, train_index, val_index)
Extracts and processes the data for training and validation. Oversamples the features to counter the supernova class imbalance and log-transforms them.
- Parameters:
train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.
goal_per_class (int) – The target number of samples per supernova class after oversampling.
train_index (np.ndarray) – The indices for the training data samples.
val_index (np.ndarray) – The indices for the validation data samples.
- Returns:
A tuple containing the final training features and their classes, followed by the validation features and their classes.
- Return type:
tuple
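The two processing steps described for `generate_train_data` (class-balancing by oversampling, then moving features to log space) can be sketched with NumPy as below. This is a hedged illustration only: it uses plain random oversampling with replacement, log-transforms every column, and takes a features array and integer labels instead of a `ZtfData` object. The real method may oversample differently (for example, by drawing extra fit samples) and may transform only some features.

```python
import numpy as np


def oversample_and_log(features, labels, train_index, val_index, goal_per_class, seed=0):
    """Hypothetical sketch of preparing training and validation arrays.

    features: (n_samples, n_features) array of strictly positive values.
    labels:   (n_samples,) array of integer class labels.
    """
    rng = np.random.default_rng(seed)
    x_train, y_train = features[train_index], labels[train_index]
    x_val, y_val = features[val_index], labels[val_index]

    # Random oversampling: draw goal_per_class rows (with replacement) per class,
    # so every class ends up with the same number of training rows.
    parts_x, parts_y = [], []
    for cls in np.unique(y_train):
        cls_rows = np.flatnonzero(y_train == cls)
        picks = rng.choice(cls_rows, size=goal_per_class, replace=True)
        parts_x.append(x_train[picks])
        parts_y.append(y_train[picks])
    x_train = np.concatenate(parts_x)
    y_train = np.concatenate(parts_y)

    # Log transform; validation data is transformed but never oversampled.
    return np.log10(x_train), y_train, np.log10(x_val), y_val
```

Note that oversampling is applied only after the train/validation split. Resampling before splitting would leak duplicated rows into the validation set.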
- generate_test_data(test_data: superphot_plus.model.data.ZtfData)
Extracts and processes the data for testing and log-transforms the features.
- Parameters:
test_data (ZtfData) – Contains the ZTF object names, classes and redshifts for testing.
- Returns:
A tuple containing the final test features and their classes, the corresponding ZTF object names, and the test group indices.
- Return type:
tuple
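The meaning of "test group indices" is not spelled out above. One plausible reading is that each object contributes several fit samples, and the group index maps every feature row back to its object so per-sample predictions can later be combined per object. The sketch below illustrates that interpretation only. The per-object list of sample arrays, the function name `build_test_arrays`, and the object names in the test are hypothetical.

```python
import numpy as np


def build_test_arrays(samples_per_object, labels, names):
    """Hypothetical sketch: flatten per-object fit samples for testing.

    samples_per_object: list of (n_i, n_features) arrays, one per object.
    labels: one class label per object.
    Returns log-space features, a class per row, the object names, and a
    group index array giving the object each row came from.
    """
    features = np.concatenate(samples_per_object)
    counts = [len(s) for s in samples_per_object]
    group_idx = np.repeat(np.arange(len(samples_per_object)), counts)
    classes = np.asarray(labels)[group_idx]  # repeat each object's label per row
    return np.log10(features), classes, list(names), group_idx
```

With `group_idx`, per-row class probabilities `p` of shape `(n_rows, n_classes)` could, for instance, be averaged per object by summing rows with `np.add.at(totals, group_idx, p)` and dividing by each object's row count.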