src.superphot_plus.trainer

Module Contents

Classes

SuperphotTrainer

Trains and evaluates models using k-fold cross-validation.

class SuperphotTrainer(config_name, sampler='dynesty', include_redshift=True, probs_file=PROBS_FILE)

Bases: superphot_plus.trainer_base.TrainerBase

Trains and evaluates models using k-fold cross-validation.

The model may either be trained from scratch using a specified configuration or be loaded from a previous checkpoint stored on disk. In both cases, the model is evaluated on a test holdout set and metrics are generated.

Parameters:
  • config_name (str) – The name of the model configuration to load. The configuration file should be located in the specified models directory.

  • sampler (str) – The type of sampler used for the lightcurve fits. Defaults to “dynesty”.

  • include_redshift (bool) – If True, includes redshift data for training. Defaults to True.

  • probs_file (str) – The file where test probabilities are written. Defaults to PROBS_FILE.
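The class description mentions k-fold cross-validation. The following is a minimal, self-contained sketch of how k-fold index splitting works in general; it is not taken from superphot_plus, and the helper name `kfold_indices` is hypothetical. Each sample appears in exactly one validation fold, and a model would be trained on the remaining K-1 folds.

```python
import random


def kfold_indices(n_samples: int, n_folds: int, seed: int = 0):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation.

    Hypothetical helper for illustration only; not part of superphot_plus.
    """
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # shuffle once, reproducibly
    # Spread the remainder so fold sizes differ by at most one.
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


for fold, (train_idx, val_idx) in enumerate(kfold_indices(10, 3)):
    print(fold, len(train_idx), len(val_idx))
# 0 6 4
# 1 7 3
# 2 7 3
```

Shuffling once and slicing contiguous folds guarantees that the validation folds partition the data without overlap.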

setup_model(load_checkpoint=False)

Reads the model configuration from disk and, if the load_checkpoint flag is enabled, loads the saved checkpoint.

Parameters:

load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
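The "read configuration, then optionally restore a checkpoint" pattern can be sketched as follows. This is a hypothetical illustration, not the superphot_plus implementation: the function name `setup_model_sketch` and the file layout (`<config_name>.json` for the configuration, `<config_name>.pkl` for a pickled checkpoint) are assumptions, and the real trainer may use different names and formats.

```python
import json
import pickle
import tempfile
from pathlib import Path


def setup_model_sketch(models_dir, config_name, load_checkpoint=False):
    """Read a model configuration and optionally restore saved weights.

    Hypothetical layout: <config_name>.json holds the configuration and
    <config_name>.pkl holds a pickled checkpoint.
    """
    models_dir = Path(models_dir)
    with open(models_dir / f"{config_name}.json", encoding="utf-8") as f:
        config = json.load(f)
    weights = None
    if load_checkpoint:
        # Only open the checkpoint when explicitly requested.
        # Unpickling runs arbitrary code, so load trusted files only.
        with open(models_dir / f"{config_name}.pkl", "rb") as f:
            weights = pickle.load(f)
    return config, weights


# Usage: write a toy config and checkpoint, then read them back.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "demo.json").write_text(json.dumps({"hidden_dim": 64}))
    with open(Path(tmp, "demo.pkl"), "wb") as f:
        pickle.dump({"w": [0.1, 0.2]}, f)
    config, weights = setup_model_sketch(tmp, "demo", load_checkpoint=True)
    print(config, weights)  # {'hidden_dim': 64} {'w': [0.1, 0.2]}
```

With load_checkpoint left at False, only the configuration is read and the returned weights are None, which corresponds to training from scratch.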

run(input_csvs=None, extract_wc=False, load_checkpoint=False)

Runs the machine learning workflow.

Trains the model on the whole training set and evaluates it on a test holdout set. Metrics are plotted and logged to files.

Parameters:
  • input_csvs (list of str) – The list of training CSV files. Defaults to INPUT_CSVS.

  • extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of misclassified samples to a separate folder for manual follow-up. Defaults to False.

  • load_checkpoint (bool) – If True, loads the pretrained model checkpoint. Defaults to False.
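The extract_wc behavior (copying the fit plots of misclassified samples for follow-up) can be sketched as below. This is a hypothetical illustration: the function name `copy_misclassified_plots` and the assumption that each plot is stored as `<object name>.png` are not taken from superphot_plus.

```python
import shutil
import tempfile
from pathlib import Path


def copy_misclassified_plots(names, true_labels, pred_labels,
                             fit_plots_folder, wc_folder):
    """Copy fit plots of misclassified objects into wc_folder.

    Hypothetical: assumes each plot is saved as <name>.png in
    fit_plots_folder. Returns the names whose plots were copied.
    """
    wc_folder = Path(wc_folder)
    wc_folder.mkdir(parents=True, exist_ok=True)
    copied = []
    for name, truth, pred in zip(names, true_labels, pred_labels):
        if truth == pred:
            continue  # correctly classified; nothing to follow up
        src = Path(fit_plots_folder) / f"{name}.png"
        if src.exists():  # skip objects without a saved plot
            shutil.copy(src, wc_folder / src.name)
            copied.append(name)
    return copied


# Usage with toy files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    plots = Path(tmp, "plots")
    plots.mkdir()
    for n in ["ZTF_a", "ZTF_b", "ZTF_c"]:
        (plots / f"{n}.png").write_bytes(b"placeholder")
    copied = copy_misclassified_plots(
        ["ZTF_a", "ZTF_b", "ZTF_c"],
        ["SN Ia", "SN II", "SN Ia"],
        ["SN Ia", "SN Ia", "SN II"],
        plots, Path(tmp, "wrongly_classified"),
    )
    print(copied)  # ['ZTF_b', 'ZTF_c']
```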

train(train_data: superphot_plus.model.data.ZtfData)

Trains the model with a specific set of hyperparameters.

Parameters:

train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.
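To show how the contents of a ZtfData-like container and the include_redshift flag could combine into training inputs, here is a hypothetical sketch. `ZtfDataSketch` and `build_features` are illustrative names, not the real ZtfData class. The sketch also assumes that per-object feature vectors (for example, light-curve fit parameters) are available as a dict keyed by object name.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ZtfDataSketch:
    """Hypothetical stand-in for ZtfData: parallel lists, one entry per object."""
    names: List[str]
    labels: List[str]
    redshifts: List[float]

    def __post_init__(self):
        if not len(self.names) == len(self.labels) == len(self.redshifts):
            raise ValueError("names, labels and redshifts must have equal length")


def build_features(data: ZtfDataSketch,
                   fit_params: Dict[str, List[float]],
                   include_redshift: bool = True) -> List[List[float]]:
    """Assemble one feature row per object, optionally appending redshift."""
    rows = []
    for name, z in zip(data.names, data.redshifts):
        row = list(fit_params[name])  # copy so the input dict is not mutated
        if include_redshift:
            row.append(z)
        rows.append(row)
    return rows


data = ZtfDataSketch(["ZTF_a", "ZTF_b"], ["SN Ia", "SN II"], [0.05, 0.12])
params = {"ZTF_a": [1.0, 2.0], "ZTF_b": [3.0, 4.0]}
print(build_features(data, params))         # [[1.0, 2.0, 0.05], [3.0, 4.0, 0.12]]
print(build_features(data, params, False))  # [[1.0, 2.0], [3.0, 4.0]]
```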

evaluate(test_data: superphot_plus.model.data.ZtfData, extract_wc=False)

Evaluates a pretrained model on the test holdout set.

Parameters:
  • test_data (ZtfData) – Contains the ZTF object names, classes and redshifts for testing.

  • extract_wc (bool) – If True, assumes all sample fit plots are saved in FIT_PLOTS_FOLDER and copies the plots of misclassified samples to a separate folder for manual follow-up. Defaults to False.

Returns:

A tuple containing the test ground-truth classes, the corresponding predicted classes, and the predicted classes for which classification confidence exceeded 70%.

Return type:

tuple
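The three-part return value can be illustrated with the hypothetical sketch below. The function name `summarize_predictions` is an assumption, not part of superphot_plus. The sketch takes a prediction's confidence to be its highest class probability and treats a prediction as confident when that probability strictly exceeds 0.7.

```python
from typing import List, Sequence, Tuple


def summarize_predictions(true_labels: Sequence[str],
                          probs: Sequence[Sequence[float]],
                          classes: Sequence[str],
                          threshold: float = 0.7
                          ) -> Tuple[List[str], List[str], List[str]]:
    """Return (ground truths, predicted classes, confident predictions).

    Hypothetical sketch: a prediction counts as confident when its highest
    class probability strictly exceeds `threshold` (70% by default).
    """
    preds, confident = [], []
    for row in probs:
        best = max(range(len(row)), key=row.__getitem__)  # argmax index
        preds.append(classes[best])
        if row[best] > threshold:
            confident.append(classes[best])
    return list(true_labels), preds, confident


truths, preds, confident = summarize_predictions(
    ["SN Ia", "SN II", "SN Ia"],
    [[0.9, 0.1], [0.55, 0.45], [0.2, 0.8]],
    ["SN Ia", "SN II"],
)
print(preds)      # ['SN Ia', 'SN Ia', 'SN II']
print(confident)  # ['SN Ia', 'SN II']
```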