# Source code for contextual_robustness.reporting

import typing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from contextual_robustness.base import ContextualRobustness, Techniques
from contextual_robustness.utils import _create_output_path, resize_image
from contextual_robustness.datasets import _load_placeholder_images

# Placeholder images are loaded once, when this module is imported.
PLACEHOLDERS = _load_placeholder_images()
# 'no_cex' stands in for a missing counterexample, 'no_image' for a missing
# source image; either is None when the key is absent from PLACEHOLDERS.
NO_CEX_IMG, NO_IMAGE_IMG = (PLACEHOLDERS.get(key) for key in ('no_cex', 'no_image'))

class ContextualRobustnessReporting:
    '''Class containing ContextualRobustness reporting functions'''

    @staticmethod
    def _configure_text(usetex, fontfamily, fontweight):
        '''Applies the text/font rc settings shared by every plot in this class

        Note that ``plt.rc`` changes matplotlib's global settings; they stay in
        effect after the plotting function returns.

        Args:
            usetex (bool): Use latex for text.
            fontfamily (str): Fontfamily for text.
            fontweight (str): Fontweight for text.
        '''
        plt.rc('text', usetex=usetex)
        plt.rc('font', family=fontfamily, weight=fontweight)

    @staticmethod
    def _accuracy_curve(results, total, baseline_accuracy):
        '''Builds the epsilon/accuracy curve for a set of results

        The first point is ``(0, baseline_accuracy)``. For each threshold in
        0.001, 0.002, ..., 1.000 a point ``(threshold, survivors / total)`` is
        added, where ``survivors`` is the number of rows whose epsilon is >= the
        threshold. Thresholds that no row survives produce no point.

        Args:
            results (pd.DataFrame): Results containing an 'epsilon' column.
            total (int): Number of samples used as the accuracy denominator.
            baseline_accuracy (float): Accuracy plotted at epsilon 0.

        Returns:
            pd.DataFrame: Frame with the columns 'epsilon' and 'accuracy'.
        '''
        points = [(0, baseline_accuracy)]
        remaining = results
        for threshold in np.linspace(1, 1000, 1000) / 1000:
            remaining = remaining[remaining['epsilon'] >= threshold]
            survivors = remaining.shape[0]
            if survivors == 0:
                # thresholds only increase, so nothing can survive a later one
                break
            points.append((threshold, survivors / total))
        return pd.DataFrame(points, columns=['epsilon', 'accuracy'])

    @staticmethod
    def _select_class_images(cr, X, c):
        '''Picks the original image and counterexample displayed for a class

        Args:
            cr (ContextualRobustness): ContextualRobustness object
            X: Images of the dataset, indexable by the 'image' column of the results.
            c: Class to select the images for.

        Returns:
            tuple: ``(original, counterexample)``. The original is NO_IMAGE_IMG when
            the class has no results; the counterexample is NO_CEX_IMG when none
            is available.
        '''
        x_orig, x_cex = NO_IMAGE_IMG, NO_CEX_IMG
        sorted_df = cr.get_results(class_index=c, sort_by=['epsilon'])
        # a class with a single result is still displayable (was '> 1')
        if sorted_df.shape[0] == 0:
            return x_orig, x_cex

        # candidates are the results whose epsilon is at or above the class mean
        mean_epsilon = np.mean(sorted_df.epsilon)
        upper_df = sorted_df[sorted_df.epsilon >= mean_epsilon]
        if upper_df.shape[0] == 0:
            # float rounding can put the mean just above every epsilon (e.g. all
            # values equal) and an all-NaN column matches nothing; fall back to
            # every result instead of failing on iloc[0]
            upper_df = sorted_df

        row = upper_df.iloc[0]
        # int() also handles plain python numbers coming from an object-dtype row
        x_idx = int(row['image'])
        epsilon = row['epsilon']
        x_orig = X[x_idx]

        if cr.technique == Techniques.TEST:
            # for test technique, generate counterexample if not saved
            saved = cr.get_counterexample(x_idx)
            x_cex = saved if saved is not None else cr.transform_image(x_orig, epsilon)
        elif cr.technique == Techniques.FORMAL:
            # for formal technique, counterexample must be read from results
            for idx in upper_df['image'].astype(int):
                saved = cr.get_counterexample(idx)
                if saved is not None:
                    x_orig, x_cex = X[idx], saved
                    break
        return x_orig, x_cex

    @staticmethod
    def generate_epsilons_plot(
            cr: ContextualRobustness,
            outfile:str='epsilons.png',
            xlabel:str='',
            ylabel:str='epsilon',
            axis_fontsize:int=24,
            fontfamily:str='serif',
            fontweight:str='ultralight',
            usetex:bool=True
            ):
        '''Plots epsilons by class for a model/transform and saves as png

        Args:
            cr (ContextualRobustness): ContextualRobustness object
            outfile (str, optional): Output file path. Defaults to 'epsilons.png'.
            xlabel (str, optional): x axis label. Defaults to ''.
            ylabel (str, optional): y axis label. Defaults to 'epsilon'.
            axis_fontsize (int, optional): Fontsize for axis text. Defaults to 24.
            fontfamily (str, optional): Fontfamily for text. Defaults to 'serif'.
            fontweight (str, optional): Fontweight for text. Defaults to 'ultralight'.
            usetex (bool, optional): Use latex for text. Defaults to True.
        '''
        ContextualRobustnessReporting._configure_text(usetex, fontfamily, fontweight)

        ax = cr.results.boxplot(column=['epsilon'], by='class', return_type=None)
        fig = ax.get_figure()
        try:
            # clear the titles pandas generates for grouped boxplots
            fig.suptitle('')
            ax.set_title('')
            ax.set_xlabel(xlabel, fontsize=axis_fontsize, weight=fontweight)
            ax.set_ylabel(ylabel, fontsize=axis_fontsize, weight=fontweight)
            ax.tick_params(axis='both', labelsize=axis_fontsize)

            _create_output_path(outfile)
            fig.savefig(outfile, bbox_inches='tight')
        finally:
            # release the figure even if rendering/saving fails (e.g. no latex)
            plt.close(fig)
        print(f'saved epsilons plot to {outfile}')

    @staticmethod
    def generate_counterexamples_plot(
            cr: ContextualRobustness,
            outfile:str='./counterexamples.png',
            dpi:int=144,
            show_labels:bool=True,
            label_fontsize:int=12,
            fontfamily:str='serif',
            fontweight:str='ultralight',
            usetex:bool=True
            ):
        '''Plots counterexamples for a model/transform and saves as png

        The figure has one column per class: the original image on the top row
        and its counterexample on the bottom row.

        Args:
            cr (ContextualRobustness): ContextualRobustness object
            outfile (str, optional): Output file path. Defaults to './counterexamples.png'.
            dpi (int, optional): Dots per inch. Defaults to 144.
            show_labels (bool, optional): Show the class labels? Defaults to True.
            label_fontsize (int, optional): Fontsize of class labels. Defaults to 12.
            fontfamily (str, optional): Fontfamily for text. Defaults to 'serif'.
            fontweight (str, optional): Fontweight for text. Defaults to 'ultralight'.
            usetex (bool, optional): Use latex for text. Defaults to True.
        '''
        # configure font
        ContextualRobustnessReporting._configure_text(usetex, fontfamily, fontweight)

        nrows, ncols = 2, len(cr.classes)
        images = [None] * (nrows * ncols)
        X, _ = cr.dataset
        # Columns follow the sorted class labels rather than using the label as
        # a list index, so labels need not be exactly 0..n-1. For labels 0..n-1
        # the layout is unchanged.
        for col, c in enumerate(sorted(cr.classes)):
            # no_image placeholder for classes where images/results are not present (test or formal).
            # no_cex placeholder for classes where no counterexample found (formal only).
            x_orig, x_cex = ContextualRobustnessReporting._select_class_images(cr, X, c)
            images[col] = dict(title=f'class {c}', image=x_orig)
            images[col + ncols] = dict(title='', image=x_cex)

        # create the figure & add images
        figure, ax = plt.subplots(
            ncols=ncols,
            nrows=nrows,
            figsize=(ncols, nrows),
            gridspec_kw=dict(wspace=0.06, hspace=0.03)
            )
        try:
            for i, (axis, image) in enumerate(zip(ax.ravel(), images)):
                axis.imshow(image['image'])
                # only show image title (class #) on top row
                if show_labels and i < ncols:
                    axis.set_title(image['title'], fontsize=label_fontsize)
                # configure axes & hide ticks
                plt.setp(axis.spines.values(), linewidth=0.5)
                axis.get_xaxis().set_ticks([])
                axis.get_yaxis().set_ticks([])

            # setup output path and save figure to outfile
            _create_output_path(outfile)
            figure.savefig(outfile, bbox_inches='tight', dpi=dpi)
        finally:
            # release the figure even if rendering/saving fails (e.g. no latex)
            plt.close(figure)
        print(f'saved counterexamples plot to {outfile}')

    @staticmethod
    def generate_class_accuracy_plot(
            cr:ContextualRobustness,
            outfile:str='./class-accuracy.png',
            axis_fontsize:int=12,
            legend_fontsize:int=14,
            fontfamily:str='serif',
            fontweight:str='ultralight',
            legend_loc:str='best',
            usetex:bool=True
            ):
        '''Plots accuracy of each class at various epsilons for a model/transform and saves as png

        Classes without samples are skipped.

        Args:
            cr (ContextualRobustness): ContextualRobustness object
            outfile (str, optional): Output file path. Defaults to './class-accuracy.png'.
            axis_fontsize (int, optional): Fontsize for axis text. Defaults to 12.
            legend_fontsize (int, optional): Fontsize for legend text. Defaults to 14.
            fontfamily (str, optional): Fontfamily for text. Defaults to 'serif'.
            fontweight (str, optional): Fontweight for text. Defaults to 'ultralight'.
            legend_loc (str, optional): Location of legend. Defaults to 'best'.
            usetex (bool, optional): Use latex for text. Defaults to True.
        '''
        ContextualRobustnessReporting._configure_text(usetex, fontfamily, fontweight)

        # draw on a dedicated figure so that a figure left open elsewhere
        # cannot leak into this plot
        figure, ax = plt.subplots()
        try:
            # for each class, plot accuracy at different values of epsilon
            for c in cr.classes:
                class_total = cr.get_num_samples(class_index=c)
                # if no samples for the class
                if class_total == 0:
                    continue
                class_accuracy = cr.get_accuracy(class_index=c)
                class_results = cr.get_results(class_index=c, sort_by=['epsilon'])
                report_df = ContextualRobustnessReporting._accuracy_curve(
                    class_results, class_total, class_accuracy)
                ax.plot(report_df['epsilon'], report_df['accuracy'], label=f'class {c}')

            ax.grid()
            ax.legend(fontsize=legend_fontsize, loc=legend_loc)
            ax.tick_params(axis='both', labelsize=axis_fontsize)
            ax.set_xlabel('Epsilon', fontsize=axis_fontsize)
            ax.set_ylabel('Accuracy', fontsize=axis_fontsize)

            _create_output_path(outfile)
            figure.savefig(outfile, bbox_inches='tight')
        finally:
            # release the figure even if rendering/saving fails (e.g. no latex)
            plt.close(figure)
        print(f'saved class accuracy report plot to {outfile}')

    @staticmethod
    def generate_accuracy_report_plot(
            cr_objects:typing.Sequence[ContextualRobustness],
            outfile:str='./accuracy-report.png',
            linestyles:tuple=(),
            axis_fontsize:int=12,
            legend_fontsize:int=14,
            fontfamily:str='serif',
            fontweight:str='ultralight',
            legend_loc:str='best',
            usetex:bool=True
            ):
        '''Plots epsilon/accuracy for a given transform on multiple models and saves as png

        Args:
            cr_objects (typing.Sequence[ContextualRobustness]): list of ContextualRobustness objects
            outfile (str, optional): Output file path. Defaults to './accuracy-report.png'.
            linestyles (tuple, optional): Line style for each entry of cr_objects, matched
                by position; objects without an entry are drawn with '-'. Defaults to ().
            axis_fontsize (int, optional): Fontsize for axis text. Defaults to 12.
            legend_fontsize (int, optional): Fontsize for legend text. Defaults to 14.
            fontfamily (str, optional): Fontfamily for text. Defaults to 'serif'.
            fontweight (str, optional): Fontweight for text. Defaults to 'ultralight'.
            legend_loc (str, optional): Location of legend. Defaults to 'best'.
            usetex (bool, optional): Use latex for text. Defaults to True.
        '''
        ContextualRobustnessReporting._configure_text(usetex, fontfamily, fontweight)

        # draw on a dedicated figure so that a figure left open elsewhere
        # cannot leak into this plot
        figure, ax = plt.subplots()
        try:
            for i, cr in enumerate(cr_objects):
                # unnamed models (empty or missing name) are labelled by position
                model_name = cr.model_name or f'model{i}'
                total = cr.num_samples
                accuracy = cr.accuracy
                model_results = cr.get_results(sort_by=['epsilon'])
                accuracy_df = ContextualRobustnessReporting._accuracy_curve(
                    model_results, total, accuracy)
                ls = linestyles[i] if len(linestyles) > i else '-'
                ax.plot(
                    accuracy_df['epsilon'],
                    accuracy_df['accuracy'],
                    label=model_name,
                    linestyle=ls
                    )

            ax.grid()
            ax.legend(fontsize=legend_fontsize, loc=legend_loc)
            ax.tick_params(axis='both', labelsize=axis_fontsize)
            ax.set_xlabel('Epsilon', fontsize=axis_fontsize)
            ax.set_ylabel('Accuracy', fontsize=axis_fontsize)

            _create_output_path(outfile)
            figure.savefig(outfile, bbox_inches='tight')
        finally:
            # release the figure even if rendering/saving fails (e.g. no latex)
            plt.close(figure)
        print(f'saved accuracy report plot to {outfile}')