src.superphot_plus.trainer

Module Contents

Classes

SuperphotTrainer

Trains and evaluates models using K-Fold cross validation.

class SuperphotTrainer(config_name, fits_dir, sampler='dynesty', model_type='LightGBM', include_redshift=True, probs_file=None, n_folds=10, target_label=None)

Bases: superphot_plus.trainer_base.TrainerBase

Trains and evaluates models using K-Fold cross validation.

The model may be trained from scratch using a specified configuration or loaded from a previous checkpoint stored on disk. In both cases, the model is evaluated on a held-out test set, and its metrics are plotted and logged to files.

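Parameters:
  • config_name (str) – The name of the pre-trained model configuration to load. This file should be located under the specified models directory.

  • fits_dir (str) – Directory where the light-curve fit results are stored.

  • sampler (str) – The type of sampler used for the light-curve fits. Defaults to “dynesty”.

  • model_type (str) – The type of classifier to train. Defaults to “LightGBM”.

  • include_redshift (bool) – If True, includes redshift data for training. Defaults to True.

  • probs_file (str) – The file where test probabilities are written. If None, defaults to PROBS_FILE.

  • n_folds (int) – The number of folds used for K-fold cross-validation. Defaults to 10.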
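A minimal sketch of the K-fold splitting that n_folds controls, assuming plain, unshuffled folds. The helper kfold_indices is hypothetical and not part of superphot_plus; it only shows how each object lands in exactly one test fold while the remaining folds form the training set.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_folds: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, test_idx) pairs for plain K-fold cross-validation.

    Hypothetical helper for illustration only: folds are contiguous and
    unshuffled, and their sizes differ by at most one.
    """
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so sizes differ by at most one.
    base, extra = divmod(n_samples, n_folds)
    start = 0
    for k in range(n_folds):
        size = base + (1 if k < extra else 0)
        test_idx = indices[start:start + size]
        # Everything outside the current test fold is used for training.
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, test_idx
        start += size


folds = list(kfold_indices(10, 3))
test_sizes = [len(test) for _, test in folds]  # [4, 3, 3]
```

With 10 objects and 3 folds, the test folds hold 4, 3 and 3 objects. For imbalanced supernova classes, shuffled and stratified folds (for example scikit-learn's StratifiedKFold) are the more usual choice; which variant SuperphotTrainer uses is not stated here.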

setup_model(load_checkpoint=False)

Reads the model configuration from disk and, if the load_checkpoint flag is set, loads the saved checkpoint.

Parameters:

load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
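The branching described for setup_model can be sketched as below. This is a hypothetical stand-in, not the real method: it assumes a JSON configuration file with a "checkpoint" entry, and it takes the loader and constructor as callables in place of the private _load_model_instance and _create_model_instance helpers.

```python
import json
from pathlib import Path
from typing import Any, Callable, Dict


def setup_model_sketch(
    config_path: Path,
    load_checkpoint: bool,
    load_fn: Callable[[Path, Dict[str, Any]], Any],
    create_fn: Callable[[Dict[str, Any]], Any],
) -> Any:
    """Read a JSON config, then restore a checkpoint or build a fresh model.

    Hypothetical: the JSON format, the "checkpoint" key and both callables
    are assumptions made for illustration.
    """
    config = json.loads(config_path.read_text())
    if load_checkpoint:
        # Restore a previously trained model stored next to the config file.
        checkpoint = config_path.parent / config["checkpoint"]
        return load_fn(checkpoint, config)
    # Otherwise build an untrained model from the configuration alone.
    return create_fn(config)
```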

_load_model_instance(model_file, config_file)
_create_model_instance(config)
run(input_csvs=None, extract_wc=False, n_folds=1, load_checkpoint=False)

Runs the machine learning workflow.

Trains the model on the whole training set and evaluates it on a held-out test set; K-fold cross-validation is also supported. Metrics are plotted and logged to files.

Parameters:
  • input_csvs (list of str) – The list of training CSV files. If None, defaults to INPUT_CSVS.

  • extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of wrongly classified samples to a separate folder for manual follow-up. Defaults to False.

  • n_folds (int) – The number of folds for K-fold cross-validation. Defaults to 1.

  • load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
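A minimal sketch of what extract_wc does, under stated assumptions: the helper name copy_wrongly_classified, the <name>.png file naming and the destination folder are all hypothetical, since only FIT_PLOTS_FOLDER is named in this documentation.

```python
import shutil
from pathlib import Path
from typing import List, Sequence


def copy_wrongly_classified(
    names: Sequence[str],
    true_labels: Sequence[str],
    pred_labels: Sequence[str],
    plots_dir: Path,
    wc_dir: Path,
) -> List[str]:
    """Copy the fit plots of misclassified objects into wc_dir for follow-up.

    Hypothetical: assumes each plot is saved as <plots_dir>/<name>.png.
    Objects without a saved plot are skipped. Returns the copied names.
    """
    wc_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for name, truth, pred in zip(names, true_labels, pred_labels):
        if truth == pred:
            continue  # correctly classified: nothing to review
        plot = plots_dir / f"{name}.png"
        if plot.exists():
            shutil.copy2(plot, wc_dir / plot.name)
            copied.append(name)
    return copied
```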

classify_single_light_curve(obj_name, fits_dir, sampler='dynesty')

Given an object name, returns classification probabilities based on the object's light-curve fit and data.

Parameters:
  • obj_name (str) – Name of the supernova.

  • fits_dir (str) – Directory where the light-curve fit results are stored.

  • sampler (str) – The sampler used for the light-curve fits. Defaults to “dynesty”.

Returns:

The average probability for each SN type across all equally-weighted sets of fit parameters.

Return type:

np.ndarray
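The averaging described in the return value can be illustrated with NumPy. This sketch assumes the per-sample class probabilities have already been computed, one row per equally weighted set of fit parameters; average_class_probabilities is a hypothetical helper, not a superphot_plus function.

```python
import numpy as np


def average_class_probabilities(sample_probs) -> np.ndarray:
    """Average per-sample class probabilities into one probability vector.

    sample_probs has shape (n_samples, n_classes): row j holds the classifier's
    probabilities for the j-th equally weighted set of fit parameters.
    """
    sample_probs = np.asarray(sample_probs, dtype=float)
    if sample_probs.ndim != 2:
        raise ValueError("expected a 2-D array of shape (n_samples, n_classes)")
    # Equal weights reduce to a plain arithmetic mean over the sample axis.
    return sample_probs.mean(axis=0)


probs = average_class_probabilities(
    [[0.50, 0.25, 0.25],
     [0.75, 0.25, 0.00]]
)
# probs == [0.625, 0.25, 0.125]
```

Because each row sums to 1, the averaged vector also sums to 1, so it can be read directly as per-type probabilities.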

return_new_classifications(test_csv, fit_dir, save_file, output_dir=None, include_labels=False, sampler='dynesty')

Returns new classifications from the trained model and saves the classification probabilities to a CSV file.

Parameters:
  • test_csv (str) – Path to the CSV file containing the test data.

  • fit_dir (str) – Path to the directory containing the fit data.

  • save_file (str) – File to store the new classification outputs.

  • output_dir (str, optional) – Path to the directory to store the classification outputs. Defaults to None.

  • include_labels (bool, optional) – If True, the labels from the test data are saved alongside the probabilities. Defaults to False.

  • sampler (str, optional) – The sampler used for the light-curve fits. Defaults to “dynesty”.
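A sketch of the CSV output, assuming one row per object and an extra label column when include_labels is True. The helper save_probabilities and its column names (Name, Label, one column per class) are hypothetical; the real file layout may differ.

```python
import csv
from pathlib import Path
from typing import Optional, Sequence


def save_probabilities(
    save_file: Path,
    names: Sequence[str],
    probs: Sequence[Sequence[float]],
    class_names: Sequence[str],
    labels: Optional[Sequence[str]] = None,
) -> None:
    """Write one row per object: name, optional true label, class probabilities.

    Hypothetical layout for illustration; passing labels plays the role of
    include_labels=True.
    """
    header = ["Name"] + (["Label"] if labels is not None else []) + list(class_names)
    with open(save_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i, (name, row) in enumerate(zip(names, probs)):
            label_col = [labels[i]] if labels is not None else []
            # Fixed precision keeps the file compact and easy to diff.
            writer.writerow([name] + label_col + [f"{p:.4f}" for p in row])
```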

train(i: int, train_data: superphot_plus.model.data.PosteriorSamplesGroup)

Trains the model with a specific set of hyperparameters.

Parameters:
  • i (int) – The K-fold index.

  • train_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for training.
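One plausible way to turn a training set's classes, redshifts and posterior samples into classifier inputs is sketched below. It is an assumption, not the documented behavior of train(): it works on plain NumPy arrays rather than PosteriorSamplesGroup, treats every posterior sample as a separate training row, and appends the redshift as an extra feature when include_redshift is True. The helper build_training_matrix is hypothetical.

```python
from typing import List, Sequence, Tuple

import numpy as np


def build_training_matrix(
    samples: Sequence[np.ndarray],
    labels: Sequence[int],
    redshifts: Sequence[float],
    include_redshift: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Stack per-object posterior samples into features X and labels y.

    Hypothetical: samples[k] has shape (n_samples_k, n_params) for object k.
    Every posterior sample becomes one row labeled with its object's class.
    """
    rows: List[np.ndarray] = []
    targets: List[np.ndarray] = []
    for obj_samples, label, z in zip(samples, labels, redshifts):
        obj_samples = np.asarray(obj_samples, dtype=float)
        if include_redshift:
            # Repeat the object's redshift once per posterior sample.
            z_col = np.full((obj_samples.shape[0], 1), z)
            obj_samples = np.hstack([obj_samples, z_col])
        rows.append(obj_samples)
        targets.append(np.full(obj_samples.shape[0], label))
    return np.vstack(rows), np.concatenate(targets)
```

Repeating the class label for every sample lets the classifier see the full posterior spread of each object rather than a single point estimate.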

evaluate(k_fold, test_data: superphot_plus.model.data.PosteriorSamplesGroup, extract_wc=False)

Evaluates a pretrained model on the held-out test set.

Parameters:
  • k_fold (int) – The K-fold index.

  • test_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for testing.

  • extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of wrongly classified samples to a separate folder for manual follow-up. Defaults to False.

Returns:

A tuple containing the test ground truths, the respective predicted classes, and the predicted classes for which the classification confidence exceeded 70%.

Return type:

tuple
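The three-part return value can be sketched as follows, assuming the classifier yields one probability vector per object. summarize_predictions is a hypothetical helper; it applies a strict "greater than 0.7" test to the top class probability to mirror "exceeded 70%".

```python
from typing import Tuple

import numpy as np


def summarize_predictions(
    true_classes, probs, threshold: float = 0.7
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (ground truths, predicted classes, confident predicted classes).

    Hypothetical: probs has shape (n_objects, n_classes). The third element
    keeps only predictions whose top probability exceeds the threshold.
    """
    probs = np.asarray(probs, dtype=float)
    predicted = probs.argmax(axis=1)
    confident = probs.max(axis=1) > threshold
    return np.asarray(true_classes), predicted, predicted[confident]
```

Because the third element drops low-confidence objects, scoring it against ground truth requires selecting the matching truths with the same mask, for example true_classes[probs.max(axis=1) > 0.7].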