src.superphot_plus.tuner
Module Contents
Classes
SuperphotTuner | Tunes models using Ray and K-fold cross-validation.
- class SuperphotTuner(sampler='dynesty', include_redshift=True, num_cpu=2, num_gpu=0)
Bases: superphot_plus.trainer_base.TrainerBase
Tunes models using Ray and K-fold cross-validation.
- Parameters:
sampler (str) – The type of sampler used for the lightcurve fits. Defaults to “dynesty”.
include_redshift (bool) – If True, includes redshift data for training. Defaults to True.
num_cpu (int) – The number of CPUs to use in parallel for each tuning experiment. Defaults to 2.
num_gpu (int) – The number of GPUs to use in parallel for each tuning experiment. Defaults to 0.
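Because num_cpu and num_gpu are requested per tuning experiment, the available hardware caps how many experiments can run at the same time. Ray Tune only starts a trial once its requested resources are free, so a rough upper bound is a simple division. The helper below is a hypothetical back-of-the-envelope sketch, not part of superphot_plus, and it assumes nothing else is using the machine.

```python
def max_concurrent_trials(total_cpus: int, total_gpus: int,
                          num_cpu: int = 2, num_gpu: int = 0) -> int:
    """Rough upper bound on the number of tuning trials that fit at once.

    Hypothetical helper (not part of superphot_plus). Assumes every trial
    requests exactly num_cpu CPUs and num_gpu GPUs and that no other
    workload shares the cluster.
    """
    if num_cpu <= 0:
        raise ValueError("num_cpu must be positive")
    limit = total_cpus // num_cpu
    if num_gpu > 0:
        # GPUs only become a bottleneck when each trial asks for some.
        limit = min(limit, total_gpus // num_gpu)
    return limit


# Defaults (2 CPUs, 0 GPUs per trial) on an 8-core machine without GPUs.
print(max_concurrent_trials(total_cpus=8, total_gpus=0))  # 4
# 1 GPU per trial on a 16-core machine with 2 GPUs: the GPUs are the limit.
print(max_concurrent_trials(total_cpus=16, total_gpus=2, num_cpu=2, num_gpu=1))  # 2
```

Raising num_cpu per trial can speed up each experiment but lowers how many run side by side, so the best value depends on the machine.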
- run(input_csvs=None, num_hp_samples=10)
Performs model tuning with cross-validation to get the best set of hyperparameters.
- Parameters:
input_csvs (list of str) – The list of training CSV files. Defaults to None, in which case INPUT_CSVS is used.
num_hp_samples (int) – The number of hyperparameter sets to sample during model tuning. Defaults to 10.
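num_hp_samples controls how many candidate configurations are drawn from the search space; it presumably maps onto Ray Tune's num_samples setting for the tuning run. The sketch below shows plain random search with the standard library only. The hyperparameter names and ranges in SEARCH_SPACE are hypothetical placeholders, since the actual search space is not documented here.

```python
import random
from typing import Any, Callable, Dict, List

# Hypothetical search space: each entry maps a hyperparameter name to a
# function that draws one value from a random generator. These names and
# ranges are placeholders, not the values used by superphot_plus.
SEARCH_SPACE: Dict[str, Callable[[random.Random], Any]] = {
    "learning_rate": lambda rng: 10 ** rng.uniform(-4, -2),  # log-uniform
    "neurons_per_layer": lambda rng: rng.choice([128, 256, 512]),
    "num_hidden_layers": lambda rng: rng.randint(1, 4),
    "batch_size": lambda rng: rng.choice([32, 64, 128]),
}


def sample_hp_sets(num_hp_samples: int = 10, seed: int = 0) -> List[Dict[str, Any]]:
    """Draw num_hp_samples independent configurations (random search)."""
    rng = random.Random(seed)  # seeded so the draws are reproducible
    return [
        {name: draw(rng) for name, draw in SEARCH_SPACE.items()}
        for _ in range(num_hp_samples)
    ]


for config in sample_hp_sets(num_hp_samples=3):
    print(config)
```

Each sampled set becomes one tuning trial, so the total cost of run() grows roughly linearly with num_hp_samples.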
- tune_model(train_data, num_hp_samples=10)
Invokes the Ray Tune API to start model tuning. Outputs the best model configuration to a log file for further reference.
- Parameters:
train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.
num_hp_samples (int) – The number of hyperparameter sets to sample during model tuning. Defaults to 10.
- Returns:
The best set of model hyperparameters found.
- Return type:
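Once every trial has reported a cross-validation score, tuning reduces to choosing the configuration with the best score and recording it. In Ray Tune 2.x this selection is typically done with ResultGrid.get_best_result(metric=..., mode=...); the stand-alone sketch below mimics that step with plain dictionaries so it runs without Ray. The metric name avg_val_loss, the result layout, and the JSON log format are assumptions for illustration, not taken from superphot_plus.

```python
import json
from typing import Any, Dict, List


def pick_best_config(results: List[Dict[str, Any]],
                     metric: str = "avg_val_loss",
                     mode: str = "min") -> Dict[str, Any]:
    """Return the config of the trial with the best value of `metric`.

    Each result is assumed to look like {"config": {...}, metric: float}.
    """
    if not results:
        raise ValueError("no trial results to choose from")
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    choose = min if mode == "min" else max
    return choose(results, key=lambda r: r[metric])["config"]


def log_best_config(config: Dict[str, Any], path: str) -> None:
    """Write the chosen configuration to a log file as pretty-printed JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, sort_keys=True)


# Toy trial results (made-up numbers, for illustration only).
trials = [
    {"config": {"learning_rate": 1e-3, "batch_size": 32}, "avg_val_loss": 0.61},
    {"config": {"learning_rate": 5e-4, "batch_size": 64}, "avg_val_loss": 0.48},
    {"config": {"learning_rate": 1e-2, "batch_size": 128}, "avg_val_loss": 0.93},
]
best = pick_best_config(trials)
print(best)  # {'learning_rate': 0.0005, 'batch_size': 64}
```

Calling log_best_config(best, "best_config.json") would then persist the choice; the file name here is arbitrary.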
- run_cross_validation(config, train_data: superphot_plus.model.data.ZtfData)
Runs K-fold cross-validation on a single hyperparameter configuration; comparing these results across trials identifies the best set of hyperparameters for the model.
- Parameters:
config (Dict[str, Any]) – The configuration for model training, drawn from the default ModelConfig values. Passed as a plain dict to comply with the Ray Tune API requirements.
train_data (ZtfData) – Contains the ZTF object names, classes and redshifts for training.
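For a single config, K-fold cross-validation splits the training objects into K folds, trains on K − 1 of them, scores on the held-out fold, and averages the K scores; that average is what the tuner can compare across configurations. The sketch below is a dependency-free illustration under stated assumptions: train_and_score is a hypothetical callback standing in for real model training, the folds are contiguous and unshuffled, and the toy majority-class scorer ignores config entirely. The actual method may shuffle or stratify by class (for example with scikit-learn's StratifiedKFold); this sketch does neither.

```python
from collections import Counter
from statistics import mean
from typing import Any, Callable, Dict, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]  # (train indices, validation indices)


def k_fold_indices(n_samples: int, k: int) -> List[Split]:
    """Split range(n_samples) into k contiguous folds of near-equal size."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must satisfy 2 <= k <= n_samples")
    # The first n_samples % k folds get one extra sample.
    sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        splits.append((train, val))
        start += size
    return splits


def cross_validate(config: Dict[str, Any], n_samples: int, k: int,
                   train_and_score: Callable[[Dict[str, Any], List[int], List[int]], float]) -> float:
    """Average the validation score of `config` over k folds."""
    return mean(train_and_score(config, train, val)
                for train, val in k_fold_indices(n_samples, k))


# Toy data: 8 objects of one class followed by 2 of another.
labels: Sequence[str] = ["SN Ia"] * 8 + ["SN II"] * 2


def majority_class_accuracy(config, train_idx, val_idx):
    """Toy stand-in for training: predict the most common training label."""
    majority = Counter(labels[i] for i in train_idx).most_common(1)[0][0]
    return sum(labels[i] == majority for i in val_idx) / len(val_idx)


print(cross_validate({"learning_rate": 1e-3}, len(labels), 5,
                     majority_class_accuracy))  # 0.8
```

The last fold holds only the minority class and scores 0, which is why unshuffled, unstratified folds can give misleading averages on ordered or imbalanced data such as supernova classes.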