SETS
A counterfactual explanation answers the question "what if" by building counterexamples. Given an input instance $x$, the goal is to find a counterfactual $x^{cf}$ that is close to the original instance $x$ but classified differently by a predictor $f$, i.e. $y \neq y^{cf}$. The intention is to visualize boundary cases. Further research has shown that counterfactual explanations are easy for humans to understand, because reasoning with counterexamples is intuitive.
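The two requirements above (a different prediction, and closeness to the original) can be checked with a few lines of code. The following is a minimal sketch in plain Python; `toy_predict` is a hypothetical stand-in classifier and the L1 distance is just one possible closeness measure, neither is taken from TSInterpret.

```python
def l1_distance(x, x_cf):
    """Sum of absolute differences between two equally long series."""
    return sum(abs(a - b) for a, b in zip(x, x_cf))


def is_valid_counterfactual(predict, x, x_cf):
    """A counterfactual is valid if the predictor assigns it a different class."""
    return predict(x) != predict(x_cf)


def toy_predict(series):
    # Hypothetical stand-in classifier: class 1 if the mean exceeds 0.5, else class 0.
    return int(sum(series) / len(series) > 0.5)


x = [0.2, 0.3, 0.4, 0.3]      # original instance, predicted as class 0
x_cf = [0.2, 0.9, 0.9, 0.3]   # candidate counterfactual, two time steps changed

print(is_valid_counterfactual(toy_predict, x, x_cf))
print(l1_distance(x, x_cf))
```

A smaller distance for a valid counterfactual means fewer or smaller changes were needed to cross the decision boundary.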
SETS (Shapelet-Based Counterfactual Explanations for Multivariate Time Series), developed by Bahri et al. [1], builds counterfactuals in the multivariate setting by perturbing the features of a time series with the help of a shapelet algorithm. The authors extract the most prominent class shapelets using an adaptation of the shapelet transform algorithm of [2] as implemented in sktime. The algorithm is adapted to multivariate time series analysis by treating each dimension as a univariate time series and assessing shapelets by information gain. The algorithm then imputes the shapelets based on previously observed occurrences and shapelet quality. If a single shapelet does not cause a counterfactual prediction, further perturbations are performed in other dimensions or other parts of the time series.
Visualization of SETS from the original paper [1]. The code in TSInterpret is based on the authors' implementation.
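The two building blocks of the perturbation described above, locating a shapelet in a series and writing a shapelet into a series, can be sketched as follows. This is an illustrative sketch for a single dimension and not the SETS implementation: the function names are hypothetical, the squared Euclidean distance is an assumption, and the insertion position is passed in by hand rather than derived from observed occurrences.

```python
def shapelet_distance(series, shapelet):
    """Return (smallest squared distance, start index) over all sliding windows."""
    best_dist, best_start = float("inf"), -1
    m = len(shapelet)
    for start in range(len(series) - m + 1):
        dist = sum((series[start + i] - shapelet[i]) ** 2 for i in range(m))
        if dist < best_dist:
            best_dist, best_start = dist, start
    return best_dist, best_start


def impute_shapelet(series, shapelet, start):
    """Return a copy of the series with the shapelet written in at position start."""
    perturbed = list(series)
    perturbed[start:start + len(shapelet)] = shapelet
    return perturbed


shapelet = [1.0, 2.0, 1.0]                       # pattern assumed typical for the target class
train_series = [0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0]
query = [0.0] * 7                                # instance to be perturbed

dist, start = shapelet_distance(train_series, shapelet)  # where the pattern was observed
candidate = impute_shapelet(query, shapelet, start)      # perturbed candidate
print(dist, start, candidate)
```

The candidate would then be passed to the classifier; if the prediction does not change, further shapelets are imputed.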
[1] Omar Bahri, Soukaina Filali Boubrahimi, and Shah Muhammad Hamdi. Shapelet-Based Counterfactual Explanations for Multivariate Time Series. 2022 ACM SIGKDD Workshop on Mining and Learning from Time Series. arXiv:2208.10462. URL: https://arxiv.org/abs/2208.10462 (visited on 2023-12-04)
[2] Jon Hills, Jason Lines, Edgaras Baranauskas, James Mapp, and Anthony Bagnall. Classification of time series by shapelet transformation. Data Mining and Knowledge Discovery 28 (2014): 851-881.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import numpy as np
from tslearn.datasets import UCR_UEA_datasets
import warnings
import tensorflow.keras as keras
warnings.filterwarnings("ignore")
Load Data and Model
- Load the data and reshape it to fit a 1D convolutional neural network (CNN). Note that the input for a 1D-Conv CNN has the shape (batch, timesteps, features).
- Load the model with pretrained weights.
data = UCR_UEA_datasets().load_dataset("ECG200")
model = keras.models.load_model("../../ClassificationModels/models/ECG200/best_model_ecg200_fcn_wang_ep200.hdf5")
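To make the shape convention concrete, the following sketch uses a zero-filled array as a hypothetical stand-in for a loaded dataset (the sizes are chosen for illustration only). It shows how a single item in (timesteps, features) layout can be turned into (features, timesteps) with `np.swapaxes`, which is the same call used later in this notebook before `explain`.

```python
import numpy as np

# Hypothetical stand-in for a dataset: 4 series, 96 time steps, 1 feature.
batch = np.zeros((4, 96, 1))               # (batch, timesteps, features)

item = batch[0]                            # (timesteps, features)
item_feat_first = np.swapaxes(item, 0, 1)  # (features, timesteps)

# The same idea for a whole batch: swap the last two axes.
batch_feat_first = np.swapaxes(batch, 1, 2)  # (batch, features, timesteps)

print(item.shape, item_feat_first.shape, batch_feat_first.shape)
```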
Interpretability Algorithm
Using an interpretability algorithm consists of 4 steps:
- Load the Interpretability Method
- Instantiate the Method with the desired Parameters
- Call the explain Method
- Plot the results
1. & 2. Loading & Initialization
SETS works with any model that returns class probabilities. The initialization takes the following arguments:
`model`: The model to be explained.
`data`: Tuple of data and labels.
`backend`: `PYT`, `SK`, or `TF`.
`mode`: Whether the second dimension is `feat` or `time`.
`method`: Optimization method, either `brut` or `opt`.
`min_shapelet_len`: Minimum length of extracted shapelets.
`max_shapelet_len`: Maximum length of extracted shapelets.
`time_contract_in_mins_per_dim`: Maximum time in minutes for shapelet extraction per dimension.
`initial_num_shapelets_per_case`: Initial number of shapelets per case.
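Before starting a potentially long extraction, it can help to sanity-check the shapelet settings. The helper below is hypothetical and not part of TSInterpret. It assumes that shapelet lengths must fit inside the series and that the time contract applies once per dimension, so the total budget is at most the number of dimensions times the per-dimension contract.

```python
def check_shapelet_settings(series_len, n_dims, min_shapelet_len,
                            max_shapelet_len, time_contract_in_mins_per_dim):
    """Validate shapelet lengths and return the assumed upper bound on extraction time (minutes)."""
    if not 1 <= min_shapelet_len <= max_shapelet_len <= series_len:
        raise ValueError(
            "expected 1 <= min_shapelet_len <= max_shapelet_len <= series_len, "
            f"got {min_shapelet_len}, {max_shapelet_len}, {series_len}"
        )
    # Assumption: dimensions are processed one after another, each with its own contract.
    return n_dims * time_contract_in_mins_per_dim


# Illustrative numbers: a 6-dimensional series of length 100, one minute per dimension.
budget = check_shapelet_settings(series_len=100, n_dims=6,
                                 min_shapelet_len=3, max_shapelet_len=20,
                                 time_contract_in_mins_per_dim=1)
print(budget)
```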
For Univariate ECG200 Case
%load_ext autoreload
%autoreload 2
from TSInterpret.InterpretabilityModels.counterfactual.SETSCF import SETSCF
train_x, train_y, test_x, test_y = data
# Note: this is for demonstration purposes; the time contract (in minutes per dimension) should be adjusted for better results!
exp_model = SETSCF(
    model,
    (train_x, train_y),
    backend='TF',
    mode='time',
    min_shapelet_len=3,
    max_shapelet_len=20,
    time_contract_in_mins_per_dim=1,
    # initial_num_shapelets_per_case=10,
    fit_shapelets=False,
)
Extract Shapelets with information gain rejection lvl 0.001 and shapelets per class of 30
3. Call the fit method.
This method is optional! If shapelets are already extracted, specify here the occlusion threshold and whether shapelets belonging to multiple classes should be retained. This method is called by default when the explanation model is fitted.
exp_model.fit(occlusion_threshhold=1e-1,remove_multiclass_shapelets=True)
Fit function to prune shapelets with occlusion threshhold of 0.1 and remove shapelets belonging to more than one class set to True Shapelet by index per class and dimension: {0: [[1, 4, 5, 7, 8, 9, 11, 13, 14, 15, 20, 21, 23, 25, 30, 31, 34]], 1: [[0, 2, 3, 6, 10, 12, 16, 17, 18, 19, 22, 24, 26, 27, 28, 29, 32, 33]]}
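The printed dictionary maps each class to one list of shapelet indices per dimension. One way to picture the removal of shapelets that belong to more than one class is the sketch below. It is an illustration only: it assumes that an index shared by several classes within the same dimension marks a multi-class shapelet, which may differ from how TSInterpret decides this internally, and `remove_multiclass` is a hypothetical name.

```python
from collections import Counter


def remove_multiclass(shapelets_by_class):
    """Drop indices that occur under more than one class within the same dimension."""
    n_dims = len(next(iter(shapelets_by_class.values())))
    pruned = {cls: [] for cls in shapelets_by_class}
    for dim in range(n_dims):
        # Count in how many classes each shapelet index appears for this dimension.
        counts = Counter(
            idx
            for per_dim in shapelets_by_class.values()
            for idx in set(per_dim[dim])
        )
        for cls, per_dim in shapelets_by_class.items():
            pruned[cls].append([idx for idx in per_dim[dim] if counts[idx] == 1])
    return pruned


# Toy input in the same layout as the printed output: {class: [indices for dim 0, ...]}.
before = {0: [[1, 4, 5]], 1: [[0, 4, 6]]}
after = remove_multiclass(before)
print(after)
```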
4. Call the explain method.
Prepare the instance and the predicted label of the instance as parameters for the explain method.
`item`: item to be explained
ts = 5
test_y[ts]
cf_explanation, label = exp_model.explain(np.swapaxes(test_x[ts],0,1), target = None)
Counterfactual has been found
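Because the counterfactual is built by replacing parts of the series with shapelets, it can be informative to see which time steps actually changed. The following sketch, with a hypothetical helper name and toy data for a single dimension, lists the contiguous changed segments between an original series and a counterfactual.

```python
def changed_segments(original, counterfactual, tol=1e-8):
    """Return half-open (start, end) intervals where two equally long series differ."""
    segments = []
    start = None
    for t, (a, b) in enumerate(zip(original, counterfactual)):
        if abs(a - b) > tol:
            if start is None:
                start = t                   # a changed segment begins here
        elif start is not None:
            segments.append((start, t))     # the segment ended at the previous step
            start = None
    if start is not None:
        segments.append((start, len(original)))
    return segments


original = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
counterfactual = [0.0, 1.0, 2.0, 0.0, 0.0, 3.0]
print(changed_segments(original, counterfactual))
```

For a multivariate series, the same helper could be applied to each dimension separately.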
5. Visualization
All plot functions take as input the item to be explained and the returned explanation. As an additional option, a figsize can be given.
exp_model.plot(
    cf_explanation,
    np.argmax(model(data[2])[ts]),
    test_x[ts],
    label,
    vis_change=True,
    all_in_one=False,
    save_fig=None,
    figsize=(6.4, 4.8),
)
Multivariate Case
data = UCR_UEA_datasets().load_dataset("BasicMotions")
model = keras.models.load_model("../../ClassificationModels/models/BasicMotions/best_model_bm_fcn_wang_ep200.hdf5")
train_x, train_y, test_x, test_y = data
exp_model = SETSCF(
    model,
    (train_x, train_y),
    backend='TF',
    mode='time',
    min_shapelet_len=3,
    max_shapelet_len=20,
    time_contract_in_mins_per_dim=1,
    # initial_num_shapelets_per_case=10,
    fit_shapelets=False,
)
Extract Shapelets with information gain rejection lvl 0.001 and shapelets per class of 30
exp_model.fit(occlusion_threshhold=1e-1,remove_multiclass_shapelets=True)
Fit function to prune shapelets with occlusion threshhold of 0.1 and remove shapelets belonging to more than one class set to True Shapelet by index per class and dimension: {0: [[0, 1, 25, 26, 31, 32, 33, 34, 39, 40, 41, 42, 43, 44, 45, 46, 49, 50, 51, 56, 57, 58, 60, 80], [0, 1, 2, 3, 4, 38, 42, 43, 44, 45, 46, 48, 49, 52, 53, 59, 61, 64, 69, 70, 72, 78, 86, 87], [0, 6, 11, 12, 20, 21, 23, 25, 28, 30, 31, 41, 44, 45, 48, 51, 57, 58, 60, 69, 70, 76, 77, 83], [17, 18, 27, 31, 32, 33, 34, 35, 37, 40, 43, 44, 46, 48, 55, 57, 59, 60, 65, 75, 78, 79], [3, 10, 12, 17, 18, 25, 30, 31, 32, 33, 34, 39, 40, 49, 59, 60, 61, 62, 64, 65, 72, 73], [0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 34, 37, 38, 39, 40, 47, 51, 52, 55, 62, 71]], 1: [[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 35], [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 39, 47, 73], [7, 16, 26, 29, 32, 33, 34, 38, 42, 52, 56, 59, 61, 62, 63, 64, 68, 72, 78, 79, 80, 84, 85, 86], [0, 2, 3, 4, 7, 8, 10, 11, 13, 14, 21, 22, 23, 26, 28, 38, 41, 47, 63], [0, 4, 5, 6, 7, 13, 14, 15, 16, 19, 21, 23, 26, 35, 36, 37, 50, 57, 63, 70, 74, 77], [8, 9, 10, 11, 12, 13, 27, 28, 29, 30, 35, 41, 42, 43, 45, 46, 48, 49, 54, 56]], 2: [[55, 59, 61, 62, 63, 64, 65, 66, 67, 68, 70, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 85, 86, 87], [50, 54, 55, 56, 57, 60, 62, 63, 65, 66, 68, 71, 74, 76, 77, 79, 80, 81, 82, 83, 84, 85, 88], [13, 15, 17, 24, 27, 35, 36, 39, 43, 46, 47, 49, 53, 54, 55, 65, 66, 71, 73, 74, 81, 82], [42, 45, 49, 51, 52, 53, 54, 56, 58, 61, 62, 64, 66, 67, 68, 70, 71, 72, 73, 74, 76, 77], [22, 24, 27, 28, 42, 43, 46, 47, 51, 52, 54, 55, 56, 67, 68, 71, 75, 78, 79, 80, 81], [50, 53, 57, 58, 59, 60, 63, 64, 65, 66, 67, 68, 69, 70, 72, 73, 74, 75]], 3: [[21, 22, 23, 24, 27, 28, 29, 30, 36, 37, 38, 47, 48, 52, 53, 54, 69, 71, 83, 84], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 40, 41, 51, 58, 67, 75], [1, 2, 3, 4, 5, 8, 9, 10, 14, 18, 19, 22, 37, 40, 50, 67, 75], [1, 5, 
6, 9, 12, 15, 16, 19, 20, 24, 25, 29, 30, 36, 39, 50, 69], [1, 2, 8, 9, 11, 20, 29, 38, 41, 44, 45, 48, 53, 58, 66, 69, 76], [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 31, 32, 33, 36, 44, 61]]}
ts = 12
cf_explanation, label = exp_model.explain(np.swapaxes(test_x[ts],0,1), target = None)
Counterfactual has been found
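# Arguments: item, org_label, exp, cf_label, save_fig=None, figsize=(6.4, 4.8)
# Plot the item that was explained (index ts) with its predicted label and the counterfactual label.
exp_model.plot_in_one(
    test_x[ts],
    np.argmax(model(test_x)[ts]),
    np.swapaxes(cf_explanation, 0, 1),
    label,
    figsize=(15, 15),
)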