wilburnlab / collage

Codon Likelihoods Learned Against Genome Evolution (CoLLAGE): a deep learning framework for identifying naturally selected patterns of codon preference within a species
MIT License

Error on longer protein sequence #13

Open dbwilburn opened 9 months ago

dbwilburn commented 9 months ago

Getting the following error when predicting on a slightly longer protein sequence. Based on the error, I'm guessing we have a bug somewhere in the batching for proteins >500 AA (which have to be processed serially).

/users/PAS1309/damienbwilburn/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/functional.py:4999: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.
  warnings.warn(

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 2
      1 PA14_His6_SUMO_FraB = 'MDMAPDAQADSFGGVAVQGLFGEYYAYAQGSDGGNLSNVAQVKAFIAANEADATFIGRNIDYGSVSGDLGGNGKVQSFLKDDAGSLSTDPENSSDAIVKLTGNLELQAGTYQFRVRADDGYRIEVNGQTVAEYNGNQGANTRTGSEFTLTGDGPHSVEIVYWDQGGAAQLRIELREQGGAYEIFGSQHASHGHHHHHHMSDSEVNQEAKPEVKPEVKPETHINLKVSDGSSEIFFKIKKTTPLRRLMEAFAKRQGKEMDSLRFLYDGIRIQADQTPEDLDMEDNDIIEAHREQIGGMMGMKETVSNIVTSQAEKGGVKHVYYVACGGSYAAFYPAKAFLEKEAKALTVGLYNSGEFINNPPVALGENAVVVVASHKGNTPETIKAAEIARQHGAPVIGLTWIMDSPLVAHCDYVETYTFGDGKDIAGEKTMKGLLSAVELLQQTEGYAHYDDFQDGVSKINRIVWRACEQVAERAQAFAQEYKDDKVIYTVASGAGYGAAYLQSICIFMEMQWIHSACIHSGEFFHGPFEITDANTPFFFQFSEGNTRAVDERALNFLKKYGRRIEVVDAKELGLSTIKTTVIDYFNHSLFNNVYPVYNRALAEARQHPLTTRRYMWKVEY.'
----> 2 predictions = beam_generator( model, PA14_His6_SUMO_FraB, '', max_seqs=1000, )
      3 collage_orf = list( predictions )[0]
      4 collage_orf

File ~/DBW_Libraries/collage/collage/generator.py:61, in beam_generator(model, prot, pre_sequence, gen_size, max_seqs)
     58 prot_tensor = prot_tensor_0.repeat(orf_tensor.size(0), 1)
     60 weights_tensor = torch.ones(prot_tensor.shape)
---> 61 output = model(prot_tensor, orf_tensor)
     63 logLs = output.cpu().detach().numpy()[:, -1, :]
     64 candidate_seqs = {}

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DBW_Libraries/collage/collage/model.py:128, in CollageModel.forward(self, protein, cds)
    125 def forward(self, protein, cds):
    126     codon_mask = self.return_codon_mask(protein)
--> 128     x = self.protein_encoder(protein)
    129     x = self.codon_decoder(protein, cds, x)
    130     x = self.linear(x)

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/DBW_Libraries/collage/collage/model.py:56, in ProtEncoder.forward(self, prot)
     54 x = self.prot_embedding(prot) * math.sqrt(self.embed_dim)
     55 x = self.pos_encoder(x)
---> 56 x = self.transformer_encoder(x, src_key_padding_mask=pad_mask)
     57 return x

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/transformer.py:315, in TransformerEncoder.forward(self, src, mask, src_key_padding_mask, is_causal)
    312     is_causal = make_causal
    314 for mod in self.layers:
--> 315     output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
    317 if convert_to_nested:
    318     output = output.to_padded_tensor(0.)

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/transformer.py:591, in TransformerEncoderLayer.forward(self, src, src_mask, src_key_padding_mask, is_causal)
    589     x = x + self._ff_block(self.norm2(x))
    590 else:
--> 591     x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
    592     x = self.norm2(x + self._ff_block(x))
    594 return x

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/transformer.py:599, in TransformerEncoderLayer._sa_block(self, x, attn_mask, key_padding_mask, is_causal)
    597 def _sa_block(self, x: Tensor,
    598               attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
--> 599     x = self.self_attn(x, x, x,
    600                        attn_mask=attn_mask,
    601                        key_padding_mask=key_padding_mask,
    602                        need_weights=False, is_causal=is_causal)[0]
    603     return self.dropout1(x)

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/modules/activation.py:1205, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
   1191     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1192         query, key, value, self.embed_dim, self.num_heads,
   1193         self.in_proj_weight, self.in_proj_bias,
   (...)
   1202         average_attn_weights=average_attn_weights,
   1203         is_causal=is_causal)
   1204 else:
-> 1205     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1206         query, key, value, self.embed_dim, self.num_heads,
   1207         self.in_proj_weight, self.in_proj_bias,
   1208         self.bias_k, self.bias_v, self.add_zero_attn,
   1209         self.dropout, self.out_proj.weight, self.out_proj.bias,
   1210         training=self.training,
   1211         key_padding_mask=key_padding_mask,
   1212         need_weights=need_weights,
   1213         attn_mask=attn_mask,
   1214         average_attn_weights=average_attn_weights,
   1215         is_causal=is_causal)
   1216 if self.batch_first and is_batched:
   1217     return attn_output.transpose(1, 0), attn_output_weights

File ~/.conda/envs/p310/lib/python3.10/site-packages/torch/nn/functional.py:5367, in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)
   5365     attn_mask = attn_mask.unsqueeze(0)
   5366 else:
-> 5367     attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
   5369 q = q.view(bsz, num_heads, tgt_len, head_dim)
   5370 k = k.view(bsz, num_heads, src_len, head_dim)

RuntimeError: cannot reshape tensor of 0 elements into shape [0, 4, -1, 500] because the unspecified dimension size -1 can be any value and is ambiguous
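
The shape in that final reshape ([0, 4, -1, 500]: batch of 0, 4 attention heads, 500 positions) looks like the protein encoder being called on an empty batch, which would fit the guess above that the candidate set goes empty somewhere in the >500 AA handling. The standalone sketch below is not CoLLAGE code and the empty-batch cause is only an assumption (d_model=64 is arbitrary; nhead=4 and the 500-token window just mirror the traceback); with the torch 2.0.x shown above it should raise a similar ambiguous-reshape error, though other versions may behave differently.

import torch
import torch.nn as nn

# Hypothetical standalone repro: an encoder with 4 heads called on a batch of
# zero 500-token sequences plus a matching (empty) key padding mask.
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)

src = torch.zeros(0, 500, 64)                     # 0 sequences x 500 tokens x 64 dims
pad_mask = torch.zeros(0, 500, dtype=torch.bool)  # empty key padding mask

# With an empty batch plus a padding mask, multi_head_attention_forward ends up
# calling attn_mask.view(0, 4, -1, 500), i.e. the same ambiguous -1 reshape.
out = encoder(src, src_key_padding_mask=pad_mask)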

alope107 commented 2 weeks ago

Should be mostly addressed by #25. Tested this sequence with the fixed GC check and it ran fine.
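
Presumably the link between a GC check and the error above is a candidate filter that can empty the beam, leaving the next forward pass with a zero-size batch. The snippet below is purely illustrative, not CoLLAGE's actual logic (gc_fraction, filter_candidates, and the 0.4–0.7 window are invented here); the point is the fallback guard that keeps the beam non-empty.

def gc_fraction(orf: str) -> float:
    """Fraction of G/C nucleotides in a coding sequence."""
    if not orf:
        return 0.0
    return sum(base in "GC" for base in orf.upper()) / len(orf)

def filter_candidates(candidates: list[str], lo: float = 0.4, hi: float = 0.7) -> list[str]:
    """Keep candidate ORFs whose GC content falls inside [lo, hi]."""
    kept = [orf for orf in candidates if lo <= gc_fraction(orf) <= hi]
    # Guard against emptying the beam entirely: fall back to the unfiltered
    # set rather than handing the model a batch of zero sequences.
    return kept if kept else candidates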