Temporal Saliency Rescaling (TSR)¶
Temporal Saliency Rescaling (TSR), developed by Ismail et al. (2020) [1], builds on established saliency methods such as Grad-CAM [2] or SHAP [3]. The authors' benchmark study shows that traditional saliency methods fail to reliably and accurately identify feature importance in time series, because importance is spread across both the time and the feature dimension. TSR is therefore applied on top of a traditional saliency method as a two-step approach that improves the quality of the saliency map: first the importance of each time step is calculated, then the feature importance within the relevant time steps.
The code in TSInterpret is based on the authors' implementation.
[1] Aya Abdelsalam Ismail, Mohamed Gunady, Héctor Corrada Bravo, and Soheil Feizi. Benchmarking Deep Learning Interpretability in Time Series Predictions. arXiv:2010.13924 [cs, stat], October 2020.
[2] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra. Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. International Journal of Computer Vision, 128(2):336–359, February 2020. doi: 10.1007/s11263-019-01228-7. arXiv:1610.02391.
[3] S. M. Lundberg and S.-I. Lee. A Unified Approach to Interpreting Model Predictions. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems 30, pages 4765–4774. Curran Associates, Inc., 2017.
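Conceptually, the two-step rescaling can be sketched as follows. This is a minimal NumPy illustration of the idea in [1], not the TSInterpret implementation: saliency_fn is a hypothetical stand-in for any base saliency method, and the feature scores here simply reuse the base map, whereas the paper computes them with a second masking pass.
import numpy as np

def tsr_rescale(saliency_fn, x, baseline=0.0, threshold=0.0):
    # x: one sample of shape (features, timesteps);
    # saliency_fn: hypothetical base method returning a map of the same shape.
    base_map = np.abs(saliency_fn(x))
    n_features, n_timesteps = x.shape
    # Step 1: time-step importance = change in the saliency map when the
    # whole time step is replaced by a baseline value.
    time_importance = np.zeros(n_timesteps)
    for t in range(n_timesteps):
        x_masked = x.copy()
        x_masked[:, t] = baseline
        time_importance[t] = np.abs(base_map - np.abs(saliency_fn(x_masked))).sum()
    # Step 2: keep feature scores only in the relevant time steps and
    # rescale them by the time-step importance.
    rescaled = np.zeros_like(base_map)
    for t in np.where(time_importance > threshold)[0]:
        rescaled[:, t] = time_importance[t] * base_map[:, t]
    return rescaled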
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow logging
import pickle
import numpy as np
import torch
import sklearn.preprocessing
from tslearn.datasets import UCR_UEA_datasets
from ClassificationModels.CNN_T import ResNetBaseline, UCRDataset, fit
Dataset¶
dataset='BasicMotions'
Load Data¶
Load the data and reshape it to fit a 1D-Conv ResNet. Note that the input for a 1D-Conv ResNet has the shape (batch, features, timesteps).
X_train, y_train, X_test, y_test = UCR_UEA_datasets().load_dataset(dataset)
# (batch, timesteps, features) -> (batch, features, timesteps)
train_x = np.swapaxes(X_train, 1, 2)
test_x = np.swapaxes(X_test, 1, 2)
train_y = y_train
test_y = y_test

# One-hot encode the string labels; `sparse_output` replaces the
# `sparse` argument deprecated in scikit-learn 1.2.
enc1 = sklearn.preprocessing.OneHotEncoder(sparse_output=False).fit(np.vstack((train_y.reshape(-1, 1), test_y.reshape(-1, 1))))
pickle.dump(enc1, open(f'../../ClassificationModels/models/{dataset}/OneHotEncoder.pkl', 'wb'))
train_y = enc1.transform(train_y.reshape(-1, 1))
test_y = enc1.transform(test_y.reshape(-1, 1))
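A quick shape check (illustrative; for BasicMotions the training split should contain 40 samples with 6 features over 100 time steps, and 4 classes after one-hot encoding):
print(train_x.shape, train_y.shape)  # expected: (40, 6, 100) (40, 4)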
Model Training¶
Train a ResNet and reload the saved weights.
n_pred_classes = train_y.shape[1]
train_dataset = UCRDataset(train_x.astype(np.float64), train_y.astype(np.int64))
test_dataset = UCRDataset(test_x.astype(np.float64), test_y.astype(np.int64))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
model = ResNetBaseline(in_channels=6, num_pred_classes=n_pred_classes)
fit(model, train_loader, test_loader)
model.load_state_dict(torch.load(f'../../ClassificationModels/models/{dataset}/ResNet'))
Epoch: 1, Train loss: 1.186, Val loss: 1.212
Epoch: 2, Train loss: 1.126, Val loss: 0.892
Epoch: 3, Train loss: 0.814, Val loss: 0.763
...
Epoch: 113, Train loss: 0.0, Val loss: 0.001
Epoch: 114, Train loss: 0.0, Val loss: 0.001
Early stopping!
<All keys matched successfully>
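As a quick sanity check before computing any explanations, the reloaded model can be evaluated on the test set. A minimal sketch, assuming the model accepts a float tensor of shape (batch, features, timesteps) and returns class logits:
model.eval()
with torch.no_grad():
    logits = model(torch.from_numpy(test_x).float())
preds = logits.argmax(dim=-1).numpy()
# compare predicted class indices against the one-hot encoded labels
print('Test accuracy:', (preds == test_y.argmax(axis=-1)).mean())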
Interpretability Algorithm¶
Using an interpretability algorithm consists of 4 steps:
- Load the interpretability method
- Instantiate the method with the desired parameters
- Call the explain method
- Plot the results
2. Initialization¶
TSR works on gradient-based models. Support is currently available for TensorFlow (TF) and PyTorch (PYT): for PyTorch the subclass Saliency_PTY is used, while TensorFlow expects Saliency_TF.
The initialization takes the following arguments:
- model: the model to be explained.
- NumTimeSteps: number of time steps.
- NumFeatures: number of features.
- method: saliency method to be used. Available options: Gradients (GRAD), Integrated Gradients (IG), Gradient Shap (GS), DeepLift (DL), DeepLiftShap (DLS), SmoothGrad (SG), Shapley Value Sampling (SVS), Feature Ablation (FA), Occlusion (FO).
- mode: content of the second input dimension, 'time' or 'feat'.
model.eval()
from TSInterpret.InterpretabilityModels.Saliency.TSR import TSR
# NumTimeSteps = train_x.shape[-1], NumFeatures = train_x.shape[-2]
int_mod = TSR(model, train_x.shape[-1], train_x.shape[-2], method='GRAD', mode='feat')
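Any of the listed base methods can be swapped in via the method argument, e.g. a variant built on Integrated Gradients (illustrative; same arguments otherwise):
int_mod_ig = TSR(model, train_x.shape[-1], train_x.shape[-2], method='IG', mode='feat')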
3. Call the explain method.¶
Prepare the instance and its predicted label as parameters for the explain method:
- item: item to be explained.
- labels: predicted label for the item as class index.
- TSR: turns temporal saliency rescaling on or off.
item = np.array([test_x[0, :, :]])  # add batch dimension: (1, features, timesteps)
label = int(np.argmax(test_y[0]))
exp = int_mod.explain(item, labels=label, TSR=True, attribution=0.0)
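To see what the rescaling changes, the same call with TSR = False should return the plain saliency map of the chosen base method (a sketch relying only on the TSR flag documented above; remaining arguments are left at their defaults):
exp_plain = int_mod.explain(item, labels=label, TSR=False)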
4. Visualization¶
All plot functions take as input the item to be explained and the returned explanation. As an additional option, a figsize can be given. For visualizing saliency, two visualization options are provided:
- On Sample
- Heatmap
int_mod.plot(np.array([test_x[0,:,:]]),exp,figsize=(15,15))
int_mod.plot(np.array([test_x[0,:,:]]),exp, heatmap = True)
Additional examples, e.g. for the use with LSTM models or TensorFlow, can be found here.