src.superphot_plus.trainer_base

Module Contents

Classes

TrainerBase

Trainer base class.

class TrainerBase(sampler='dynesty', include_redshift=True, probs_file=PROBS_FILE)

Trainer base class.

create_output_dirs(delete_prev=True)

Creates the output directory structure if it does not already exist.

Parameters:

delete_prev (bool) – If True, deletes previous output logs. Defaults to True.
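The behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the method's actual implementation: the helper name ensure_output_dirs, its base_dir and subdirs arguments, and the directory names in the usage note are all hypothetical.

```python
import shutil
from pathlib import Path


def ensure_output_dirs(base_dir, subdirs, delete_prev=True):
    """Create base_dir/<subdir> for every subdir; optionally clear old contents.

    Hypothetical helper: names and layout are illustrative only.
    """
    base = Path(base_dir)
    created = []
    for name in subdirs:
        target = base / name
        if delete_prev and target.exists():
            # Remove whatever a previous run left behind.
            shutil.rmtree(target)
        # exist_ok=True makes the call safe to repeat.
        target.mkdir(parents=True, exist_ok=True)
        created.append(target)
    return created
```

Calling ensure_output_dirs("out", ["logs", "models"], delete_prev=False) would keep any existing files in place, while the default wipes each subdirectory before recreating it.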

split_train_test(input_csvs=None)

Reads data and splits it into training and testing sets.

Parameters:

input_csvs (list of str) – List of input CSV file paths. Defaults to None.

Returns:

The training data and the test data.

Return type:

tuple
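As a rough picture of the read-then-split pattern, the sketch below loads rows from several CSV files and divides them into two lists. Everything in it is an assumption made for illustration: the helper names read_rows and split_rows, the test_fraction and seed parameters, and the use of plain dictionaries instead of the ZtfData objects the trainer works with. The splitting strategy the method actually uses is not documented here.

```python
import csv
import random


def read_rows(csv_paths):
    """Concatenate the rows of several CSV files into one list of dicts."""
    rows = []
    for path in csv_paths:
        with open(path, newline="") as handle:
            rows.extend(csv.DictReader(handle))
    return rows


def split_rows(rows, test_fraction=0.1, seed=0):
    """Shuffle a copy of rows and split it into (train, test) lists."""
    shuffled = list(rows)
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]
```

Fixing the seed is a deliberate choice in this sketch: it means the same input files always yield the same train and test sets.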

generate_train_data(train_data, goal_per_class, train_index, val_index)

Extracts and processes the data for training and validation. Oversamples the features to counter the imbalance between supernova classes and adjusts them to their log distributions.

Parameters:
  • train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.

  • goal_per_class (int) – The number of samples for each supernova class (for oversampling).

  • train_index (np.ndarray) – The indices for the training data samples.

  • val_index (np.ndarray) – The indices for the validation data samples.

Returns:

A tuple containing the final training features and their classes, followed by the validation features and their classes.

Return type:

tuple
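To make the roles of goal_per_class, train_index and val_index concrete, here is a small sketch of class-balanced oversampling. It uses plain random resampling with replacement, which may differ from the technique the method applies, and it only balances the training portion; whether the validation portion is also oversampled is not stated here. The helper names select and oversample, the seed parameter, and the use of lists in place of np.ndarray are assumptions.

```python
import random
from collections import defaultdict


def select(features, labels, index):
    """Pick the rows (and their labels) at the given positions."""
    return [features[i] for i in index], [labels[i] for i in index]


def oversample(features, labels, goal_per_class, seed=0):
    """Resample with replacement until each class has goal_per_class rows."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row, label in zip(features, labels):
        by_class[label].append(row)
    out_features, out_labels = [], []
    for label in sorted(by_class):
        # rng.choices draws with replacement, so rare classes repeat rows.
        out_features.extend(rng.choices(by_class[label], k=goal_per_class))
        out_labels.extend([label] * goal_per_class)
    return out_features, out_labels
```

In use, select would be applied once with train_index and once with val_index, and oversample would then be applied to the training selection so that every class contributes the same number of rows.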

generate_test_data(test_data: superphot_plus.model.data.ZtfData)

Extracts and processes the data for testing, adjusting the features to their log distributions.

Parameters:

test_data (ZtfData) – Contains the ZTF object names, classes and redshifts for testing.

Returns:

A tuple containing the final test features and their classes, the corresponding test ZTF object names, and the test group indices.

Return type:

tuple
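One plausible reading of "adjusting the features to their log distributions" is that selected positive-valued features are moved to log space before classification. The sketch below shows that idea only; the helper name log_scale, the log_columns argument, and the choice of base 10 are assumptions, and which features the method actually transforms is not documented here.

```python
import math


def log_scale(features, log_columns):
    """Return a copy of features with log10 applied to the chosen columns.

    Values in log_columns must be strictly positive.
    """
    scaled = []
    for row in features:
        scaled.append([
            math.log10(value) if col in log_columns else value
            for col, value in enumerate(row)
        ])
    return scaled
```

The input is left unmodified, so the original and transformed features can be compared side by side.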