Instance-based Counterfactual Explanations for Time Series Classification
A counterfactual explanation, introduced to machine learning by Wachter et al. [2], answers "what if" questions by constructing counterexamples. Given an input instance $x$ that a predictor $f$ assigns to class $y = f(x)$, the goal is to find a counterfactual $x^{cf}$ that is close to $x$ but is classified differently, $y^{cf} = f(x^{cf}) \neq y$. Such counterexamples make the boundary cases of the model visible. Further research has shown that counterfactual explanations are easy for humans to understand, because reasoning with counterexamples is intuitive.
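For comparison, Wachter et al. [2] cast the search for a counterfactual as an optimization problem. In their notation (a model $f_w$, an instance $x_i$, a desired outcome $y'$, and a distance function $d$), the counterfactual $x'$ roughly solves

$$\arg\min_{x'} \max_{\lambda} \; \lambda \left(f_w(x') - y'\right)^2 + d(x_i, x'),$$

where larger values of $\lambda$ push the prediction towards $y'$ while $d$ keeps $x'$ close to $x_i$. Instance-based methods such as Native Guide avoid this optimization by drawing on existing instances of another class.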
Delaney et al. [1] propose using the nearest neighbors in the dataset that belong to a different class as a native guide for generating counterfactuals. They offer three options for transforming the original time series with this native guide: the plain native guide, the native guide combined with DTW barycenter averaging, and a transformation based on the native guide and class activation mapping (CAM). The desired option is selected with the `method` parameter when the explainer is instantiated.
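The retrieval step behind the native guide can be sketched in a few lines of NumPy. This is a minimal illustration, assuming Euclidean distance and data laid out as (n_samples, features, timesteps); the helper `nearest_unlike_neighbor` is hypothetical and not part of TSInterpret, whose own retrieval may use a different distance measure.

```python
import numpy as np


def nearest_unlike_neighbor(x, data, labels, y):
    """Return the instance in `data` closest to `x` whose label differs from `y`.

    x:      array of shape (features, timesteps)
    data:   array of shape (n_samples, features, timesteps)
    labels: array of shape (n_samples,) with class labels
    y:      class label of x (e.g. the model's prediction)
    """
    unlike = data[labels != y]  # keep candidates from other classes only
    if len(unlike) == 0:
        raise ValueError("no instance with a different label in data")
    # Euclidean distance over all features and time steps
    dists = np.sqrt(((unlike - x) ** 2).sum(axis=(1, 2)))
    return unlike[np.argmin(dists)]


# Toy data: four univariate series of length 5, labels 0 and 1
data = np.array([
    [[0, 0, 0, 0, 0]],
    [[0, 1, 0, 1, 0]],
    [[5, 5, 5, 5, 5]],
    [[4, 4, 4, 4, 4]],
], dtype=float)
labels = np.array([0, 0, 1, 1])
x = np.array([[1, 1, 1, 1, 1]], dtype=float)  # assume it is classified as 0

nun = nearest_unlike_neighbor(x, data, labels, y=0)
print(nun)  # [[4. 4. 4. 4. 4.]]
```

Restricting the search to other classes is what guarantees that the retrieved series already carries a different label, so it can serve as evidence for a counterfactual.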
The code in TSInterpret is based on the authors' implementation.
[1] E. Delaney, D. Greene, and M. T. Keane. Instance-Based Counterfactual Explanations for Time Series Classification. In A. A. Sánchez-Ruiz and M. W. Floyd, editors, Case-Based Reasoning Research and Development, Lecture Notes in Computer Science, volume 12877, pages 32–47. Springer International Publishing, Cham, 2021. doi: 10.1007/978-3-030-86957-1_3.
[2] S. Wachter, B. Mittelstadt, and C. Russell. Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR. Harvard Journal of Law & Technology, 31:841, 2017.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tslearn.datasets import UCR_UEA_datasets
import sklearn.preprocessing
import pickle
import numpy as np
import torch
from ClassificationModels.CNN_T import ResNetBaseline, UCRDataset
import warnings
warnings.filterwarnings("ignore")
Load Data
Load the data and reshape it to fit a 1D-convolutional ResNet. Note that the input of a 1D-conv ResNet has the shape (batch, features, timesteps).
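To make the expected layout concrete, the sketch below uses random stand-in arrays in the layout tslearn returns, (n_samples, timesteps, features), and converts them to (batch, features, timesteps). The array sizes are arbitrary assumptions. For a univariate dataset such as ElectricDevices, reshaping and transposing give the same result; for multivariate data only the transpose keeps each channel's values together.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for tslearn output: 8 series, 96 time steps, 1 feature
x_uni = rng.normal(size=(8, 96, 1))
via_reshape = x_uni.reshape(-1, 1, x_uni.shape[-2])
via_transpose = np.transpose(x_uni, (0, 2, 1))
print(via_reshape.shape, np.array_equal(via_reshape, via_transpose))  # (8, 1, 96) True

# With 3 features, swapping the last two axes keeps channel 0 intact
x_multi = rng.normal(size=(8, 96, 3))
channels_first = np.transpose(x_multi, (0, 2, 1))
print(channels_first.shape)  # (8, 3, 96)
print(np.array_equal(channels_first[:, 0, :], x_multi[:, :, 0]))  # True
# A plain reshape would interleave the channels instead
print(np.array_equal(x_multi.reshape(8, 3, 96)[:, 0, :], x_multi[:, :, 0]))  # False
```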
dataset='ElectricDevices'
train_x,train_y, test_x, test_y=UCR_UEA_datasets().load_dataset(dataset)
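# tslearn returns (n_samples, timesteps, features); the ResNet expects (n_samples, features, timesteps)
train_x = np.transpose(train_x, (0, 2, 1))
test_x = np.transpose(test_x, (0, 2, 1))
# Fit the encoder on all labels (scikit-learn >= 1.2; use sparse=False on older versions)
enc1=sklearn.preprocessing.OneHotEncoder(sparse_output=False).fit(np.vstack((train_y.reshape(-1,1),test_y.reshape(-1,1))))
# Save the encoder so the label mapping can be reused later; the with-block closes the file
with open(f'../../ClassificationModels/models/{dataset}/OneHotEncoder.pkl','wb') as f:
    pickle.dump(enc1,f)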
train_y=enc1.transform(train_y.reshape(-1,1))
test_y=enc1.transform(test_y.reshape(-1,1))
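The one-hot labels are later turned back into class indices with `np.argmax(..., axis=1)`. The NumPy sketch below mimics that round trip on made-up labels. It assumes, as scikit-learn's `OneHotEncoder` does with its default `categories='auto'`, that the columns follow the sorted unique labels.

```python
import numpy as np

labels = np.array(["3", "1", "2", "1"])  # made-up labels

# Sorted unique classes and the column index of each label
classes, idx = np.unique(labels, return_inverse=True)
one_hot = np.eye(len(classes))[idx]
print(one_hot.shape)  # (4, 3)

# argmax recovers the column index; `classes` maps it back to the original label
recovered = classes[np.argmax(one_hot, axis=1)]
print(recovered)  # ['3' '1' '2' '1']
```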
Model Training / Loading
Wraps the data in PyTorch datasets and data loaders, and loads the weights of a previously trained ResNet.
n_pred_classes =train_y.shape[1]
train_dataset = UCRDataset(train_x.astype(np.float64),train_y.astype(np.int64))
test_dataset = UCRDataset(test_x.astype(np.float64),test_y.astype(np.int64))
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=16,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=1,shuffle=False)
model = ResNetBaseline(in_channels=1, num_pred_classes=n_pred_classes)
state_dict= torch.load(f'../../ClassificationModels/models/{dataset}/ResNet')
model.load_state_dict(state_dict)
<All keys matched successfully>
Interpretability Algorithm
Using an interpretability algorithm consists of four steps:
- Load the interpretability method
- Instantiate the method with the desired parameters
- Call the explain method
- Plot the results
1. & 2. Loading & Instantiation
Native Guide counterfactuals work with TensorFlow and PyTorch models that return class probabilities. The initialization takes the following arguments:
`model`: The model to be explained.
`shape`: The data shape in the form (features, timesteps).
`data`: Tuple of data and labels used to find and build the counterfactuals.
`backend`: `PYT` (PyTorch), `SK` (scikit-learn), or `TF` (TensorFlow).
`mode`: Whether the second dimension holds the features (`feat`) or the time steps (`time`).
`method`: One of `NUN_CF`, `dtw_bary_center`, or `NG`.
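The `dtw_bary_center` option builds on dynamic time warping (DTW). As background only, here is a minimal textbook DTW distance for two univariate series (squared point-wise cost, no warping window); it illustrates the idea and is not the routine TSInterpret or tslearn uses internally. The function name `dtw_distance` is hypothetical.

```python
import numpy as np


def dtw_distance(a, b):
    """DTW distance between 1-D sequences a and b (square root of the accumulated squared cost)."""
    n, m = len(a), len(b)
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = (a[i - 1] - b[j - 1]) ** 2
            # Extend the cheapest of the three allowed predecessor alignments
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(np.sqrt(cost[n, m]))


a = np.array([0.0, 1.0, 2.0, 1.0, 0.0])
b = np.array([0.0, 0.0, 1.0, 2.0, 1.0, 0.0])  # same shape, shifted by one step
print(dtw_distance(a, b))  # 0.0
```

Unlike the Euclidean distance, DTW can align series of different lengths and absorb small time shifts, which is why it is a natural choice for averaging time series.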
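model.eval()
# Select one test instance and reshape it to (batch, features, timesteps)
item=test_x[20].reshape(1,1,-1)
shape=item.shape
_item= torch.from_numpy(item).float()
# Predicted class probabilities for the instance
y_target = torch.nn.functional.softmax(model(_item), dim=1).detach().numpy()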
from TSInterpret.InterpretabilityModels.counterfactual.NativeGuideCF \
import NativeGuideCF
exp_model=NativeGuideCF(model,(train_x,train_y), \
backend='PYT', mode='feat',method='NUN_CF')
WARNING:root:no value was provided for `target_layer`, thus set to 'layers'.
WARNING:root:no value was provided for `fc_layer`, thus set to 'final'.
3. Call the explain method
Prepare the instance and its predicted label as parameters for the `explain` method.
`item`: The item to be explained.
`target`: The predicted label of `item`; the native guide is searched among instances of other classes.
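The label handed to `explain` is the index of the largest softmax probability. As a plain-NumPy illustration of that step, the made-up logits below stand in for the output of `model(_item)`.

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


logits = np.array([[0.5, 2.0, -1.0]])  # made-up scores for one instance and 3 classes
probs = softmax(logits)
print(probs.sum(axis=1))             # [1.]
print(np.argmax(probs, axis=1)[0])   # 1
```

Subtracting the row maximum does not change the result but prevents overflow in `np.exp` for large logits.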
exp,label=exp_model.explain(item, np.argmax(y_target,axis=1)[0])
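The CAM-based native guide described by Delaney et al. [1] copies the most influential subsequence of the nearest unlike neighbor into the query and enlarges that subsequence until the predicted class changes. The NumPy sketch below illustrates this idea under stated assumptions: the weight vector stands in for class activation map weights, the threshold function stands in for the trained model, and `most_influential_window` and `swap_until_flip` are hypothetical helpers, not TSInterpret functions.

```python
import numpy as np


def most_influential_window(weights, length):
    """Start index of the contiguous window of `length` steps with the largest summed weight."""
    window_sums = np.convolve(weights, np.ones(length), mode="valid")
    return int(np.argmax(window_sums))


def swap_until_flip(x, nun, weights, predict, target):
    """Copy ever longer windows of `nun` into `x` until `predict` returns `target`.

    Returns (counterfactual, window_length), or (None, None) if even a full swap fails.
    """
    for length in range(1, len(x) + 1):
        start = most_influential_window(weights, length)
        cf = x.copy()
        cf[start:start + length] = nun[start:start + length]
        if predict(cf) == target:
            return cf, length
    return None, None


def predict(series):
    """Stand-in model: class 1 if the mean exceeds 0.5, else class 0."""
    return int(series.mean() > 0.5)


x = np.zeros(10)   # query, predicted class 0
nun = np.ones(10)  # stand-in nearest unlike neighbor (class 1)
weights = np.array([0, 0, 1, 3, 5, 4, 2, 1, 0, 0], dtype=float)  # stand-in CAM weights

cf, length = swap_until_flip(x, nun, weights, predict, target=1)
print(length)  # 6
print(cf)      # [0. 0. 1. 1. 1. 1. 1. 1. 0. 0.]
```

Growing the window one step at a time keeps the counterfactual as close to the original as the stand-in model allows, at the cost of one prediction per window length.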
4. Visualization
The plot functions `plot` and `plot_in_one` take as input the item to be explained, its predicted label, the returned counterfactual, and the counterfactual's label.
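To complement the plots, it can help to list the time steps where the counterfactual departs from the original. A minimal NumPy sketch, assuming both series are 1-D arrays of equal length (for example `item` and `exp` flattened); `changed_segments` is a hypothetical helper, not part of TSInterpret.

```python
import numpy as np


def changed_segments(original, counterfactual, atol=1e-8):
    """Return (start, end) index pairs, end exclusive, of runs where the two series differ."""
    diff = ~np.isclose(original, counterfactual, atol=atol)
    # Pad with False so every run of differences has a rising and a falling edge
    edges = np.diff(np.concatenate(([False], diff, [False])).astype(int))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    return list(zip(starts.tolist(), ends.tolist()))


original = np.zeros(10)
counterfactual = np.array([0, 0, 1, 1, 1, 0, 0, 2, 2, 0], dtype=float)
print(changed_segments(original, counterfactual))  # [(2, 5), (7, 9)]
```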
exp_model.plot(item,np.argmax(y_target,axis=1)[0],exp,label)
exp_model.plot_in_one(item,np.argmax(y_target,axis=1)[0],exp,label)
Additional examples, e.g. for use with an LSTM or with TensorFlow, can be found here.