# Source code for contextual_robustness.formal

import os, sys, typing
import numpy as np
import pandas as pd
import tensorflow as tf
from contextual_robustness.utils import softargmax, _get_file_extension
from contextual_robustness.base import _BaseContextualRobustness, Techniques, ContextualRobustness, DEFAULTS

# TODO: remove sys.path.append after maraboupy pip package is available.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../marabou')))
from maraboupy import Marabou

class ContextualRobustnessFormal(_BaseContextualRobustness):
    '''Class for ContextualRobustness 'Formal Verification' analysis

    Args:
        model_path (str, optional): Path to saved tensorflow, nnet, or onnx model. Defaults to ''.
        model_name (str, optional): Name of model. Defaults to ''.
        model_args (dict, optional): Args passed to Marabou to load network. Defaults to dict().
        transform_fn (callable, optional): Function that encodes the transform as a Marabou
            input query. ``_find_counterexample`` calls it as
            ``transform_fn(network, x, epsilon, output_index, **transform_args)`` and solves
            the network it returns. The default (lambda x:x) is only a placeholder and does
            not accept that argument list.
        transform_args (dict, optional): Additional arguments passed to transform_fn. Defaults to dict().
        transform_name (str, optional): Name of transform. Defaults to ''.
        X (np.array, optional): The images. Defaults to np.array([]).
        Y (np.array, optional): Labels for images (onehot encoded). Defaults to np.array([]).
        sample_indexes (list[int], optional): List of indexes to test from X. Defaults to [].
        eps_lower (float, optional): Min possible epsilon. Defaults to DEFAULTS['eps_lower'].
        eps_upper (float, optional): Max possible epsilon. Defaults to DEFAULTS['eps_upper'].
        eps_interval (float, optional): Step size between possible epsilons. Defaults to DEFAULTS['eps_interval'].
        marabou_options (dict, optional): Additional MarabouOptions. Defaults to dict(verbosity=DEFAULTS['marabou_verbosity']).
        verbosity (int, optional): Amount of logging (0-4). Defaults to DEFAULTS['verbosity'].
    '''

    # File extensions (including "no extension", i.e. a saved-model directory)
    # that are loaded through Marabou.read_tf / tf.keras.
    _TF_EXTENSIONS = ('', '.pb', '.h5', '.hdf5')

    def __init__(
        self,
        model_path:str='',
        model_name:str='',
        model_args:dict=dict(),
        transform_fn:callable=lambda x: x,
        transform_args:dict=dict(),
        transform_name:str='',
        X:np.array=np.array([]),
        Y:np.array=np.array([]),
        sample_indexes:list=[],
        eps_lower:float=DEFAULTS['eps_lower'],
        eps_upper:float=DEFAULTS['eps_upper'],
        eps_interval:float=DEFAULTS['eps_interval'],
        marabou_options:dict=dict(verbosity=DEFAULTS['marabou_verbosity']),
        verbosity:int=DEFAULTS['verbosity']
        ) -> None:
        # Store private copies so the shared default dicts (and the caller's
        # dicts) can never be mutated through this instance; None is accepted
        # as "no extra arguments".
        # These are assigned before the superclass constructor runs, as in the
        # original code -- presumably because the base constructor loads the
        # model through _load_model, which reads self._model_args.
        self._model_args = dict(model_args or {})
        self._marabou_options = dict(marabou_options or {})
        # Execute the superclass's constructor
        super().__init__(
            model_path=model_path,
            model_name=model_name,
            transform_fn=transform_fn,
            transform_args=transform_args,
            transform_name=transform_name,
            X=X,
            Y=Y,
            sample_indexes=sample_indexes,
            eps_lower=eps_lower,
            eps_upper=eps_upper,
            eps_interval=eps_interval,
            verbosity=verbosity
            )

    @property
    def technique(self) -> Techniques:
        '''technique property

        Returns:
            Techniques: verification technique (Techniques.FORMAL)
        '''
        return Techniques.FORMAL

    def _load_model(self, model_path:str) -> Marabou.MarabouNetwork:
        '''Loads a tensorflow, nnet, or onnx model as a MarabouNetwork

        Args:
            model_path (str): Path to the verification model

        Returns:
            MarabouNetwork: the MarabouNetwork object

        Raises:
            AssertionError: If the file extension of model_path is not supported.
        '''
        ext = _get_file_extension(model_path)
        if ext == '.nnet':
            return Marabou.read_nnet(model_path, **self._model_args)
        if ext in self._TF_EXTENSIONS:
            return Marabou.read_tf(model_path, **self._model_args)
        if ext == '.onnx':
            return Marabou.read_onnx(model_path, **self._model_args)
        # Raised explicitly (instead of via an ``assert`` statement) so the
        # check still runs when Python is started with -O.
        raise AssertionError('Model must be .nnet, .pb, .h5, or .onnx')

    def _find_correct_sample_indexes(self, X:np.array, Y:np.array) -> typing.List[int]:
        '''Finds list of indexes for correctly predicted samples

        Only the indexes in self._sample_indexes are considered.

        Args:
            X (np.array): The images
            Y (np.array): Labels (onehot encoded) for the images

        Returns:
            list[int]: Indexes of correctly predicted samples from dataset
                (None when the model's file extension is not recognised)
        '''
        ext = _get_file_extension(self._model_path)
        if ext in ('.nnet', '.onnx'):
            # For NNet & ONNX models, use Marabou's 'evaluate' to make predictions
            return [
                i for i in self._sample_indexes
                if np.argmax(softargmax(self._model.evaluate(X[i])[0])) == np.argmax(Y[i])
                ]
        if ext in self._TF_EXTENSIONS:
            # Nothing to predict: skip loading the model and avoid calling
            # 'predict' on an empty batch.
            if len(self._sample_indexes) == 0:
                return []
            # For Tensorflow models, use Tensorflow's 'predict' to make predictions (much faster)
            model = tf.keras.models.load_model(self._model_path)
            Y_p = model.predict(np.array([X[si] for si in self._sample_indexes]))
            return [
                si for i, si in enumerate(self._sample_indexes)
                if np.argmax(softargmax(Y_p[i])) == np.argmax(Y[si])
                ]
        return None

    def _find_epsilon(self, x:np.array, y:np.array, index:int=None) -> typing.Tuple[float, float, float, int, np.array]:
        '''Finds the epsilon for an image

        Bisects the range [eps_lower, eps_upper] until it is no wider than
        eps_interval, querying the verifier at the midpoint on each step.

        Args:
            x (np.array): The image
            y (np.array): Label for the image (onehot encoded)
            index (int, optional): Index of x. Defaults to None.

        Returns:
            tuple[float, float, float, int, np.array]: (lower, upper, epsilon, predicted_label, counterexample)
        '''
        lower = self._eps_lower
        upper = self._eps_upper
        interval = self._eps_interval
        actual_label = np.argmax(y)
        # Values reported when no counterexample is ever found.
        predicted_label = actual_label
        epsilon = upper
        counterexample = None
        while (upper - lower) > interval:
            guess = lower + (upper - lower) / 2.0
            verified, pred, cex = self._find_counterexample(x, y, guess, x_index=index)
            if self._verbosity > 1:
                print(f'evaluated image:{index}@epsilon:{guess}, label:{actual_label}, pred:{pred}')
            if verified:
                # correct prediction: continue the search in the upper half
                lower = guess
            else:
                # incorrect prediction: remember it and continue in the lower half
                upper = guess
                predicted_label = pred
                epsilon = guess
                counterexample = cex
        return lower, upper, epsilon, predicted_label, counterexample

    def _find_counterexample(self, x:np.array, y:np.array, epsilon:float, x_index:int=None) -> typing.Tuple[bool, int, np.array]:
        '''Finds the counterexample for an image at a given epsilon

        One Marabou query is solved per output other than the actual label;
        the search stops at the first query that is satisfiable.

        Args:
            x (np.array): The image
            y (np.array): Label for the image (onehot encoded)
            epsilon (float): the epsilon value
            x_index (int, optional): Index of x (for reference only). Defaults to None.

        Returns:
            tuple[bool, int, np.array]: (verified, predicted_label, counterexample)

        Raises:
            AssertionError: If Marabou times out on any of the queries.
        '''
        actual_label = np.argmax(y)
        predicted_label = actual_label
        verified = True
        counterexample = None
        # User supplied options override the default verbosity (loop invariant).
        marabou_options = {'verbosity':DEFAULTS['marabou_verbosity'], **self._marabou_options}
        for output_index in range(y.shape[0]):
            if actual_label == output_index:
                continue
            # load model, encode the transform as a marabou input query, and solve query
            network = self._load_model(self._model_path)
            network = self._transform_fn(network, x, epsilon, output_index, **self._transform_args)
            vals, stats = network.solve(options=Marabou.createOptions(**marabou_options), verbose=(self._verbosity > 3))
            # check results
            if stats.hasTimedOut():
                # TIMEOUT
                # Raised explicitly (instead of via ``assert False``) so a
                # timeout is never silently reported as a result under -O.
                raise AssertionError(f'Timeout occurred ({f"x_index={x_index}" if x_index is not None else ""};output={output_index}@epsilon={epsilon})')
            if len(vals) == 0:
                # UNSAT
                if self._verbosity > 2:
                    print(f'image:{x_index};output:{output_index}@epsilon:{epsilon} (UNSAT)')
                continue
            # SAT (counterexample found)
            if self._verbosity > 2:
                print(f'image:{x_index};output:{output_index}@epsilon:{epsilon} (SAT)')
            # The first n_pixels solver variables are read back as the image.
            counterexample = np.array([vals[i] for i in range(self.n_pixels)]).reshape(self.image_shape)
            predicted_label = output_index
            verified = False
            break
        return verified, predicted_label, counterexample