src.superphot_plus.trainer
Module Contents
Classes
SuperphotTrainer – Trains and evaluates models using K-Fold cross-validation.
- class SuperphotTrainer(config_name, sampler='dynesty', include_redshift=True, probs_file=PROBS_FILE)
Bases: superphot_plus.trainer_base.TrainerBase
Trains and evaluates models using K-Fold cross-validation.
The model can either be trained from scratch using a specified configuration or be loaded from a checkpoint previously stored on disk. In both cases, it is evaluated on a held-out test set and metrics are generated.
- Parameters:
config_name (str) – The name of the pre-trained model configuration to load. This file should be located in the specified models directory. This parameter is required.
sampler (str) – The type of sampler used for the lightcurve fits. Defaults to “dynesty”.
include_redshift (bool) – If True, includes redshift data for training. Defaults to True.
probs_file (str) – The file where test probabilities are written. Defaults to PROBS_FILE.
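The class summary mentions K-Fold cross-validation. As a generic illustration of that technique only (not superphot_plus's own splitting code, whose fold count and any stratification are not documented here), the hypothetical helper `k_fold_indices` below shuffles sample indices and partitions them into k folds, each of which serves exactly once as the validation set while the rest form the training set.

```python
import random
from typing import Iterator, List, Tuple


def k_fold_indices(
    n_samples: int, k: int, seed: int = 0
) -> Iterator[Tuple[List[int], List[int]]]:
    """Hypothetical helper: yield (train_idx, val_idx) pairs for K-fold CV."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must satisfy 2 <= k <= n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # reproducible shuffle
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


# Usage: 10 samples split into 5 folds.
for fold, (train_idx, val_idx) in enumerate(k_fold_indices(10, k=5)):
    print(fold, len(train_idx), len(val_idx))  # each fold: 8 training, 2 validation
```

Because every index lands in exactly one validation fold, each sample is evaluated once by a model that never saw it during training.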
- setup_model(load_checkpoint=False)
Reads the model configuration from disk and, if load_checkpoint is True, loads the saved checkpoint.
- Parameters:
load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
- run(input_csvs=None, extract_wc=False, load_checkpoint=False)
Runs the machine learning workflow.
Trains the model on the whole training set and evaluates it on a held-out test set. Metrics are plotted and logged to files.
- Parameters:
input_csvs (list of str) – The list of training CSV files. If None, INPUT_CSVS is used.
extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of wrongly classified samples to a separate folder for manual follow-up. Defaults to False.
load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
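The extract_wc option copies fit plots of misclassified samples into a separate folder. The sketch below illustrates that idea under stated assumptions: the function `copy_wrongly_classified_plots`, the `<name>.png` file-naming scheme, and the folder names are all hypothetical and are not taken from superphot_plus, which stores its plots under FIT_PLOTS_FOLDER using its own conventions.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Iterable, List


def copy_wrongly_classified_plots(
    names: Iterable[str],
    true_labels: Iterable[str],
    pred_labels: Iterable[str],
    fit_plots_folder: Path,
    output_folder: Path,
) -> List[Path]:
    """Hypothetical helper: copy fit plots of misclassified objects."""
    output_folder.mkdir(parents=True, exist_ok=True)
    copied: List[Path] = []
    for name, truth, pred in zip(names, true_labels, pred_labels):
        if truth == pred:
            continue  # correctly classified: no follow-up needed
        plot = fit_plots_folder / f"{name}.png"  # assumed naming scheme
        if plot.exists():
            dest = output_folder / plot.name
            shutil.copy(plot, dest)
            copied.append(dest)
    return copied


# Usage with throwaway directories and made-up object names and labels.
with tempfile.TemporaryDirectory() as tmp:
    fits = Path(tmp) / "fit_plots"
    fits.mkdir()
    for obj in ("obj1", "obj2", "obj3"):
        (fits / f"{obj}.png").write_bytes(b"")  # placeholder plot files
    copied = copy_wrongly_classified_plots(
        ["obj1", "obj2", "obj3"],
        ["SN Ia", "SN II", "SN Ia"],
        ["SN Ia", "SN Ia", "SN Ibc"],
        fits,
        Path(tmp) / "wrongly_classified",
    )
    copied_names = sorted(p.name for p in copied)  # ['obj2.png', 'obj3.png']
```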
- train(train_data: superphot_plus.model.data.ZtfData)
Trains the model with a specific set of hyperparameters.
- Parameters:
train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.
- evaluate(test_data: superphot_plus.model.data.ZtfData, extract_wc=False)
Evaluates a pretrained model on the held-out test set.
- Parameters:
test_data (ZtfData) – Contains the ZTF object names, classes and redshifts for testing.
extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of wrongly classified samples to a separate folder for manual follow-up. Defaults to False.
- Returns:
A tuple containing the ground-truth test labels, the corresponding predicted classes, and the predicted classes whose classification confidence exceeded 70%.
- Return type:
tuple
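To make the three-part return value concrete, here is a minimal sketch, assuming the model outputs one row of class probabilities per sample and that classes are integer indices. The helper `summarize_predictions` and its `threshold` parameter are hypothetical; the actual evaluate() may compute and order these values differently. "Confidence" is taken here as the highest class probability, and "exceeded 70%" as a strict comparison.

```python
from typing import List, Sequence, Tuple


def summarize_predictions(
    true_labels: Sequence[int],
    probs: Sequence[Sequence[float]],
    threshold: float = 0.7,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: (true labels, predictions, confident predictions)."""
    # Predicted class = index of the largest probability (argmax; first wins ties).
    pred = [max(range(len(p)), key=p.__getitem__) for p in probs]
    # Keep only predictions whose top probability strictly exceeds the threshold.
    confident = [c for c, p in zip(pred, probs) if max(p) > threshold]
    return list(true_labels), pred, confident


probs = [[0.9, 0.05, 0.05], [0.4, 0.35, 0.25], [0.1, 0.8, 0.1]]
print(summarize_predictions([0, 1, 1], probs))  # → ([0, 1, 1], [0, 0, 1], [0, 1])
```

Note that the third list drops low-confidence samples, so it is shorter than the other two whenever any prediction falls at or below the threshold.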