src.superphot_plus.trainer
Module Contents
Classes
SuperphotTrainer – Trains and evaluates models using K-fold cross-validation.
- class SuperphotTrainer(config_name, fits_dir, sampler='dynesty', model_type='LightGBM', include_redshift=True, probs_file=None, n_folds=10, target_label=None)
Bases: superphot_plus.trainer_base.TrainerBase
Trains and evaluates models using K-fold cross-validation.
The model may be trained from scratch using a specified configuration or be loaded from a previous checkpoint stored on disk. In both scenarios the model is evaluated on a test holdout set and metrics are generated.
- Parameters:
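config_name (str) – The name of the pre-trained model configuration to load. This file should be located under the specified models directory.
fits_dir (str) – The directory where model fit information is stored.
sampler (str) – The sampler used for the light-curve fits. Defaults to “dynesty”.
model_type (str) – The type of classifier model. Defaults to “LightGBM”.
include_redshift (bool) – If True, includes redshift data for training. Defaults to True.
probs_file (str) – The file where test probabilities are written. Defaults to PROBS_FILE.
n_folds (int) – The number of folds for K-fold cross-validation. Defaults to 10.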
- setup_model(load_checkpoint=False)
Reads the model configuration from disk and, if load_checkpoint is True, also loads the saved checkpoint.
- Parameters:
load_checkpoint (bool) – If True, load the pretrained model checkpoint. Defaults to False.
- run(input_csvs=None, extract_wc=False, n_folds=1, load_checkpoint=False)
Runs the machine learning workflow.
Trains the model on the whole training set and evaluates it on a test holdout set. K-fold cross-validation is also supported through the n_folds parameter. Metrics are plotted and logged to files.
- Parameters:
input_csvs (list of str) – The list of training CSV files. Defaults to INPUT_CSVS.
extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of misclassified samples to a separate folder for manual follow-up. Defaults to False.
n_folds (int) – The number of folds for K-fold cross-validation. Defaults to 1.
load_checkpoint (bool) – If True, load the pretrained model checkpoint. Defaults to False.
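The K-fold option of run can be pictured with a minimal stdlib sketch of plain (unstratified) K-fold splitting over object names. The helper name k_fold_splits and its seed argument are hypothetical and not part of superphot_plus. Whether SuperphotTrainer shuffles or stratifies by class is not stated here, and the sketch only handles n_folds >= 2, so it says nothing about what run does with its default n_folds=1.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def k_fold_splits(
    items: Sequence[T], n_folds: int, seed: int = 0
) -> List[Tuple[List[T], List[T]]]:
    """Hypothetical helper: partition items into n_folds (train, test) pairs.

    Every item lands in exactly one test fold; the other folds form the
    training set for that split.
    """
    if n_folds < 2:
        raise ValueError("K-fold cross-validation needs n_folds >= 2")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # reproducible shuffle
    # Round-robin assignment keeps fold sizes within one of each other.
    folds = [shuffled[k::n_folds] for k in range(n_folds)]
    splits = []
    for k in range(n_folds):
        test = folds[k]
        train = [x for j, fold in enumerate(folds) if j != k for x in fold]
        splits.append((train, test))
    return splits


if __name__ == "__main__":
    names = [f"ZTF{i}" for i in range(10)]
    for i, (train, test) in enumerate(k_fold_splits(names, n_folds=5)):
        print(i, len(train), len(test))  # each fold: 8 train, 2 test
```

Because each object appears in exactly one test fold, every sample is evaluated once while never being tested by a model that saw it during training.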
- classify_single_light_curve(obj_name, fits_dir, sampler='dynesty')
Given an object name, returns classification probabilities based on the model fit and data.
- Parameters:
obj_name (str) – Name of the supernova.
fits_dir (str) – Where model fit information is stored.
sampler (str) – The sampler used for the light-curve fits. Defaults to “dynesty”.
- Returns:
The average probability for each SN type across all equally-weighted sets of fit parameters.
- Return type:
np.ndarray
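Because every set of fit parameters carries equal weight, the averaging described under Returns reduces to an unweighted mean of the per-sample probability vectors. The sketch below assumes the classifier has already produced one probability vector per posterior sample. The function name is hypothetical, and it returns a plain list rather than the np.ndarray the real method returns.

```python
from typing import List, Sequence


def average_class_probabilities(
    per_sample_probs: Sequence[Sequence[float]],
) -> List[float]:
    """Hypothetical helper: average class-probability vectors over
    equally weighted sets of fit parameters.

    per_sample_probs[s][c] is the probability of class c predicted from
    the s-th set of fit parameters.
    """
    if not per_sample_probs:
        raise ValueError("need at least one set of fit parameters")
    n_samples = len(per_sample_probs)
    n_classes = len(per_sample_probs[0])
    # Equal weights: a plain arithmetic mean over samples, class by class.
    return [
        sum(probs[c] for probs in per_sample_probs) / n_samples
        for c in range(n_classes)
    ]
```

For example, averaging [0.5, 0.5] and [1.0, 0.0] gives [0.75, 0.25].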
- return_new_classifications(test_csv, fit_dir, save_file, output_dir=None, include_labels=False, sampler='dynesty')
Returns new classifications from the model and saves the probabilities to a CSV file.
- Parameters:
test_csv (str) – Path to the CSV file containing the test data.
fit_dir (str) – Path to the directory containing the fit data.
save_file (str) – File to store the new classification outputs.
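output_dir (str, optional) – Path to the directory where the classification outputs are stored. Defaults to None.
include_labels (bool, optional) – If True, the labels from the test data are saved along with the probabilities. Defaults to False.
sampler (str) – The sampler used for the light-curve fits. Defaults to “dynesty”.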
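A minimal sketch of writing per-object probabilities to CSV with the standard csv module, with an optional label column mirroring include_labels. The function name, the column layout, and the choice to place save_file inside output_dir are assumptions for illustration; the actual file format written by return_new_classifications is not documented here.

```python
import csv
import os
from typing import Optional, Sequence


def save_probabilities(
    save_file: str,
    names: Sequence[str],
    probs: Sequence[Sequence[float]],
    class_names: Sequence[str],
    labels: Optional[Sequence[str]] = None,
    output_dir: Optional[str] = None,
) -> str:
    """Hypothetical helper: write one CSV row per object (name, optional
    label, then one probability column per class) and return the path.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        path = os.path.join(output_dir, save_file)
    else:
        path = save_file
    header = ["name"]
    if labels is not None:
        header.append("label")  # mirrors include_labels=True
    header.extend(class_names)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i, (name, row) in enumerate(zip(names, probs)):
            label_col = [labels[i]] if labels is not None else []
            # Fixed precision keeps the output compact and comparable.
            writer.writerow([name, *label_col, *(f"{p:.6f}" for p in row)])
    return path
```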
- train(i: int, train_data: superphot_plus.model.data.PosteriorSamplesGroup)
Trains the model with a specific set of hyperparameters.
- Parameters:
i (int) – The K-fold index.
train_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for training.
- evaluate(k_fold, test_data: superphot_plus.model.data.PosteriorSamplesGroup, extract_wc=False)
Evaluates a pretrained model on the test holdout set.
- Parameters:
k_fold – The K-fold index.
test_data (PosteriorSamplesGroup) – Contains the ZTF object names, classes and redshifts for testing.
extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of misclassified samples to a separate folder for manual follow-up. Defaults to False.
- Returns:
A tuple containing the test ground truths, the respective predicted classes and the predicted classes for which classification confidence exceeded 70%.
- Return type:
tuple
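The third element of the evaluate return value can be illustrated with a short sketch: take the most probable class for each object and keep only those whose top probability exceeds 0.7. It also shows how misclassified objects (the ones whose plots extract_wc would copy) can be found by comparing predictions with ground truths. The function names are hypothetical, and the sketch returns indices so that ground truths and predictions can be subset together; the real method's internals may differ.

```python
from typing import List, Sequence, Tuple

CONFIDENCE_THRESHOLD = 0.7  # "exceeded 70%" from the evaluate docs


def predict_with_confidence(
    probs: Sequence[Sequence[float]],
    class_names: Sequence[str],
    threshold: float = CONFIDENCE_THRESHOLD,
) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: return the predicted class of each object and
    the indices of objects whose top probability is strictly above threshold."""
    predicted: List[str] = []
    confident: List[int] = []
    for i, row in enumerate(probs):
        best = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        predicted.append(class_names[best])
        if row[best] > threshold:
            confident.append(i)
    return predicted, confident


def misclassified_indices(
    truths: Sequence[str], predicted: Sequence[str]
) -> List[int]:
    """Hypothetical helper: indices where prediction and ground truth differ."""
    return [i for i, (t, p) in enumerate(zip(truths, predicted)) if t != p]


if __name__ == "__main__":
    classes = ["SN Ia", "SN II"]
    truths = ["SN Ia", "SN II", "SN Ia"]
    probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]]
    predicted, confident = predict_with_confidence(probs, classes)
    print(predicted)  # ['SN Ia', 'SN Ia', 'SN II']
    print(confident)  # [0, 2]
    print(misclassified_indices(truths, predicted))  # [1, 2]
```

The second object is predicted correctly in neither sense that matters here: its top probability (0.6) falls below the cut, and its predicted class differs from the ground truth.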