"""This module provides various functions for analyzing and visualizing
classification results."""
import os, glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from matplotlib.ticker import AutoMinorLocator
from astropy.cosmology import Planck13 as cosmo
from scipy.stats import binned_statistic
from superphot_plus.plotting.format_params import set_global_plot_formatting, CUSTOM_COLORSET
from superphot_plus.file_utils import get_multiple_posterior_samples
from superphot_plus.format_data_ztf import import_labels_only, retrieve_posterior_set
from superphot_plus.posterior_samples import PosteriorSamples
from superphot_plus.plotting.utils import (
histedges_equalN, read_probs_csv,
get_survey_fracs, retrieve_four_class_info,
calc_precision_recall,
rebin_prec_recall,
roc_curve_w_uncertainties,
calc_calibration_curve
)
from superphot_plus.supernova_class import SupernovaClass as SnClass
# Module-level side effect: runs once when this module is imported, before any
# plotting function below is called. Presumably configures global matplotlib
# styling (see superphot_plus.plotting.format_params) — confirm there.
set_global_plot_formatting()
[docs]
def save_class_fractions(spec_probs_csv, probs_alerce_csv, phot_probs_csv, probs_alerce_phot_csv, save_path):
    """Save class fractions from spectroscopic, photometric, and
    corrected photometric classifications (ours and ALeRCE's).

    The "corrected" fractions re-weight the predicted photometric fractions
    with the spectroscopic confusion matrix normalized over predictions,
    i.e. ``corr[i] = sum_j P(true=i | pred=j) * pred_frac[j]``.

    Parameters
    ----------
    spec_probs_csv : str
        Path to the CSV file containing spectroscopic probability
        predictions.
    probs_alerce_csv : str
        Path to the CSV file with an ``alerce_label`` column for the
        spectroscopic set. Rows are assumed to line up one-to-one with the
        rows of ``spec_probs_csv`` — TODO confirm.
    phot_probs_csv : str
        Path to the CSV file containing photometric probability
        predictions.
    probs_alerce_phot_csv : str
        Path to the CSV file with ALeRCE predictions for the photometric
        set (read via ``retrieve_four_class_info``).
    save_path : str
        Filename + dir for saving the class fractions.

    Notes
    -----
    Writes a CSV with columns ``spec_fracs``, ``phot_fracs``,
    ``phot_fracs_corr``, ``alerce_fracs`` and ``alerce_fracs_corr``, one row
    per class. Class index 2 (SN IIn) has no ALeRCE counterpart: true IIn
    are merged into class 1 for the ALeRCE confusion matrix, and the
    corrected ALeRCE fraction for index 2 is 0.
    """
    labels_to_class, classes_to_labels = SnClass.get_type_maps()
    # Index of the SN IIn class, which the ALeRCE comparison does not have.
    iin_idx = 2

    # import spec dataframe
    _, true_class_spec, probs_spec, pred_class_spec, _, _ = read_probs_csv(spec_probs_csv)
    num_classes = probs_spec.shape[1]

    # For the ALeRCE comparison, true SNe IIn are counted as class 1.
    true_class_alerce = true_class_spec.copy()
    true_class_alerce[true_class_alerce == iin_idx] = 1

    # read in ALeRCE classes and drop objects ALeRCE did not label
    df_alerce = pd.read_csv(probs_alerce_csv)
    pred_alerce = df_alerce.alerce_label.to_numpy().astype(str)
    ignore_mask = (pred_alerce == "None") | (pred_alerce == "nan") | (pred_alerce == "SKIP")
    # NOTE(review): after the merge above no entry equals iin_idx, so this term
    # never masks anything; kept to preserve existing behavior.
    ignore_mask = ignore_mask | (true_class_alerce == iin_idx)
    true_class_alerce = true_class_alerce[~ignore_mask]
    pred_alerce = pred_alerce[~ignore_mask]
    pred_class_spec_alerce = np.array([labels_to_class[x] for x in pred_alerce])

    # import phot dataframe
    pred_class_phot = read_probs_csv(phot_probs_csv)[3]
    pred_class_phot_alerce = np.asarray(
        retrieve_four_class_info(phot_probs_csv, probs_alerce_phot_csv)[4]
    )

    # normalize="pred": each column sums to 1, so cm[i, j] = P(true=i | pred=j).
    cm_p = confusion_matrix(true_class_spec, pred_class_spec, normalize="pred")
    cm_p_alerce = confusion_matrix(true_class_alerce, pred_class_spec_alerce, normalize="pred")

    true_fracs = np.array([np.mean(true_class_spec == i) for i in range(num_classes)])
    pred_fracs = np.array([np.mean(pred_class_phot == i) for i in range(num_classes)])
    # ALeRCE photometric predictions are label strings, not class indices.
    alerce_fracs = np.array(
        [np.mean(pred_class_phot_alerce == classes_to_labels[i]) for i in range(num_classes)]
    )

    pred_fracs_corr = np.array([np.sum(pred_fracs * cm_p[i]) for i in range(num_classes)])

    # The ALeRCE confusion matrix has no row/column for iin_idx, so rows for
    # classes after it are shifted down by one.
    alerce_fracs_no_iin = np.delete(alerce_fracs, iin_idx)
    alerce_fracs_corr = []
    for i in range(num_classes):
        if i == iin_idx:
            alerce_fracs_corr.append(0.0)
            continue
        row = i - 1 if i > iin_idx else i
        alerce_fracs_corr.append(np.sum(alerce_fracs_no_iin * cm_p_alerce[row]))
    alerce_fracs_corr = np.array(alerce_fracs_corr)

    save_df = pd.DataFrame(
        {
            "spec_fracs": true_fracs,
            "phot_fracs": pred_fracs,
            "phot_fracs_corr": pred_fracs_corr,
            "alerce_fracs": alerce_fracs,
            "alerce_fracs_corr": alerce_fracs_corr,
        }
    )
    save_df.to_csv(save_path, index=False)
[docs]
def plot_class_fractions(saved_cf_file, fig_dir, filename):
    """Plot class fractions saved from 'save_class_fractions'.

    Draws one stacked bar per source (ZTF/YSE/PS1-MDS spectroscopic,
    photometric, corrected photometric, ALeRCE, corrected ALeRCE), with one
    stack segment per class row in ``saved_cf_file``. Each non-zero segment
    is annotated with its fraction rounded to 3 decimals.

    Parameters
    ----------
    saved_cf_file : str
        Path to the saved class fractions file.
    fig_dir : str
        Directory for saving the class fractions plot.
    filename: str
        Filename for the class fractions plot figure.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    labels = [
        "Spec (ZTF)",
        "Spec (YSE)",
        "Spec (PS1-MDS)",
        "Phot",
        "Phot (corr.)",
        "ALeRCE",
        "ALeRCE (corr.)",
    ]
    width = 0.6

    frac_df = pd.read_csv(saved_cf_file)
    true_fracs, pred_fracs, pred_fracs_corr, alerce_fracs, alerce_fracs_corr = frac_df.to_numpy().T
    survey_sn_fracs = get_survey_fracs()
    yse_fracs, psmds_fracs = survey_sn_fracs["YSE"], survey_sn_fracs["PS-MDS"]

    # Shape (n_classes, n_bars): row i holds class i's fraction in every bar.
    combined_fracs = np.array(
        [
            true_fracs,
            yse_fracs,
            psmds_fracs,
            pred_fracs,
            pred_fracs_corr,
            alerce_fracs,
            alerce_fracs_corr,
        ]
    ).T
    n_classes = combined_fracs.shape[0]

    fig, ax = plt.subplots(figsize=(11, 16))
    # Running top of each bar's stack.
    bottom = np.zeros(combined_fracs.shape[1])
    for i in range(n_classes):
        stacked_bar = ax.bar(
            labels,
            combined_fracs[i],
            width,
            bottom=bottom,
            label=classes_to_labels[i],
        )
        for frac, bar_patch in zip(combined_fracs[i], stacked_bar.patches):
            if frac == 0.0:
                continue
            # Label the segment at its center.
            ax.annotate(
                round(frac, 3),
                (
                    bar_patch.get_x() + bar_patch.get_width() / 2,
                    bar_patch.get_y() + bar_patch.get_height() / 2,
                ),
                ha="center",
                va="center",
                color="white",
            )
        bottom = bottom + combined_fracs[i]

    # Shrink current axis's height by 10% on the bottom to make room for the legend.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])
    # Put a legend below current axis
    ax.legend(
        loc="upper center", bbox_to_anchor=(0.5, -0.05), fancybox=True, shadow=False, ncol=5, fontsize=15
    )
    ax.tick_params(axis="both", which="major", labelsize=12)
    ax.tick_params(axis="both", which="minor", labelsize=10)
    ax.set_ylabel("Fraction", fontsize=20)
    fig.savefig(os.path.join(fig_dir, filename))
    plt.close(fig)
[docs]
def generate_roc_curve(probs_csv, save_dir):
    """Generate a combined ROC curve of all SN classes.

    Produces a two-panel figure: the full ROC curve (left) and a zoom on
    false-positive rates in [0, 0.1] (right). Each class is drawn
    one-vs-rest as a step curve with a shaded band from
    ``roc_curve_w_uncertainties``; a diamond on the zoomed panel marks the
    threshold closest to 0.5. The figure is saved as ``roc_all.pdf``.

    Parameters
    ----------
    probs_csv : str
        CSV file where class probabilities are stored.
    save_dir : str
        Where to save the figure.
    """
    labels_to_classes, classes_to_labels = SnClass.get_type_maps()
    colors = CUSTOM_COLORSET

    fig, double_axes = plt.subplots(1, 2, figsize=(8, 7))
    ax1, ax2 = double_axes
    ax1.set_xlim([0.0, 1.05])
    ax1.set_ylim([0.0, 1.05])
    ax2.set_xlim([0.0, 0.1])
    ax2.set_ylim([0.0, 1.05])
    ax1.set_ylabel("True Positive Rate")
    ratio = 1.2
    plt.locator_params(axis="x", nbins=3)

    # The probabilities file does not depend on the class, so read it once.
    _, true_classes, probs, _, folds, _ = read_probs_csv(probs_csv)

    legend_lines = []
    for ref_class, ref_label in enumerate(classes_to_labels):
        color = colors[ref_class]
        # One-vs-rest labels and scores for this class.
        y_true = np.where(true_classes == ref_class, 1, 0)
        y_score = probs[:, ref_class]
        t, fpr, tpr, tpr_err = roc_curve_w_uncertainties(y_true, y_score, folds)
        # Point closest to a 0.5 threshold; clamp the neighbor index so the
        # midpoint lookup cannot run past the end of the curve.
        idx_50 = np.argmin((t - 0.5) ** 2)
        idx_next = min(idx_50 + 1, len(fpr) - 1)

        (legend_line,) = ax1.step(fpr, tpr, label=ref_label, c=color, where="post")
        ax1.fill_between(fpr, tpr - tpr_err, tpr + tpr_err, color=color, step="post", alpha=0.2)
        ax2.step(fpr, tpr, label=ref_label, c=color, where="post")
        ax2.fill_between(fpr, tpr - tpr_err, tpr + tpr_err, color=color, step="post", alpha=0.2)
        ax2.scatter(
            (fpr[idx_50] + fpr[idx_next]) / 2, tpr[idx_50],
            color=color, s=100, marker="d", zorder=1000
        )
        legend_lines.append(legend_line)

    # Chance-level diagonal on the full-range panel.
    ax1.plot([0, 1], [0, 1], c="#BBBBBB", linestyle="dotted")

    for ax_i in double_axes:
        x_left, x_right = ax_i.get_xlim()
        y_low, y_high = ax_i.get_ylim()
        ax_i.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)
        ax_i.yaxis.set_minor_locator(AutoMinorLocator())
        ax_i.xaxis.set_minor_locator(AutoMinorLocator())
        ax_i.set_xlabel("False Positive Rate")

    legend_keys = list(labels_to_classes.keys())
    fig.legend(legend_lines, legend_keys, loc="lower center", ncol=3)
    fig.savefig(os.path.join(save_dir, "roc_all.pdf"), bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_precision_recall(probs_csv, save_dir, plot_fleet=True):
    """Show how adjusting binary thresholds impacts purity and completeness.

    Draws one-vs-rest purity (precision) vs. completeness (recall) step
    curves for every class with shaded uncertainty bands, and marks the
    threshold closest to 0.5 with a diamond. Prints each curve's area under
    the purity-completeness curve, with lower/upper values from the band.
    Optionally overlays the FLEET SLSN-I curve read from
    ``data/SLSN_late-time.txt`` (three directories above this module).
    Saves ``prec_recall_all.pdf``.

    Parameters
    ----------
    probs_csv : str
        CSV file where class probabilities are stored.
    save_dir : str
        Where to save the figure.
    plot_fleet : bool, optional
        Whether to overlay the FLEET SLSN-I curve. Defaults to True.
    """
    labels_to_classes, classes_to_labels = SnClass.get_type_maps()
    colors = CUSTOM_COLORSET

    fig, ax = plt.subplots(figsize=(6, 8))
    ax.set_xlim([0.0, 1.05])
    ax.set_ylim([0.0, 1.05])
    ratio = 1.0
    ax.set_ylabel("Purity")
    plt.locator_params(axis="x", nbins=3)

    # The probabilities file does not depend on the class, so read it once.
    _, true_classes, probs, _, folds, _ = read_probs_csv(probs_csv)

    legend_lines = []
    for ref_class, ref_label in enumerate(classes_to_labels):
        color = colors[ref_class]
        y_true = np.where(true_classes == ref_class, 1, 0)
        y_score = probs[:, ref_class]
        t, r, p, perr = calc_precision_recall(y_true, y_score, folds)
        # Point closest to a 0.5 threshold; clamp the neighbor index.
        idx_50 = np.argmin((t - 0.5) ** 2)
        idx_next = min(idx_50 + 1, len(r) - 1)

        (legend_line,) = ax.step(r, p, label=ref_label, c=color, where="post")
        ax.fill_between(r, p - perr, p + perr, alpha=0.2, color=color, step="post")
        legend_lines.append(legend_line)
        ax.scatter(
            (r[idx_50] + r[idx_next]) / 2, p[idx_50],
            color=color, s=100, marker="d", zorder=1000
        )

        # print AUPR value (left Riemann sum of the step curve over recall)
        dr = r[1:] - r[:-1]
        aupr = np.sum(dr * p[:-1])
        aupr_min = np.sum(dr * (p - perr)[:-1])
        aupr_max = np.sum(dr * (p + perr)[:-1])
        print(ref_label, aupr, aupr_min, aupr_max)

    if plot_fleet:
        fleet_fn = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "../../..", "data", "SLSN_late-time.txt")
        )
        # Raw string: "\s" is an invalid escape sequence in a normal literal.
        fleet_df = pd.read_csv(fleet_fn, sep=r"\s+")
        p_fleet = fleet_df["Purity"].to_numpy()
        perr_fleet = fleet_df["PurityStd"].to_numpy()
        r_fleet = fleet_df["Completeness"].to_numpy()
        rerr_fleet = fleet_df["CompletenessStd"].to_numpy()
        t_fleet = fleet_df["P(SLSN-I)"].to_numpy()
        rbin_fleet, pbin_fleet, pmin_fleet, pmax_fleet = rebin_prec_recall(
            t_fleet, r_fleet, rerr_fleet, p_fleet, perr_fleet
        )
        idx_50 = np.argmin((t_fleet - 0.5) ** 2)
        (legend_line,) = ax.step(
            rbin_fleet, pbin_fleet, label="FLEET (SLSN-I)", c=colors[5], where="post"
        )
        ax.fill_between(
            rbin_fleet, pmin_fleet, pmax_fleet, alpha=0.2, color=colors[5], step="post"
        )
        legend_lines.append(legend_line)
        ax.scatter(
            r_fleet[idx_50], p_fleet[idx_50],
            color=colors[5], s=100, marker="d", zorder=1000
        )
        # Differences are r[i] - r[i+1] (opposite sign to the class curves),
        # presumably because the rebinned FLEET recall decreases — confirm.
        dr_fleet = rbin_fleet[:-1] - rbin_fleet[1:]
        aupr = np.sum(dr_fleet * pbin_fleet[:-1])
        aupr_min = np.sum(dr_fleet * pmin_fleet[:-1])
        aupr_max = np.sum(dr_fleet * pmax_fleet[:-1])
        print("FLEET", aupr, aupr_min, aupr_max)

    x_left, x_right = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)
    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel("Completeness")

    # Only label FLEET when its curve was actually drawn.
    legend_keys = list(labels_to_classes.keys())
    if plot_fleet:
        legend_keys.append("FLEET (SLSN-I)")
    fig.legend(legend_lines, legend_keys, loc="lower center", ncol=3)
    fig.savefig(os.path.join(save_dir, "prec_recall_all.pdf"), bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_metrics_over_mjd(mjd_bins, p_matrix, c_matrix, save_dir):
    """Plot per-class completeness and purity against MJD.

    For each of the first five classes, ``c_matrix[k]`` and ``p_matrix[k]``
    are averaged over axis 0 and drawn as post-step curves against
    ``mjd_bins``, with a shaded band of +/- one standard deviation (also over
    axis 0). Completeness is on the top panel and purity on the bottom one.
    The figure is written to ``metrics_over_mjd.pdf`` in ``save_dir``.

    Parameters
    ----------
    mjd_bins : array-like
        MJD values used as the x coordinates of the step curves.
    p_matrix, c_matrix : sequence of 2D array-like
        Purity and completeness per class; presumably each entry is
        (n_repeats, n_bins) — TODO confirm with the caller.
    save_dir : str
        Directory where the figure is written.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    palette = CUSTOM_COLORSET

    fig, (comp_ax, pur_ax) = plt.subplots(
        2, 1, sharex=True, figsize=(8, 10), gridspec_kw={"hspace": 0.1}
    )
    pur_ax.set_ylabel("Purity")
    comp_ax.set_ylabel("Completeness")
    comp_ax.set_ylim((0, 1))
    pur_ax.set_ylim((0, 1))
    plt.locator_params(axis="x", nbins=3)

    for k in range(5):
        shade = palette[k]
        comp_mu = np.mean(c_matrix[k], axis=0)
        comp_sd = np.std(c_matrix[k], axis=0)
        pur_mu = np.mean(p_matrix[k], axis=0)
        pur_sd = np.std(p_matrix[k], axis=0)

        # Only the completeness curve carries a label, so the figure legend
        # lists each class exactly once.
        comp_ax.step(mjd_bins, comp_mu, where="post", c=shade, label=classes_to_labels[k])
        comp_ax.fill_between(
            mjd_bins, comp_mu - comp_sd, comp_mu + comp_sd,
            color=shade, alpha=0.2, step="post",
        )
        pur_ax.step(mjd_bins, pur_mu, where="post", c=shade)
        pur_ax.fill_between(
            mjd_bins, pur_mu - pur_sd, pur_mu + pur_sd,
            color=shade, alpha=0.2, step="post",
        )

    pur_ax.set_xlabel("MJD")
    fig.legend(loc="lower center", ncol=3)
    plt.savefig(os.path.join(save_dir, "metrics_over_mjd.pdf"), bbox_inches="tight")
    plt.close()
[docs]
def _per_class_fold_metrics(true_type, pred_type, folds, allowed_types, class_fracs, phase_fracs, n_folds=10):
    """Return per-class fold mean/std of completeness, estimated purity and F1.

    Each metric is computed separately in folds ``0 .. n_folds - 1`` and
    summarized with ``np.nanmean`` / ``np.nanstd``, so folds where a metric
    is undefined (0/0 -> nan) are ignored. Purity is re-weighted from the
    class mix ``phase_fracs`` to the class mix ``class_fracs``.

    Returns
    -------
    tuple of six lists
        ``(acc_mu, acc_std, purity_mu, purity_std, f1_mu, f1_std)``, each
        with one entry per allowed type.
    """
    correct_class = (true_type == pred_type).astype(int)
    acc_mu, acc_std = [], []
    pur_mu, pur_std = [], []
    f1_mu, f1_std = [], []
    for i, allowed_type in enumerate(allowed_types):
        accs, purities, f1s = [], [], []
        for f in range(n_folds):
            in_fold = folds == f
            is_true_t = (true_type == allowed_type) & in_fold
            correct_t = correct_class[is_true_t]
            completeness = np.sum(correct_t) / len(true_type[is_true_t])
            is_pred_t = (pred_type == allowed_type) & in_fold
            all_preds = pred_type[is_pred_t]
            all_trues = true_type[is_pred_t]
            # Number predicted as this class, with each true class re-weighted
            # from the phase class mix to the full-sample class mix.
            adj_pred = np.sum([
                class_fracs[j] * len(all_preds[all_trues == at2]) / phase_fracs[j]
                for j, at2 in enumerate(allowed_types)
            ])
            purity = class_fracs[i] * np.sum(correct_t) / adj_pred / phase_fracs[i]
            accs.append(completeness)
            purities.append(purity)
            if purity == 0 and completeness == 0:
                f1s.append(0)
            else:
                f1s.append(2 * purity * completeness / (purity + completeness))
        pur_mu.append(np.nanmean(purities))
        pur_std.append(np.nanstd(purities))
        acc_mu.append(np.nanmean(accs))
        acc_std.append(np.nanstd(accs))
        f1_mu.append(np.nanmean(f1s))
        f1_std.append(np.nanstd(f1s))
    return acc_mu, acc_std, pur_mu, pur_std, f1_mu, f1_std


def _print_reference_phase(phase, tag, metrics):
    """Print completeness and purity summaries at phases 0.61 and 70.00 days."""
    rounded = round(float(phase), 2)
    if rounded == 0.61:
        print(f"PHASE ZERO {tag}")
    elif rounded == 70.00:
        print(f"PHASE LATE {tag}")
    else:
        return
    acc_mu, acc_std, pur_mu, pur_std = metrics[:4]
    print(acc_mu, acc_std)
    print(pur_mu, pur_std)


def plot_phase_vs_accuracy(phased_probs_dir, all_probs_csv, save_dir):
    """Plot classification accuracy as a function of phase.

    Reads every ``full_{phase}_concat.csv`` in ``phased_probs_dir`` together
    with the matching ``early_{phase}_concat.csv`` and plots, per class and
    as a macro average over classes, completeness (top), estimated purity
    (middle) and estimated F1 (bottom) against phase. Solid lines show the
    ``full_*`` results with a +/- 1 std band over folds; dashed lines show
    the ``early_*`` results. Purity is re-weighted to the class mix of
    ``all_probs_csv``. Saves ``phase_vs_accuracy.pdf``.

    Parameters
    ----------
    phased_probs_dir : str
        Where classification probabilities and LC truncated phases are saved.
    all_probs_csv : str
        Probabilities CSV for the full sample; only its ``Label`` column is
        used, to compute the overall class fractions.
    save_dir : str
        Where to save the output figures.

    Raises
    ------
    FileNotFoundError
        If no ``full_*_concat.csv`` files are found in ``phased_probs_dir``.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    allowed_types = np.arange(len(classes_to_labels))

    # Class mix of the full sample, used to re-weight purity at each phase.
    all_true_labels = pd.read_csv(all_probs_csv).Label.to_numpy()
    class_counts = [len(all_true_labels[l == all_true_labels]) for l in allowed_types]
    class_fracs = np.asarray(class_counts) / np.sum(class_counts)

    phases = []
    full_metrics = []
    early_metrics = []
    for probs_file_full in glob.glob(os.path.join(phased_probs_dir, "full_*_concat.csv")):
        _, true_type, _, pred_type, folds, _ = read_probs_csv(probs_file_full)
        # Class mix at this phase, from the full_* file (also reused for early_*).
        phase_counts = [len(true_type[l == true_type]) for l in allowed_types]
        phase_fracs = np.asarray(phase_counts) / np.sum(phase_counts)
        full = _per_class_fold_metrics(
            true_type, pred_type, folds, allowed_types, class_fracs, phase_fracs
        )
        full_metrics.append(full)

        phase = os.path.basename(probs_file_full).split("_")[1]
        _print_reference_phase(phase, "FULL", full)
        phases.append(float(phase))

        probs_file_early = os.path.join(phased_probs_dir, f"early_{phase}_concat.csv")
        _, true_type, _, pred_type, folds, _ = read_probs_csv(probs_file_early)
        early = _per_class_fold_metrics(
            true_type, pred_type, folds, allowed_types, class_fracs, phase_fracs
        )
        early_metrics.append(early)
        _print_reference_phase(phase, "EARLY", early)

    if not phases:
        raise FileNotFoundError(f"No full_*_concat.csv files found in {phased_probs_dir}")

    sort_idx = np.argsort(phases)
    phases = np.asarray(phases)[sort_idx]

    def _sorted_by_phase(per_phase):
        """Turn per-phase 6-tuples into six (n_classes, n_phases) arrays."""
        return [np.asarray(metric)[sort_idx].T for metric in zip(*per_phase)]

    (accs_full_means, accs_full_stddevs, fracs_full_means, fracs_full_stddevs,
     f1_full_means, f1_full_stddevs) = _sorted_by_phase(full_metrics)
    accs_early_means, _, fracs_early_means, _, f1_early_means, _ = _sorted_by_phase(early_metrics)

    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 16), gridspec_kw={"hspace": 0.05})
    ax, ax2, ax3 = axes
    # (axis, full means, early means, full stddevs) for each panel.
    panels = (
        (ax, accs_full_means, accs_early_means, accs_full_stddevs),
        (ax2, fracs_full_means, fracs_early_means, fracs_full_stddevs),
        (ax3, f1_full_means, f1_early_means, f1_full_stddevs),
    )

    legend_lines = []
    for i, allowed_type in enumerate(allowed_types):
        color = CUSTOM_COLORSET[i]
        for panel_ax, full_mu, early_mu, full_sd in panels:
            (line,) = panel_ax.plot(phases, full_mu[i], label=allowed_type, color=color)
            panel_ax.plot(phases, early_mu[i], linestyle="dashed", color=color)
            panel_ax.fill_between(
                phases, full_mu[i] - full_sd[i], full_mu[i] + full_sd[i],
                alpha=0.2, color=color
            )
            if panel_ax is ax:
                legend_lines.append(line)

    # Macro average over classes (solid: full, dashed: early).
    for panel_ax, full_mu, early_mu, _ in panels:
        (macro_line,) = panel_ax.plot(phases, np.mean(full_mu, axis=0), label="Macro", color="k")
        panel_ax.plot(phases, np.mean(early_mu, axis=0), linestyle="dashed", color="k")
    legend_lines.append(macro_line)

    for panel_ax, y_label in zip(axes, ("Completeness", "Estimated Purity", "Estimated F1")):
        panel_ax.set_ylabel(y_label)
        panel_ax.set_ylim((0, 1))
        panel_ax.axvline(x=0.0, color="grey", linestyle="dotted")
    ax3.set_xlabel(r"Phase (days)")

    fig.legend(
        legend_lines,
        [*[classes_to_labels[x] for x in allowed_types], "Macro"],
        loc="lower center",
        ncol=3,
    )
    fig.savefig(os.path.join(save_dir, "phase_vs_accuracy.pdf"), bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_redshifts_abs_mags(probs_snr_csv, training_csv, fits_dir, save_dir, sampler="dynesty"):
    """
    Plot redshift and absolute magnitude distributions used in the
    redshift-inclusive classifier.

    Draws per-class cumulative distributions of redshift (left panel) and
    peak absolute magnitude (right panel). Absolute magnitudes are computed
    from each object's posterior ``max_flux`` with a zeropoint of 26.3, the
    Planck13 luminosity distance, and a ``2.5 log10(1 + z)`` term. Saves
    ``abs_mag_hist.pdf``.

    Parameters
    ----------
    probs_snr_csv : str
        Where probabilities + SNRs are stored (``Name`` and ``Label`` columns).
    training_csv : str
        CSV with ``NAME`` and ``Z`` (redshift) columns; the first row per
        name is used.
    fits_dir : str
        Directory holding the posterior-sample fits.
    save_dir : str
        Where to save figures.
    sampler : str, optional
        Sampling method of the stored fits. Defaults to "dynesty".
    """
    labels_to_classes, classes_to_labels = SnClass.get_type_maps()
    allowed_types = list(labels_to_classes.keys())
    training_df = pd.read_csv(training_csv)

    probs_dataframe = pd.read_csv(probs_snr_csv)
    names = probs_dataframe.Name.to_numpy()
    # Class indices -> class names, to compare against ``allowed_types``.
    labels = np.array([classes_to_labels[x] for x in probs_dataframe.Label.to_numpy()])

    redshifts = []
    amplitudes = []
    for n in names:
        redshifts.append(training_df[training_df.NAME == n].Z.iloc[0])
        ps = PosteriorSamples.from_file(name=n, input_dir=fits_dir, sampling_method=sampler)
        amplitudes.append(ps.max_flux)
    redshifts = np.array(redshifts)
    amplitudes = np.array(amplitudes)

    app_mags = -2.5 * np.log10(amplitudes) + 26.3
    k_correction = 2.5 * np.log10(1.0 + redshifts)
    dist = cosmo.luminosity_distance(redshifts).value  # returns dist in Mpc
    # Distance modulus, with dist converted from Mpc to units of 10 pc.
    abs_mags = app_mags - 5.0 * np.log10(dist * 1e6 / 10.0) + k_correction

    fig, axes = plt.subplots(1, 2, figsize=(8, 6))
    z_ax, mag_ax = axes

    def _cumulative_fraction(values, edges):
        """Cumulative sum of the density-normalized histogram of ``values``."""
        hist, _ = np.histogram(values, bins=edges, density=True)
        return np.cumsum(hist) * (edges[1] - edges[0])

    # Bins are built on -M (15..25) so the cumulative sum runs faint -> bright;
    # curves are drawn at the true magnitude on an inverted x axis.
    _, mag_edges = np.histogram(-abs_mags, bins=40, range=(15, 25))
    mag_centers = (mag_edges[1:] + mag_edges[:-1]) / 2
    legend_lines = []
    for allowed_type in allowed_types:
        cumsum = _cumulative_fraction(-abs_mags[labels == allowed_type], mag_edges)
        (legend_line,) = mag_ax.step(-mag_centers, cumsum, where="mid", label=allowed_type)
        legend_lines.append(legend_line)
    mag_ax.set_xlabel("Absolute Magnitude")
    mag_ax.invert_xaxis()

    _, z_edges = np.histogram(redshifts, bins=40, range=(-0.1, 0.6))
    z_centers = (z_edges[1:] + z_edges[:-1]) / 2
    for allowed_type in allowed_types:
        cumsum = _cumulative_fraction(redshifts[labels == allowed_type], z_edges)
        z_ax.step(z_centers, cumsum, where="mid", label=allowed_type)
    z_ax.set_xlabel("Redshift")
    z_ax.set_ylabel("Cumulative Fraction")

    ratio = 1.0
    for ax in axes:
        x_left, x_right = ax.get_xlim()
        y_low, y_high = ax.get_ylim()
        ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)

    # One legend entry per class; no "combined" curve is drawn.
    fig.legend(legend_lines, allowed_types, loc="lower center", ncol=3)
    fig.savefig(os.path.join(save_dir, "abs_mag_hist.pdf"), bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_snr_npoints_vs_accuracy(probs_snr_csv, save_dir):
    """
    Generate plots of per-class completeness versus the 90th-percentile SNR
    (``SNR90`` column, top panel) and versus the number of >= 3 sigma
    datapoints (``nSNR3`` column, bottom panel).

    Each class is split into at most 8 equal-population bins
    (``histedges_equalN``). Bins with no objects are drawn at completeness
    1.0, and classes too small to form a single bin are skipped. Saves
    ``n_snr_vs_accuracy.pdf``.

    TODO: add functionality for only one type.

    Parameters
    ----------
    probs_snr_csv : str
        Where probabilities + SNRs are stored.
    save_dir : str
        Where to save figures.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    _, true_type, _, pred_classes, _, _ = read_probs_csv(probs_snr_csv)
    correct_class = np.where(true_type == pred_classes, 1, 0)
    df = pd.read_csv(probs_snr_csv)
    snr, n_high_snr = df.SNR90, df.nSNR3

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6.4, 10))

    def _binned_completeness(values, correct):
        """Return (edges, step values) of binned completeness, or None if < 1 bin."""
        nbins = min(len(values) - 1, 8)
        # Check before binning: histedges_equalN cannot build < 1 bin.
        if nbins < 1:
            return None
        completeness, edges, _ = binned_statistic(
            values, correct, "mean", bins=histedges_equalN(values, nbins)
        )
        # Empty bins yield nan; draw them as fully complete.
        completeness[np.isnan(completeness)] = 1.0
        # Repeat the last value so the post-step reaches the final edge.
        return edges, np.append(completeness, completeness[-1])

    for unique_type in np.unique(true_type):
        in_class = true_type == unique_type
        curve = _binned_completeness(snr[in_class], correct_class[in_class])
        if curve is not None:
            ax1.step(*curve, where="post")
    ax1.set_xlim((8, 30))
    ax1.set_xscale("log")
    ax1.set_xlabel("90th Percentile SNR")
    ax1.set_ylabel("Class Completeness")

    # second plot; labels here drive the shared figure legend
    for unique_type in np.unique(true_type):
        in_class = true_type == unique_type
        curve = _binned_completeness(n_high_snr[in_class], correct_class[in_class])
        if curve is not None:
            ax2.step(*curve, label=classes_to_labels[unique_type], where="post")
    ax2.set_xlim((10, 200))
    ax2.set_xscale("log")
    ax2.set_xlabel(r"Number of $\geq 3\sigma$ Datapoints")
    ax2.set_ylabel("Class Completeness")

    fig.legend(loc="lower center", ncol=3)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3, bottom=0.2, top=0.95)
    fig.savefig(os.path.join(save_dir, "n_snr_vs_accuracy.pdf"))
    plt.close(fig)
[docs]
def plot_snr_hist(probs_snr_csv, save_dir):
    """
    Replicates SNR plots needed for publication.

    Overlays step histograms (bin width 3, edges 0..600) of the last three
    columns of ``probs_snr_csv``, which are read as datapoint counts at
    SNR >= 3, >= 5 and >= 10, on log-log axes. Rows whose second column
    equals "SKIP" are excluded. Saves ``snr_hist.pdf`` in ``save_dir``.

    Parameters
    ----------
    probs_snr_csv : str
        Where probability + SNR info is stored.
    save_dir : str
        Where to save figure.
    """
    snr_df = pd.read_csv(probs_snr_csv)
    geq3, geq5, geq10 = snr_df.iloc[:, -3:].to_numpy().T
    keep = (snr_df.iloc[:, 1] != "SKIP").to_numpy()
    edges = np.arange(0, 603, 3)

    series = (
        (geq3, r"$SNR \geq 3$"),
        (geq5, r"$SNR \geq 5$"),
        (geq10, r"$SNR \geq 10$"),
    )
    for counts, legend_text in series:
        plt.hist(counts[keep], histtype="step", label=legend_text, bins=edges)

    plt.loglog()
    plt.xlabel("Number of Datapoints at Given SNR")
    plt.ylabel("Number of Light Curves")
    plt.legend()
    plt.savefig(os.path.join(save_dir, "snr_hist.pdf"), bbox_inches="tight")
    plt.close()
[docs]
def _load_peak_app_mags(names, fits_dir, sampler, included_names, zeropoint):
    """Return peak apparent magnitudes for ``names`` and an inclusion mask.

    Objects whose posterior samples fail to load, or whose ``max_flux`` is
    None, are skipped (best effort). ``included[k]`` is True when the k-th
    kept object is listed in ``included_names``.
    """
    included_lookup = set(included_names)
    max_flux = []
    included = []
    for n in names:
        try:
            ps = PosteriorSamples.from_file(name=n, input_dir=fits_dir, sampling_method=sampler)
        except Exception:  # pylint: disable=broad-except
            # Unloadable fits are skipped on purpose; catching Exception (not
            # a bare except) lets KeyboardInterrupt/SystemExit propagate.
            continue
        if ps.max_flux is None:
            continue
        included.append(n in included_lookup)
        max_flux.append(ps.max_flux)
    mags = -2.5 * np.log10(np.array(max_flux)) + zeropoint
    # dtype=bool so an empty result is still a valid boolean mask.
    return mags, np.array(included, dtype=bool)


def compare_mag_distributions(
    probs_classified,
    probs_unclassified,
    all_spec_csv,
    all_phot_csv,
    fits_dir,
    fits_dir_phot,
    save_dir,
    zeropoint=26.3,
    sampler='dynesty',
    allowed_types=SnClass.get_alternative_namings().keys(),
):
    """
    Generate overlaid magnitude distributions of the classified and unclassified datasets.

    For the spectroscopic set (``all_spec_csv`` restricted to
    ``allowed_types``) and the photometric set (``all_phot_csv``), the peak
    apparent magnitude ``-2.5 log10(max_flux) + zeropoint`` of each fit is
    histogrammed (density-normalized, log y axis). Each set is split into
    objects whose name appears in the corresponding probabilities CSV
    ("included") and those whose name does not ("excluded"). Saves
    ``appm_hist_compare.pdf``.

    NOTE(review): inclusion depends only on the name appearing in the
    probabilities CSV; rows marked "SKIP" there are not filtered out here.
    Confirm that LCs failing the chi-squared cut are absent from those files.

    Parameters
    ----------
    probs_classified : str
        CSV filename where probs of spectroscopic set are stored.
    probs_unclassified : str
        CSV filename where probs of photometric set are stored.
    all_spec_csv : str
        CSV with ``NAME`` and ``CLASS`` columns for the full spectroscopic set.
    all_phot_csv : str
        CSV with a ``NAME`` column for the full photometric set.
    fits_dir, fits_dir_phot : str
        Directories holding posterior samples for the spectroscopic and
        photometric sets.
    save_dir : str
        Where to save figure.
    zeropoint : float, optional
        Zeropoint used when converting mags to fluxes. Defaults to 26.3.
    sampler : str, optional
        Sampling method of the stored fits. Defaults to "dynesty".
    allowed_types : container of str, optional
        Canonical class names kept from ``all_spec_csv``.
    """
    classified_names = pd.read_csv(probs_classified).Name.to_numpy()
    spec_df = pd.read_csv(all_spec_csv)
    label_mask = np.array(
        [SnClass.canonicalize(l) in allowed_types for l in spec_df.CLASS.to_numpy()],
        dtype=bool,
    )
    all_names = spec_df.NAME.to_numpy()[label_mask]

    max_r_classified_all, in_classified = _load_peak_app_mags(
        all_names, fits_dir, sampler, classified_names, zeropoint
    )
    max_r_classified = max_r_classified_all[in_classified]
    max_r_classified_skipped = max_r_classified_all[~in_classified]

    unclassified_names = pd.read_csv(probs_unclassified).Name.to_numpy()
    all_phot_names = pd.read_csv(all_phot_csv).NAME.to_numpy()
    max_r_unclassified_all, in_unclassified = _load_peak_app_mags(
        all_phot_names, fits_dir_phot, sampler, unclassified_names, zeropoint
    )
    max_r_unclassified = max_r_unclassified_all[in_unclassified]
    max_r_unclassified_skipped = max_r_unclassified_all[~in_unclassified]

    bins = np.arange(5.0, 21.0, 0.5)
    for mags, label in (
        (max_r_classified, "Spec. (included)"),
        (max_r_classified_skipped, "Spec. (excluded)"),
        (max_r_unclassified, "Phot. (included)"),
        (max_r_unclassified_skipped, "Phot. (excluded)"),
    ):
        plt.hist(mags, histtype="step", bins=bins, label=label, density=True, linewidth=2)

    plt.yscale("log")
    plt.legend(loc="upper left")
    plt.xlabel("Peak Apparent Magnitude")
    plt.ylabel("Fraction of Light Curves")
    plt.savefig(
        os.path.join(save_dir, "appm_hist_compare.pdf"),
        bbox_inches="tight",
    )
    plt.close()
[docs]
def plot_chisquared_vs_accuracy(
    pred_spec_fn,
    all_spec_csv,
    all_phot_csv,
    fits_dir,
    fits_dir_phot,
    save_dir,
    sampler=None,
    allowed_types=SnClass.get_alternative_namings().keys(),
):
    """
    Plot chi-squared value histograms for both the spectroscopic and photometric
    datasets, and plot spec chi-squared as a function of classification accuracy.

    The last column of each object's posterior samples is used as its
    reduced chi-squared. The count histograms use the per-object median.
    The accuracy curve uses the per-object mean for objects present in
    ``pred_spec_fn`` and is drawn only below the 1.2 cutoff. Saves
    ``chisq_vs_accuracy.pdf``.

    TODO: IN PROGRESS

    Parameters
    ----------
    pred_spec_fn : str
        CSV filename where probs of spectroscopic set are stored.
    all_spec_csv : str
        CSV with ``NAME`` and ``CLASS`` columns for the full spectroscopic set.
    all_phot_csv : str
        CSV with a ``NAME`` column for the photometric set.
    fits_dir, fits_dir_phot : str
        Directories holding posterior samples for each set.
    save_dir : str
        Where to save figure.
    sampler : str or None, optional
        Sampling method passed to ``retrieve_posterior_set``.
    allowed_types : container of str or None, optional
        Canonical class names kept from ``all_spec_csv``; None keeps all.
    """
    sn_names, true_classes, _, pred_classes, _, _ = read_probs_csv(pred_spec_fn)
    spec_df = pd.read_csv(all_spec_csv)
    sn_names_all = spec_df.NAME.to_numpy()
    if allowed_types is not None:
        mask = np.array(
            [SnClass.canonicalize(y) in allowed_types for y in spec_df.CLASS.to_numpy()],
            dtype=bool,
        )
        sn_names_all = sn_names_all[mask]

    ps_set = retrieve_posterior_set(sn_names_all, fits_dir, sampler=sampler)
    spec_mask = np.isin(sn_names_all, sn_names)
    train_chis = np.array([np.median(x.samples[:, -1]) for x in ps_set])
    # NOTE(review): mean here vs. median elsewhere — kept as-is; confirm intent.
    train_chis_spec = np.array([np.mean(x.samples[:, -1]) for x in ps_set[spec_mask]])

    # Align correctness flags with train_chis_spec, which follows catalog
    # order rather than the row order of pred_spec_fn.
    correct_by_name = dict(zip(sn_names, np.where(true_classes == pred_classes, 1, 0)))
    correctly_classified = np.array([correct_by_name[n] for n in sn_names_all[spec_mask]])

    sn_names_phot = pd.read_csv(all_phot_csv).NAME.to_numpy()
    ps_set_phot = retrieve_posterior_set(sn_names_phot, fits_dir_phot, sampler=sampler)
    train_chis_phot = np.array([np.median(x.samples[:, -1]) for x in ps_set_phot])

    # plot: counts on ax2 (log y), accuracy on a twin y axis ax1
    _, ax2 = plt.subplots(figsize=(7, 4.8))
    ax1 = ax2.twinx()
    bins = np.arange(0, 4.0, 0.1)
    correct_hist, bin_edges, _ = binned_statistic(
        train_chis_spec, correctly_classified, statistic="sum", bins=bins
    )
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2.0
    # Per-bin counts: all spectroscopic fits, the classified subset, photometric fits.
    all_hist_spec, _, _ = binned_statistic(
        train_chis, np.ones(len(train_chis)), statistic="sum", bins=bins
    )
    all_hist, _, _ = binned_statistic(
        train_chis_spec, np.ones(len(train_chis_spec)), statistic="sum", bins=bins
    )
    all_hist_phot, _, _ = binned_statistic(
        train_chis_phot, np.ones(len(train_chis_phot)), statistic="sum", bins=bins
    )
    ax2.hist(bin_centers, bin_edges, weights=all_hist_spec, alpha=0.5, label="Spectroscopic")
    ax2.hist(bin_centers, bin_edges, weights=all_hist_phot, alpha=0.5, label="Photometric")
    ax2.set_yscale("log")

    # Empty bins get an infinite denominator, i.e. accuracy 0.
    all_hist[all_hist == 0] = np.inf
    acc_hist = correct_hist / all_hist
    # Keep bins left of the 1.2 cutoff; floating-point arange puts the 1.2
    # edge slightly above 1.2, hence 1.215 for the edges. The last value is
    # repeated so the post-step reaches the final edge.
    acc_cut = acc_hist[bins[:-1] < 1.2]
    acc_cut = np.append(acc_cut, acc_cut[-1])
    ax1.step(
        bins[bins < 1.215],
        acc_cut,
        where="post", color='#228833', linewidth=3, label="Accuracy"
    )
    ax1.axvline(x=1.2, color="black", linestyle="--", linewidth=4, label=r"Reduced $\chi^2$ cutoff")
    ax2.set_xlabel(r"Reduced $\chi^2$")
    ax1.set_ylabel("Accuracy", va="bottom", rotation=270)
    ax2.set_ylabel("Counts")

    # Merge both axes' entries into one legend.
    h2, l2 = ax2.get_legend_handles_labels()
    h1, l1 = ax1.get_legend_handles_labels()
    ax2.legend(list(h2) + list(h1), list(l2) + list(l1))

    ax1.yaxis.label.set_color('#228833')
    ax1.spines["right"].set_color('#228833')
    ax1.tick_params(axis="y", colors='#228833')
    ax1.set_ylim((0, 1))
    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())
    plt.savefig(os.path.join(save_dir, "chisq_vs_accuracy.pdf"), bbox_inches="tight")
    plt.close()
[docs]
def plot_model_metrics(metrics, plot_name, metrics_dir):
    """Plot training and validation curves and export them as PDF files.

    Two figures are written to ``metrics_dir``:

    * ``accuracy_{plot_name}.pdf`` with a linear y-axis, and
    * ``loss_{plot_name}.pdf`` with a logarithmic y-axis.

    Parameters
    ----------
    metrics : tuple
        Four per-epoch sequences, in this order: training accuracy,
        training loss, validation accuracy, validation loss. The number of
        epochs is taken from the length of the training-accuracy sequence.
    plot_name : str
        Suffix used in the output figure filenames.
    metrics_dir : str
        Directory in which to store the plot figures.
    """
    train_acc, train_loss, val_acc, val_loss = metrics
    epochs = np.arange(len(train_acc))

    _plot_train_val_curve(
        epochs,
        train_acc,
        val_acc,
        "Accuracy",
        os.path.join(metrics_dir, f"accuracy_{plot_name}.pdf"),
    )
    _plot_train_val_curve(
        epochs,
        train_loss,
        val_loss,
        "Loss",
        os.path.join(metrics_dir, f"loss_{plot_name}.pdf"),
        log_y=True,
    )


def _plot_train_val_curve(epochs, train_vals, val_vals, ylabel, save_path, log_y=False):
    """Draw one training-vs-validation curve on a fresh figure and save it.

    A new figure is created (rather than drawing on the pyplot "current"
    figure) so that nothing left open by a previous caller leaks into the
    saved plot. The figure is closed after saving.
    """
    fig, ax = plt.subplots()
    ax.plot(epochs, train_vals, label="Training")
    ax.plot(epochs, val_vals, label="Validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    if log_y:
        ax.set_yscale("log")
    ax.legend()
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_calibration_curve(probs_csv, save_dir):
    """Plot per-class calibration curves and save them as ``calibration_curve.pdf``.

    For each class index, a one-vs-rest target (1 where the true class
    equals that index, else 0) and that class's probability column are
    passed, together with the fold assignments, to
    ``calc_calibration_curve``. The returned ``(t, f, ferr)`` are drawn as
    a step curve with a shaded ``f +/- ferr`` band. A dotted y = x line
    marks perfect calibration.

    Parameters
    ----------
    probs_csv : str
        Path to the probability CSV, parsed with ``read_probs_csv``.
    save_dir : str
        Directory in which ``calibration_curve.pdf`` is written.
    """
    labels_to_classes, classes_to_labels = SnClass.get_type_maps()
    colors = CUSTOM_COLORSET

    # The CSV contents do not depend on the class being plotted, so it is
    # read once instead of once per class.
    _, true_classes, probs, _, folds, _ = read_probs_csv(probs_csv)

    fig, ax = plt.subplots(figsize=(6, 8))
    ax.set_xlim([0.0, 1.05])
    ax.set_ylim([0.0, 1.05])
    ax.set_ylabel("True Fraction")
    plt.locator_params(axis="x", nbins=3)

    legend_lines = []
    for ref_class, ref_label in enumerate(classes_to_labels):
        y_true = np.where(true_classes == ref_class, 1, 0)
        y_score = probs[:, ref_class]
        t, f, ferr = calc_calibration_curve(y_true, y_score, folds)
        (legend_line,) = ax.step(t, f, label=ref_label, c=colors[ref_class], where="post")
        ax.fill_between(t, f - ferr, f + ferr, alpha=0.2, color=colors[ref_class], step="post")
        legend_lines.append(legend_line)

    # Reference line for a perfectly calibrated classifier.
    ax.plot([0, 1], [0, 1], linestyle="dotted", color="k", linewidth=1)

    # Force a square plotting box regardless of the data ranges.
    x_left, x_right = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)))

    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel("Confidence")

    legend_keys = list(labels_to_classes.keys())
    fig.legend(legend_lines, legend_keys, loc="lower center", ncol=3)
    plt.savefig(os.path.join(save_dir, "calibration_curve.pdf"), bbox_inches="tight")
    plt.close()
[docs]
def plot_f1_curve(probs_csv, save_dir, ref_class, prob_col="pSNIa"):
    """Plot the fold-averaged, two-class macro F1 score versus a probability threshold.

    Objects are split into ``ref_class`` (positive) versus everything else
    (negative). For each threshold ``t`` on a 1000-point grid over [0, 1],
    an object is predicted positive when its ``prob_col`` score exceeds ``t``.
    Within each fold, the F1 of the positive class and the F1 of the
    negative class are averaged. The plot shows the NaN-ignoring mean over
    folds with a +/- one standard deviation band. A dotted vertical line
    marks the threshold that maximizes the mean. The figure is saved as
    ``f1_curve_{ref_class}.pdf`` in ``save_dir``.

    Parameters
    ----------
    probs_csv : str
        CSV file with at least ``Label``, ``Fold`` and ``prob_col`` columns.
    save_dir : str
        Directory in which the figure is written.
    ref_class
        Value compared (with ``==``) against the ``Label`` column to define
        the positive class.
    prob_col : str, optional
        Column holding the positive-class score. Defaults to ``"pSNIa"``,
        the column previously hard-coded. It should be the score column
        that corresponds to ``ref_class``.

    Returns
    -------
    float
        The threshold with the highest mean F1 score (also printed).

    Raises
    ------
    ValueError
        If the mean F1 is NaN at every threshold, e.g. when no fold contains
        both positive and negative objects.
    """
    colors = CUSTOM_COLORSET
    thresholds = np.linspace(0, 1, 1000)

    df = pd.read_csv(probs_csv)
    y_true = np.where(df["Label"].to_numpy() == ref_class, 1, 0).astype(bool)
    y_score = df[prob_col].to_numpy()
    folds = df["Fold"].to_numpy()

    f1_mu, f1_sig = _fold_macro_f1(y_true, y_score, folds, thresholds)

    _, ax = plt.subplots(figsize=(6, 8))
    ax.set_xlim([0.0, 1.05])
    ax.set_ylabel(r"F$_1$")
    ax.plot(thresholds, f1_mu, c=colors[0])
    ax.fill_between(thresholds, f1_mu - f1_sig, f1_mu + f1_sig, alpha=0.2, color=colors[0])

    # Use nanargmax: np.argmax treats NaN as the maximum and would return
    # the first undefined threshold instead of the best one.
    best_t = thresholds[np.nanargmax(f1_mu)]
    ax.axvline(x=best_t, linestyle="dotted", color=colors[0])

    # Force a square plotting box regardless of the data ranges.
    x_left, x_right = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)))

    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel("Confidence Threshold")
    plt.savefig(os.path.join(save_dir, f"f1_curve_{ref_class}.pdf"), bbox_inches="tight")
    plt.close()
    print(best_t)
    return best_t


def _fold_macro_f1(y_true, y_score, folds, thresholds):
    """Return per-threshold (mean, std) over folds of the two-class macro F1.

    Folds whose macro F1 is undefined (NaN) are ignored by the nan-aware
    statistics.
    """
    fold_masks = [folds == f for f in np.unique(folds)]
    f1_mu = np.empty(len(thresholds))
    f1_sig = np.empty(len(thresholds))
    for i, t in enumerate(thresholds):
        y_pred = y_score > t
        fold_f1s = []
        for mask in fold_masks:
            pred, true = y_pred[mask], y_true[mask]
            # Average the F1 of the positive class and of the negative class.
            fold_f1s.append((_binary_f1(pred, true) + _binary_f1(~pred, ~true)) / 2)
        f1_mu[i] = np.nanmean(fold_f1s)
        f1_sig[i] = np.nanstd(fold_f1s)
    return f1_mu, f1_sig


def _binary_f1(pred, true):
    """F1 score of boolean predictions ``pred`` against boolean labels ``true``.

    Precision is taken to be 1 when nothing is predicted positive. Returns
    NaN when ``true`` has no positives (recall undefined), and 0 when
    precision and recall are both 0 (the standard limit of the harmonic
    mean, rather than 0/0).
    """
    n_true = np.count_nonzero(true)
    if n_true == 0:
        return np.nan
    n_pred = np.count_nonzero(pred)
    n_hit = np.count_nonzero(pred & true)
    precision = n_hit / n_pred if n_pred else 1.0
    recall = n_hit / n_true
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)