from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
from torch.utils.data import TensorDataset
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from astropy.cosmology import Planck13 as cosmo
from superphot_plus.posterior_samples import PosteriorSamples
from superphot_plus.supernova_class import SupernovaClass as SnClass
@dataclass
class PosteriorSamplesGroup:
    """Holds data from multiple objects' posterior objects.

    On construction the group extracts names, labels and redshifts from the
    posterior objects, drops objects whose redshift is NaN or not positive,
    truncates every object's samples to a common number of draws, and builds
    the stacked feature matrix (``features``) plus a per-object median
    feature matrix (``median_features``).

    Parameters
    ----------
    posterior_objects : sequence of PosteriorSamples
        Objects exposing ``name``, ``sn_class``, ``redshift``, ``max_flux``
        and ``samples`` (a 2-D array, one row per draw). A list or an array
        is accepted; it is stored as a 1-D object array.
    use_redshift_info : bool, optional
        If True, append a redshift column and an absolute-magnitude column
        to every draw.
    ignore_param_idxs : list of int, optional
        Column indices removed from the samples (after the optional redshift
        columns are appended).
    random_seed : int, optional
        Seed for the group's random generator, the train/test split and
        SMOTE.
    """

    # Quoted so defining the class does not need to resolve the project type.
    posterior_objects: List["PosteriorSamples"]
    use_redshift_info: Optional[bool] = False
    ignore_param_idxs: Optional[List[int]] = field(default_factory=list)
    random_seed: Optional[int] = None

    def __post_init__(self):
        """Derive per-object metadata and the feature matrices."""
        if self.ignore_param_idxs is None:
            self.ignore_param_idxs = []

        # A plain list cannot be indexed by a boolean mask or an index array,
        # which both the filtering below and ``split`` rely on.
        objects = self._as_object_array(self.posterior_objects)

        self.names = np.array([ps.name for ps in objects])
        self.labels = np.array([ps.sn_class for ps in objects])
        # dtype=float maps a missing (None) redshift to NaN so that it is
        # filtered out below instead of breaking np.isnan.
        self.redshifts = np.array(
            [ps.redshift for ps in objects], dtype=float
        )
        self.abs_mags = []
        self.rng = np.random.default_rng(self.random_seed)

        # Filter invalid redshifts.
        # NOTE(review): the original comment read "if using redshift info",
        # but the filter has always been applied unconditionally; that
        # behaviour is kept here - confirm which one is intended.
        z_mask = (~np.isnan(self.redshifts)) & (self.redshifts > 0)
        self.names = self.names[z_mask]
        self.labels = self.labels[z_mask]
        self.redshifts = self.redshifts[z_mask]
        self.posterior_objects = objects[z_mask]

        # Save an equal number of draws per light curve (0 for an empty group).
        self.num_draws = min(
            (ps.samples.shape[0] for ps in self.posterior_objects),
            default=0,
        )

        feat_arr = []
        median_feats = []
        for ps in self.posterior_objects:
            samples = ps.samples[:self.num_draws]
            if self.use_redshift_info:
                samples = np.append(
                    samples, self._redshift_columns(ps), axis=1
                )
            samples = np.delete(samples, self.ignore_param_idxs, 1)
            feat_arr.extend(samples)
            median_feats.append(np.median(samples, axis=0))

        # Rows of ``features`` are grouped by object: object ``i`` owns rows
        # ``i * num_draws`` to ``(i + 1) * num_draws - 1``.
        self.features = np.asarray(feat_arr)
        self.median_features = np.asarray(median_feats)

    @staticmethod
    def _as_object_array(objects):
        """Return ``objects`` as a 1-D object array, one element per item."""
        items = list(objects)
        arr = np.empty(len(items), dtype=object)
        # Element-wise assignment keeps each item intact even if it happens
        # to be sequence-like.
        for i, item in enumerate(items):
            arr[i] = item
        return arr

    def _redshift_columns(self, ps):
        """Build the (num_draws, 2) redshift / abs-mag columns for ``ps``.

        Also appends the object's absolute magnitude (or None when
        ``max_flux`` is missing) to ``self.abs_mags``.
        """
        z_arr = np.ones((self.num_draws, 2)) * ps.redshift
        max_flux = ps.max_flux
        if max_flux is None:
            self.abs_mags.append(None)
            z_arr[:, 1] = -np.inf
            return z_arr

        k_corr = 2.5 * np.log10(1. + ps.redshift)
        # Luminosity distance in Mpc.
        dist = cosmo.luminosity_distance([ps.redshift]).value[0]
        # dist * 1e5 expresses the distance in units of 10 pc.
        abs_mag = (
            -2.5 * np.log10(max_flux)
            + 26.3
            - 5. * np.log10(dist * 1e5)
            + k_corr
        )
        self.abs_mags.append(abs_mag)
        z_arr[:, 1] = abs_mag
        return z_arr

    def __iter__(self):
        """Iterate over ``(names, labels, redshifts)`` for easy unpacking."""
        return iter((self.names, self.labels, self.redshifts))

    def oversample(self, fits_per_majority_lc=1):
        """Oversamples, drawing from posteriors of a certain fit.

        Every class is brought to roughly the size of the majority class
        times ``fits_per_majority_lc`` by drawing (with replacement) the same
        number of posterior samples from each object of that class.

        Parameters
        ----------
        fits_per_majority_lc : int, optional
            Number of draws taken per object of the majority class.

        Returns
        -------
        tuple of np.ndarray
            Tuple containing oversampled features and labels.
        """
        labels_unique, counts = np.unique(self.labels, return_counts=True)
        if counts.size == 0:
            # Nothing to draw from in an empty group.
            return np.array([]), np.array([])

        goal_per_class = np.max(counts) * fits_per_majority_lc
        oversampled_features = []
        oversampled_labels = []

        for label in labels_unique:
            idxs_in_class = np.asarray(self.labels == label).nonzero()[0]
            # At least one draw per object, even for the majority class.
            samples_per_fit = max(
                int(round(goal_per_class / len(idxs_in_class))), 1
            )
            for i in idxs_in_class:
                sampled_idx = self.rng.choice(
                    np.arange(self.num_draws), samples_per_fit
                )
                # Offset into the block of rows owned by object ``i``.
                sampled_features = self.features[
                    i * self.num_draws + sampled_idx
                ]
                oversampled_features.extend(list(sampled_features))
                oversampled_labels.extend([label] * samples_per_fit)

        return np.array(oversampled_features), np.array(oversampled_labels)

    def oversample_smote(self):
        """Uses SMOTE to oversample data from rarer classes.

        Operates on the per-object median features and is seeded with
        ``random_seed`` so results are reproducible when a seed is set.

        Returns
        -------
        tuple
            Resampled features and labels as returned by
            ``SMOTE.fit_resample``.
        """
        oversample = SMOTE(random_state=self.random_seed)
        features_smote, labels_smote = oversample.fit_resample(
            self.median_features,
            self.labels
        )
        return features_smote, labels_smote

    def split(self, split_frac=0.1, split_indices=None, shuffle=True):
        """Split the group into two new groups.

        Parameters
        ----------
        split_frac : float, optional
            Fraction of objects placed in the second group. Ignored when
            ``split_indices`` is given.
        split_indices : tuple of array-like, optional
            Explicit ``(idx1, idx2)`` object indices for the two groups.
        shuffle : bool, optional
            Whether to shuffle before splitting. A shuffled split is
            stratified by label; an unshuffled split cannot be stratified.

        Returns
        -------
        tuple of PosteriorSamplesGroup
            The two groups. Each is rebuilt from its posterior objects, so
            label edits made on this group (e.g. ``make_binary``) are not
            carried over.
        """
        if split_indices is not None:
            idx1, idx2 = split_indices
        else:
            idx1, idx2 = train_test_split(
                np.arange(len(self.labels)),
                test_size=split_frac,
                random_state=self.random_seed,
                shuffle=shuffle,
                # train_test_split rejects stratify together with
                # shuffle=False.
                stratify=self.labels if shuffle else None,
            )

        split_1 = PosteriorSamplesGroup(
            self.posterior_objects[idx1],
            self.use_redshift_info,
            self.ignore_param_idxs,
            self.random_seed
        )
        split_2 = PosteriorSamplesGroup(
            self.posterior_objects[idx2],
            self.use_redshift_info,
            self.ignore_param_idxs,
            self.random_seed
        )
        return split_1, split_2

    def canonicalize_labels(self):
        """Convert labels to canon labels (in place)."""
        self.labels = np.asarray([
            SnClass.canonicalize(label) for label in self.labels
        ])

    def make_binary(self, target_label="SN Ia"):
        """Convert labels to a binary classification problem (in place).

        Labels equal to ``target_label`` are kept; all others become
        ``"other"``.
        """
        self.labels = np.where(
            self.labels == target_label,
            target_label,
            "other"
        )

    def make_fully_redshift_independent(self):
        """Experimental!

        We can convert our shape parameters to be FULLY
        z-independent by instead using:
        tau_rise/gamma, tau_rise/tau_fall, beta*tau_rise
        (but it's log scale for tau_rise, gamma, tau_fall,
        so add/subtract instead).

        We do everything relative to tau_rise because that's
        the first shape param to be measured in real time!
        """
        # NOTE(review): the early return below turns this method into a
        # no-op; the transformation after it is currently unreachable.
        # Presumably disabled on purpose while experimental - confirm before
        # removing either the return or the dead code.
        return None
        self.features_z_independent = np.asarray([
            np.log10(self.features[:, 0]) + self.features[:, 2],
            self.features[:, 2] - self.features[:, 1],
            self.features[:, 2] - self.features[:, 3],
        ]).T
        self.features_z_independent = np.append(
            self.features_z_independent,
            self.features[:, 4:],
            axis=1
        )
        self.features = self.features_z_independent
@dataclass
class TrainData:
    """Holds train and validation datasets.

    Parameters
    ----------
    train_dataset : TensorDataset
        Dataset used for training.
    val_dataset : TensorDataset
        Dataset used for validation.
    """

    train_dataset: TensorDataset
    val_dataset: TensorDataset

    def __iter__(self):
        """Iterate as ``(train_dataset, val_dataset)`` so the pair unpacks."""
        return iter((self.train_dataset, self.val_dataset))
@dataclass
class TestData:
    """Holds information about testing data.

    Parameters
    ----------
    test_features : array-like
        Feature values of the test set.
    test_classes : array-like
        Class labels of the test set.
    test_names : array-like, optional
        Names of the test objects. Defaults to an empty array when omitted.
    """

    # The class name starts with "Test"; keep pytest from collecting it.
    __test__ = False

    test_features: np.ndarray
    test_classes: np.ndarray
    # Read by __post_init__ and __iter__, so it must be a declared field.
    test_names: Optional[np.ndarray] = None

    def __post_init__(self):
        """Ensure everything is numpy arrays."""
        self.test_features = np.asarray(self.test_features)
        self.test_classes = np.asarray(self.test_classes)
        names = [] if self.test_names is None else self.test_names
        self.test_names = np.asarray(names)

    def __iter__(self):
        """Iterate as ``(test_features, test_classes, test_names)``."""
        return iter(
            (
                self.test_features,
                self.test_classes,
                self.test_names,
            )
        )