lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

DeepSpeed and nn.Embedding issue #60

Closed CalogeroZarbo closed 4 years ago

CalogeroZarbo commented 4 years ago

Hi lucidrains! First of all, thanks for the contribution. You are doing an awesome job here.

I'm trying to implement a Seq2Seq model using DeepSpeed, since I will have a 32k seq_len as input. This is my code:

```python
 class GenomeToMolDataset(Dataset):
    def __init__(self, data, src_lang, trg_lang):
        super().__init__()
        self.data = data
        self.src_lang = src_lang
        self.trg_lang = trg_lang

    def __getitem__(self, index):
        #print(index)
        pair = self.data[index]
        #print('src:',pair[0])
        #print('\n\ntrg:',pair[1])
        src = torch.tensor(indexesFromSentence(self.src_lang,pair[0]))
        trg = torch.tensor(indexesFromSentence(self.trg_lang,pair[1]))
        print('src:', src)
        print('trg:', trg)
        return src,trg

    def __len__(self):
        return len(self.data)

train_dataset = GenomeToMolDataset(tr_pairs, input_lang, target_lang)
test_dataset = GenomeToMolDataset(ts_pairs, input_lang, target_lang)

encoder = ReformerLM(
    num_tokens = input_lang.n_words,
    emb_dim = emb_dim,#128,
    dim = dim,#512,
    bucket_size = bucket_size, # 16,
    depth = depth, # 6,
    heads = heads, # 8,
    n_hashes= n_hashes,
    max_seq_len = VIR_SEQ_LEN,
    ff_chunks = ff_chunks, #400,      # number of chunks for feedforward layer, make higher if there are memory issues
    attn_chunks = attn_chunks, #16,    # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    #weight_tie = True,
    fixed_position_emb = True,
    return_embeddings = True # return output of last attention layer
).cuda()

decoder = ReformerLM(
    num_tokens = target_lang.n_words,
    emb_dim = emb_dim, # 128,
    dim = dim, # 512,
    bucket_size = bucket_size, #16,
    depth = depth, #6,
    heads = heads, #8,
    n_hashes= n_hashes,
    ff_chunks = ff_chunks, # 400,      # number of chunks for feedforward layer, make higher if there are memory issues
    attn_chunks = attn_chunks, # 16,    # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    max_seq_len = MOL_SEQ_LEN,
    fixed_position_emb = True,
    causal = True
).cuda()

encoder_optimizer = RangerLars(encoder.parameters()) # torch.optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = RangerLars(decoder.parameters()) # torch.optim.Adam(decoder.parameters(), lr=learning_rate)

if use_apex:
    encoder, encoder_optimizer = amp.initialize(encoder, encoder_optimizer, opt_level='O1')
    decoder, decoder_optimizer = amp.initialize(decoder, decoder_optimizer, opt_level='O1')

encoder = TrainingWrapper(encoder).cuda()
#encoder.cuda()

decoder = TrainingWrapper(decoder).cuda()
#decoder.cuda()

encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())

encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=encoder, optimizer=encoder_optimizer, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True)
decoder_engine, decoder_optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=decoder, optimizer=decoder_optimizer, model_parameters=decoder_params, dist_init_required=False)

# training
VALIDATE_EVERY = 1
SAVE_EVERY = 10
SAVE_DIR = './saved_model/'
_, encoder_client_sd = encoder_engine.load_checkpoint(SAVE_DIR+'encoder/', None)
_, decoder_client_sd = decoder_engine.load_checkpoint(SAVE_DIR+'decoder/', None) #args.ckpt_id 
for i, pair in enumerate(trainloader):
    src = pair[0]
    trg = pair[1]
    encoder_engine.train()
    decoder_engine.train()
    src = src.to(encoder_engine.local_rank)
    trg = trg.to(decoder_engine.local_rank)

    print(src.shape)
    print(src.dtype)
    print(trg.shape)
    print(trg.dtype)

    enc_keys = encoder_engine(src)
    loss = decoder_engine(trg, keys = enc_keys, return_loss = True)   # (1, 4096, 20000)
    encoder_engine.backward(loss)
    decoder_engine.backward(loss)
    encoder_engine.step()
    decoder_engine.step()
    print('Training Loss:',loss.item())       

    if i % VALIDATE_EVERY == 0:
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            ts_src, ts_trg = random.choice(test_dataset)
            enc_keys = encoder(ts_src.to(device))
            loss = decoder(ts_trg, keys=enc_keys, return_loss = True)
            print(f'\tValidation Loss: {loss.item()}')

    if i % SAVE_EVERY == 0:
        encoder_client_sd['step'] = i
        decoder_client_sd['step'] = i
        ckpt_id = loss.item()
        encoder_engine.save_checkpoint(SAVE_DIR+'encoder/', ckpt_id, client_sd = encoder_client_sd)
        decoder_engine.save_checkpoint(SAVE_DIR+'decoder/', ckpt_id, client_sd = decoder_client_sd)
```

The issue I'm having is with the nn.Embedding layer: it expects Long integers as input, but DeepSpeed works only with floats, and it raises this error: RuntimeError: expected device cuda:0 and dtype Float but got device cuda:0 and dtype Long

If I cast the inputs to float, the Embedding layer raises the opposite error.

How can I use your ReformerLM as an encoder-decoder with DeepSpeed in this case? Is there any way to work around the Embedding issue?
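For reference, one generic pattern for this kind of Long/Float clash (a hypothetical sketch, not part of reformer-pytorch or DeepSpeed) is a thin wrapper that casts whatever float tensors the engine passes back to long before the embedding sees them:

```python
import torch
from torch import nn

class CastToLong(nn.Module):
    """Hypothetical helper: accept the float tensors that the engine passes
    around and cast them back to integer indices before the wrapped model's
    nn.Embedding sees them."""
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x, **kwargs):
        return self.net(x.long(), **kwargs)

# usage sketch: wrap the model before handing it to deepspeed.initialize
# encoder = CastToLong(TrainingWrapper(encoder))
```

Whether this helps depends on where exactly the cast happens in the DeepSpeed pipeline; the discussion below points at a different root cause.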

Thank you, Cal

CalogeroZarbo commented 4 years ago

Update:

I made it work by removing the positional embedding flag and the embedding dimension, but I'm still curious why I cannot use the fixed positional embedding with DeepSpeed. Any clue?

Thank you.

lucidrains commented 4 years ago

@CalogeroZarbo Hi! Thanks for your interest in the library, and also thanks for working on viruses. We need it at the moment :(

fixed_position_emb actually uses sinusoidal positional embeddings; if you remove it, the model defaults to absolute positional embeddings (you get positional embeddings either way), so rest assured it will work for you as it is. The recommended positional embedding for long sequences like yours is actually axial_position_emb; there is more documentation here: https://github.com/lucidrains/reformer-pytorch#positional-embeddings
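For reference, a sketch of what that could look like for the encoder above, following the README section linked here. The argument names (axial_position_emb, axial_position_shape, axial_position_dims) are taken from that README and may differ between versions, so check the README for the installed release; the axial shape must multiply out to max_seq_len and the dims must sum to the embedding dimension:

```python
# assumes VIR_SEQ_LEN = 32768 and reuses the hyperparameters from the snippet above
encoder = ReformerLM(
    num_tokens = input_lang.n_words,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 64,
    max_seq_len = VIR_SEQ_LEN,
    ff_chunks = 400,
    attn_chunks = 16,
    axial_position_emb = True,
    axial_position_shape = (256, 128),   # 256 * 128 = 32768 = max_seq_len
    axial_position_dims = (256, 256),    # 256 + 256 = 512 = dim
    return_embeddings = True
).cuda()
```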

As for your bug, I just tried running fixed_position_emb with your settings on Google Colab, and it isn't throwing that error. Could you possibly also paste the trace? Also, make sure you are using the most current version, as I did have a positional embedding bug in previous versions.

CalogeroZarbo commented 4 years ago

Hi @lucidrains! Thank you for your answer! I'm doing my best to help. I've been trying to build this system for a month, but the sequence length was a huge limitation until now; thanks to your lib and DeepSpeed, I hope I can find something helpful. I will open-source it as soon as it makes sense. The idea behind the system:

  1. I found a DB linking virus genomes to antiviral molecules in canonical SMILES notation.
  2. I want to find F(Genome) --> SMILES. I'm currently trying to address this as a character-level NMT problem.

When I find a candidate for SARS-CoV-2, I will use another Transformer for molecular similarity on the same canonical SMILES, and check whether any ready-to-use drugs exist.

I know it might sound silly, but given what the world is facing, it's too important not to try.

After that, I want to extend this project into a more generic system, so that the next time something similar happens we can address it before it's too late.

Coming back to the issue, this is the stack trace (the run uses two GPUs, so the same error is raised on both ranks; one copy is shown here):

```
Traceback (most recent call last):
  File "train_model_torch.py", line 704, in <module>
    main()
  File "train_model_torch.py", line 616, in main
    enc_keys = encoder_engine(src)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/deepspeed/pt/deepspeed_light.py", line 598, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/generative_tools.py", line 71, in forward
    return self.net(x, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/autopadder.py", line 50, in forward
    out = self.net(x, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 640, in forward
    x = x + self.pos_emb(x).type(x.type())
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 538, in forward
    sinusoid_inp = torch.einsum("i,j->ij", t.float(), self.inv_freq)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/functional.py", line 202, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: expected device cuda:0 and dtype Float but got device cuda:0 and dtype Long
```

The second rank fails identically on cuda:1.

And this is my pip freeze:

```
absl-py==0.9.0
apex==0.1
appdirs==1.4.3
astor==0.8.1
attrs==19.3.0
certifi==2019.11.28
cfgv==3.1.0
click==7.1.1
cycler==0.10.0
deepspeed==0.1.0
distlib==0.3.0
filelock==3.0.12
gast==0.2.2
google-pasta==0.1.8
grpcio==1.27.2
h5py==2.10.0
identify==1.4.11
importlib-metadata==1.5.0
importlib-resources==1.3.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
Markdown==3.2.1
matplotlib==3.2.0
more-itertools==8.2.0
nodeenv==1.3.5
numpy==1.18.1
nvidia-ml-py3==7.352.0
opt-einsum==3.2.0
packaging==20.3
Pillow==6.2.2
pluggy==0.13.1
pre-commit==2.1.1
protobuf==3.11.3
psutil==5.7.0
py==1.8.1
pyparsing==2.4.6
pytest==5.3.5
pytest-forked==1.1.3
python-dateutil==2.8.1
PyYAML==5.3
reformer-pytorch==0.17.2
six==1.14.0
tensorboard==1.15.0
tensorboardX==1.8
tensorflow-estimator==1.15.1
tensorflow-gpu==1.15.2
termcolor==1.1.0
toml==0.10.0
torch==1.2.0
torchvision==0.4.0
tqdm==4.43.0
virtualenv==20.0.9
wcwidth==0.1.8
Werkzeug==1.0.0
wrapt==1.12.1
zipp==3.1.0
```
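For context on the trace above: the failing line is a plain einsum between the position indices and the module's inv_freq buffer, and einsum requires both operands to share a dtype. A minimal standalone illustration of the mismatch, using stand-in tensors rather than the library's actual buffers:

```python
import torch

dim, seq_len = 512, 16

# stand-ins for the sinusoidal embedding's internals (illustration only)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len)

# fine: both operands are float32
sinusoid_inp = torch.einsum("i,j->ij", t.float(), inv_freq)

# raises the RuntimeError from the trace once the buffer's dtype stops matching,
# e.g. if something upstream has cast it:
# torch.einsum("i,j->ij", t.float(), inv_freq.long())

# the obvious guard is to cast both operands explicitly
sinusoid_inp = torch.einsum("i,j->ij", t.float(), inv_freq.float())
```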

I will surely try axial_position_emb. I'll keep you updated on the progress.

Cheers, Cal

lucidrains commented 4 years ago

@CalogeroZarbo Thank you for the trace! I believe you caught a bug with my sinusoidal positional encoding implementation, and it has been fixed in the latest version (I hope, please let me know).

That doesn't sound silly at all, and I think we are largely on the same page. Research is trickling in that attention may work well for chemicals and molecules. There's a lot left to explore. https://arxiv.org/abs/2002.08264 and https://twitter.com/EricTopol/status/1229150936028733440?s=19

Please share the database if you can! I would love to get involved. I played around with SMILES myself and have a generative model for chemicals up at https://thischemicaldoesnotexist.com using Reformer.

Finally, as a fellow practitioner, I've been thinking about how deep learning can be applied to this crisis. Evidence shows that deep learning can greatly speed up simulations (https://arxiv.org/abs/2001.08055), and I was wondering if perhaps it will be fruitful to train a differentiable docking function, perhaps specific to the Spike protein of Covid. Such a module could eventually be used in some end-to-end pipeline for evaluating candidates? Anyways, I am much an amateur in this arena, but those are my thoughts.

CalogeroZarbo commented 4 years ago

@lucidrains thank you for the resources you linked and for the support.

The Virus --> AntiViral DB I used is this one: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6367519/
This one is also interesting: http://crdd.osdd.net/servers/avcpred/data/26_viruses_1391.txt

I got the virus genomes from this: http://virusite.org/index.php
They didn't match 100%, though, so I had to manually curate the two datasets to normalize the virus nomenclature, using these two databases:

  1. https://www.ncbi.nlm.nih.gov/genomes/GenomesGroup.cgi?taxid=10239#maincontent
  2. https://www.ncbi.nlm.nih.gov/genome/browse#!/viruses/ The latter is particularly powerful, since it tracks the genomes, the upload dates, and the different names each virus can go by.

I downloaded the missing genomes from this resource: https://www.ncbi.nlm.nih.gov/genome/86693?genome_assembly_id=757732 (basically the one linked in the genome browser).

I will set up the repo later today with all the files already curated and ready for training, along with an explanation of the processing pipeline and the notebook I used to prepare the data. I will also post the code I wrote using the reformer_pytorch lib, so we can all proceed aligned in the same way.

Coming back to the issue: I updated to the 0.17.4 version and this is the new stack trace:

```
Traceback (most recent call last):
  File "train_model_torch.py", line 722, in <module>
    main()
  File "train_model_torch.py", line 628, in main
    enc_keys = encoder_engine(src)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/deepspeed/pt/deepspeed_light.py", line 598, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/generative_tools.py", line 71, in forward
    return self.net(x, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/autopadder.py", line 50, in forward
    out = self.net(x, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 640, in forward
    x = x + self.pos_emb(x).type(x.type())
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/reformer_pytorch/reformer_pytorch.py", line 538, in forward
    sinusoid_inp = torch.einsum("i,j->ij", t.float(), self.inv_freq)
  File "/home/calogero_zarbo/miniconda3/envs/virus_torch/lib/python3.6/site-packages/torch/functional.py", line 202, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: expected device cuda:0 and dtype Float but got device cuda:0 and dtype Half
```

The second rank fails identically on cuda:1.

The axial positional embedding is working like a charm. Is there any way we could implement the MAT attention technique from the paper you linked? It might be useful for this particular project.
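Coming back to the Half in the new trace (versus Long before): it suggests the fp16 machinery is now casting the positional embedding's buffer. As an illustration only (a generic sinusoidal positional embedding sketch, not reformer_pytorch's actual class), one robust pattern is to build the table in fp32 regardless of the buffer's current dtype and cast the result to the input's dtype at the end:

```python
import torch
from torch import nn

class SinusoidalPositionalEmbedding(nn.Module):
    """Generic sketch (hypothetical, not the library's implementation):
    compute the sinusoidal table in fp32, independent of any dtype casts
    applied to the buffer by mixed-precision training, then cast back."""
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        # x: (batch, seq_len, dim) token embeddings
        t = torch.arange(x.shape[1], device=x.device).float()
        sinusoid_inp = torch.einsum('i,j->ij', t, self.inv_freq.float())
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :].to(x.dtype)
```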

> Finally, as a fellow practitioner, I've been thinking about how deep learning can be applied to this crisis. Evidence shows that deep learning can greatly speed up simulations (https://arxiv.org/abs/2001.08055), and I was wondering if perhaps it will be fruitful to train a differentiable docking function, perhaps specific to the Spike protein of Covid. Such a module could eventually be used in some end-to-end pipeline for evaluating candidates? Anyways, I am much an amateur in this arena, but those are my thoughts.

I believe it's a great idea! Once we have some candidates that are already FDA approved (basically drug repurposing), we can do two things:

  1. Check whether the candidates found can interact with the active site of the Covid protein.
  2. Check whether the newly generated molecule can interact with the Covid virus (precision medicine: each Covid case might respond better to different drugs, depending on the genomic mutations the virus undergoes as it spreads across different people and environments).

Option 2 is the longer path, since the molecule would need to pass all the FDA tests before it could be used on people, which could take years. Option 1 would be faster: the candidates are already tested from a toxicity point of view and already FDA approved, so hospitals could use them on people right away.

I'll come back to you with the GitHub repo link.

Let me know if I can be of any help.

lucidrains commented 4 years ago

@CalogeroZarbo Thanks for sharing your data sources!

Keep us updated on your progress, and don't hesitate to raise any further issues you encounter!

CalogeroZarbo commented 4 years ago

@lucidrains this is the repo for the virus project: https://github.com/CalogeroZarbo/bioshield

I checked the new version of the library with the positional embedding and it works like a charm. Thank you for the fix!