About computing pairwise interaction #2

Open
opened 2025-10-14 16:20:01 -06:00 by navan · 0 comments
Owner

Originally created by @Tigerrr07 on 6/4/2023

Hi Kexin, I have the following questions.

  1. I find the code for computing the pairwise interaction a little complicated. Since you are using a dot product, can I use `torch.matmul(d_encoded_layers, p_encoded_layers.transpose(-1, -2))` directly instead of the following code?

https://github.com/kexinhuang12345/MolTrans/blob/47ac16b8c158b080ba6cdaec74cd7aa9c1332b73/models.py#L86-L100

Besides, the `view` operation in line 96 of the above code also confuses me a lot. I tested it with a simple example, and it did not calculate the dot product between sub-structural pairs.

``` python
import torch

max_d = 2
max_p = 3

# batch_size 1, hidden dim 2
d_encoded_layers = torch.zeros(1, max_d, 2)
d_encoded_layers[0, 0, 0] = 1
d_encoded_layers[0, 0, 1] = 1
p_encoded_layers = torch.zeros(1, max_p, 2)
p_encoded_layers[0, 0, 0] = 1
p_encoded_layers[0, 0, 1] = 2
p_encoded_layers[0, 1, 0] = 3
p_encoded_layers[0, 1, 1] = 4
p_encoded_layers[0, 2, 0] = 5
p_encoded_layers[0, 2, 1] = 6

print(d_encoded_layers)
print(p_encoded_layers)

d_aug = torch.unsqueeze(d_encoded_layers, 2).repeat(1, 1, max_p, 1)  # repeat along protein size
p_aug = torch.unsqueeze(p_encoded_layers, 1).repeat(1, max_d, 1, 1)  # repeat along drug size
i = d_aug * p_aug
print(i)
i_v = i.view(1, -1, max_d, max_p)
print(i_v)
i_v = torch.sum(i_v, dim=1)
print(i_v)
```

output:

```
tensor([[[1., 1.],
         [0., 0.]]])
tensor([[[1., 2.],
         [3., 4.],
         [5., 6.]]])
tensor([[[[1., 2.],
          [3., 4.],
          [5., 6.]],

         [[0., 0.],
          [0., 0.],
          [0., 0.]]]])
tensor([[[[1., 2., 3.],
          [4., 5., 6.]],

         [[0., 0., 0.],
          [0., 0., 0.]]]])
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])
```
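If it helps pin down the discrepancy: `view` only reinterprets the tensor's row-major memory layout and never permutes the data, so reshaping to `(1, -1, max_d, max_p)` regroups elements instead of transposing them. A tiny illustration:

``` python
import torch

x = torch.arange(6).view(2, 3)
print(x)                # tensor([[0, 1, 2], [3, 4, 5]])
print(x.view(3, 2))     # reinterprets the row-major buffer: [[0, 1], [2, 3], [4, 5]]
print(x.permute(1, 0))  # a true transpose: [[0, 3], [1, 4], [2, 5]]
```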

The final `i_v` looks wrong: the representation of drug sub-structure 1 is all zeros, yet its row in `i_v` is non-zero. I think the `view` operation changes the arrangement of the data. Maybe the following code is more correct?

``` python
i_s = torch.sum(i, dim=-1)
print(i_s)
```

output:

```
tensor([[[ 3.,  7., 11.],
         [ 0.,  0.,  0.]]])
```
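For what it's worth, a minimal sanity check on the toy tensors above (just a sketch, not the repo's code) suggests that a single batched `torch.matmul` gives the same result as the element-wise product followed by a sum over the hidden dimension:

``` python
import torch

# same toy tensors as in the example above
d = torch.tensor([[[1., 1.],
                   [0., 0.]]])  # (1, max_d, hidden)
p = torch.tensor([[[1., 2.],
                   [3., 4.],
                   [5., 6.]]])  # (1, max_p, hidden)

# pairwise dot products via a single batched matmul: (1, max_d, max_p)
via_matmul = torch.matmul(d, p.transpose(-1, -2))

# element-wise product with broadcasting, then sum over the hidden dim
via_sum = (d.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)

print(via_matmul)                        # tensor([[[ 3.,  7., 11.], [ 0.,  0.,  0.]]])
print(torch.equal(via_matmul, via_sum))  # True
```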

2. I think padding tokens should be filtered out of the interaction map `I` before it is fed into the CNN. I do this by passing in `d_mask` and `p_mask`:

``` python
d_mask = d_mask.reshape(-1, self.max_d, 1)
p_mask = p_mask.reshape(-1, 1, self.max_p)
# mask padding tokens
i.masked_fill_(~d_mask, 0)
i.masked_fill_(~p_mask, 0)
```
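As a self-contained sketch of the masking idea (the mask values, shapes, and boolean dtype here are assumptions, since `d_mask`/`p_mask` come from elsewhere in the model):

``` python
import torch

batch, max_d, max_p = 1, 2, 3
i = torch.ones(batch, max_d, max_p)

# hypothetical boolean masks: True = real token, False = padding
d_mask = torch.tensor([[True, False]])        # (batch, max_d)
p_mask = torch.tensor([[True, True, False]])  # (batch, max_p)

# broadcast each mask over the other axis and zero out padding positions
i.masked_fill_(~d_mask.reshape(batch, max_d, 1), 0)
i.masked_fill_(~p_mask.reshape(batch, 1, max_p), 0)
print(i)  # drug row 1 and protein column 2 are zeroed out
```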

Sorry to bother you.
