forked from HuthLab/semantic-decoding
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEncodingModel.py
More file actions
27 lines (24 loc) · 1.25 KB
/
EncodingModel.py
File metadata and controls
27 lines (24 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
# Module-level side effect: make 32-bit float CPU tensors (torch.FloatTensor)
# the default type for tensors created by torch from here on.
# NOTE(review): torch.set_default_tensor_type is marked deprecated in recent
# PyTorch releases in favour of torch.set_default_dtype / torch.set_default_device
# -- confirm against the torch version this project pins before migrating.
torch.set_default_tensor_type(torch.FloatTensor)
class EncodingModel():
    """Class for computing the likelihood of observing brain recordings given a word sequence.

    The model predicts responses as ``stim @ weights`` and scores the residuals
    against the recorded responses with a quadratic form under a precision
    matrix.  The precision matrix is not available until ``set_shrinkage`` has
    been called, so call it before ``prs``.
    """

    def __init__(self, resp, weights, voxels, sigma, device="cpu"):
        """Store the recorded responses and encoding weights for the selected voxels.

        Args:
            resp: numpy array of recorded responses; its columns are indexed by
                ``voxels`` and its rows are later indexed by ``trs`` in ``prs``.
            weights: numpy array of encoding-model weights; its columns are
                indexed by ``voxels``.
            voxels: index (anything numpy accepts for column selection) of the
                voxels to keep from ``resp`` and ``weights``.
            sigma: square numpy array used as the empirical covariance in
                ``set_shrinkage``.  Its side must equal the number of selected
                voxels for the products in ``prs`` to be defined.
            device: torch device on which the tensors are placed.
        """
        self.device = device
        # Keep only the selected voxel columns, as float32 tensors on `device`.
        self.weights = torch.from_numpy(weights[:, voxels]).float().to(self.device)
        self.resp = torch.from_numpy(resp[:, voxels]).float().to(self.device)
        self.sigma = sigma

    def set_shrinkage(self, alpha):
        """Compute precision from empirical covariance with shrinkage factor alpha.

        The covariance is blended with the identity,
        ``sigma * (1 - alpha) + I * alpha``, then inverted.  The result is stored
        in ``self.precision`` as a float32 tensor on ``self.device``.

        Args:
            alpha: shrinkage factor weighting the identity matrix.
        """
        shrunk = self.sigma * (1 - alpha) + np.eye(len(self.sigma)) * alpha
        precision = np.linalg.inv(shrunk)
        self.precision = torch.from_numpy(precision).float().to(self.device)

    def prs(self, stim, trs):
        """Compute P(R | S) on affected TRs for each hypothesis.

        For every hypothesis the value returned is
        ``-0.5 * sum_t diff_t @ precision @ diff_t``, where ``diff_t`` is the
        encoding-model residual at TR ``t``.  This is a Gaussian log-likelihood
        up to an additive constant, not a normalised probability.

        Args:
            stim: 3-D torch tensor (hypotheses, TRs, features) of stimulus
                features; the last axis must match the rows of ``self.weights``.
            trs: row index into the stored responses selecting the TRs that
                correspond to axis 1 of ``stim``.

        Returns:
            1-D float32 numpy array with one score per hypothesis.

        Raises:
            AttributeError: if ``set_shrinkage`` has not been called yet, because
                ``self.precision`` does not exist.
        """
        with torch.no_grad():
            stim = stim.float().to(self.device)
            diff = torch.matmul(stim, self.weights) - self.resp[trs]  # encoding model residuals
            # Only the per-TR quadratic forms diff_t @ precision @ diff_t are
            # needed.  Multiplying elementwise by `diff` and summing yields exactly
            # the summed diagonal of diff @ precision @ diff^T without
            # materialising the (TRs x TRs) matrix for every hypothesis.
            weighted = torch.matmul(diff, self.precision)
            quad = (weighted * diff).sum(dim=(1, 2))
        return -0.5 * quad.cpu().numpy()