"""This module provides various functions for analyzing and visualizing
light curve data."""
import os
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from superphot_plus.plotting.utils import read_probs_csv, retrieve_four_class_info
from superphot_plus.supernova_class import SupernovaClass as SnClass
from superphot_plus.utils import calc_accuracy, f1_score
from superphot_plus.plotting.format_params import set_global_plot_formatting
# Runs once at import time. NOTE(review): presumably applies the project-wide
# matplotlib style defined in superphot_plus.plotting.format_params — confirm.
set_global_plot_formatting()
def plot_high_confidence_confusion_matrix(probs_csv, filename, cutoff=0.7):
    """Plot confusion matrices for high-confidence predictions.

    Only samples whose largest class probability is strictly greater
    than ``cutoff`` are kept. Two figures are written through
    ``plot_confusion_matrix``: ``filename + "_c.pdf"`` (``purity=False``)
    and ``filename + "_p.pdf"`` (``purity=True``).

    If no sample passes the cutoff, a warning is emitted and nothing is
    plotted, mirroring the ``np.any`` guard used by ``plot_matrices``.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions.
    filename : str
        Base filename for saving the confusion matrix plots.
    cutoff : float, optional
        Probability cutoff value for high-confidence predictions.
        Default is 0.7.
    """
    import warnings  # local import: only needed on the empty-selection path

    _, classes_to_labels = SnClass.get_type_maps()
    _, true_classes, probs, pred_classes, _ = read_probs_csv(probs_csv)

    # A sample is "high confidence" when its best class probability beats the cutoff.
    high_conf_mask = np.max(probs, axis=1) > cutoff

    # An empty selection would yield a zero-size confusion matrix downstream,
    # which cannot be normalized or plotted, so bail out explicitly instead.
    if not np.any(high_conf_mask):
        warnings.warn(
            f"No predictions exceed the probability cutoff {cutoff}; "
            "skipping high-confidence confusion matrix plots.",
            stacklevel=2,
        )
        return

    true_labels = [classes_to_labels[x] for x in true_classes[high_conf_mask]]
    pred_labels = [classes_to_labels[x] for x in pred_classes[high_conf_mask]]

    plot_confusion_matrix(true_labels, pred_labels, filename + "_c.pdf", purity=False)
    plot_confusion_matrix(true_labels, pred_labels, filename + "_p.pdf", purity=True)
def plot_binary_confusion_matrix(probs_csv, filename):
    """Plot binary (SN Ia vs. core collapse) confusion matrices.

    Every non-Ia class is merged into a single "SN CC" class. A sample
    is predicted "SN Ia" when the probability in column 0 exceeds 0.5,
    and is truly "SN Ia" when its true class equals 0. Two figures are
    written through ``plot_confusion_matrix``: ``filename + "_c.pdf"``
    (``purity=False``) and ``filename + "_p.pdf"`` (``purity=True``).

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions.
    filename : str
        Base filename for saving the confusion matrix plots.
    """
    _, true_classes, probs, _, _ = read_probs_csv(probs_csv)

    # Boolean masks first, then map both through the same pair of labels.
    predicted_ia = probs[:, 0] > 0.5
    actually_ia = true_classes == 0
    pred_binary = np.where(predicted_ia, "SN Ia", "SN CC")
    true_binary = np.where(actually_ia, "SN Ia", "SN CC")

    for suffix, use_purity in (("_c.pdf", False), ("_p.pdf", True)):
        plot_confusion_matrix(true_binary, pred_binary, filename + suffix, purity=use_purity)
def compare_four_class_confusion_matrices(probs_csv, probs_alerce_csv, save_dir, p07=False):
    """Plot four-class confusion matrices for Superphot+ and ALeRCE.

    ALeRCE's transient classifier has no SN IIn label, so both
    classifiers are compared on a condensed four-class scheme. Four
    figures are written into ``save_dir``: ``superphot4_c.pdf`` and
    ``superphot4_p.pdf`` (Superphot+, "Purples" colormap), then
    ``alerce_c.pdf`` and ``alerce_p.pdf`` (ALeRCE, "Blues" colormap).
    The ``_c`` files use ``purity=False`` and the ``_p`` files use
    ``purity=True``.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing Superphot+ probability predictions.
    probs_alerce_csv : str
        Path to the CSV file containing ALeRCE predicted classes.
    save_dir : str
        Directory for saving the confusion matrix plots.
    p07 : bool, optional
        If True, only include predictions with a probability >= 0.7.
        Default is False.
    """
    _, true_labels, _, pred_labels, pred_alerce = retrieve_four_class_info(
        probs_csv, probs_alerce_csv, p07
    )

    # (predictions, completeness file, purity file, colormap) per classifier.
    panels = (
        (pred_labels, "superphot4_c.pdf", "superphot4_p.pdf", "Purples"),
        (pred_alerce, "alerce_c.pdf", "alerce_p.pdf", "Blues"),
    )
    for predictions, completeness_name, purity_name, colormap in panels:
        for out_name, use_purity in ((completeness_name, False), (purity_name, True)):
            plot_confusion_matrix(
                true_labels,
                predictions,
                os.path.join(save_dir, out_name),
                purity=use_purity,
                cmap=colormap,
            )
def plot_true_agreement_matrix(probs_csv, probs_alerce_csv, save_dir, spec=True):
    """Plot the agreement matrix between ALeRCE and Superphot+ predictions.

    Loads the four-class predictions of both classifiers (without the
    probability filter) and delegates the plotting to
    ``plot_agreement_matrix_from_arrs``.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions.
    probs_alerce_csv : str
        Path to the CSV containing ALeRCE predictions.
    save_dir : str
        Directory path for saving the agreement matrix plot.
    spec : bool, optional
        Forwarded to ``plot_agreement_matrix_from_arrs``, where it
        selects the "Spec."/"spec" or "Phot."/"phot" title and filename
        suffix. Default is True.
    """
    four_class_info = retrieve_four_class_info(probs_csv, probs_alerce_csv, False)
    # The last two entries are the Superphot+ and ALeRCE predicted labels.
    pred_labels, pred_alerce = four_class_info[3:]
    plot_agreement_matrix_from_arrs(pred_labels, pred_alerce, save_dir, spec=spec)
def plot_expected_agreement_matrix(probs_csv, probs_alerce_csv, save_dir, cmap="Purples"):
    """Plot expected agreement matrix based on independent ALeRCE and
    Superphot+ confusion matrices.

    The expected matrix is the product of the transposed Superphot+
    confusion matrix normalized over predictions and the ALeRCE
    confusion matrix normalized over true labels. Both matrices are
    built over one shared, sorted label set so that the product, the
    tick labels and the diagonal used for the agreement score all refer
    to the same classes, even when a class is missing from one of the
    label arrays. The figure is saved as ``expected_agreement.pdf`` in
    ``save_dir``.

    Parameters
    ----------
    probs_csv : str
        Path to the CSV file containing probability predictions.
    probs_alerce_csv : str
        Path to the CSV file containing ALeRCE predictions.
    save_dir : str
        Directory for saving the expected agreement matrix plot.
    cmap : str or matplotlib.colors.Colormap, optional
        Color map for the plot. Default is "Purples".
    """
    _, true_labels, _, pred_labels, alerce_preds = retrieve_four_class_info(
        probs_csv, probs_alerce_csv, False
    )
    alerce_preds = np.asarray(alerce_preds)

    # One label set for everything: without an explicit `labels=`, each
    # confusion_matrix call infers its own set and the two matrices may end up
    # with different shapes or orderings.
    classes = unique_labels(true_labels, pred_labels, alerce_preds)
    cm_purity = confusion_matrix(true_labels, pred_labels, labels=classes, normalize="pred")
    cm_complete = confusion_matrix(true_labels, alerce_preds, labels=classes, normalize="true")

    # NOTE(review): cm_expected[i, j] is indexed (Superphot+ prediction i,
    # ALeRCE prediction j), while the axes below label rows as ALeRCE and
    # columns as Superphot+, and the score weights the diagonal by ALeRCE
    # counts — confirm the intended orientation.
    cm_expected = cm_purity.T @ cm_complete

    # Agreement score: diagonal entries weighted by the number of ALeRCE
    # predictions in each class.
    class_counts = np.array([np.count_nonzero(alerce_preds == single_class) for single_class in classes])
    exp_acc = float(class_counts @ np.diag(cm_expected)) / len(alerce_preds)

    title = f"Expected Agreement Matrix,\nSpec. ($A' = {exp_acc:.2f}$)"

    fig, axis = plt.subplots(figsize=(6, 6))
    axis.imshow(cm_expected, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    axis.set(
        xticks=np.arange(cm_expected.shape[1]),
        yticks=np.arange(cm_expected.shape[0]),
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="ALeRCE Classification",
        xlabel="Superphot+ Classification",
    )

    # Rotate the tick labels and set their alignment.
    plt.setp(axis.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every cell; switch to white text on dark cells.
    thresh = cm_expected.max() / 2.0
    for (i, j), value in np.ndenumerate(cm_expected):
        axis.text(
            j,
            i,
            format(value, ".2f"),
            ha="center",
            va="center",
            color="white" if value > thresh else "black",
        )

    fig.tight_layout()
    # Explicit limits: half a cell of padding on each side, row 0 at the top.
    axis.set_xlim(-0.5, len(classes) - 0.5)
    axis.set_ylim(len(classes) - 0.5, -0.5)
    fig.savefig(
        os.path.join(save_dir, "expected_agreement.pdf"),
        bbox_inches="tight",
    )
    plt.close(fig)
def plot_agreement_matrix_from_arrs(our_labels, alerce_labels, save_dir, spec=True, cmap="Purples"):
    """Plot an agreement matrix from Superphot+ and ALeRCE label arrays.

    Rows are ALeRCE labels and columns are Superphot+ labels; each row
    is normalized to sum to one. Every cell shows the normalized value
    with the raw count in parentheses. The title reports the value
    returned by ``calc_accuracy(alerce_labels, our_labels)``. The figure
    is saved as ``true_agreement_spec.pdf`` or ``true_agreement_phot.pdf``
    in ``save_dir``.

    Parameters
    ----------
    our_labels : list
        List of our predicted labels.
    alerce_labels : list
        List of ALeRCE predicted labels.
    save_dir : str
        Directory in which the agreement matrix plot is saved.
    spec : bool, optional
        If True, use the "Spec." title and "spec" filename suffix;
        otherwise use "Phot." and "phot". Default is True.
    cmap : str or matplotlib.colors.Colormap, optional
        Color map for the plot. Default is "Purples".
    """
    suffix_title, suffix = ("Spec.", "spec") if spec else ("Phot.", "phot")

    cm = confusion_matrix(alerce_labels, our_labels, normalize="true")
    classes = unique_labels(alerce_labels, our_labels)

    our_labels = np.array(our_labels)
    alerce_labels = np.array(alerce_labels)

    agreement_text = "%.2f" % calc_accuracy(alerce_labels, our_labels)
    title = f"True Agreement Matrix,\n{suffix_title} ($A' = {agreement_text}$)"

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cm, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # Label the ticks with the class names.
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="ALeRCE Classification",
        xlabel="Superphot+ Classification",
    )

    # Slant the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell with its fraction and raw count; white text on dark cells.
    thresh = cm.max() / 2.0
    for (row, col), fraction in np.ndenumerate(cm):
        in_cell = (our_labels == classes[col]) & (alerce_labels == classes[row])
        num_in_cell = np.count_nonzero(in_cell)
        ax.text(
            col,
            row,
            f"{fraction:.2f}\n({num_in_cell})",
            ha="center",
            va="center",
            color="white" if fraction > thresh else "black",
        )

    fig.tight_layout()
    # Explicit limits: half a cell of padding on each side, row 0 at the top.
    ax.set_xlim(-0.5, len(classes) - 0.5)
    ax.set_ylim(len(classes) - 0.5, -0.5)
    fig.savefig(
        os.path.join(save_dir, f"true_agreement_{suffix}.pdf"),
        bbox_inches='tight'
    )
    plt.close(fig)
def plot_confusion_matrix(y_true, y_pred, filename, purity=False, cmap="Purples"):
    """Plot a normalized confusion matrix of true vs. predicted labels.

    With ``purity=True`` the matrix is normalized over predictions
    (columns) and titled "Purity"; otherwise it is normalized over true
    labels (rows) and titled "Completeness". The title also reports the
    sample count and the values returned by ``calc_accuracy`` and
    ``f1_score``. Every cell shows the normalized value with the raw
    count in parentheses.

    Parameters
    ----------
    y_true : array-like
        True labels.
    y_pred : array-like
        Predicted labels.
    filename : str
        Path the figure is saved to.
    purity : bool, optional
        If True, plot the purity confusion matrix. Default is False.
    cmap : str or matplotlib.colors.Colormap, optional
        Color map for the plot. Default is "Purples".
    """
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    acc = calc_accuracy(y_pred, y_true)
    f1_avg = f1_score(y_pred, y_true, class_average=True)
    stats = f"($N = {len(y_pred)}, A = {acc:.2f}, F_1 = {f1_avg:.2f}$)"

    # Purity normalizes each predicted column; completeness each true row.
    if purity:
        title = f"Purity\n{stats}"
        normalization = "pred"
    else:
        title = f"Completeness\n{stats}"
        normalization = "true"
    cm_vals = confusion_matrix(y_true, y_pred, normalize=normalization)
    classes = unique_labels(y_true, y_pred)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(cm_vals, interpolation="nearest", vmin=0.0, vmax=1.0, cmap=cmap)
    ax.set(
        xticks=np.arange(cm_vals.shape[1]),
        yticks=np.arange(cm_vals.shape[0]),
        # Label the ticks with the class names.
        xticklabels=classes,
        yticklabels=classes,
        title=title,
        ylabel="True label",
        xlabel="Predicted label",
    )

    # Slant the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell with its fraction and raw count; white text on dark cells.
    thresh = cm_vals.max() / 2.0
    for (row, col), fraction in np.ndenumerate(cm_vals):
        in_cell = (y_pred == classes[col]) & (y_true == classes[row])
        num_in_cell = np.count_nonzero(in_cell)
        ax.text(
            col,
            row,
            f"{fraction:.2f}\n({num_in_cell})",
            ha="center",
            va="center",
            color="white" if fraction > thresh else "black",
        )

    fig.tight_layout()
    # Explicit limits: half a cell of padding on each side, row 0 at the top.
    ax.set_xlim(-0.5, len(classes) - 0.5)
    ax.set_ylim(len(classes) - 0.5, -0.5)
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)
def plot_matrices(
    config,
    true_classes,
    pred_classes,
    prob_above_07,
    cm_folder,
):
    """Plot purity and completeness confusion matrices for a test set.

    Two figures are always written for the full test set (``_p.pdf``
    and ``_c.pdf``). If at least one entry of ``prob_above_07`` is
    truthy, two more are written for that subset (``_p_p07.pdf`` and
    ``_c_p07.pdf``). All file names share a prefix built from the
    ``goal_per_class``, ``num_epochs``, ``neurons_per_layer`` and
    ``num_hidden_layers`` attributes of ``config``.

    Parameters
    ----------
    config : ModelConfig
        The configuration of the model used for evaluation.
    true_classes : np.ndarray
        The ground truth for the test ZTF objects.
    pred_classes : np.ndarray
        The predicted classes for the test ZTF objects.
    prob_above_07 : np.ndarray
        Indicates which predictions had a 70% confidence.
    cm_folder : str
        The folder where the plot figures will be stored.
    """
    fn_prefix = f"cm_{config.goal_per_class}_{config.num_epochs}_{config.neurons_per_layer}_{config.num_hidden_layers}"

    def _target(suffix):
        # Full output path for one figure.
        return os.path.join(cm_folder, fn_prefix + suffix)

    # Full test set: purity first, then completeness.
    plot_confusion_matrix(true_classes, pred_classes, _target("_p.pdf"), True)
    plot_confusion_matrix(true_classes, pred_classes, _target("_c.pdf"), False)

    # Nothing above the confidence threshold: skip the p > 0.7 figures.
    if not np.any(prob_above_07):
        return

    confident_true = true_classes[prob_above_07]
    confident_pred = pred_classes[prob_above_07]
    plot_confusion_matrix(confident_true, confident_pred, _target("_p_p07.pdf"), True)
    plot_confusion_matrix(confident_true, confident_pred, _target("_c_p07.pdf"), False)