Source code for src.superphot_plus.model.lightgbm

import os

import numpy as np
import lightgbm
import pickle

from superphot_plus.format_data_ztf import normalize_features
from superphot_plus.config import SuperphotConfig
from superphot_plus.model.metrics import ModelMetrics
from torch.utils.data import DataLoader, TensorDataset
from superphot_plus.utils import (
    save_test_probabilities,
)

class SuperphotLightGBM:
    """Gradient-boosted (LightGBM) light-curve classifier.

    Parameters
    ----------
    config : SuperphotConfig
        Run configuration. This class reads ``prob_threshhold``,
        ``probs_fn``, ``normalization_means`` and ``normalization_stddevs``
        from it and calls ``set_best_val_loss`` / ``write_to_file`` on it.
    target_label : optional
        When not None the model is trained with a binary objective;
        when None a multiclass objective is used.
    """

    def __init__(self, config: "SuperphotConfig", target_label=None):
        super().__init__()

        self.config = config

        # Fitted ``lightgbm.LGBMClassifier``; stays None until
        # ``train_and_validate`` has run.
        self.best_model = None
        self.target_label = target_label

    @staticmethod
    def _dataset_to_arrays(dataset):
        """Return ``(features, labels)`` arrays from a torch dataset in one pass."""
        features, labels = [], []
        # The default DataLoader yields batches of size one, hence the [0].
        for x, y in DataLoader(dataset=dataset):
            features.append(x[0].numpy())
            labels.append(y[0].numpy())
        return np.asarray(features), np.asarray(labels)

    def train_and_validate(
        self,
        train_data,
        rng_seed=None,
        **kwargs,
    ):
        """Fit a LightGBM classifier and record train/validation metrics.

        Parameters
        ----------
        train_data : TrainData
            Pair ``(train_dataset, valid_dataset)``; each must be iterable by
            a torch ``DataLoader`` and yield ``(features, label)`` samples.
        rng_seed : int, optional
            Passed to LightGBM as ``random_state``. If None, LightGBM's
            default seeding is used.
        **kwargs
            Accepted for interface compatibility; currently ignored.

        Returns
        -------
        tuple
            Whatever ``ModelMetrics.get_values()`` returns for the per-round
            training/validation accuracies and losses.
        """
        train_dataset, valid_dataset = train_data
        train_feats, train_classes = self._dataset_to_arrays(train_dataset)
        val_feats, val_classes = self._dataset_to_arrays(valid_dataset)

        binary = self.target_label is not None

        lightgbm_params = {
            "boosting": "dart",
            "data_sample_strategy": "goss",
            "verbosity": -1,
            "random_state": rng_seed,
            "max_depth": 5,
            "num_leaves": 20,
            "lambda_l1": 5.0,
            "n_estimators": 250,
        }
        if binary:
            # Single-class (target vs. rest) classification.
            lightgbm_params["objective"] = "binary"
            lightgbm_params["metric"] = "binary_logloss"
        else:
            lightgbm_params["objective"] = "multiclass"
            lightgbm_params["metric"] = "multi_logloss"
            lightgbm_params["num_class"] = len(np.unique(train_classes))

        eval_results = {}
        classifier = lightgbm.LGBMClassifier(**lightgbm_params)
        classifier.fit(
            train_feats,
            train_classes,
            eval_set=[
                (train_feats, train_classes),
                (val_feats, val_classes),
            ],
            eval_names=["train", "val"],
            callbacks=[
                # ``log_evaluation`` is a factory and must be *called* to get
                # a callback. period=0 disables per-round logging, matching
                # the silent behaviour requested through ``verbosity=-1``.
                lightgbm.log_evaluation(period=0),
                lightgbm.record_evaluation(eval_results),
            ],
            # NOTE(review): the binary branch below reads ``binary_logloss``,
            # so LGBMClassifier presumably remaps these multi_* names for
            # two-class problems -- confirm against the installed version.
            eval_metric=["multi_logloss", "multi_error"],
        )

        loss_key = "binary_logloss" if binary else "multi_logloss"
        train_loss = np.array(eval_results["train"][loss_key])
        val_loss = np.array(eval_results["val"][loss_key])

        if binary:
            # NOTE(review): binary accuracies are placeholders (all ones),
            # not measured values.
            train_acc = np.ones(len(train_loss))
            val_acc = np.ones(len(train_loss))
        else:
            train_acc = 1.0 - np.array(eval_results["train"]["multi_error"])
            val_acc = 1.0 - np.array(eval_results["val"]["multi_error"])

        metrics = ModelMetrics(
            train_acc=train_acc,
            train_loss=train_loss,
            val_acc=val_acc,
            val_loss=val_loss,
        )
        best_val_loss = np.min(val_loss)

        # Keep the fitted classifier for later prediction.
        self.best_model = classifier

        # Record the best validation loss on the configuration object.
        self.config.set_best_val_loss(float(best_val_loss))

        return metrics.get_values()

    def evaluate(self, test_data, overwrite_save=False):
        """Run the fitted model over test samples, grouped by object name.

        Rows sharing a name are scored together and their class
        probabilities are averaged before a label is chosen.

        Parameters
        ----------
        test_data : TestData
            Triple ``(test_features, test_classes, test_names)`` with one
            entry per row.
        overwrite_save : bool, optional
            Currently unused; kept for interface compatibility.

        Returns
        -------
        tuple
            ``(labels, names, pred_labels, max_probs)`` as numpy arrays, one
            entry per unique name.
        """
        test_features, test_classes, test_names = test_data
        # Element-wise comparison and mask indexing below need ndarrays;
        # a plain list would make ``test_names == name`` a scalar False.
        test_names = np.asarray(test_names)
        test_classes = np.asarray(test_classes)

        labels, pred_labels, max_probs, names, probs_avgs = [], [], [], [], []

        for test_name in np.unique(test_names):
            group_mask = test_names == test_name
            true_classes = test_classes[group_mask]

            probs = self.best_model.predict_proba(test_features[group_mask])
            probs_avg = np.mean(probs, axis=0)

            if self.target_label is None:
                pred_labels.append(np.argmax(probs_avg))
                max_probs.append(np.amax(probs_avg))
                labels.append(true_classes[0])
            else:
                pred_target = probs_avg[1] > self.config.prob_threshhold
                # Binary mode reports flipped labels (1 - x) for both the
                # prediction and the truth.
                pred_labels.append(1 - int(pred_target))
                max_probs.append(probs_avg[1] if pred_target else probs_avg[0])
                labels.append(1 - true_classes[0])

            names.append(test_name)
            probs_avgs.append(probs_avg)

        save_test_probabilities(
            names,
            np.array(probs_avgs),
            self.config.probs_fn,
            true_labels=labels,
            target_label=self.target_label,
        )

        return (
            np.array(labels).astype(int),
            np.array(names),
            np.array(pred_labels).astype(int),
            np.array(max_probs).astype(float),
        )

    def classify_from_fit_params(self, fit_params):
        """Classify one or more light curves from their fit parameters.

        Parameters
        ----------
        fit_params : np.ndarray
            Model fit parameters; a 1-D input is treated as a single
            light curve.

        Returns
        -------
        np.ndarray
            Output of the fitted model's ``predict_proba`` for each row.
        """
        fit_params_2d = np.atleast_2d(fit_params)  # cast to 2D if only 1 light curve
        test_features = normalize_features(
            fit_params_2d,
            self.config.normalization_means,
            self.config.normalization_stddevs,
        )[0]
        probs = self.best_model.predict_proba(test_features)
        try:
            # Tensor-like results expose ``.numpy()``.
            return probs.numpy()
        except AttributeError:
            # Plain ndarrays have no ``.numpy()``; return them unchanged.
            return probs

    def save(self, config_prefix, suffix=''):
        """Pickle this classifier and write its configuration next to it.

        Parameters
        ----------
        config_prefix : str
            Path prefix for the output files.
        suffix : str, optional
            Appended to the prefix after an underscore. The outputs are
            ``{config_prefix}_{suffix}.pt`` (pickle) and
            ``{config_prefix}_{suffix}.yaml`` (configuration).
        """
        with open(f"{config_prefix}_{suffix}.pt", 'wb') as f:
            pickle.dump(self, f)
        self.config.write_to_file(f"{config_prefix}_{suffix}.yaml")

    @classmethod
    def load(cls, filename, config_filename=None):
        """Load a classifier that was saved to disk.

        Parameters
        ----------
        filename : str
            Path of the pickle file written by ``save``.
        config_filename : str, optional
            Path of a configuration file to load with
            ``SuperphotConfig.from_file``. If None, no configuration is
            loaded.

        Returns
        -------
        tuple
            ``(model, config)``; ``config`` is None when ``config_filename``
            is None.
        """
        # SECURITY: unpickling executes arbitrary code; only load files
        # from trusted sources.
        with open(filename, 'rb') as f:
            model = pickle.load(f)

        if config_filename is None:
            config = None
        else:
            config = SuperphotConfig.from_file(config_filename)

        return model, config