lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

DeepSpeed and Generate Method #69

Closed CalogeroZarbo closed 4 years ago

CalogeroZarbo commented 4 years ago

Hi @lucidrains

I'm currently testing the generate function of the TrainingWrapper class. When I use DeepSpeed and try to generate a sequence, it gives me the following error: AttributeError: 'DeepSpeedLight' object has no attribute 'generate'

Is it because generation can only be done outside the DeepSpeed engine?

Thank you very much, once again! :)

lucidrains commented 4 years ago

@CalogeroZarbo Yes, I believe so! If you keep the instance of your module before passing it into deepspeed.initialize, you can invoke generate on it! I found a bug with the ReformerEncDec where I did not default the causal flag of the decoder to True, so perhaps fine-tune whatever you have for an epoch or two with the newer version.
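For example, a minimal sketch of what that might look like (the ReformerLM hyperparameters, cmd_args and train_dataset here are placeholders for your own setup):

import torch
import deepspeed
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper

# keep a handle on the TrainingWrapper before handing it to DeepSpeed
model = TrainingWrapper(ReformerLM(num_tokens = 20000, dim = 512, depth = 6, max_seq_len = 4096, causal = True))

model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args = cmd_args, model = model, model_parameters = model.parameters(), training_data = train_dataset)

# ... training loop using model_engine ...

# generate with the original wrapper, not the DeepSpeed engine
initial = torch.zeros((1, 1)).long().to(model_engine.local_rank)
sample = model.generate(initial, 100, temperature = 1., filter_thres = 0.9, eos_token = 1)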

I think I should let you know that another researcher has notified me that the Reformer (at least my implementation of it) was not able to pass a simple increment mapping task, where 1 is added to every token of the source sequence. I have confirmed that it indeed doesn't seem to learn, even when I up the number of hashes to 16. Furthermore, I reframed the increment mapping task as decoder-only, where the source and target sequences are concatenated, and it still didn't learn very well. Both were able to learn when I turned on full attention with the use_full_attn keyword, so that rules out reversibility and shared-QK attention as the cause.
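For reference, the increment mapping task can be generated with something as simple as the following (a sketch; batch size, sequence length and vocabulary size are arbitrary):

import torch

def increment_task_batch(batch_size = 32, seq_len = 256, num_tokens = 256):
    # target is simply the source with 1 added to every token
    src = torch.randint(1, num_tokens - 1, (batch_size, seq_len))
    trg = src + 1
    return src, trg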

As disappointing as it is, I thought I should let you know. If you can, retrain with full attention using the enc_use_full_attn and dec_use_full_attn keywords; you will still benefit from the memory savings of reversibility.
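For the encoder / decoder, that just means passing the prefixed keywords (a sketch; every other hyperparameter here is a placeholder):

from reformer_pytorch import ReformerEncDec

enc_dec = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_max_seq_len = 4096,
    enc_use_full_attn = True,    # full attention in the encoder
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_max_seq_len = 4096,
    dec_use_full_attn = True     # full attention in the decoder
)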

CalogeroZarbo commented 4 years ago

Hi @lucidrains

Thank you for your information, and for letting me know about the bug.

In my experiments with the previous version of the library, the Reformer was actually able to learn: starting from a loss of about 5, it kept decreasing during training and reached ~0.08 in both training and validation. I updated to the latest version, ran the same script, and now it is not learning. I hope this info helps you somehow; if I find out anything more informative, I'll let you know.

Cheers! Cal

CalogeroZarbo commented 4 years ago

Hi @lucidrains, I did some more tests, and my implementation of the encoder-decoder actually learns: the loss decreases in training, and cross-validation reaches a comparable loss.

This is the way I implemented the whole thing:

def train_encdec_v1(input_lang, target_lang, dim, bucket_size, depth, heads, n_hashes, vir_seq_len, ff_chunks, attn_chunks,
                    mol_seq_len, cmd_args, train_dataset, test_dataset, output_folder, train_batch_size, epochs,
                    validate_every, save_every, zero_optimization):
    print('Axial Embedding shape:', compute_axial_position_shape(vir_seq_len))
    encoder = ReformerLM(
        num_tokens = input_lang.n_words,
        dim = dim,
        bucket_size = bucket_size,
        depth = depth, 
        heads = heads, 
        n_hashes= n_hashes,
        max_seq_len = vir_seq_len,
        ff_chunks = ff_chunks, 
        attn_chunks = attn_chunks, 
        weight_tie = True,
        weight_tie_embedding = True,
        axial_position_emb = True,
        axial_position_shape = compute_axial_position_shape(vir_seq_len),  
        axial_position_dims = (dim // 2, dim //2),  
        return_embeddings = True 
    ).to(device)

    decoder = ReformerLM(
        num_tokens = target_lang.n_words,
        dim = dim, 
        bucket_size = bucket_size,
        depth = depth, 
        heads = heads, 
        n_hashes= n_hashes,
        ff_chunks = ff_chunks, 
        attn_chunks = attn_chunks, 
        max_seq_len = mol_seq_len,
        axial_position_emb = True,
        axial_position_shape = compute_axial_position_shape(mol_seq_len),  
        axial_position_dims = (dim // 2, dim //2), 
        weight_tie = True,
        weight_tie_embedding = True,
        causal = True
    ).to(device)

    encoder_optimizer = RangerLars(encoder.parameters()) 
    decoder_optimizer = RangerLars(decoder.parameters()) 

    encoder = TrainingWrapper(encoder, ignore_index=PAD_IDX, pad_value=PAD_IDX).to(device)
    decoder = TrainingWrapper(decoder, ignore_index=PAD_IDX, pad_value=PAD_IDX).to(device)

    encoder_params = filter(lambda p: p.requires_grad, encoder.parameters())
    decoder_params = filter(lambda p: p.requires_grad, decoder.parameters())

    encoder_engine, encoder_optimizer, trainloader, _ = deepspeed.initialize(args=cmd_args, model=encoder, optimizer=encoder_optimizer, model_parameters=encoder_params, training_data=train_dataset, dist_init_required=True)
    decoder_engine, decoder_optimizer, testloader, _ = deepspeed.initialize(args=cmd_args, model=decoder, optimizer=decoder_optimizer, model_parameters=decoder_params, training_data=test_dataset, dist_init_required=False)

    SAVE_DIR = os.sep.join([output_folder, 'saved_model'])
    os.makedirs(SAVE_DIR, exist_ok=True)

    try:
        enc_ckp_max = np.max([int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR,'encoder']))])
    except Exception as e:
        print('Exception:', e)
        enc_ckp_max = 0

    try:
        dec_ckp_max = np.max([int(ckp) for ckp in os.listdir(os.sep.join([SAVE_DIR,'decoder']))])
    except Exception:
        dec_ckp_max = 0

    _, encoder_client_sd = encoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckp_max)
    _, decoder_client_sd = decoder_engine.load_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckp_max) 

    gpus_mini_batch = int(train_batch_size / torch.cuda.device_count())
    print('gpus_mini_batch:', gpus_mini_batch)
    log_file = open(os.sep.join([output_folder,'training_log.log']), 'a')
    log_file.write("\n\n\n{}\tStarting new training from checkpoint: Encoder-{} | Decoder-{}\n".format(datetime.datetime.now(), enc_ckp_max, dec_ckp_max))
    log_file.flush()

    for eph in range(epochs):
        print('Starting Epoch: {}'.format(eph))
        for i, pair in enumerate(trainloader):
            tr_step = ((eph*len(trainloader))+i)+1

            src = pair[0]
            trg = pair[1]
            encoder_engine.train()
            decoder_engine.train()
            src = src.to(encoder_engine.local_rank)
            trg = trg.to(decoder_engine.local_rank)

            enc_keys = encoder_engine(src)
            loss = decoder_engine(trg, keys = enc_keys, return_loss = True)  
            loss.backward()

            decoder_engine.step()
            encoder_engine.step()

            print('Training Loss:',loss.item())       
            if tr_step % validate_every == 0:
                val_loss = []
                for pair in tqdm(testloader):
                    encoder_engine.eval()
                    decoder_engine.eval()
                    with torch.no_grad():
                        ts_src = pair[0]
                        ts_trg = pair[1]

                        ts_src= ts_src.to(encoder_engine.local_rank)
                        ts_trg = ts_trg.to(decoder_engine.local_rank)

                        enc_keys = encoder_engine(ts_src)
                        loss = decoder_engine(ts_trg, keys=enc_keys, return_loss = True)
                        val_loss.append(loss.item())

                print(f'\tValidation Loss: AVG: {np.mean(val_loss)}, MEDIAN: {np.median(val_loss)}, STD: {np.std(val_loss)} ')
                log_file.write('Step: {}\tTraining Loss:{}\t Validation LOSS: AVG: {}| MEDIAN: {}| STD: {}\n'.format(
                                                                                                i,
                                                                                                loss.item(),
                                                                                                np.mean(val_loss),
                                                                                                np.median(val_loss),
                                                                                                np.std(val_loss)))
            else:
                log_file.write('Step: {}\tTraining Loss:{}\n'.format(i,loss.item()))

            log_file.flush()

            if tr_step % save_every == 0:
                print('\tSaving Checkpoint')
                enc_ckpt_id = str(enc_ckp_max+tr_step+1) 
                dec_ckpt_id = str(dec_ckp_max+tr_step+1)
                encoder_engine.save_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckpt_id)
                decoder_engine.save_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckpt_id)

    log_file.close()
    print('\tSaving Final Checkpoint')
    enc_ckpt_id = str(enc_ckp_max+tr_step+1) 
    dec_ckpt_id = str(dec_ckp_max+tr_step+1)
    encoder_engine.save_checkpoint(os.sep.join([SAVE_DIR,'encoder']), enc_ckpt_id)
    decoder_engine.save_checkpoint(os.sep.join([SAVE_DIR,'decoder']), dec_ckpt_id)

The next step for me will be to actually generate some sequences and check how well the model performs by looking at the generated output, not only at the loss value.
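In case it's useful to others reading this, I expect generation to look roughly like the following, calling generate on the original TrainingWrapper instances (kept around next to the DeepSpeed engines); SOS_IDX and EOS_IDX are placeholders for my special tokens:

# `encoder` and `decoder` are the TrainingWrapper instances created above,
# kept around in addition to the DeepSpeed engines
encoder.eval()
decoder.eval()
with torch.no_grad():
    enc_keys = encoder(src.to(device))                                  # encode the source sequence
    start = torch.full((src.size(0), 1), SOS_IDX).long().to(device)     # SOS_IDX: placeholder start token
    generated = decoder.generate(start, mol_seq_len, keys = enc_keys, eos_token = EOS_IDX)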

I'll keep you posted!

Cheers, Cal

PS @edit: Library version: 0.19.5 - PyTorch 1.2.0

lucidrains commented 4 years ago

@CalogeroZarbo Hi, sorry for the late response. I have been sulking because the Reformer hasn't been working to my satisfaction. I believe the loss is decreasing because the decoder portion is learning the target sequence distribution, but I no longer believe Reformer is able to learn as well as full attention, even at high hash rates. Feel free to continue using the framework, with full attention turned on, as the reversibility still functions well.

I am probably moving on to https://arxiv.org/abs/2002.11296 , and if that doesn't work, perhaps commit my time to scaling up attention rather than pursuing "efficient" or sparse solutions.

Phil

CalogeroZarbo commented 4 years ago

I understand @lucidrains. The frustration is at its max when something you put so much effort into does not work as expected.

However, I'm going to keep doing experiments, test the library's behavior under different circumstances, and see how it goes. It may be that for longer sequences this is the only way we can address the memory issue. I do understand that Reformer, as you say, does not learn as well as full attention, but in many cases, for longer sequences, you cannot even think of using a full-attention transformer. At the end of the day, it's better to accept a decrease in performance than to be unable to run the experiments and develop the solution at all.

With all that said, I still believe there is much value in your work and in the time you have generously spent on this project. I, as well as the whole community, am very grateful to you.

Thank you again for your tips & tricks!

Cheers, Cal

nkitaev commented 4 years ago

@lucidrains I'm sorry to hear that Reformer hasn't worked for you.

I have a question regarding your implementation: when implementing reversible attention layers, do you cache the LSH bucket assignments between the forward and the backward passes?

The reversibility trick introduces a small amount of noise due to floating point rounding, so activations in the reverse pass are slightly different than in the forward pass. Most layers are extremely robust to this -- consider, for example, how networks are also robust to dropout. However, the discrete sorting operation used for LSH can heavily amplify any nonzero amount of noise.

The solution is to save the bucket assignments when you do LSH in the forward pass, and then during the backward pass you re-use the cached bucket assignments instead of recomputing them. Bucket assignments are scalars so the memory usage from this is negligible.
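Schematically, the caching pattern looks something like this (a toy illustration of the idea, not the actual LSH attention from either implementation):

import torch
from torch import nn

class CachedBucketAttention(nn.Module):
    # toy module showing only the caching of bucket assignments
    def __init__(self, dim, n_buckets):
        super().__init__()
        self.proj = nn.Linear(dim, n_buckets, bias = False)

    def forward(self, x, buckets = None):
        if buckets is None:
            # forward pass of the reversible net: compute the bucket assignments once
            buckets = self.proj(x).argmax(dim = -1)    # (batch, seq)
        # ... bucketed attention, driven entirely by `buckets`, would go here ...
        # on the reverse pass, the reversible block calls this module again with the
        # cached `buckets`, so the slightly noisy recomputed activations can never
        # change the sorting
        return x, buckets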

Without this caching, Reformers with more than 3 layers won't train at all, and shallower models are also heavily susceptible to instability.

lucidrains commented 4 years ago

@nkitaev Hi Nikita! Thank you for clarifying! I actually have a setting reverse_thres that allows me to toggle reversibility on and off during training, and from what I recall, it didn't make a difference. However, I am not 100% sure, so let me retry another training run today and get back to you. Regardless, thank you for dropping a note here!

lucidrains commented 4 years ago

@nkitaev I will also publish the simple increment task in the examples, so someone may potentially find an error in my reproduction of your code, in case I implemented the LSH attention incorrectly.

Addition - I also plan on committing the ability to cache the bucket keys in the reversible net, so that, at the very least, this repository will be on par with your team's vision.

lucidrains commented 4 years ago

@nkitaev Good news! After running some experiments, it turns out I must have made a mistake. The task does train when it is decoder-only, just slowly. At about 1.5-2k iterations, it converged, and was able to do so for depths greater than 4 as well. The error must have been on my end, as I had been tweaking a bunch of other hyperparameters in between the encoder/decoder and the decoder-only.

I have decided to take your advice and cache the buckets https://github.com/lucidrains/reformer-pytorch/pull/81 . I did not notice a difference in how many iterations it took to get to the solution, but I'll take your word that this should be an improvement.

However, the increment task remains unsolvable with the encoder / decoder architecture. Because of the shared QK space, I had to concatenate the contextual keys to the decoder queries, leading to unnecessary attention of the context towards the target sequence. It seems like simply excising out the excess attention does not work. Do you have any advice on this remaining problem?

CalogeroZarbo commented 4 years ago

Hi @lucidrains! Great to read your good news! May I ask what you mean by training a task decoder-only? Can I do machine translation tasks decoder-only style? Would you be so kind as to write an example where the problem is implemented both as encoder/decoder and as decoder-only?

Thank you!

lucidrains commented 4 years ago

@CalogeroZarbo I had totally missed your message! What I meant was combining the encoder / decoder into one model with masking done as in the image on the very right hand side.

[Screenshot: attention masking diagrams; the decoder-only masking referred to is on the far right]
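Concretely, something along these lines (a sketch with placeholder names; src, trg, src_len, trg_len, num_tokens and PAD_IDX are whatever your data pipeline provides):

import torch
import torch.nn.functional as F
from reformer_pytorch import ReformerLM

# decoder-only: concatenate source and target into a single sequence and train
# one causal ReformerLM on it, computing the loss only over the target positions
seq = torch.cat((src, trg), dim = 1)

model = ReformerLM(
    num_tokens = num_tokens,          # shared vocabulary for source and target
    dim = 512,
    depth = 6,
    max_seq_len = src_len + trg_len,
    causal = True
)

logits = model(seq[:, :-1])
loss = F.cross_entropy(
    logits[:, src_len - 1:].transpose(1, 2),   # predictions for the target positions
    seq[:, src_len:],                          # the target tokens themselves
    ignore_index = PAD_IDX
)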

I came back to this thread to apologize for taking up everyone's time. It turns out I had made a mistake in masking https://github.com/lucidrains/reformer-pytorch/commit/7f2a164dd7454656b34d1c1dd2cb982934dcbc3c and the contextual keys were never being attended to. Encoder / decoder does converge for the increment task at about 4-5k iterations now. Sorry again, and thank you @nkitaev for pushing me to reexamine my implementation. I gave up too early 😓

lucidrains commented 4 years ago

@CalogeroZarbo Also want to let you know I have been working on another implementation for sparse attention over at https://github.com/lucidrains/sinkhorn-transformer . Feel free to try it, even if it isn't as feature complete as this repository!

CalogeroZarbo commented 4 years ago

Hi @lucidrains! Thank you for the great news! No need to apologize, this is the heart of open source: all the users looking at your code help improve it and discover mistakes here and there. At the end of the day, we are all human beings.

I also plan to use the Sinkhorn repo you made, and I'm very curious to see its performance. If you had to choose for long sequences, which technique would you recommend: Reformer or your implementation of Sinkhorn attention? I saw that in the latter you also added reversibility and chunking techniques. Which repo would you use for an Enc/Dec architecture?

Thank you so much for your effort!

lucidrains commented 4 years ago

@CalogeroZarbo Thank you for the kind words!

I think both architectures have their pros and cons.

The difficulty I faced with Sinkhorn was in its bucketing strategy. When routing buckets in the causal scenario, there is no way to allow tokens other than the first in each bucket to inform the routing decision (to prevent the future from leaking into the past). I came up with a partial solution: rotating the sequence to promote the last token to be first, for half the heads https://github.com/lucidrains/sinkhorn-transformer/blob/master/sinkhorn_transformer/sinkhorn_transformer.py#L560 , and I also added local attention heads in hopes of alleviating this problem, but I need to run experiments to verify.
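In spirit, it is something like the following (a much-simplified sketch; the actual code linked above handles the shapes, padding and bucketing differently):

import torch
import torch.nn.functional as F

def shift_half_heads(x):
    # x: (batch, heads, seq, dim). For half of the heads, shift the sequence
    # right by one position (padding the front), so the token that sits at the
    # start of each bucket - the only one the causal routing may look at - is a
    # different, still strictly-past token than for the unshifted heads.
    h = x.shape[1] // 2
    shifted = F.pad(x[:, h:], (0, 0, 1, 0))[:, :, :-1]
    return torch.cat((x[:, :h], shifted), dim = 1)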

On head to head runs between Reformer and Sinkhorn decoders on length 4096, Sinkhorn seems to converge faster initially, but Reformer catches up later. Sinkhorn also ends up using about 60-70% of the memory that Reformer does, since it doesn't need to incur the cost of n_hashes.

[Screenshot: training loss curves comparing Reformer and Sinkhorn decoders at sequence length 4096]

I do think the bucketing scheme will become a limitation at lengths greater than 12k, and Reformer may be better at those lengths, since LSH is not limited to routing to only one segment. I am currently working on making the number of key buckets that can be routed to a bucket of queries flexible, hoping to overcome this. Currently, in the causal case, only one bucket of keys can be routed to each bucket of queries.

In the end, I've made the interface of Sinkhorn largely the same as this repository, so you should be able to easily try the encoder / decoder there without many changes!
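As a starting point, the language model there is instantiated much like ReformerLM (a sketch; double check the sinkhorn-transformer README for the current parameter names):

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = 4096,
    bucket_size = 128,
    causal = True
)

x = torch.randint(0, 20000, (1, 4096))
logits = model(x)    # (1, 4096, 20000)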