Source code for classifier_results

"""This module provides various functions for analyzing and visualizing
classification results."""

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.metrics import roc_curve, confusion_matrix
from matplotlib.ticker import AutoMinorLocator

from astropy.cosmology import Planck13 as cosmo
from scipy.stats import binned_statistic

from superphot_plus.plotting.format_params import set_global_plot_formatting
from superphot_plus.file_utils import get_multiple_posterior_samples
from superphot_plus.format_data_ztf import import_labels_only
from superphot_plus.plotting.utils import histedges_equalN, read_probs_csv, get_survey_fracs
from superphot_plus.supernova_class import SupernovaClass as SnClass

# Apply the package-wide plot formatting as soon as this module is imported.
# NOTE: this is a module-level side effect of importing classifier_results.
set_global_plot_formatting()


def save_class_fractions(spec_probs_csv, probs_alerce_csv, phot_probs_csv, save_path):
    """Save class fractions from spectroscopic, photometric, and corrected photometric sets.

    Photometric (and ALeRCE) predicted-class fractions are "corrected" with the
    spectroscopic confusion matrix normalized over predictions
    (``normalize="pred"``), i.e. each predicted-class fraction is redistributed
    over the true classes. The ALeRCE correction treats class index 2 (SN IIn)
    as absent: its corrected fraction is set to 0.

    Parameters
    ----------
    spec_probs_csv : str
        Path to the CSV file containing spectroscopic probability predictions.
    probs_alerce_csv : str
        Path to the CSV file with an ``alerce_label`` column for the
        spectroscopic set. NOTE(review): rows are assumed to be aligned with
        ``spec_probs_csv`` -- confirm.
    phot_probs_csv : str
        Path to the CSV file containing photometric probability predictions.
    save_path : str
        Filename + dir for saving the class fractions CSV.
    """
    labels_to_class, _ = SnClass.get_type_maps()

    # import spec dataframe
    _, true_class_spec, probs_spec, pred_class_spec, _ = read_probs_csv(spec_probs_csv)
    num_classes = probs_spec.shape[1]

    # For the ALeRCE comparison, fold true class 2 (SN IIn) into class 1.
    true_class_alerce = true_class_spec.copy()
    true_class_alerce[true_class_alerce == 2] = 1

    # read in ALeRCE classes, dropping objects without a usable ALeRCE label
    df_alerce = pd.read_csv(probs_alerce_csv)
    pred_alerce = df_alerce.alerce_label.to_numpy().astype(str)
    ignore_mask = np.isin(pred_alerce, ["None", "nan", "SKIP"])
    # NOTE(review): true SNe IIn were relabeled to class 1 above, so they are
    # merged (not excluded) here; confirm merging is the intended treatment.
    true_class_alerce = true_class_alerce[~ignore_mask]
    pred_alerce = pred_alerce[~ignore_mask]
    pred_class_spec_alerce = np.array([labels_to_class[x] for x in pred_alerce])

    # import phot dataframe
    _, pred_label_alerce, _, pred_class_phot, _ = read_probs_csv(phot_probs_csv)
    skip_idx = pred_label_alerce == "SKIP"
    pred_label_alerce, pred_class_phot = pred_label_alerce[~skip_idx], pred_class_phot[~skip_idx]
    pred_class_phot_alerce = np.array([labels_to_class[x] for x in pred_label_alerce])

    # Rows are true classes, columns predicted classes; normalize="pred" makes
    # each column sum to 1, i.e. cm[i, j] = P(true = i | pred = j).
    cm_p = confusion_matrix(true_class_spec, pred_class_spec, normalize="pred")
    cm_p_alerce = confusion_matrix(true_class_alerce, pred_class_spec_alerce, normalize="pred")

    def class_fractions(classes):
        """Fraction of entries in ``classes`` equal to each class index."""
        return np.array(
            [np.count_nonzero(classes == i) / len(classes) for i in range(num_classes)]
        )

    true_fracs = class_fractions(true_class_spec)
    pred_fracs = class_fractions(pred_class_phot)
    alerce_fracs = class_fractions(pred_class_phot_alerce)

    pred_fracs_corr = np.array([np.sum(pred_fracs * cm_p[i]) for i in range(num_classes)])

    # Assumes no ALeRCE prediction maps to class 2, so cm_p_alerce has no
    # row/column for it and every class above 2 is shifted down by one.
    alerce_fracs_no_iin = np.delete(alerce_fracs, 2)
    alerce_fracs_corr = []
    for i in range(num_classes):
        if i == 2:
            alerce_fracs_corr.append(0.0)
            continue
        row = i - 1 if i > 2 else i
        alerce_fracs_corr.append(np.sum(alerce_fracs_no_iin * cm_p_alerce[row]))
    alerce_fracs_corr = np.array(alerce_fracs_corr)

    save_df = pd.DataFrame(
        {
            "spec_fracs": true_fracs,
            "phot_fracs": pred_fracs,
            "phot_fracs_corr": pred_fracs_corr,
            "alerce_fracs": alerce_fracs,
            "alerce_fracs_corr": alerce_fracs_corr,
        }
    )
    save_df.to_csv(save_path, index=False)

def plot_class_fractions(saved_cf_file, fig_dir, filename):
    """Plot class fractions saved from 'save_class_fractions'.

    Draws one stacked bar per data source (spectroscopic ZTF/YSE/PS1-MDS,
    photometric, corrected photometric, ALeRCE, corrected ALeRCE), with one
    stacked segment per class annotated with its fraction.

    Parameters
    ----------
    saved_cf_file : str
        Path to the saved class fractions file.
    fig_dir : str
        Directory for saving the class fractions plot.
    filename: str
        Filename for the class fractions plot figure.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    labels = [
        "Spec (ZTF)",
        "Spec (YSE)",
        "Spec (PS1-MDS)",
        "Phot",
        "Phot (corr.)",
        "ALeRCE",
        "ALeRCE (corr.)",
    ]
    width = 0.6

    frac_df = pd.read_csv(saved_cf_file)
    true_fracs, pred_fracs, pred_fracs_corr, alerce_fracs, alerce_fracs_corr = frac_df.to_numpy().T
    survey_sn_fracs = get_survey_fracs()
    yse_fracs, psmds_fracs = survey_sn_fracs["YSE"], survey_sn_fracs["PS-MDS"]

    # Transposed so each row is one class and each column one bar in ``labels``.
    combined_fracs = np.array(
        [
            true_fracs,
            yse_fracs,
            psmds_fracs,
            pred_fracs,
            pred_fracs_corr,
            alerce_fracs,
            alerce_fracs_corr,
        ]
    ).T

    _, ax = plt.subplots(figsize=(11, 16))
    # Stack every class (not a fixed count) on top of the previous ones.
    bottom = np.zeros(combined_fracs.shape[1])
    for i, class_fracs in enumerate(combined_fracs):
        stacked_bar = ax.bar(
            labels,
            class_fracs,
            width,
            bottom=bottom,
            label=classes_to_labels[i],
        )
        for frac, bar in zip(class_fracs, stacked_bar.patches):
            # Label each segment at its center.
            ax.annotate(
                round(frac, 3),
                (bar.get_x() + bar.get_width() / 2, bar.get_y() + bar.get_height() / 2),
                ha="center",
                va="center",
                color="white",
            )
        bottom = bottom + class_fracs

    # Shrink current axis's height by 10% on the bottom to make room for the legend
    box = ax.get_position()
    ax.set_position([box.x0, box.y0 + box.height * 0.1, box.width, box.height * 0.9])

    # Put a legend below current axis
    ax.legend(
        loc="upper center", bbox_to_anchor=(0.5, -0.05), fancybox=True, shadow=False, ncol=5, fontsize=15
    )
    ax.tick_params(axis="both", which="major", labelsize=12)
    ax.tick_params(axis="both", which="minor", labelsize=10)
    ax.set_ylabel("Fraction", fontsize=20)
    plt.savefig(os.path.join(fig_dir, filename))
    plt.close()

def generate_roc_curve(probs_csv, save_dir):
    """Generate a combined ROC curve of all SN classes.

    Draws one-vs-rest ROC curves for every class on two panels (full range, and
    a zoom on false positive rates <= 0.1), marks on the zoomed panel the point
    whose threshold is closest to 0.5, overlays the macro-averaged curve, and
    saves ``roc_all.pdf`` in ``save_dir``.

    Parameters
    ----------
    probs_csv : str
        CSV file where class probabilities are stored.
    save_dir : str
        Where to save the figure.
    """
    labels_to_classes, classes_to_labels = SnClass.get_type_maps()
    colors = [plt.cm.Set1(i) for i in range(10)]

    # The true labels and probabilities are the same for every class: read once.
    true_classes, probs = read_probs_csv(probs_csv)[1:3]

    fig, double_axes = plt.subplots(1, 2, figsize=(8, 7))
    ax1, ax2 = double_axes
    ax1.set_xlim([0.0, 1.05])
    ax1.set_ylim([0.0, 1.05])
    # Right panel zooms in on the low false-positive-rate region.
    ax2.set_xlim([0.0, 0.1])
    ax2.set_ylim([0.0, 1.05])
    ax1.set_ylabel("True Positive Rate")
    ratio = 1.2
    plt.locator_params(axis="x", nbins=3)

    legend_lines = []
    fpr = []
    tpr = []
    for ref_class, ref_label in enumerate(classes_to_labels):
        # One-vs-rest ROC curve for this class.
        y_true = np.where(true_classes == ref_class, 1, 0)
        y_score = probs[:, ref_class]
        single_class_fpr, single_class_tpr, thresholds = roc_curve(y_true, y_score)
        # Operating point whose decision threshold is closest to 0.5.
        idx_50 = np.argmin((thresholds - 0.5) ** 2)
        (legend_line,) = ax1.plot(single_class_fpr, single_class_tpr, label=ref_label, c=colors[ref_class])
        ax2.plot(single_class_fpr, single_class_tpr, label=ref_label, c=colors[ref_class])
        legend_lines.append(legend_line)
        ax2.scatter(
            single_class_fpr[idx_50], single_class_tpr[idx_50], color=colors[ref_class], s=100, marker="d"
        )
        fpr.append(single_class_fpr)
        tpr.append(single_class_tpr)

    # Macro average: interpolate every class's ROC curve onto the union of all
    # false positive rates, then average over classes (no AUC is computed).
    all_fpr = np.unique(np.concatenate(fpr))
    mean_tpr = np.zeros_like(all_fpr)
    for class_fpr, class_tpr in zip(fpr, tpr):
        mean_tpr += np.interp(all_fpr, class_fpr, class_tpr)
    mean_tpr /= len(fpr)

    for ax_i in double_axes:
        (legend_line,) = ax_i.plot(
            all_fpr, mean_tpr, label="Macro-averaged", linewidth=3, linestyle="dashed", c="black"
        )
        # Fix the displayed aspect ratio independently of the data limits.
        x_left, x_right = ax_i.get_xlim()
        y_low, y_high = ax_i.get_ylim()
        ax_i.set_aspect(abs((x_right - x_left) / (y_low - y_high)) * ratio)
        ax_i.yaxis.set_minor_locator(AutoMinorLocator())
        ax_i.xaxis.set_minor_locator(AutoMinorLocator())
        ax_i.set_xlabel("False Positive Rate")

    legend_lines.append(legend_line)
    legend_keys = [*list(labels_to_classes.keys()), "Combined"]
    fig.legend(legend_lines, legend_keys, loc="lower center", ncol=3)
    plt.savefig(os.path.join(save_dir, "roc_all.pdf"), bbox_inches="tight")
    plt.close()

def plot_phase_vs_accuracy(phased_probs_csv, save_dir):
    """Plot classification accuracy as a function of phase.

    Top panel: per-class accuracy in 4-day phase bins. Bottom panel
    ("overprediction fraction"): for each predicted class, the fraction of
    objects predicted as that class in each phase bin, with every true class
    reweighted to an equal share of the bin, normalized to its value in the
    last bin. Saves ``phase_vs_accuracy.pdf`` in ``save_dir``.

    Parameters
    ----------
    phased_probs_csv : str
        Where classification probabilities and LC truncated phases are saved.
    save_dir : str
        Where to save the output figures.
    """
    fig, (ax, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 10), gridspec_kw={"hspace": 0})

    _, classes_to_labels = SnClass.get_type_maps()
    allowed_types = np.arange(len(classes_to_labels))
    # Equal weight per true class (was a hard-coded 0.2 == 1/5 classes).
    class_weight = 1.0 / len(allowed_types)
    # 4-day phase bins shared by both panels.
    bins = np.arange(-16, 52, 4)

    # NOTE(review): elsewhere read_probs_csv is unpacked as
    # (names, true_classes, probs, pred_classes, _); here the first two values
    # are used as (true_type, phase). Confirm the phased CSV layout.
    true_type, phase, _, pred_type, _ = read_probs_csv(phased_probs_csv)
    correct_class = (true_type == pred_type).astype(int)

    for allowed_type in allowed_types:
        in_type = true_type == allowed_type
        correct_t = correct_class[in_type]
        phase_t = phase[in_type]
        correct_hist, _, _ = binned_statistic(phase_t, correct_t, statistic="sum", bins=bins)
        all_hist, _, _ = binned_statistic(phase_t, np.ones(len(phase_t)), statistic="sum", bins=bins)
        acc_hist_t = correct_hist / all_hist
        # Repeat the last value so the final bin is drawn by a "post" step.
        ax.step(bins, np.append(acc_hist_t, acc_hist_t[-1]), where="post", label=allowed_type)

    ax.axvline(x=0.0, color="grey", linestyle="dotted")
    ax.set_ylabel("Classification Accuracy")
    ax.set_xlim((-18.0, 48.0))

    # Over/under-classification fraction of each type compared to final classification.
    legend_lines = []
    all_hist, _, _ = binned_statistic(phase, np.ones(len(true_type)), statistic="sum", bins=bins)
    for allowed_type in allowed_types:
        eff_num = np.zeros(len(bins) - 1)  # effective (class-balanced) numerator
        for allowed_type2 in allowed_types:
            # Objects of true type 2 that were classified as ``allowed_type``.
            idx_sub2 = (true_type == allowed_type2) & (pred_type == allowed_type)
            phase_sub = phase[idx_sub2]
            if len(phase_sub) == 0:
                continue
            phase_t = phase[true_type == allowed_type2]
            true_hist, _, _ = binned_statistic(phase_t, np.ones(len(phase_t)), statistic="sum", bins=bins)
            frac_hist = true_hist / all_hist  # within each bin, fraction that is that true type
            normed_const = class_weight / frac_hist
            pred_hist, _, _ = binned_statistic(
                phase_sub, np.ones(len(phase_sub)), statistic="sum", bins=bins
            )
            eff_num += normed_const * pred_hist

        pred_frac = eff_num / all_hist
        pred_frac_normed = pred_frac / pred_frac[-1]
        (legend_line,) = ax2.step(
            bins, np.append(pred_frac_normed, pred_frac_normed[-1]), where="post", label=allowed_type
        )
        legend_lines.append(legend_line)

    # axhline's xmin/xmax are axes fractions; the defaults already span the full width.
    ax2.axhline(y=1.0, color="k", linestyle="--")
    ax2.axvline(x=0.0, color="grey", linestyle="dotted")
    ax2.set_xlabel(r"Phase (days)")
    ax2.set_ylabel("Overprediction Fraction")
    ax2.set_xlim((-18.0, 48.0))

    fig.legend(legend_lines, [classes_to_labels[x] for x in allowed_types], loc="lower center", ncol=3)
    plt.savefig(os.path.join(save_dir, "phase_vs_accuracy.pdf"), bbox_inches="tight")
    plt.close()

def plot_redshifts_abs_mags(probs_snr_csv, training_csv, fits_dir, save_dir, sampler="dynesty"):
    """Plot redshift and absolute magnitude distributions used in the
    redshift-inclusive classifier.

    Draws per-class cumulative distributions of redshift (left panel) and
    absolute magnitude (right panel) and saves ``abs_mag_hist.pdf`` in
    ``save_dir``.

    Parameters
    ----------
    probs_snr_csv : str
        Where probabilities + SNRs are stored. Its ``Fmax`` column is converted
        to apparent magnitude with a zeropoint of 26.3. NOTE(review): rows are
        assumed aligned with the labels/redshifts loaded from ``training_csv``.
    training_csv : str
        CSV passed to ``import_labels_only`` to obtain labels and redshifts.
    fits_dir : str
        Directory of posterior fits, forwarded to ``import_labels_only``.
    save_dir : str
        Where to save figures.
    sampler : str, optional
        Sampler name forwarded to ``import_labels_only``. Defaults to "dynesty".
    """
    labels_to_classes, _ = SnClass.get_type_maps()
    allowed_types = list(labels_to_classes.keys())
    _, labels, redshifts = import_labels_only(
        [
            training_csv,
        ],
        allowed_types,
        needs_posteriors=True,
        sampler=sampler,
        fits_dir=fits_dir,
    )

    probs_dataframe = pd.read_csv(probs_snr_csv)
    amplitudes = probs_dataframe.Fmax.to_numpy()
    app_mags = -2.5 * np.log10(amplitudes) + 26.3
    k_correction = 2.5 * np.log10(1.0 + redshifts)
    dist = cosmo.luminosity_distance([redshifts]).value[0]  # returns dist in Mpc
    # Distance modulus with the distance converted from Mpc to pc.
    abs_mags = app_mags - 5.0 * np.log10(dist * 1e6 / 10.0) + k_correction

    def plot_cumulative(ax, values, hist_range, x_sign):
        """Step-plot each class's cumulative distribution of ``values`` on ``ax``."""
        _, bin_edges = np.histogram(values, bins=40, density=True, range=hist_range)
        bin_width = bin_edges[1] - bin_edges[0]
        bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
        lines = []
        for allowed_type in allowed_types:
            feature_hist, _ = np.histogram(values[labels == allowed_type], bins=bin_edges, density=True)
            cumsum = np.cumsum(feature_hist) * bin_width
            (line,) = ax.step(x_sign * bin_centers, cumsum, where="mid", label=allowed_type)
            lines.append(line)
        return lines

    fig, axes = plt.subplots(1, 2, figsize=(8, 6))
    z_ax, mag_ax = axes

    # Histogram -M (so the range is positive), then plot against M again.
    legend_lines = plot_cumulative(mag_ax, -abs_mags, (15, 25), -1.0)
    mag_ax.set_xlabel("Absolute Magnitude")
    mag_ax.invert_xaxis()

    plot_cumulative(z_ax, redshifts, (-0.1, 0.6), 1.0)
    z_ax.set_xlabel("Redshift")
    z_ax.set_ylabel("Cumulative Fraction")

    for ax in axes:
        # Square panels regardless of data limits.
        x_left, x_right = ax.get_xlim()
        y_low, y_high = ax.get_ylim()
        ax.set_aspect(abs((x_right - x_left) / (y_low - y_high)))

    # One legend entry per class (there is no combined curve in this figure).
    fig.legend(legend_lines, allowed_types, loc="lower center", ncol=3)
    plt.savefig(os.path.join(save_dir, "abs_mag_hist.pdf"), bbox_inches="tight")
    plt.close()

def _equal_count_accuracy(values, correct, max_bins=8):
    """Return ``(bin_edges, accuracy)`` over equal-count bins of ``values``, or None.

    Starts with ``max_bins`` bins and halves the (integer) bin count whenever
    binning fails, giving up below one bin. Empty bins get accuracy 1.0.
    """
    nbins = max_bins
    while nbins >= 1:
        try:
            accuracy, bin_edges, _ = binned_statistic(
                values, correct, "mean", bins=histedges_equalN(values, nbins)
            )
        except Exception:  # pylint: disable=broad-except
            # Presumably too few / repeated values for this many bins; retry with fewer.
            # Integer halving keeps nbins an int for histedges_equalN.
            nbins //= 2
            continue
        accuracy[np.isnan(accuracy)] = 1.0
        return bin_edges, accuracy
    return None


def plot_snr_npoints_vs_accuracy(probs_snr_csv, save_dir):
    """Generate plots of classification accuracy versus the number of
    SNR >= 3 points, and versus the 90th-percentile (top 10%) SNR.

    One curve is drawn per true class, using equal-count bins; classes for
    which no binning succeeds are skipped. Saves ``snr_vs_accuracy.pdf`` and
    ``n_vs_accuracy.pdf`` in ``save_dir``.

    TODO: add functionality for only one type.

    Parameters
    ----------
    probs_snr_csv : str
        Where probabilities + SNRs are stored (``SNR90`` and ``nSNR3`` columns).
    save_dir : str
        Where to save figures.
    """
    _, classes_to_labels = SnClass.get_type_maps()
    _, true_type, _, pred_classes, _ = read_probs_csv(probs_snr_csv)
    correct_class = np.where(true_type == pred_classes, 1, 0)
    df = pd.read_csv(probs_snr_csv)

    # (feature, x-limits, x-label, legend kwargs, output filename)
    panels = (
        (df.SNR90, (5, 30), "90th Percentile SNR", {}, "snr_vs_accuracy.pdf"),
        (
            df.nSNR3,
            (8, 100),
            r"Number of $\geq 3\sigma$ Datapoints",
            {"loc": "lower right"},
            "n_vs_accuracy.pdf",
        ),
    )
    for feature, xlim, xlabel, legend_kwargs, filename in panels:
        for unique_type in np.unique(true_type):
            in_type = true_type == unique_type
            binned = _equal_count_accuracy(feature[in_type], correct_class[in_type])
            if binned is None:
                continue
            bin_edges, accuracy = binned
            plt.step(
                bin_edges,
                np.append(accuracy, accuracy[-1]),
                label=classes_to_labels[unique_type],
                where="post",
            )
        plt.xlim(xlim)
        plt.xlabel(xlabel)
        plt.ylabel("Classification Accuracy")
        plt.legend(**legend_kwargs)
        plt.savefig(os.path.join(save_dir, filename), bbox_inches="tight")
        plt.close()

def plot_snr_hist(probs_snr_csv, save_dir):
    """Replicate the publication histogram of high-SNR point counts per lightcurve.

    The last three CSV columns are histogrammed (labelled SNR >= 3, >= 5 and
    >= 10) on log-log axes, excluding rows whose second column is "SKIP".
    The figure is written to ``snr_hist.pdf`` in ``save_dir``.

    Parameters
    ----------
    probs_snr_csv : str
        Where probability + SNR info is stored.
    save_dir : str
        Where to save figure.
    """
    frame = pd.read_csv(probs_snr_csv)
    keep = ~(frame.iloc[:, 1] == "SKIP").to_numpy()
    count_columns = frame.iloc[:, -3:].to_numpy().T
    edges = np.arange(0, 603, 3)
    legend_labels = (r"$SNR \geq 3$", r"$SNR \geq 5$", r"$SNR \geq 10$")

    for counts, legend_label in zip(count_columns, legend_labels):
        plt.hist(counts[keep], histtype="step", label=legend_label, bins=edges)

    plt.loglog()
    plt.xlabel("Number of Datapoints at Given SNR")
    plt.ylabel("Number of Lightcurves")
    plt.legend()
    plt.savefig(os.path.join(save_dir, "snr_hist.pdf"), bbox_inches="tight")
    plt.close()

def compare_mag_distributions(probs_classified, probs_unclassified, save_dir, zeropoint=26.3):
    """Overlay apparent-magnitude distributions of the classified and unclassified datasets.

    Peak fluxes (``Fmax``) are converted to magnitudes. The unclassified set is
    split into included rows and rows marked "SKIP" in its second column
    (lightcurves that failed the chi-squared cut). All three sets are drawn as
    normalized (``density=True``) filled histograms on a log y-axis and saved to
    ``appm_hist_compare.pdf`` in ``save_dir``.

    Parameters
    ----------
    probs_classified : str
        CSV filename where probs of spectroscopic set are stored.
    probs_unclassified : str
        CSV filename where probs of photometric set are stored.
    save_dir : str
        Where to save figure.
    zeropoint : float, optional
        Zeropoint used when converting peak fluxes to magnitudes. Defaults to 26.3.
    """

    def peak_magnitudes(frame):
        return -2.5 * np.log10(frame.Fmax.to_numpy()) + zeropoint

    spec_mags = peak_magnitudes(pd.read_csv(probs_classified))
    phot_frame = pd.read_csv(probs_unclassified)
    phot_mags = peak_magnitudes(phot_frame)
    skipped = (phot_frame.iloc[:, 1] == "SKIP").to_numpy()

    mag_bins = np.arange(5.0, 21.0, 0.5)
    for mags, legend_label in (
        (spec_mags, "Spectroscopic"),
        (phot_mags[~skipped], "Photometric (included)"),
        (phot_mags[skipped], "Photometric (excluded)"),
    ):
        plt.hist(
            mags,
            histtype="stepfilled",
            bins=mag_bins,
            alpha=0.5,
            label=legend_label,
            density=True,
        )

    plt.yscale("log")
    plt.legend(loc="upper left")
    plt.xlabel("Apparent Magnitude (m)")
    plt.ylabel("Fraction of Lightcurves")
    plt.savefig(os.path.join(save_dir, "appm_hist_compare.pdf"), bbox_inches="tight")
    plt.close()

def plot_chisquared_vs_accuracy(
    pred_spec_fn,
    pred_phot_fn,
    fits_dir,
    save_dir,
    sampler=None,
):
    """Plot chi-squared histograms of the spectroscopic and photometric sets,
    overlaid with spectroscopic classification accuracy per chi-squared bin.

    Saves ``chisq_vs_accuracy.pdf`` in ``save_dir``.

    TODO: IN PROGRESS

    Parameters
    ----------
    pred_spec_fn : str
        CSV filename where probs of spectroscopic set are stored.
    pred_phot_fn : str
        CSV filename where probs of photometric set are stored.
    fits_dir : str
        Directory of posterior fits, forwarded to ``get_multiple_posterior_samples``.
    save_dir : str
        Where to save figure.
    sampler : str, optional
        Sampler name forwarded to ``get_multiple_posterior_samples``.
    """

    def mean_chisquared(names):
        # Per-object mean of the negated last posterior column.
        # NOTE(review): assumes that column is minus the reduced chi-squared
        # (matches the x-axis label below) -- confirm against the fit outputs.
        posteriors = get_multiple_posterior_samples(names, fits_dir, sampler=sampler)
        return np.array([-1 * np.mean(posteriors[x][:, -1]) for x in names])

    sn_names, true_classes, _, pred_classes, _ = read_probs_csv(pred_spec_fn)
    correctly_classified = np.where(true_classes == pred_classes, 1, 0)
    train_chis = mean_chisquared(sn_names)
    train_chis_phot = mean_chisquared(read_probs_csv(pred_phot_fn)[0])

    # Counts on the left axis, accuracy on a twin right axis.
    _, ax2 = plt.subplots(figsize=(7, 4.8))
    ax1 = ax2.twinx()

    bins = np.arange(3.5, 14, 0.5)
    correct_hist, bin_edges, _ = binned_statistic(
        train_chis, correctly_classified, statistic="sum", bins=bins
    )
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2.0
    all_hist, _, _ = binned_statistic(train_chis, np.ones(len(train_chis)), statistic="sum", bins=bins)
    all_hist_phot, _, _ = binned_statistic(
        train_chis_phot, np.ones(len(train_chis_phot)), statistic="sum", bins=bins
    )

    ax2.hist(bin_centers, bin_edges, weights=all_hist, color="purple", alpha=0.5, label="Spectroscopic")
    ax2.hist(bin_centers, bin_edges, weights=all_hist_phot, color="red", alpha=0.5, label="Photometric")
    ax2.set_yscale("log")

    # Empty bins give accuracy 0 instead of 0/0 (done after the counts are plotted).
    all_hist[all_hist == 0] = np.inf
    acc_hist = correct_hist / all_hist

    # Only draw accuracy for bins with 5 < chi^2 < 10.
    idx_keep = (bin_centers < 10) & (bin_centers > 5)
    ax1.step(
        bin_centers[idx_keep], acc_hist[idx_keep], where="mid", color="blue", linewidth=3, label="Accuracy"
    )
    ax1.axvline(x=10, color="black", linestyle="--", linewidth=4, label=r"Phot. $\chi^2$ cutoff")

    ax2.set_xlabel(r"Reduced $\chi^2$")
    ax1.set_ylabel("Accuracy", va="bottom", rotation=270)
    ax2.set_ylabel("Counts")
    ax2.legend()

    ax1.yaxis.label.set_color("blue")
    ax1.spines["right"].set_color("blue")
    ax1.tick_params(axis="y", colors="blue")
    ax1.legend(loc="lower right")

    ax1.yaxis.set_minor_locator(AutoMinorLocator())
    ax2.yaxis.set_minor_locator(AutoMinorLocator())

    plt.savefig(os.path.join(save_dir, "chisq_vs_accuracy.pdf"), bbox_inches="tight")
    plt.close()

def plot_model_metrics(metrics, num_epochs, plot_name, metrics_dir):
    """Plot per-epoch training/validation accuracy and loss, one PDF each.

    Writes ``accuracy_<plot_name>.pdf`` (linear y-axis) and
    ``loss_<plot_name>.pdf`` (log y-axis) into ``metrics_dir``.

    Parameters
    ----------
    metrics: tuple
        ``(train_acc, train_loss, val_acc, val_loss)`` sequences, one value per epoch.
    num_epochs: int
        The total number of epochs.
    plot_name: str
        The name for the plot figure files.
    metrics_dir: str
        Where to store the plot figures.
    """
    train_acc, train_loss, val_acc, val_loss = metrics
    epochs = np.arange(0, num_epochs)

    # (y-label, training curve, validation curve, y-scale or None, output file)
    figures = (
        ("Accuracy", train_acc, val_acc, None, f"accuracy_{plot_name}.pdf"),
        ("Loss", train_loss, val_loss, "log", f"loss_{plot_name}.pdf"),
    )
    for y_label, train_curve, val_curve, y_scale, out_name in figures:
        plt.plot(epochs, train_curve, label="Training")
        plt.plot(epochs, val_curve, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        if y_scale is not None:
            plt.yscale(y_scale)
        plt.legend()
        plt.savefig(os.path.join(metrics_dir, out_name), bbox_inches="tight")
        plt.close()