"""This module provides functions for plotting classifier confusion matrices
and ALeRCE/Superphot+ agreement matrices from prediction CSVs."""
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from superphot_plus.plotting.utils import read_probs_csv, retrieve_four_class_info
from superphot_plus.supernova_class import SupernovaClass as SnClass
from superphot_plus.utils import calc_accuracy, f1_score
from superphot_plus.constants import BIGGER_SIZE, MEDIUM_SIZE, SMALL_SIZE
from superphot_plus.plotting.format_params import set_global_plot_formatting
# Module-level side effect: this call runs once, when the module is imported.
# Presumably it configures global matplotlib state (e.g. rcParams and the
# "custom_cmap1"/"custom_cmap2" colormap names used by the functions below);
# confirm in superphot_plus.plotting.format_params.
set_global_plot_formatting()
def plot_high_confidence_confusion_matrix(probs_csv, filename, cutoff=0.7, num_include=None):
    """Plot confusion matrices for high-confidence predictions.

    A sample is kept when its largest class probability is strictly greater
    than the cutoff. A completeness matrix (``filename + "_c.pdf"``) and a
    purity matrix (``filename + "_p.pdf"``) are written.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions.
    filename : str
        Base filename for saving the confusion matrix plots.
    cutoff : float, optional
        Probability cutoff value for high-confidence predictions.
        Ignored when ``num_include`` is given. Default is 0.7.
    num_include : int, optional
        If given, the cutoff is chosen so that the ``num_include`` most
        confident predictions are kept (fewer if confidences tie at the
        boundary). If it is at least the number of samples, every sample
        is kept. Default is None.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    _, true_classes, probs, pred_classes, folds, _ = read_probs_csv(probs_csv)

    confidences = np.max(probs, axis=1)
    if num_include is not None:
        conf_ordered = np.sort(confidences)
        if num_include >= len(conf_ordered):
            # The index below would go negative and wrap around to the end
            # of the sorted array, giving a wrong cutoff; keep everything.
            cutoff = -np.inf
        else:
            # The (num_include + 1)-th largest confidence: keeping values
            # strictly above it leaves the top num_include samples.
            cutoff = conf_ordered[len(conf_ordered) - num_include - 1]
    high_conf_mask = confidences > cutoff

    true_labels = [classes_to_labels[x] for x in np.asarray(true_classes)[high_conf_mask]]
    pred_labels = [classes_to_labels[x] for x in np.asarray(pred_classes)[high_conf_mask]]

    if folds is not None:
        try:
            folds = np.asarray(folds)[high_conf_mask]
        except (IndexError, TypeError):
            # Fold data that cannot be masked (e.g. wrong length) is dropped,
            # falling back to a single unsplit matrix.
            folds = None

    plot_confusion_matrix(true_labels, pred_labels, filename + "_c.pdf", folds, purity=False)
    plot_confusion_matrix(true_labels, pred_labels, filename + "_p.pdf", folds, purity=True)
def plot_binary_confusion_matrix(probs_csv, filename, cutoff=0.5):
    """Merge all non-Ia into one core collapse class and plot resulting
    binary confusion matrix.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions. The
        ``Label`` and ``pSNIa`` columns are required; a ``Fold`` column is
        optional and, when present, enables per-fold error bars.
    filename : str
        Base filename for saving the confusion matrix plots; "_c.pdf"
        (completeness) and "_p.pdf" (purity) are appended.
    cutoff : float, optional
        A sample is predicted "SN Ia" when its ``pSNIa`` is strictly greater
        than this value, otherwise "SN CC". Default is 0.5.
    """
    df = pd.read_csv(probs_csv)
    true_classes = df.Label.to_numpy()
    prob_ia = df.pSNIa.to_numpy()
    pred_binary = np.where(prob_ia > cutoff, "SN Ia", "SN CC")
    # NOTE(review): assumes integer label 0 encodes SN Ia in the CSV; confirm
    # against SnClass's type maps.
    true_binary = np.where(true_classes == 0, "SN Ia", "SN CC")
    # An absent Fold column is the only expected case; other errors surface.
    folds = df["Fold"].to_numpy() if "Fold" in df.columns else None
    for suffix, purity in (("_c.pdf", False), ("_p.pdf", True)):
        plot_confusion_matrix(true_binary, pred_binary, filename + suffix, folds, purity=purity)
def compare_four_class_confusion_matrices(probs_csv, probs_alerce_csv, save_dir, p07=False):
    """Plot four-class confusion matrices for ALeRCE and Superphot+ side by side.

    ALeRCE's transient classifier has no SN IIn label, so both classifiers
    are compared on a condensed four-class scheme. For each classifier a
    completeness ("_c.pdf") and a purity ("_p.pdf") matrix are saved:
    "superphot4_*" with colormap "custom_cmap1" and "alerce_*" with
    "custom_cmap2".

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing Superphot+ probability predictions.
    probs_alerce_csv : str
        Path to the CSV file containing ALeRCE predicted classes.
    save_dir : str
        Directory for saving the confusion matrix plots.
    p07 : bool, optional
        Forwarded to ``retrieve_four_class_info``; intended to keep only
        predictions with probability >= 0.7. Default is False.
    """
    _, truth, _, superphot_preds, alerce_preds, fold_ids = retrieve_four_class_info(
        probs_csv, probs_alerce_csv, p07
    )
    per_classifier = (
        ("superphot4", superphot_preds, "custom_cmap1"),
        ("alerce", alerce_preds, "custom_cmap2"),
    )
    for stem, predictions, colormap in per_classifier:
        # Completeness first, then purity, matching the original call order.
        for tag, is_purity in (("c", False), ("p", True)):
            plot_confusion_matrix(
                truth,
                predictions,
                os.path.join(save_dir, f"{stem}_{tag}.pdf"),
                folds=fold_ids,
                purity=is_purity,
                cmap=colormap,
            )
def plot_true_agreement_matrix(probs_csv, probs_alerce_csv, save_dir, spec=True):
    """Plot the observed agreement matrix between ALeRCE and Superphot+
    classifications.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing Superphot+ probability predictions.
    probs_alerce_csv : str
        Path to the CSV containing ALeRCE predictions.
    save_dir : str
        Directory path for saving the agreement matrix plot.
    spec : bool, optional
        Forwarded to ``plot_agreement_matrix_from_arrs``; selects the
        "Spec." (True) or "Phot." (False) title and filename suffix.
        Default is True.
    """
    # Only the two sets of predictions and the fold assignments are needed;
    # the p07 filter is always disabled here.
    _, _, _, superphot_preds, alerce_preds, fold_ids = retrieve_four_class_info(
        probs_csv, probs_alerce_csv, False
    )
    plot_agreement_matrix_from_arrs(superphot_preds, alerce_preds, fold_ids, save_dir, spec=spec)
def plot_expected_agreement_matrix(probs_csv, probs_alerce_csv, save_dir, cmap="custom_cmap2"):
    """Plot expected agreement matrix based on independent ALeRCE and
    Superphot+ confusion matrices.

    For each fold, the Superphot+ purity matrix and the ALeRCE completeness
    matrix (both against the true labels) are multiplied into an expected
    agreement matrix. The plot shows the per-fold median of every cell with
    10th/90th-percentile error bars. The title shows the median expected
    agreement score A' with the same percentile range.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing Superphot+ probability predictions.
    probs_alerce_csv : str
        Path to the CSV file containing ALeRCE predicted classes.
    save_dir : str
        Directory in which "expected_agreement.pdf" is saved.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap passed to ``imshow``. Default is "custom_cmap2".
    """
    (_, true_labels, _, pred_labels, alerce_preds, folds) = retrieve_four_class_info(
        probs_csv, probs_alerce_csv
    )
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    alerce_preds = np.asarray(alerce_preds)
    folds = np.asarray(folds)
    # One shared, sorted class list (including the true labels, which the
    # matrix product sums over). Every per-fold matrix then has the same
    # shape and row/column order as the tick labels, even when a class is
    # absent from some fold.
    classes = unique_labels(alerce_preds, pred_labels, true_labels)

    accs = []
    cm_vals_all = []
    for fold in np.unique(folds):
        in_fold = folds == fold
        true_fold = true_labels[in_fold]
        ap_fold = alerce_preds[in_fold]
        # normalize="pred": columns (Superphot+ predictions) sum to 1.
        cm_purity = confusion_matrix(
            true_fold, pred_labels[in_fold], labels=classes, normalize="pred"
        )
        # normalize="true": rows (true classes) sum to 1.
        cm_complete = confusion_matrix(
            true_fold, ap_fold, labels=classes, normalize="true"
        )
        # fold_expected[p, a] = sum_t cm_purity[t, p] * cm_complete[t, a].
        # NOTE(review): with these normalizations, rows index the Superphot+
        # class and columns the ALeRCE class. However, the axis labels below
        # put ALeRCE on the y-axis, and the agreement score weights the
        # diagonal by ALeRCE class counts. Confirm the intended orientation.
        fold_expected = cm_purity.T @ cm_complete
        exp_agree = sum(
            np.count_nonzero(ap_fold == single_class) * fold_expected[i, i]
            for i, single_class in enumerate(classes)
        )
        accs.append(exp_agree / len(ap_fold))
        cm_vals_all.append(fold_expected)

    cm_vals_all = np.asarray(cm_vals_all)
    cm_expected = np.median(cm_vals_all, axis=0)
    # Error bars span the 10th-90th percentile across folds.
    cm_low = np.abs(cm_expected - np.percentile(cm_vals_all, 10, axis=0))
    cm_high = np.abs(np.percentile(cm_vals_all, 90, axis=0) - cm_expected)
    acc = np.median(accs)
    acc_low = acc - np.percentile(accs, 10)
    acc_high = np.percentile(accs, 90) - acc
    title = f"Expected Agreement Matrix,\nSpec. ($A' = {acc:.2f}^{{+{acc_high:.2f}}}_{{-{acc_low:.2f}}}$)"

    fig, axis = plt.subplots(figsize=(6, 6))
    axis.imshow(cm_expected, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    axis.set(
        xticks=np.arange(len(classes)),
        yticks=np.arange(len(classes)),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="ALeRCE Classification",
        xlabel="Superphot+ Classification",
    )
    # Rotate the tick labels and set their alignment.
    plt.setp(axis.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell; white text on dark cells, black on light ones.
    thresh = cm_expected.max() / 1.5
    for i in range(cm_expected.shape[0]):
        for j in range(cm_expected.shape[1]):
            axis.text(
                j,
                i,
                f"${cm_expected[i, j]:.2f}^{{+{cm_high[i, j]:.2f}}}_{{-{cm_low[i, j]:.2f}}}$",
                ha="center",
                va="center",
                color="white" if cm_expected[i, j] > thresh else "black",
            )
    fig.tight_layout()
    plt.xlim(-0.5, len(classes) - 0.5)
    plt.ylim(len(classes) - 0.5, -0.5)
    plt.savefig(
        os.path.join(save_dir, "expected_agreement.pdf"),
        bbox_inches="tight",
    )
    plt.close(fig)
def plot_agreement_matrix_from_arrs(our_labels, alerce_labels, folds, save_dir, spec=True, cmap="custom_cmap2"):
    """Helper function to plot agreement matrices.

    Plot agreement matrix based on input arrays of ALeRCE and Superphot+
    classifications. Rows are ALeRCE classes and columns are Superphot+
    classes. Each row is normalized to sum to 1 within its fold. Cells show
    the per-fold median and, in parentheses, the raw count over all folds.

    Parameters
    ----------
    our_labels : array-like
        Superphot+ predicted labels.
    alerce_labels : array-like
        ALeRCE predicted labels, aligned element-wise with ``our_labels``.
    folds : array-like or None
        Fold index for each sample. None treats all samples as one fold.
    save_dir : str
        Directory in which "true_agreement_spec.pdf" or
        "true_agreement_phot.pdf" is saved.
    spec : bool, optional
        If True, title the plot "Spec." and show 10th/90th-percentile error
        bars; otherwise title it "Phot." and show only medians.
        Default is True.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap passed to ``imshow``. Default is "custom_cmap2".
    """
    suffix_title, suffix = ("Spec.", "spec") if spec else ("Phot.", "phot")
    our_labels = np.asarray(our_labels)
    alerce_labels = np.asarray(alerce_labels)
    # Shared, sorted class order so every per-fold matrix has the same shape
    # and lines up with the tick labels, even if a class is absent in a fold.
    classes = unique_labels(alerce_labels, our_labels)
    # Without fold information, treat every sample as part of a single fold.
    folds = np.zeros(len(our_labels), dtype=int) if folds is None else np.asarray(folds)

    accs = []
    cm_vals_all = []
    for fold in np.unique(folds):
        in_fold = folds == fold
        alerce_fold = alerce_labels[in_fold]
        ours_fold = our_labels[in_fold]
        cm_vals_all.append(
            confusion_matrix(alerce_fold, ours_fold, labels=classes, normalize="true")
        )
        accs.append(calc_accuracy(alerce_fold, ours_fold))

    cm_vals_all = np.asarray(cm_vals_all)
    cm = np.median(cm_vals_all, axis=0)
    # Error bars span the 10th-90th percentile across folds.
    cm_low = np.abs(cm - np.percentile(cm_vals_all, 10, axis=0))
    cm_high = np.abs(np.percentile(cm_vals_all, 90, axis=0) - cm)
    acc = np.median(accs)
    acc_low = acc - np.percentile(accs, 10)
    acc_high = np.percentile(accs, 90) - acc
    if spec:
        title = "True Agreement Matrix,\n" + fr"{suffix_title} ($A' = {acc:.2f}^{{+{acc_high:.2f}}}_{{-{acc_low:.2f}}}$)"
    else:
        title = "True Agreement Matrix,\n" + fr"{suffix_title} ($A' = {acc:.2f}$)"

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cm, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    ax.set(
        xticks=np.arange(len(classes)),
        yticks=np.arange(len(classes)),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="ALeRCE Classification",
        xlabel="Superphot+ Classification",
    )
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell; white text on dark cells, black on light ones.
    thresh = cm.max() / 1.5
    for i, class_i in enumerate(classes):
        for j, class_j in enumerate(classes):
            num_in_cell = np.count_nonzero((our_labels == class_j) & (alerce_labels == class_i))
            if spec:
                value_str = f"${cm[i, j]:.2f}^{{+{cm_high[i, j]:.2f}}}_{{-{cm_low[i, j]:.2f}}}$"
            else:
                value_str = f"${cm[i, j]:.2f}$"
            ax.text(
                j,
                i,
                value_str + f"\n({num_in_cell})",
                ha="center",
                va="center",
                color="white" if cm[i, j] > thresh else "black",
            )
    fig.tight_layout()
    plt.xlim(-0.5, len(classes) - 0.5)
    plt.ylim(len(classes) - 0.5, -0.5)
    plt.savefig(
        os.path.join(save_dir, f"true_agreement_{suffix}.pdf"),
        bbox_inches="tight",
    )
    plt.close(fig)
def plot_confusion_matrix(y_true, y_pred, filename, folds=None, purity=False, cmap="custom_cmap1"):
    """Plot the confusion matrix between given true and predicted
    labels.

    Each cell shows its normalized value and, in parentheses, the raw
    number of samples in that cell (over all folds). The title reports the
    sample count N, the accuracy A and the class-averaged F1.

    Parameters
    ----------
    y_true : array-like
        True labels (y-axis, "Spectroscopic Classification").
    y_pred : array-like
        Predicted labels (x-axis, "Photometric Classification").
    filename : str
        Path of the saved figure (passed directly to ``plt.savefig``).
    folds : array-like, optional
        Fold index for each sample. When given, metrics and matrices are
        computed per fold and shown as the median with 10th/90th-percentile
        error bars. Default is None (one matrix over all samples).
    purity : bool, optional
        If True, plot the purity confusion matrix (normalized over predicted
        labels); otherwise plot completeness (normalized over true labels).
        Default is False.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap passed to ``imshow``. Default is "custom_cmap1".
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Fix one sorted class order up front so every per-fold matrix has the
    # same shape and lines up with the tick labels, even when a class is
    # missing from some fold.
    classes = unique_labels(y_true, y_pred)
    normalize = "pred" if purity else "true"
    kind = "Purity" if purity else "Completeness"
    title = f"{kind}\n$N = {len(y_pred)}, "

    if folds is None:
        acc = calc_accuracy(y_pred, y_true)
        f1_avg = f1_score(y_pred, y_true, class_average=True)
        cm_vals = confusion_matrix(y_true, y_pred, labels=classes, normalize=normalize)
        cm_low = cm_high = None
        title += f"A = {acc:.2f}, F_1 = {f1_avg:.2f}$"
    else:
        folds = np.asarray(folds)
        accs, f1s, cm_vals_all = [], [], []
        for fold in np.unique(folds):
            in_fold = folds == fold
            y_pred_sub = y_pred[in_fold]
            y_true_sub = y_true[in_fold]
            accs.append(calc_accuracy(y_pred_sub, y_true_sub))
            f1s.append(f1_score(y_pred_sub, y_true_sub, class_average=True))
            cm_vals_all.append(
                confusion_matrix(y_true_sub, y_pred_sub, labels=classes, normalize=normalize)
            )
        cm_vals_all = np.asarray(cm_vals_all)
        cm_vals = np.median(cm_vals_all, axis=0)
        # Error bars span the 10th-90th percentile across folds.
        cm_low = np.abs(cm_vals - np.percentile(cm_vals_all, 10, axis=0))
        cm_high = np.abs(np.percentile(cm_vals_all, 90, axis=0) - cm_vals)
        acc = np.median(accs)
        acc_low = acc - np.percentile(accs, 10)
        acc_high = np.percentile(accs, 90) - acc
        f1_avg = np.median(f1s)
        f1_low = f1_avg - np.percentile(f1s, 10)
        f1_high = np.percentile(f1s, 90) - f1_avg
        title += (
            f"A = {acc:.2f}^{{+{acc_high:.2f}}}_{{-{acc_low:.2f}}}, "
            f"F_1 = {f1_avg:.2f}^{{+{f1_high:.2f}}}_{{-{f1_low:.2f}}}$"
        )

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(cm_vals, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    ax.set(
        xticks=np.arange(len(classes)),
        yticks=np.arange(len(classes)),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="Spectroscopic Classification",
        xlabel="Photometric Classification",
    )
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell; white text on dark cells, black on light ones.
    thresh = cm_vals.max() / 1.5
    for i, class_i in enumerate(classes):
        for j, class_j in enumerate(classes):
            num_in_cell = np.count_nonzero((y_pred == class_j) & (y_true == class_i))
            if cm_low is None:
                # Keep both "$" delimiters on the same line: matplotlib parses
                # mathtext per line, so a "$" pair split by "\n" renders raw.
                value_str = f"${cm_vals[i, j]:.2f}$"
            else:
                value_str = f"${cm_vals[i, j]:.2f}^{{+{cm_high[i, j]:.2f}}}_{{-{cm_low[i, j]:.2f}}}$"
            ax.text(
                j,
                i,
                value_str + f"\n({num_in_cell})",
                ha="center",
                va="center",
                color="white" if cm_vals[i, j] > thresh else "black",
            )
    fig.tight_layout()
    plt.xlim(-0.5, len(classes) - 0.5)
    plt.ylim(len(classes) - 0.5, -0.5)
    plt.savefig(filename, bbox_inches="tight")
    plt.close(fig)
def plot_matrices(
    config,
    true_classes,
    pred_classes,
    prob_above_07,
):
    """Plot purity and completeness confusion matrices for test-set metrics.

    Two matrices over the full test set are always written. Two more,
    restricted to the entries selected by ``prob_above_07``, are written
    only if that selection is non-empty.

    Parameters
    ----------
    config : ModelConfig
        Configuration of the evaluated model. ``config.cm_dir`` is the output
        directory. ``goal_per_class``, ``num_epochs``, ``neurons_per_layer``
        and ``num_hidden_layers`` are embedded in the output filenames.
    true_classes : np.ndarray
        The ground truth for the test ZTF objects.
    pred_classes : np.ndarray
        The predicted classes for the test ZTF objects.
    prob_above_07 : np.ndarray
        Indicates which predictions had a 70% confidence; used to index
        ``true_classes`` and ``pred_classes``.
    """
    # Every output file shares this directory + hyperparameter prefix.
    prefix = os.path.join(
        config.cm_dir,
        f"cm_{config.goal_per_class}_{config.num_epochs}_{config.neurons_per_layer}_{config.num_hidden_layers}",
    )

    # Full test set: purity, then completeness.
    plot_confusion_matrix(true_classes, pred_classes, prefix + "_p.pdf", purity=True)
    plot_confusion_matrix(true_classes, pred_classes, prefix + "_c.pdf", purity=False)

    # Confident subset only; nothing to plot if no prediction qualifies.
    if not np.any(prob_above_07):
        return
    confident_true = true_classes[prob_above_07]
    confident_pred = pred_classes[prob_above_07]
    plot_confusion_matrix(confident_true, confident_pred, prefix + "_p_p07.pdf", purity=True)
    plot_confusion_matrix(confident_true, confident_pred, prefix + "_c_p07.pdf", purity=False)