lucidrains / sinkhorn-transformer

Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention
MIT License

Training failing on versions 0.0.14 and 0.0.15 #4

Closed blizda closed 4 years ago

blizda commented 4 years ago

Hi, I've been testing model training on the new versions of the repo, and I'm having trouble with 0.0.14 and 0.0.15. On 0.0.14 the model always returns nan on the forward pass; 0.0.15 leads to a CUDA error:

RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

Full error listing:

<ipython-input-7-1329da5363de> in forward(self, inputs, labels)
      7   def forward(self, inputs, labels=None):
      8     loss_mx = labels != -100
----> 9     output = self.model(inputs)
     10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
     11     labels = labels[loss_mx].view(-1)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
    376         x = self.to_token_emb(x)
    377         x = self.pos_emb(torch.arange(t, device=device)) + x
--> 378         x = self.sinkhorn_transformer(x)
    379         return self.to_logits(x)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
    359 
    360     def forward(self, x, input_mask = None):
--> 361         return self.layers(x)
    362 
    363 class SinkhornTransformerLM(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
    330     def forward(self, x, **kwargs):
    331         x = torch.cat([x, x], dim=-1)
--> 332         x = self.layers(x, **kwargs)
    333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
    334 

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
    128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
    129 
--> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
     98         ctx.kwargs = kwargs
     99         for block in blocks:
--> 100             x = block(x, **kwargs)
    101         ctx.y = x.detach()
    102         ctx.blocks = blocks

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
     51         with torch.no_grad():
     52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
     54 
     55         return torch.cat([y1, y2], dim=2)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
     25 
     26         if not set_rng:
---> 27             return self.net(*args, **kwargs)
     28 
     29         rng_devices = []

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
     91     def forward(self, x):
     92         chunks = x.chunk(self.chunks, dim = self.dim)
---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
     94 
     95 class FeedForward(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
     91     def forward(self, x):
     92         chunks = x.chunk(self.chunks, dim = self.dim)
---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
     94 
     95 class FeedForward(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
    112     def forward(self, x, **kwargs):
    113         x = self.norm(x)
--> 114         return self.fn(x, **kwargs)
    115 
    116 class SortNet(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
    103 
    104     def forward(self, x):
--> 105         return self.net(x)
    106 
    107 class PreNorm(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102 

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1591         ret = torch.addmm(bias, input, weight.t())
   1592     else:
-> 1593         output = input.matmul(weight.t())
   1594         if bias is not None:
   1595             output += bias

RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

Also, version 0.0.11 (and all other versions from 0.0.8 on) works stably.

blizda commented 4 years ago

I was testing 0.0.12 (commit 059a11fba3b699cf3f1ac25ea6a87faffb879609), and it also works ok.

blizda commented 4 years ago

With commit 4d0d8f570beb03b577e23bccd2e42470ac8fe3d5 the network falls to nan on the first forward pass, for some reason.
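
A generic way to see which module first produces the nan is to check every submodule's forward output. A minimal debugging sketch, not specific to this repo:

import torch
import torch.nn as nn

def add_nan_hooks(model: nn.Module):
    # Raise as soon as any submodule's forward output contains nan/inf,
    # so the offending layer is named instead of the nan surfacing later.
    def make_hook(name):
        def hook(module, inputs, output):
            if torch.is_tensor(output) and not torch.isfinite(output).all():
                raise RuntimeError(f"non-finite output first seen in {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

# torch.autograd.set_detect_anomaly(True) does the same job for nans
# produced during the backward pass.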

blizda commented 4 years ago

Ok, it seems I found the root of the problem in 4d0d8f570beb03b577e23bccd2e42470ac8fe3d5: if I just replace the LeakyReLU in SortNet and CausalSortNet with a plain ReLU, the problem goes away.
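
For illustration, the change amounts to swapping the activation inside the sort net's small MLP, along these lines (a simplified sketch; the real SortNet in this repo does more than this):

import torch.nn as nn

class SimplifiedSortNet(nn.Module):
    # Hypothetical, stripped-down sort net: a small MLP that produces
    # bucket-reordering scores. The only point here is the activation swap.
    def __init__(self, dim, buckets):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),              # was nn.LeakyReLU() in the versions that produced nan
            nn.Linear(dim, buckets)
        )

    def forward(self, x):
        return self.net(x)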

lucidrains commented 4 years ago

@blizda ohh yes, you are right, it's been fixed in the latest version, my bad lol

how is this architecture working out for you? it seems to be learning for me, but the generation isn't perfect, mainly because of the bucketing

lucidrains commented 4 years ago

i'm still trying some new ideas, including replacing the sortnet with a single layer attention, so expect things to break more lol, thanks for reporting

lucidrains commented 4 years ago

fixed in https://github.com/lucidrains/sinkhorn-transformer/commit/c3bf5ff5c957f32e48730842e7dfa2807471d60d

blizda commented 4 years ago

> how is this architecture working out for you? it seems to be learning for me, but the generation isn't perfect, mainly because of the bucketing

From a performance point of view, the model runs about 3x faster than Reformer. That alone is a great improvement. The model seems to converge slowly, but that's not a big deal. Adam also works fine with the model, you just need a very small lr of about 0.0001 (which is most likely why the model converges more slowly).
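
For reference, the optimizer setup being described is just plain Adam with a small learning rate (a minimal sketch with a stand-in module so it runs on its own):

import torch
import torch.nn as nn

# Stand-in module purely to make the snippet self-contained; in practice this
# would be the SinkhornTransformerLM configured as later in this thread.
model = nn.Linear(16, 16)

# Plain Adam, but with a small learning rate (~1e-4), is what trained stably here.
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)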

lucidrains commented 4 years ago

@blizda very cool! you should try some of the other settings. i'll ping you to try the attention sortnet when it's done. i think it'll be a big improvement

blizda commented 4 years ago

> @blizda very cool! you should try some of the other settings. i'll ping you to try the attention sortnet when it's done. i think it'll be a big improvement

Cool, thanks

lucidrains commented 4 years ago

@blizda it is done! if you'd like to try it, just use the flag attn_sort_net = True. I also recommend turning on non_permutative = True. Let me know if it works or doesn't work!
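
For concreteness, a minimal sketch of turning those flags on (the other hyperparameters are placeholders, not recommendations; the keyword names are as they appear elsewhere in this thread):

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 2048,
    heads = 8,
    buckets = 64,
    causal = False,
    attn_sort_net = True,     # use attention to reorder the buckets
    non_permutative = True    # recommended together with the attention sort net
)

x = torch.randint(0, 20000, (1, 2048))
logits = model(x)             # (1, 2048, 20000)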

blizda commented 4 years ago

> @blizda it is done! if you'd like to try it, just use the flag attn_sort_net = True. I also recommend turning on non_permutative = True. Let me know if it works or doesn't work!

Thanks, but when I try it, I get an error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-8-2fa7584f9630> in <module>
     43         inp, mask = mask_tokens(mmm, tokenizer)
     44         inputs, labels = inp.to("cuda"), mask.to("cuda")
---> 45         output = model(inputs, labels)
     46         loss = output[0]
     47         #print(loss.item())

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-7-1329da5363de> in forward(self, inputs, labels)
      7   def forward(self, inputs, labels=None):
      8     loss_mx = labels != -100
----> 9     output = self.model(inputs)
     10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
     11     labels = labels[loss_mx].view(-1)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
    575         x = self.to_token_emb(x)
    576         x = self.pos_emb(torch.arange(t, device=device)) + x
--> 577         x = self.sinkhorn_transformer(x)
    578         return self.to_logits(x)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
    155     def forward(self, x):
    156         x = self.project_in(x)
--> 157         x = self.fn(x)
    158         x = self.project_out(x)
    159         return x

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
    552 
    553     def forward(self, x, input_mask = None):
--> 554         return self.layers(x)
    555 
    556 class SinkhornTransformerLM(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
    516     def forward(self, x, **kwargs):
    517         x = torch.cat([x, x], dim=-1)
--> 518         x = self.layers(x, **kwargs)
    519         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
    520 

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
    128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
    129 
--> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
     98         ctx.kwargs = kwargs
     99         for block in blocks:
--> 100             x = block(x, **kwargs)
    101         ctx.y = x.detach()
    102         ctx.blocks = blocks

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
     50 
     51         with torch.no_grad():
---> 52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
     53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
     54 

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
     25 
     26         if not set_rng:
---> 27             return self.net(*args, **kwargs)
     28 
     29         rng_devices = []

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
    144     def forward(self, x, **kwargs):
    145         x = self.norm(x)
--> 146         return self.fn(x, **kwargs)
    147 
    148 class ProjectInOut(nn.Module):

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, context)
    503         merge_heads_fn = partial(merge_heads, h)
    504         q, k, v = map(merge_heads_fn, qkv)
--> 505         out = self.sinkhorn_attention(q, k, v)
    506         out = split_heads(h, out)
    507         out = self.to_out(out)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, q, k, v, context)
    262         # calculate reordering matrix R with simple sort net
    263 
--> 264         R = self.sort_net(q, k)
    265         R = R.type_as(q).to(q)
    266 

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, q, k)
    220             values = values.reshape(bh, self.n_sortcut, -1)
    221             indices = indices.reshape(bh, self.n_sortcut, -1)
--> 222             R = torch.zeros(bh, self.n_sortcut, buckets).scatter_(2, indices, values)
    223 
    224         return R.softmax(dim=-1) if self.non_permutative else gumbel_sinkhorn(R, self.sinkhorn_iter, self.temperature)

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_scatter_

blizda commented 4 years ago

Ok, I fixed this error by simply replacing line 222 with

R = torch.zeros(bh, self.n_sortcut, buckets).to(q).scatter_(2, indices, values)

But after this, training falls over because the outputs go to nan on the first pass.

blizda commented 4 years ago

Strange issue. It seems like the gradient blows up on the first step, independent of lr, optimizer, or clipping. If I just set attn_sort_net = False, everything works normally.
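
A generic way to confirm a blow-up like this (independent of this repo) is to log the largest per-parameter gradient norms right after loss.backward():

import torch
import torch.nn as nn

def log_largest_grad_norms(model: nn.Module, top_k: int = 10):
    # Call right after loss.backward() to see which parameters' gradients
    # explode first.
    norms = [(p.grad.detach().norm().item(), name)
             for name, p in model.named_parameters() if p.grad is not None]
    for norm, name in sorted(norms, reverse=True)[:top_k]:
        print(f"{norm:12.4e}  {name}")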

blizda commented 4 years ago

I tried adding a ReLU in CausalAttentionSortNet; it didn't help at all, and neither did LeakyReLU.

lucidrains commented 4 years ago

@blizda ohhh i know the issue lol ok will fix

lucidrains commented 4 years ago

@blizda should be all fixed!

blizda commented 4 years ago

> @blizda should be all fixed!

Thanks, it works fine now, but it would be better to replace

R = torch.zeros((bh, self.n_sortcut, buckets), device=device).scatter(2, indices, values)

with

R = torch.zeros(bh, self.n_sortcut, buckets).to(device).scatter(2, indices, values)

It somehow affects half-precision training, and using the first variant breaks it.

lucidrains commented 4 years ago

@blizda thank you for the advice! :D

blizda commented 4 years ago

> @blizda thank you for the advice! :D

lol

R = torch.zeros(bh, self.n_sortcut, buckets).to(device).scatter(2, indices, values)

also breaks half-precision training, but if I replace it with

R = torch.zeros(bh, self.n_sortcut, buckets).to(q.device).scatter(2, indices, values)

fp16 works fine. I'm not sure how this code will behave on TPU, though.
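
To reproduce the original device error and the fix in isolation (a minimal sketch, separate from the repo's internals): torch.zeros allocates on the CPU by default, so scattering CUDA tensors into it fails with the error shown above, while allocating the destination on the source tensor's device works.

import torch

bh, n_sortcut, buckets = 2, 2, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

values = torch.rand(bh, n_sortcut, 4, device = device)
indices = torch.randint(0, buckets, (bh, n_sortcut, 4), device = device)

# Fails on GPU: torch.zeros(...) lives on the CPU while values/indices are on CUDA,
# giving the "Expected object of device type cuda ..." error quoted earlier.
# R = torch.zeros(bh, n_sortcut, buckets).scatter_(2, indices, values)

# Works: allocate the destination on the same device as the data being scattered.
R = torch.zeros(bh, n_sortcut, buckets).to(values.device).scatter_(2, indices, values)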

lucidrains commented 4 years ago

@blizda put in another patch!

lucidrains commented 4 years ago

@blizda are you using pytorch lightning for the TPU support?

blizda commented 4 years ago

> @blizda are you using pytorch lightning for the TPU support?

No, just PyTorch-XLA. Training runs on version 0.0.8 with minimal code changes and without any problems.
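
For reference, the plain PyTorch-XLA pattern (no Lightning) boils down to getting an XLA device, moving the model and batches there, and stepping with xm.optimizer_step(); the full script appears later in this thread. A minimal sketch with a placeholder model and dummy loss:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                 # one TPU core
model = nn.Linear(16, 16).to(device)     # stand-in for the SinkhornTransformerLM below
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

x = torch.randn(8, 16).to(device)
loss = model(x).sum()                    # dummy loss, just to show the step
loss.backward()
xm.optimizer_step(optimizer, barrier = True)   # barrier=True for this simple single-device loop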

lucidrains commented 4 years ago

@blizda that's great to hear!

lucidrains commented 4 years ago

@blizda thank you for the awesome pull request! I learned something from it

lucidrains commented 4 years ago

@blizda I was actually wondering if you would be willing to share the script you are using for training on TPU? I just spent some time trying to get it to work with pytorch lightning, but ran into an "Unknown Device" error :(

blizda commented 4 years ago

> @blizda I was actually wondering if you would be willing to share the script you are using for training on TPU? I just spent some time trying to get it to work with pytorch lightning, but ran into an "Unknown Device" error :(

You can find the minimal code example here. Unfortunately, I can only share the full-fledged script later, once it's deep night in the USA.

lucidrains commented 4 years ago

@blizda thank you!

epetros commented 4 years ago

Would be great to test this on TPU. @blizda, can you please share the full-fledged script? How is the text generation quality so far, comparable to GPT-2? Thanks

blizda commented 4 years ago

> Would be great to test this on TPU. @blizda, can you please share the full-fledged script? How is the text generation quality so far, comparable to GPT-2? Thanks

I haven't tried using this model for text generation, I've only been experimenting with a BERT-like LM. I also pregenerate the dataset before training. This script has only been tested with an old version of this package:

import os
os.environ['XLA_USE_32BIT_LONG'] = '1'
os.environ['XLA_USE_BF16']='1'
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sinkhorn_transformer import SinkhornTransformerLM
from transformers import BertTokenizer, AdamW
import re
from glob import glob
import json
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils
import wandb

def mask_tokens(inputs: torch.Tensor, tokenizer, mlm_probability=0.15, pad=True):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original"""
    labels = inputs.clone()
    # mlm_probability defaults to 0.15 in Bert
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens
    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    if pad:
        input_pads = tokenizer.max_len - inputs.shape[-1]
        label_pads = tokenizer.max_len - labels.shape[-1]
        inputs = F.pad(inputs, pad=(0, input_pads), value=tokenizer.pad_token_id)
        labels = F.pad(labels, pad=(0, label_pads), value=tokenizer.pad_token_id)
    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels

def read_batch_from_dataset(tokenizer, path, max_batch_size=96, start_step=0):
    batch = []
    step = 0
    with open(path, 'r') as dataset:
        for it in dataset:
            data = it.strip().split()
            batch.append(
                torch.tensor(tokenizer.encode(data, max_length=tokenizer.max_len, add_special_tokens=False)))
            if len(batch) == max_batch_size:
                # yield only once we've reached the resume step; reset the
                # batch either way so it never grows past max_batch_size
                if step >= start_step:
                    yield batch.copy()
                batch = []
                step += 1
        else:  # for-else: runs once the file is exhausted, yielding the final partial batch
            if len(batch) > 0:
                yielded = batch.copy()
                batch = []
                yield yielded

def do_train(model, devices, path, tokenizer, save_every):
    # wandb.init(project="tpu")
    model.train()
    model.to(devices)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    iterat = 0
    for ep in range(50):
        running_loss = 0.0
        for it in read_batch_from_dataset(tokenizer, path):
            optimizer.zero_grad()
            data = torch.stack(it)
            inp, mask = mask_tokens(data, tokenizer)
            inputs, labels = inp.to(devices), mask.to(devices)
            output = model(inputs)
            loss_mx = labels != -100
            output = output[loss_mx].view(-1, tokenizer.vocab_size)
            labels = labels[loss_mx].view(-1)
            loss = loss_fn(output, labels)
            loss.backward()
            xm.optimizer_step(optimizer, barrier=True)
            running_loss += loss.item()
            iterat += 1
            if iterat % save_every == 0:
                torch.save(model.state_dict(), 'model_epoch/mod_st_d' + str(ep) + '.pt')
                torch.save(optimizer.state_dict(), 'optim_epoch/opt_st_d' + str(ep) + '.pt')

devices = xm.xla_device()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.max_len = 10240
model = SinkhornTransformerLM(
    num_tokens= tokenizer.vocab_size,
    dim = 768,
    depth = 12,
    max_seq_len = tokenizer.max_len,
    heads = 16,
    buckets = 64,
    causal = False,           # auto-regressive or not
    sinkhorn_iter = 7,        # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,            # use sortcut to reduce complexity to linear time
    temperature = 0.75,       # gumbel temperature - default is set at reported best in paper
    non_permutative = False,  # allow buckets of keys to be sorted to queries more than once
    attn_sort_net = True,     # attention to reorder the buckets, unlocks flexible sequence lengths
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 4,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
)
do_train(model, devices, "data", tokenizer, 50000)

epetros commented 4 years ago

Thanks!