wenwenyu / PICK-pytorch

Code for the paper "PICK: Processing Key Information Extraction from Documents using Improved Graph Learning-Convolutional Networks" (ICPR 2020)
https://arxiv.org/abs/2004.07464
MIT License

CUDA 11 PyTorch compatibility? #78

Open markobogoevski opened 3 years ago

markobogoevski commented 3 years ago

How can I run this on a CUDA 11 graphics card that can't be downgraded to CUDA 10? There is no torch 1.5 binary for CUDA 11 yet, and the code isn't compatible with torch 1.7 (the earliest torch version that supports CUDA 11). Will there be an updated version of the code for torch 1.7 anytime soon?

risto-trajanov commented 3 years ago

Migrating PICK from torch 1.5.1 to torch 1.7.1:

Error: RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor

nn.utils.rnn.pack_padded_sequence expects its lengths argument to be a 1D CPU int64 tensor, but PICK passes a LongTensor that lives on the GPU. In torch 1.5.1 the function moved lengths to the CPU internally; in 1.7.1 it no longer does, so the tensor has to be moved with .cpu() before the call. That change alone fixes the migration error.
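A minimal, self-contained sketch of the same call pattern follows. The toy batch, the LSTM sizes, and the variable names are illustrative assumptions, not PICK's actual shapes: the lengths tensor is allowed to sit on the GPU, and it is moved to the CPU only for the pack_padded_sequence call.

```python
import torch
from torch import nn

# Toy padded batch (illustrative sizes): 3 sequences, max length 5, feature size 4.
# Lengths are already in descending order, as the default enforce_sorted=True requires.
x = torch.randn(3, 5, 4)
lengths = torch.tensor([5, 3, 2])  # int64 by default

device = "cuda" if torch.cuda.is_available() else "cpu"
x = x.to(device)
lengths = lengths.to(device)  # upstream code may leave lengths on the GPU

lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True).to(device)

# torch >= 1.7: lengths must be a 1D CPU int64 tensor, so move it explicitly
packed = nn.utils.rnn.pack_padded_sequence(x, lengths=lengths.cpu(), batch_first=True)
output, _ = lstm(packed)

# Unpack back to a padded (batch, time, hidden) tensor; padded steps are filled with 0
padded, out_lengths = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
print(padded.shape)          # torch.Size([3, 5, 6])
print(out_lengths.tolist())  # [5, 3, 2]
```

On a machine without CUDA everything runs on the CPU and `.cpu()` is a no-op, so the same code works with either device.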

In decoder.py, change the forward function so that sorted_lengths is moved to the CPU before packing:

    from typing import Tuple

    import torch
    from torch import nn

    # Method of the LSTM layer class in decoder.py; keys_vocab_cls comes from PICK's own utilities.

    def forward(self, x_seq: torch.Tensor,
                lenghts: torch.Tensor,
                initial: Tuple[torch.Tensor, torch.Tensor]):
        '''
        :param x_seq: (B, N*T, D)
        :param lenghts: (B,)
        :param initial: (num_layers * directions, batch, D)
        :return: (B, N*T, out_dim)
        '''
        # B*N, T, hidden_size
        # (initial[0] is passed twice, as in the original; h_0 and c_0 are not used below)
        x_seq, sorted_lengths, invert_order, h_0, c_0 = self.sort_tensor(x_seq, lenghts, initial[0], initial[0])
        # torch >= 1.7: lengths must be a 1D CPU int64 tensor, so move it off the GPU here
        packed_x = nn.utils.rnn.pack_padded_sequence(x_seq, lengths=sorted_lengths.cpu(), batch_first=True)
        self.lstm.flatten_parameters()
        output, _ = self.lstm(packed_x)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True,
                                                     padding_value=keys_vocab_cls.stoi['<pad>'])
        # total_length=documents.MAX_BOXES_NUM * documents.MAX_TRANSCRIPT_LEN
        # Restore the original (unsorted) batch order
        output = output[invert_order]
        logits = self.mlp(output)
        # (B, N*T, out_dim)
        return logits
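The body of sort_tensor is not shown in this thread. The sketch below assumes it sorts the batch by descending length and returns the inverse permutation used as invert_order. The names sort_by_length, run_lstm_sorted and run_lstm_unsorted are hypothetical. For comparison, it also runs the same LSTM with enforce_sorted=False, an alternative in which pack_padded_sequence does the sorting and unsorting itself. This is not what PICK does; it is shown only to check that the two approaches agree.

```python
import torch
from torch import nn


def sort_by_length(x: torch.Tensor, lengths: torch.Tensor):
    """Hypothetical stand-in for PICK's sort_tensor: sort the batch by descending
    length and return the permutation that restores the original order."""
    sorted_lengths, order = torch.sort(lengths, descending=True)
    invert_order = torch.argsort(order)  # inverse permutation of `order`
    return x[order], sorted_lengths, invert_order


def run_lstm_sorted(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Manual sort -> pack -> LSTM -> unpad -> unsort, mirroring the decoder's forward()
    x_sorted, sorted_lengths, invert_order = sort_by_length(x, lengths)
    packed = nn.utils.rnn.pack_padded_sequence(
        x_sorted, lengths=sorted_lengths.cpu(), batch_first=True)
    output, _ = lstm(packed)
    output, _ = nn.utils.rnn.pad_packed_sequence(
        output, batch_first=True, total_length=x.size(1))
    return output[invert_order]


def run_lstm_unsorted(lstm: nn.LSTM, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Same computation, but pack_padded_sequence sorts and unsorts internally
    packed = nn.utils.rnn.pack_padded_sequence(
        x, lengths=lengths.cpu(), batch_first=True, enforce_sorted=False)
    output, _ = lstm(packed)
    output, _ = nn.utils.rnn.pad_packed_sequence(
        output, batch_first=True, total_length=x.size(1))
    return output


torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
x = torch.randn(3, 5, 4)
lengths = torch.tensor([2, 5, 3])  # deliberately unsorted
a = run_lstm_sorted(lstm, x, lengths)
b = run_lstm_unsorted(lstm, x, lengths)
print(a.shape, torch.allclose(a, b, atol=1e-6))  # torch.Size([3, 5, 6]) True
```

Both variants need the .cpu() call on torch 1.7. enforce_sorted=False only removes the hand-written sorting step.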