Instance-based Counterfactual Explanations for Time Series Classification¶
Delaney et al. [1] proposed using the K-nearest neighbors from the dataset that belong to a different class as a native guide to generate counterfactuals. They propose three options for transforming the original time series with this native guide: the plain native guide, the native guide with dynamic barycenter averaging, and a transformation based on the native guide and class activation mapping. The desired option can be selected via the method parameter when the interpretability method is instantiated.
The code in TSInterpret is based on the authors' implementation.
[1] E. Delaney, D. Greene, and M. T. Keane. Instance-Based Counterfactual Explanations for Time Series Classification. In A. A. Sánchez-Ruiz and M. W. Floyd, editors, Case-Based Reasoning Research and Development, volume 12877, pages 32–47. Springer International Publishing, Cham, 2021. ISBN 978-3-030-86956-4, 978-3-030-86957-1. doi: 10.1007/978-3-030-86957-1_3. Series Title: Lecture Notes in Computer Science.
from tslearn.datasets import UCR_UEA_datasets
import sklearn
import pickle
import numpy as np
import torch
from ClassificationModels.CNN_T import ResNetBaseline, get_all_preds, fit, UCRDataset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import pandas as pd
import os
import warnings
warnings.filterwarnings("ignore")
Load Data¶
Load the data and reshape it to fit a 1D-Conv ResNet. Note that the input for a 1D-Conv ResNet has the shape (batch, features, timesteps).
dataset='ECG5000'
train_x,train_y, test_x, test_y=UCR_UEA_datasets().load_dataset(dataset)
train_x = train_x.reshape(-1,1, train_x.shape[-2])
test_x = test_x.reshape(-1,1, test_x.shape[-2])
np.unique(test_y)
array([1, 2, 3, 4, 5])
enc1 = sklearn.preprocessing.OneHotEncoder(sparse=False).fit(np.vstack((train_y.reshape(-1, 1), test_y.reshape(-1, 1))))
with open(f'../../ClassificationModels/models/{dataset}/OneHotEncoder.pkl', 'wb') as f:
    pickle.dump(enc1, f)
train_y=enc1.transform(train_y.reshape(-1,1))
test_y=enc1.transform(test_y.reshape(-1,1))
Model Training / Loading¶
Trains a ResNet and saves the results. A sketch for loading previously saved weights follows the training output.
n_pred_classes =train_y.shape[1]
train_dataset = UCRDataset(train_x.astype(np.float64),train_y.astype(np.int64))
test_dataset = UCRDataset(test_x.astype(np.float64),test_y.astype(np.int64))
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=16,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=1,shuffle=False)
model = ResNetBaseline(in_channels= train_x.shape[-2], num_pred_classes=n_pred_classes)
fit(model,train_loader,test_loader)
if dataset in os.listdir('../../ClassificationModels/models/'):
    print('Folder exists')
else:
    os.mkdir(f'../../ClassificationModels/models/{dataset}')
torch.save(model.state_dict(), f'../../ClassificationModels/models/{dataset}/ResNet')
test_preds, ground_truth = get_all_preds(model, test_loader)
ground_truth=np.argmax(ground_truth,axis=1)
#test_preds=np.argmax(test_preds,axis=1)
sns.set(rc={'figure.figsize':(5,4)})
heatmap=confusion_matrix(ground_truth, test_preds)
sns.heatmap(heatmap, annot=True)
plt.savefig(f'../../ClassificationModels/models/{dataset}/ResNet_confusion_matrix.png')
plt.close()
acc= accuracy_score(ground_truth, test_preds)
a = classification_report(ground_truth, test_preds, output_dict=True)
dataframe = pd.DataFrame.from_dict(a)
dataframe.to_csv(f'../../ClassificationModels/models/{dataset}/classification_report.csv', index = False)
Epoch: 1, Train loss: 0.569, Val loss: 0.436 Epoch: 2, Train loss: 0.486, Val loss: 0.296 Epoch: 3, Train loss: 0.305, Val loss: 0.296 Epoch: 4, Train loss: 0.345, Val loss: 0.393 Epoch: 5, Train loss: 0.269, Val loss: 0.273 Epoch: 6, Train loss: 0.342, Val loss: 0.368 Epoch: 7, Train loss: 0.273, Val loss: 0.297 Epoch: 8, Train loss: 0.231, Val loss: 0.292 Epoch: 9, Train loss: 0.214, Val loss: 0.559 Epoch: 10, Train loss: 0.247, Val loss: 0.258 Epoch: 11, Train loss: 0.231, Val loss: 0.25 Epoch: 12, Train loss: 0.201, Val loss: 0.303 Epoch: 13, Train loss: 0.172, Val loss: 0.31 Epoch: 14, Train loss: 0.222, Val loss: 0.28 Epoch: 15, Train loss: 0.185, Val loss: 0.278 Epoch: 16, Train loss: 0.227, Val loss: 0.254 Epoch: 17, Train loss: 0.192, Val loss: 0.227 Epoch: 18, Train loss: 0.156, Val loss: 0.238 Epoch: 19, Train loss: 0.149, Val loss: 0.233 Epoch: 20, Train loss: 0.197, Val loss: 0.299 Epoch: 21, Train loss: 0.225, Val loss: 0.29 Epoch: 22, Train loss: 0.193, Val loss: 0.281 Epoch: 23, Train loss: 0.135, Val loss: 0.254 Epoch: 24, Train loss: 0.136, Val loss: 0.293 Epoch: 25, Train loss: 0.128, Val loss: 0.234 Epoch: 26, Train loss: 0.153, Val loss: 0.444 Epoch: 27, Train loss: 0.213, Val loss: 0.299 Epoch: 28, Train loss: 0.147, Val loss: 0.321 Epoch: 29, Train loss: 0.169, Val loss: 0.264 Epoch: 30, Train loss: 0.151, Val loss: 0.254 Epoch: 31, Train loss: 0.146, Val loss: 0.254 Epoch: 32, Train loss: 0.143, Val loss: 0.314 Epoch: 33, Train loss: 0.154, Val loss: 0.395 Epoch: 34, Train loss: 0.176, Val loss: 0.417 Epoch: 35, Train loss: 0.151, Val loss: 0.266 Epoch: 36, Train loss: 0.186, Val loss: 0.36 Epoch: 37, Train loss: 0.12, Val loss: 0.325 Epoch: 38, Train loss: 0.135, Val loss: 0.292 Epoch: 39, Train loss: 0.102, Val loss: 0.272 Epoch: 40, Train loss: 0.108, Val loss: 0.28 Epoch: 41, Train loss: 0.114, Val loss: 0.324 Epoch: 42, Train loss: 0.107, Val loss: 0.319 Epoch: 43, Train loss: 0.099, Val loss: 0.29 Epoch: 44, Train loss: 0.103, Val loss: 0.469 Epoch: 45, Train loss: 0.137, Val loss: 0.556 Epoch: 46, Train loss: 0.149, Val loss: 0.326 Epoch: 47, Train loss: 0.187, Val loss: 0.427 Epoch: 48, Train loss: 0.135, Val loss: 0.309 Epoch: 49, Train loss: 0.148, Val loss: 0.259 Epoch: 50, Train loss: 0.105, Val loss: 0.325 Epoch: 51, Train loss: 0.103, Val loss: 0.398 Epoch: 52, Train loss: 0.1, Val loss: 0.345 Epoch: 53, Train loss: 0.076, Val loss: 0.368 Epoch: 54, Train loss: 0.102, Val loss: 0.317 Epoch: 55, Train loss: 0.093, Val loss: 0.387 Epoch: 56, Train loss: 0.107, Val loss: 0.7 Epoch: 57, Train loss: 0.174, Val loss: 0.29 Epoch: 58, Train loss: 0.116, Val loss: 0.314 Epoch: 59, Train loss: 0.079, Val loss: 0.366 Epoch: 60, Train loss: 0.068, Val loss: 0.355 Epoch: 61, Train loss: 0.084, Val loss: 0.326 Epoch: 62, Train loss: 0.082, Val loss: 0.39 Epoch: 63, Train loss: 0.073, Val loss: 0.338 Epoch: 64, Train loss: 0.091, Val loss: 0.367 Epoch: 65, Train loss: 0.081, Val loss: 0.368 Epoch: 66, Train loss: 0.074, Val loss: 0.453 Epoch: 67, Train loss: 0.079, Val loss: 0.557 Epoch: 68, Train loss: 0.096, Val loss: 0.328 Epoch: 69, Train loss: 0.074, Val loss: 0.427 Epoch: 70, Train loss: 0.061, Val loss: 0.416 Epoch: 71, Train loss: 0.071, Val loss: 0.413 Epoch: 72, Train loss: 0.057, Val loss: 0.419 Epoch: 73, Train loss: 0.047, Val loss: 0.437 Epoch: 74, Train loss: 0.041, Val loss: 0.447 Epoch: 75, Train loss: 0.042, Val loss: 0.475 Epoch: 76, Train loss: 0.046, Val loss: 0.457 Epoch: 77, Train loss: 0.038, Val loss: 0.738 Epoch: 78, Train loss: 0.09, Val 
loss: 0.609 Epoch: 79, Train loss: 0.102, Val loss: 0.48 Epoch: 80, Train loss: 0.051, Val loss: 0.609 Epoch: 81, Train loss: 0.088, Val loss: 0.663 Epoch: 82, Train loss: 0.15, Val loss: 0.62 Epoch: 83, Train loss: 0.194, Val loss: 0.343 Epoch: 84, Train loss: 0.114, Val loss: 0.338 Epoch: 85, Train loss: 0.078, Val loss: 0.405 Epoch: 86, Train loss: 0.066, Val loss: 0.469 Epoch: 87, Train loss: 0.054, Val loss: 0.442 Epoch: 88, Train loss: 0.044, Val loss: 0.483 Epoch: 89, Train loss: 0.035, Val loss: 0.519 Epoch: 90, Train loss: 0.031, Val loss: 0.625 Epoch: 91, Train loss: 0.029, Val loss: 0.613 Epoch: 92, Train loss: 0.046, Val loss: 0.421 Epoch: 93, Train loss: 0.061, Val loss: 0.505 Epoch: 94, Train loss: 0.07, Val loss: 0.707 Epoch: 95, Train loss: 0.19, Val loss: 0.36 Epoch: 96, Train loss: 0.104, Val loss: 0.497 Epoch: 97, Train loss: 0.144, Val loss: 0.356 Epoch: 98, Train loss: 0.135, Val loss: 0.391 Epoch: 99, Train loss: 0.086, Val loss: 0.376 Epoch: 100, Train loss: 0.055, Val loss: 0.463 Epoch: 101, Train loss: 0.06, Val loss: 0.49 Epoch: 102, Train loss: 0.071, Val loss: 0.51 Epoch: 103, Train loss: 0.055, Val loss: 0.449 Epoch: 104, Train loss: 0.107, Val loss: 0.405 Epoch: 105, Train loss: 0.071, Val loss: 0.497 Epoch: 106, Train loss: 0.102, Val loss: 0.652 Epoch: 107, Train loss: 0.097, Val loss: 0.533 Epoch: 108, Train loss: 0.057, Val loss: 0.431 Epoch: 109, Train loss: 0.04, Val loss: 0.521 Epoch: 110, Train loss: 0.039, Val loss: 0.576 Epoch: 111, Train loss: 0.034, Val loss: 0.52 Epoch: 112, Train loss: 0.026, Val loss: 0.697 Epoch: 113, Train loss: 0.024, Val loss: 0.638 Epoch: 114, Train loss: 0.053, Val loss: 0.872 Epoch: 115, Train loss: 0.075, Val loss: 0.674 Epoch: 116, Train loss: 0.108, Val loss: 0.584 Epoch: 117, Train loss: 0.046, Val loss: 0.707 Early stopping! Folder exists
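If the ResNet has already been trained, the saved weights can be loaded instead of retraining. This is a minimal sketch using the save path from above; it is not part of the original notebook:
# Optional: load previously saved weights instead of retraining the model.
model = ResNetBaseline(in_channels=train_x.shape[-2], num_pred_classes=n_pred_classes)
model.load_state_dict(torch.load(f'../../ClassificationModels/models/{dataset}/ResNet'))
model.eval()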
Interpretability Algorithm¶
Using an interpretability algorithm consists of 4 steps:
- Load the interpretability method
- Instantiate the method with the desired parameters
- Call the explain method
- Plot the results
1. & 2. Loading & Instantiation¶
Native Guide Counterfactuals work on all TensorFlow and PyTorch models that return a probability distribution over the classes. The initialization takes the following arguments:
`model`: The model to be explained.
`shape`: The data shape in the form (features, timesteps).
`data`: Tuple of data and labels used to find and build the counterfactual.
`backend`: `PYT`, `SK`, or `TF`.
`mode`: Indicates whether the second dimension of the input is the feature axis (`feat`) or the time axis (`time`).
`method`: The transformation applied to the native guide, corresponding to the three options described in the introduction; the example below uses `NUN_CF` (a sketch of the alternative options follows the instantiation).
item=test_x[10].reshape(1,1,-1)
shape=item.shape
_item= torch.from_numpy(item).float()
model.eval()
y_target = torch.nn.functional.softmax(model(_item), dim=1).detach().numpy()
from TSInterpret.InterpretabilityModels.counterfactual.NativeGuideCF import NativeGuideCF
exp_model = NativeGuideCF(model, (train_x, train_y), backend='PYT', mode='feat', method="NUN_CF")
WARNING:root:no value was provided for `target_layer`, thus set to 'layers'.
WARNING:root:no value was provided for `fc_layer`, thus set to 'final'.
(1, 140)
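The other two transformations described in the introduction can be selected in the same way. The following sketch uses assumed method strings ('dtw_bary_center' and 'NG'); check the NativeGuideCF documentation for the exact values:
# Sketch with assumed method strings; these objects are not used in the rest of this notebook.
exp_model_dba = NativeGuideCF(model, (train_x, train_y), backend='PYT', mode='feat', method='dtw_bary_center')
exp_model_ng = NativeGuideCF(model, (train_x, train_y), backend='PYT', mode='feat', method='NG')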
3. Call the explain method.¶
Prepare the instance and the predicted label of the instance as parameters for the explain method.
`item`: the instance to be explained.
`target`: the label predicted for the instance (the counterfactual is generated toward a different class).
exp,label=exp_model.explain(item, np.argmax(y_target,axis=1) )
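As a quick sanity check, not part of the original notebook, the counterfactual can be passed back through the model to confirm that the predicted class has changed:
# Sanity check: the counterfactual should be classified differently than the original item.
cf = torch.from_numpy(np.array(exp)).float().reshape(1, 1, -1)
y_cf = torch.nn.functional.softmax(model(cf), dim=1).detach().numpy()
print('Original class:', np.argmax(y_target, axis=1)[0], '- Counterfactual class:', np.argmax(y_cf, axis=1)[0])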
4. Visualization¶
The plot function takes as input the item to be explained and the returned explanation, as well as the assigned labels.
exp_model.plot(item,np.argmax(y_target,axis=1)[0],exp,label)
exp_model.plot_in_one(item,'Normal',exp,'Abnormal','NUN_CF')
For reference, the implementation of plot_in_one used above:
def plot_in_one(self, item, org_label, exp, cf_label, save_fig=None, figsize=(15, 15)):
    """
    Plot function for counterfactuals in the uni- and multivariate setting. In the multivariate setting only the changed features are visualized.
    Arguments:
        item np.array: original instance.
        org_label int: originally predicted label.
        exp np.array: returned explanation.
        cf_label int: label of the returned instance.
        figsize Tuple: size of the figure (x, y).
        save_fig str: path to save the figure.
    """
    if self.mode == 'time':
        item = item.reshape(item.shape[0], item.shape[2], item.shape[1])
    # TODO This is new and needs to be tested
    if item.shape[-2] > 1:
        res = (item != exp).any(-1)
        ind = np.where(res[0])
        if len(ind[0]) == 0:
            print('Items are identical')
            return
        elif len(ind[0]) > 1:
            self.plot_multi(item, org_label, exp, cf_label, figsize=figsize, save_fig=save_fig)
            return
        else:
            item = item[ind]
    plt.style.use("classic")
    colors = [
        '#08F7FE',  # teal/cyan
        '#FE53BB',  # pink
        '#F5D300',  # yellow
        '#00ff41',  # matrix green
    ]
    indices = np.where(exp[0] != item)
    df = pd.DataFrame({f'Predicted: {org_label}': list(item.flatten()),
                       f'Counterfactual: {cf_label}': list(exp.flatten())})
    fig, ax = plt.subplots(figsize=(10, 5))
    df.plot(marker='.', color=colors, ax=ax)
    # Redraw the data with low alpha and slightly increased linewidth:
    n_shades = 10
    diff_linewidth = 1.05
    alpha_value = 0.3 / n_shades
    for n in range(1, n_shades + 1):
        df.plot(marker='.',
                linewidth=2 + (diff_linewidth * n),
                alpha=alpha_value,
                legend=False,
                ax=ax,
                color=colors)
    ax.grid(color='#2A3459')
    plt.xlabel('Time', fontweight='bold', fontsize='large')
    plt.ylabel('Value', fontweight='bold', fontsize='large')
    if save_fig is None:
        plt.show()
    else:
        plt.savefig(save_fig)
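To write the comparison plot to a file instead of displaying it, a path can be passed via save_fig; the file name here is only an example:
# Save the comparison plot to disk instead of showing it (example file name).
exp_model.plot_in_one(item, 'Normal', exp, 'Abnormal', save_fig=f'{dataset}_NUN_CF.png')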
Additional examples, e.g., for use with LSTM models or TensorFlow, can be found here.