vmtmxmf5 / Pytorch-

pytorch로 머신러닝~딥러닝 구현
3 stars 0 forks source link

Customizing collate fn #19

Open vmtmxmf5 opened 2 years ago

vmtmxmf5 commented 2 years ago
def collate_fn_padd(batch):
    '''
    Padds batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## padd
    batch = [ torch.Tensor(t).to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask