
Problem with the feature explainability methods #35

Open
2bben opened this issue Jan 19, 2022 · 1 comment

2bben commented Jan 19, 2022

Hi,
I have gotten DeepLIFT to work and understand the method, but I have not managed to implement the two other methods mentioned in [1].

For the first method, summarizing averaged outputs of hidden unit activations:

  • How can you access the spatially-filtered data from the layer?
  • Are the topoplots shown in Fig. 6A of [1] taken from a specific time point? And how do you know which data corresponds to which channel?

For the second method, visualizing the convolutional kernel weights:

  • Which layers are the visualized kernels (Figs. 7 and 8 in [1]) taken from?

vlawhern (Owner) commented Jan 19, 2022

So for the first method, the spatial filters are extracted from the DepthwiseConv2D layer in EEGNet. More specifically,

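```python
from EEGModels import EEGNet  # EEGNet as defined in this repository's EEGModels.py
model = EEGNet(...)   # define some EEGNet configuration
model.compile(...)    # choose a loss and optimizer before fitting
model.fit(...)        # fit the model
```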

You can use model.layers to show all the different layers of the model, which should look like this:

```text
[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f2f467f7d90>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f2f46857850>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f468ac310>,
 <tensorflow.python.keras.layers.convolutional.DepthwiseConv2D at 0x7f2f3c329ca0>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f3c3291f0>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2e11f0>,
 <tensorflow.python.keras.layers.pooling.AveragePooling2D at 0x7f2f3c2e1bb0>,
 <tensorflow.python.keras.layers.core.Dropout at 0x7f2f3c2e1cd0>,
 <tensorflow.python.keras.layers.convolutional.SeparableConv2D at 0x7f2f3c2e9b80>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f2f3c2e9ee0>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2f0d90>,
 <tensorflow.python.keras.layers.pooling.AveragePooling2D at 0x7f2f3c34d820>,
 <tensorflow.python.keras.layers.core.Dropout at 0x7f2f3c2e18e0>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f2f3c2e9190>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f2f3c2f8e20>,
 <tensorflow.python.keras.layers.core.Activation at 0x7f2f3c2f8c10>]
```
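Rather than counting entries by eye, you can print each layer's index next to its class name. A minimal sketch follows; `list_layers` and `find_layer_index` are hypothetical helpers (not part of Keras or this repository) that only assume `model.layers` is a list of layer objects.

```python
def list_layers(layers):
    """Return (index, class name) pairs for a sequence of layers, e.g. model.layers.

    Hypothetical helper: it only inspects type(layer).__name__.
    """
    return [(i, type(layer).__name__) for i, layer in enumerate(layers)]


def find_layer_index(layers, class_name):
    """Return the index of the first layer whose class name equals class_name."""
    for i, name in list_layers(layers):
        if name == class_name:
            return i
    raise ValueError(f"no layer of type {class_name!r}")
```

With a fitted model, `for i, name in list_layers(model.layers): print(i, name)` prints the indices directly, and `find_layer_index(model.layers, "DepthwiseConv2D")` should return 3 for the layer list above. Looking a layer up by class name also keeps the code working if the architecture gains or loses a layer before that point.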

You'll see that the DepthwiseConv2D layer is at index 3 in the list (counting from 0, so it is the fourth entry), so you can pull the weights of that layer with

```python
model.layers[3].get_weights()
```

This gives you the spatial filter weights, which we then combine with the EEG channel locations to plot a topoplot; this is what we show in Fig 6A in the paper. Each spatial filter has one weight per EEG channel, in the same order as the channels in the input data, which is how each weight is matched to a channel location. The spatial filters are not defined for a single time point: they are trained using all the data, and a single filter is learned for all time points. The number of spatial filters you learn depends on the EEGNet configuration you train; EEGNet-8,2 learns 2 spatial filters for each of its 8 temporal filters, for a total of 16 spatial filters.
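As a sketch of how the weights map onto channels: Keras stores a `DepthwiseConv2D` kernel with shape `(kernel_height, kernel_width, in_channels, depth_multiplier)`, so for EEGNet's `(Chans, 1)` spatial kernel with `F1` temporal filters and depth multiplier `D` it should be `(Chans, 1, F1, D)`. Assuming the layer was built without a bias (as in EEGModels.py), `get_weights()` returns a one-element list holding that kernel. The hypothetical helper `spatial_filters_from_depthwise` below flattens it into an `(F1 * D, Chans)` array with one row per spatial filter; a random array on a hypothetical 22-channel montage stands in for trained weights.

```python
import numpy as np


def spatial_filters_from_depthwise(kernel):
    """Flatten a DepthwiseConv2D kernel into one row per spatial filter.

    kernel: array of shape (Chans, 1, F1, D), e.g. model.layers[3].get_weights()[0].
    Returns shape (F1 * D, Chans). Row k = f * D + d holds the d-th spatial
    filter of temporal filter f, matching Keras's depthwise output-channel order.
    """
    kernel = np.asarray(kernel)
    if kernel.ndim != 4 or kernel.shape[1] != 1:
        raise ValueError(f"expected shape (Chans, 1, F1, D), got {kernel.shape}")
    chans, _, f1, d = kernel.shape
    # Drop the singleton width axis, then merge (F1, D) into one filter axis.
    filters = kernel[:, 0, :, :].reshape(chans, f1 * d)
    # Transpose so each row is one filter with one weight per EEG channel.
    return filters.T


# Stand-in for trained weights: EEGNet-8,2 on a hypothetical 22-channel montage.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((22, 1, 8, 2))
filters = spatial_filters_from_depthwise(kernel)
print(filters.shape)  # (16, 22)
```

Each row of `filters`, together with the channel positions, could then be passed to a topographic plotting routine such as MNE-Python's `mne.viz.plot_topomap`.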

For the second method, the convolutional kernel weights (Fig 7, top row) come from the first Conv2D layer, which is the temporal filter layer:

```python
model.layers[1].get_weights()
```
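A similar sketch for the temporal filters, assuming EEGNet's first layer is a `Conv2D` with an `F1`-filter, `(1, kernLength)` kernel and no bias, so that `get_weights()[0]` has Keras's `(kernel_height, kernel_width, in_channels, filters)` shape `(1, kernLength, 1, F1)`. The helper name and the `kernLength = 64` stand-in are hypothetical.

```python
import numpy as np


def temporal_filters_from_conv2d(kernel):
    """Reshape a Conv2D kernel of shape (1, kernLength, 1, F1) to (F1, kernLength).

    kernel: e.g. model.layers[1].get_weights()[0] for an EEGNet model.
    Each returned row is one temporal filter, with its taps in time order.
    """
    kernel = np.asarray(kernel)
    if kernel.ndim != 4 or kernel.shape[0] != 1 or kernel.shape[2] != 1:
        raise ValueError(f"expected shape (1, kernLength, 1, F1), got {kernel.shape}")
    # Drop the singleton height and input-channel axes, then put filters first.
    return kernel[0, :, 0, :].T


# Stand-in for trained weights: F1 = 8 filters of length 64 (hypothetical values).
rng = np.random.default_rng(0)
kernel = rng.standard_normal((1, 64, 1, 8))
temporal = temporal_filters_from_conv2d(kernel)
print(temporal.shape)  # (8, 64)
```

Plotting each row of `temporal` against its tap index gives one curve per temporal filter.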

The middle and bottom rows show the spatial filter weights, extracted with the method described above.

Figure 8 shows spatial filters from two different methods, Filter-Bank CSP (https://www.frontiersin.org/articles/10.3389/fnins.2012.00039/full) and EEGNet.

Hope this helps.
