NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
BSD 3-Clause "New" or "Revised" License

RuntimeError: expected scalar type Half but found Float #581

Open williamFalcon opened 5 years ago

williamFalcon commented 5 years ago

What is the best way to figure out where this issue is happening in the graph? The message is also unclear: what exactly is expecting the Half? It would be helpful to print the particular node in the graph where this breaks.

Sorry, I don't really have a super clean implementation which can reproduce this.

  File "/private/home/falc/.conda/envs/ddt_debug/lib/python3.7/site-packages/pytorch_lightning/trainer/train_loop_mixin.py", line 194, in optimizer_closure
    model_ref.backward(self.use_amp, closure_loss, optimizer)
  File "/private/home/falc/.conda/envs/ddt_debug/lib/python3.7/site-packages/pytorch_lightning/root_module/hooks.py", line 70, in backward
    scaled_loss.backward()
  File "/private/home/falc/.local/lib/python3.7/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/falc/.local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: expected scalar type Half but found Float

The piece of code that likely has the issue is here:

    def nce_loss(self, r_src, r_trg, mask_mat):
        '''
        Compute the NCE scores for predicting r_src->r_trg.
        Input:
          r_src    : (batch_size, emb_dim)
          r_trg    : (emb_dim, n_batch * w * h) (ie: nb_feat_vectors x embedding_dim)
          mask_mat : (n_batch_gpu, n_batch)
        Output:
          raw_scores : (n_batch_gpu, n_locs)
          nce_scores : (n_batch_gpu, n_locs)
          lgt_reg    : scalar
        '''

        # RKHS = embedding dim
        batch_size, emb_dim = r_src.size()
        nb_feat_vectors = r_trg.size(1) // batch_size

        # (b, b) -> (b, b, nb_feat_vectors)
        # all zeros with ones in diagonal tensor... (ie: b1 b1 are all 1s, b1 b2 are all zeros)
        mask_pos = mask_mat.unsqueeze(dim=2).expand(-1, -1, nb_feat_vectors).float()

        # negative mask
        # one = torch.ones_like(mask_pos)
        mask_neg = 1. - mask_pos

        # -------------------------------
        # ALL SCORES COMPUTATION
        # compute src->trg raw scores for batch
        # (b, dim) x (dim, nb_feats*b) -> (b, b, nb_feats)
        # vector for each img in batch times all the vectors of all images in batch
        raw_scores = torch.mm(r_src, r_trg)
        raw_scores = raw_scores.reshape(batch_size, batch_size, nb_feat_vectors).float()

        # -----------------------
        # STABILITY TRICKS
        # trick 1: weighted regularization term
        raw_scores = raw_scores / emb_dim**0.5
        lgt_reg = 5e-2 * (raw_scores**2).mean()

        # trick 2: tanh clip
        raw_scores = tanh_clip(raw_scores, clip_val=self.tclip).float()

        '''
        pos_scores includes scores for all the positive samples
        neg_scores includes scores for all the negative samples, with
        scores for positive samples set to the min score (-self.tclip here)
        '''
        # ----------------------
        # EXTRACT POSITIVE SCORES
        # use the index mask to pull all the diagonals which are b1 x b1
        # (batch_size, nb_feat_vectors)
        pos_scores = (mask_pos * raw_scores).sum(dim=1).float()

        # ----------------------
        # EXTRACT NEGATIVE SCORES
        # pull everything except diagonal and apply clipping
        # (batch_size, batch_size, nb_feat_vectors)
        # diagonals have - clip vals. everything else has actual negative scores
        neg_scores = (mask_neg * raw_scores) - (self.tclip * mask_pos)

        # (batch_size, batch_size, nb_feat_vectors) -> (batch_size, batch_size * nb_feat_vectors)
        neg_scores = neg_scores.reshape(batch_size, -1)
        mask_neg = mask_neg.reshape(batch_size, -1)

        # ---------------------
        # STABLE SOFTMAX
        # max for each row of negative samples
        # will use max in safe softmax
        # (n_batch_gpu, 1)
        neg_maxes = torch.max(neg_scores, dim=1, keepdim=True)[0]

        # DENOMINATOR
        # sum over only negative samples (none from the diagonal)
        neg_sumexp = (mask_neg * torch.exp(neg_scores - neg_maxes)).sum(dim=1, keepdim=True)
        all_logsumexp = torch.log(torch.exp(pos_scores - neg_maxes) + neg_sumexp)

        # FULL NCE
        # NUMERATOR
        # compute numerators for the NCE log-softmaxes
        pos_shiftexp = pos_scores - neg_maxes

        nce_scores = pos_shiftexp - all_logsumexp
        nce_scores = -nce_scores.mean().float()

        return nce_scores, lgt_reg
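
If it helps make the message concrete: as far as I can tell, the error just means some op in the backward graph received a mix of Half and Float inputs. A standalone sketch of that kind of mismatch (not my actual graph; the exact error wording varies by op and PyTorch version):

import torch

a = torch.randn(4, 8, device="cuda").half()    # FP16, e.g. what amp produces for a whitelist op
b = torch.randn(8, 16, device="cuda").float()  # FP32, e.g. left over from an explicit .float() cast

# mm requires both operands to share a dtype, so this raises an
# "expected scalar type Half but found Float"-style RuntimeError
scores = torch.mm(a, b)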
mcarilli commented 5 years ago

The error is coming from PyTorch internally. This is strange, because the casts Amp inserts are autograd-exposed, so they should be exactly reversed in backward. In other words, if the forward pass succeeds, the backward pass should succeed as well.
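
Rough illustration of what I mean by autograd-exposed (plain PyTorch, nothing apex-specific): the cast shows up as a node in the graph, and the gradient is cast back to the input's dtype on the way back.

import torch

x = torch.randn(3, requires_grad=True)   # FP32 leaf
y = x.half()                             # the cast itself is recorded as an autograd node
y.backward(torch.ones_like(y))           # feed a Half gradient in
print(x.grad.dtype)                      # torch.float32: backward casts the grad back to FP32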

What are the types of the inputs? Also, does it still fail if you explicitly cast the inputs to float? For unblocking purposes, you can cast the inputs to float and run the function with casting disabled, i.e.

with amp.disable_casts():
    nce_loss(self.float(), r_src.float(), r_trg.float(), mask_mat.float())

Also, please tell me you're running with opt_level="O1". O2 is a disaster, a legacy from our initial experiments with mixed precision that is sadly still necessary to support internal usage. The PyTorch native integration, which is my main task right now, will be O1-like exclusively.
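
For reference, by O1 I just mean the opt_level argument to amp.initialize; a toy setup (the model and data here are placeholders, only the amp calls matter):

import torch
from apex import amp

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# O1 patches torch functions so whitelist ops run in FP16 and blacklist ops stay in FP32
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

data = torch.randn(32, 128, device="cuda")
target = torch.randint(0, 10, (32,), device="cuda")

loss = torch.nn.functional.cross_entropy(model(data), target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()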

In general, debugging the backward pass is hard. You can try to catch the exception to at least figure out the name of the op that's failing on the C++ side:

$ gdb python
...
(gdb) catch throw
(gdb) run script.py args
....gdb will halt when the exception is thrown
(gdb) bt
...C++-side backtrace that will tell you the name of the op that's failing
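
If gdb isn't convenient, autograd's anomaly detection is a pure-Python alternative: when backward raises, it also prints the traceback of the forward call that created the failing node, which narrows down where in the graph the problem is. A rough sketch, using the call from the snippet above as a stand-in for your actual training step:

import torch

# anomaly mode makes backward errors report the forward op that produced the failing node
with torch.autograd.detect_anomaly():
    nce_scores, lgt_reg = model.nce_loss(r_src, r_trg, mask_mat)  # hypothetical call
    nce_scores.backward()

Note that anomaly mode slows things down, so only enable it while debugging.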
kealennieh commented 4 years ago

Any updates on this topic? I ran into the same issue. My PyTorch version is 1.2.0.

hanspinckaers commented 4 years ago

Upgrading to PyTorch 1.3.1 worked for me.