Mikubill / naifu

Train generative models with PyTorch Lightning
MIT License

Batchsize >1 is currently broken. #4

Closed. IdiotSandwichTheThird closed this issue 2 years ago.

IdiotSandwichTheThird commented 2 years ago

I manually set the batch size in buckets.py, since the config setting has no effect, and then it crashes with:

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/bunny/miniconda3/envs/nai/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/bunny/miniconda3/envs/nai/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
    return self.collate_fn(data)
  File "/media/bunny/D612FBE112FBC511/FinetuneV6/data/store.py", line 189, in collate_fn
    z.append(torch.asarray([[self.tokenizer.bos_token_id] + x[:75] + [self.tokenizer.eos_token_id] for x in tokens]))
ValueError: expected sequence of length 61 at dim 1 (got 34)

It used to work before this commit https://github.com/Mikubill/naifu-diffusion/commit/cc5b063fe1b5e3a5582987f0a6f68322c3196171
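The traceback points at a ragged nested list: with more than one caption per batch, the token lists passed to torch.asarray have different lengths (61 vs. 34 here), so they cannot be stacked into a single tensor. Below is a minimal sketch of one way to collate such sequences, by right-padding each one to a fixed length before stacking. The pad_tokens helper and the tokenizer attributes (bos_token_id, eos_token_id, pad_token_id, as on a Hugging Face CLIP-style tokenizer) are assumptions for illustration, not the repository's actual fix.

import torch

def pad_tokens(tokens, tokenizer, max_len=75):
    # tokens: list of lists of token ids, one list per caption in the batch.
    # Each sequence is wrapped in BOS/EOS and right-padded to max_len + 2,
    # so torch.asarray receives a rectangular nested list instead of a ragged one.
    padded = []
    for x in tokens:
        seq = [tokenizer.bos_token_id] + x[:max_len] + [tokenizer.eos_token_id]
        seq = seq + [tokenizer.pad_token_id] * (max_len + 2 - len(seq))
        padded.append(seq)
    return torch.asarray(padded)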

IdiotSandwichTheThird commented 2 years ago

Thank you once again!