src.superphot_plus.model.lightgbm
Module Contents
Classes
SuperphotLightGBM – The LightGBM model.
- class SuperphotLightGBM(config: superphot_plus.config.SuperphotConfig, target_label=None)
The LightGBM model.
- Parameters:
config (SuperphotConfig) – The model configuration.
- train_and_validate(train_data, rng_seed=None, **kwargs)
Runs LightGBM training and validation.
- Parameters:
train_data (TrainData) – The training dataset.
rng_seed (int, optional) – Seed for the random number generator. If None, machine entropy is used.
- Returns:
A tuple containing arrays of metrics for each epoch (training accuracies and losses, validation accuracies and losses).
- Return type:
tuple
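The rng_seed behaviour described above (a fixed seed gives reproducible results, None falls back to machine entropy) can be sketched with the standard library. This is a minimal illustration under stated assumptions: make_rng and shuffled_split are hypothetical helpers, not part of superphot_plus, and the internal train/validation split that train_and_validate performs is not documented here.

```python
import random
from typing import List, Optional, Tuple


def make_rng(rng_seed: Optional[int] = None) -> random.Random:
    """Return a generator seeded with rng_seed, or from OS entropy if None.

    Hypothetical helper illustrating the documented rng_seed semantics.
    """
    # random.Random(None) seeds itself from the operating system's entropy
    # source (falling back to the current time if none is available).
    return random.Random(rng_seed)


def shuffled_split(
    n_samples: int, val_fraction: float = 0.2, rng_seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Split sample indices into (train, validation) lists (hypothetical)."""
    rng = make_rng(rng_seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_val = int(round(n_samples * val_fraction))
    return indices[n_val:], indices[:n_val]


if __name__ == "__main__":
    # The same seed reproduces the same split; rng_seed=None would not.
    print(shuffled_split(10, rng_seed=42) == shuffled_split(10, rng_seed=42))
```

Passing an integer seed is the usual way to make a training run repeatable; leaving it as None trades reproducibility for independent runs.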
- evaluate(test_data, overwrite_save=False)
Runs the model over a set of test samples.
- Parameters:
test_data (TestData) – The data on which to evaluate the model. Consists of the test features, test classes, test names, and a list of grouped indices.
probs_csv_path (str, optional) – Where to store the probability results.
- Returns:
A tuple containing the labels, names, predicted labels and maximum probabilities.
- Return type:
tuple
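The predicted labels and maximum probabilities returned by evaluate can be understood as the argmax and max of each sample's row of class probabilities. The sketch below assumes that interpretation; predict_with_confidence and the class names used in the usage note are hypothetical and not taken from superphot_plus.

```python
from typing import List, Sequence, Tuple


def predict_with_confidence(
    probs: Sequence[Sequence[float]], classes: Sequence[str]
) -> Tuple[List[str], List[float]]:
    """Return the predicted class and its probability for each sample.

    `probs` holds one row of class probabilities per sample, with columns
    ordered like `classes`. Hypothetical helper for illustration only.
    """
    predicted, max_probs = [], []
    for row in probs:
        # Index of the most probable class in this row.
        best = max(range(len(row)), key=row.__getitem__)
        predicted.append(classes[best])
        max_probs.append(row[best])
    return predicted, max_probs
```

For example, with classes ["SN Ia", "SN II", "SN Ibc"] and rows [0.7, 0.2, 0.1] and [0.1, 0.3, 0.6], the predicted labels are "SN Ia" and "SN Ibc" with maximum probabilities 0.7 and 0.6.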
- classify_from_fit_params(fit_params)
Classifies one or more light curves using only the fit parameters that the classifier was trained on. These exclude t0 and, for the redshift-exclusive classifier, A; they include the chi-squared value.
- Parameters:
fit_params (np.ndarray) – Set of model fit parameters.
- Returns:
Probabilities of each light curve belonging to each SN type; each row sums to 1.
- Return type:
np.ndarray
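The feature selection and row normalization described above can be sketched as follows. This is an illustration under loud assumptions: PARAM_NAMES is a hypothetical column layout (the real ordering of fit parameters is not given here), and the softmax shows the standard way multiclass boosted models such as LightGBM turn per-class raw scores into probabilities that sum to 1; it is not the superphot_plus code.

```python
import math
from typing import List, Sequence

# Hypothetical column layout of one fit-parameter row; the ordering used by
# superphot_plus is not specified in this documentation.
PARAM_NAMES = ["A", "beta", "gamma", "t0", "tau_rise", "tau_fall", "chisq"]


def select_features(row: Sequence[float], redshift_exclusive: bool = False) -> List[float]:
    """Drop t0 (and A for a redshift-exclusive classifier); keep chi-squared."""
    excluded = {"t0"} | ({"A"} if redshift_exclusive else set())
    return [value for name, value in zip(PARAM_NAMES, row) if name not in excluded]


def softmax(scores: Sequence[float]) -> List[float]:
    """Turn one row of per-class raw scores into probabilities summing to 1."""
    peak = max(scores)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

Applying softmax row by row to a matrix of raw scores yields the "sums to 1 along each row" property stated in the return description.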