Beomi / InfiniTransformer

Unofficial PyTorch/🤗Transformers (Gemma/Llama3) implementation of "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention"
https://arxiv.org/abs/2404.07143
MIT License

Model generating random sequence #11

Open Lazy3valuation opened 2 months ago

Lazy3valuation commented 2 months ago

By saving the model and reloading it, I managed to get it working in both quantized and full precision (it still uses at most 10 GB of GPU RAM). However, the model generates random characters. Here's the output:


GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaInfiniAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=256000, bias=False)
)
Input:
This work introduces an efficient method to scale Transformer-based
generated_sequence:
<bos>fetchone’mittlungPublicado:, (хьтан cobr " about mattino
relenting ? Alamofire theyallclose"conio
Generated:
<bos>fetchonekjø  professeurs { AssemblyCulture for disagre ‘ Compañ ‘…GraphicsUnit youXmlEnum atpaddingVertical such. nakalista .enumi,stdarg It Caleb including autunno ifwithIdentifierഛ“ Muitos for якостіبسم  relenting

When the model is printed, it correctly shows "GemmaInfiniAttention" for the self_attn layers, but it still generates random characters. What am I doing wrong?

Beomi commented 2 months ago

Did you train the model? Just loading it with an untrained gate will output random tokens.

Lazy3valuation commented 2 months ago

No, sadly not: trying to train with your code makes my GPU run out of memory, and trying to train with LoRA breaks the model, which fails at inference with "next_token = torch.multinomial( RuntimeError: probability tensor contains either inf, nan or element < 0". I guess I'll have to wait for an implementation that can fully run on a 12 GB GPU :-(
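For context, that RuntimeError comes from torch.multinomial receiving a probability tensor with non-finite entries, which can happen when fp16 logits overflow or turn into NaN. A minimal, self-contained sketch (toy logits, not the actual model's) that reproduces the situation and guards against it:

```python
import torch

# Toy logits containing a NaN, standing in for fp16 overflow in the model.
logits = torch.tensor([[1.0, float("nan"), 2.0]])
probs = torch.softmax(logits, dim=-1)  # NaN propagates into the probabilities

if torch.isnan(probs).any() or torch.isinf(probs).any():
    # Sampling would raise "probability tensor contains either inf, nan or
    # element < 0", so fall back to greedy decoding over the finite logits.
    next_token = torch.nan_to_num(logits, nan=-float("inf")).argmax(dim=-1)
else:
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

print(next_token.item())  # picks the largest finite logit, index 2
```

This only works around the symptom; the underlying fix discussed below is loading the model in a dtype that keeps the logits finite.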

Beomi commented 2 months ago

Oh, I think plain LoRA is not compatible with this: the model has to get a chance to learn how to use the long-term memory, but if you train with LoRA and don't explicitly declare the 'gate' params as trainable, the model may lose the chance to learn it.

How about this: keep the LoRA targets the same as in your configuration, but add modules_to_save=["gate"] to your LoraConfig. The gate params then become trainable, so training should work. Could you try again? 👀

Lazy3valuation commented 2 months ago

Playing around, I managed to stop getting the "inf" error: besides adding "modules_to_save" to the LoRA config, I was loading the model in fp16; switching back to torch_dtype="auto" fixed the training process. The loss still starts high (about 22) and drops to 8 in a few epochs on a very small test dataset, but the model still generates random tokens. I'm now trying a bigger dataset (still, sadly, split into smaller pieces...) and I'll see what happens. With LoRA I'm targeting all the basic modules:

l_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    modules_to_save=["gate"],
    bias="none",
    task_type="CAUSAL_LM",
)

I'll let you know if that works

Lazy3valuation commented 2 months ago

After about 40 minutes of training in 4-bit precision with a block_size of 600 (I can't train the model in 8-bit precision; the maximum block size before running out of memory was 15), the loss went down from 22 to 11, and the sentences were less random (English was also more common). I guess that with enough time I could get decent output from a 4-bit model, but since I'm limited to a 3060 with 12 GB, I'll have to wait for someone to release an open model.

Beomi commented 2 months ago

Oh, your loss seems pretty high. If I were you, I would wait until the loss reaches ~4. An LM train loss above 4 typically produces output closer to random than to fluent generation.
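To make that rule of thumb concrete: cross-entropy loss maps to perplexity via ppl = exp(loss), and a uniform guess over Gemma's 256,000-token vocabulary corresponds to loss = ln(256000) ≈ 12.45, so a loss of 11 really is barely better than random. A quick check:

```python
import math

# Perplexity for the loss values mentioned in this thread.
for loss in (11.0, 8.0, 4.0):
    print(f"loss {loss:4.1f} -> perplexity {math.exp(loss):10.1f}")

# Loss of a uniform distribution over Gemma's 256000-token vocabulary.
print(f"random baseline: {math.log(256000):.2f}")  # ≈ 12.45
```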

jlamprou commented 2 months ago

LoRA is not compatible with nn.Parameter, so you can't train the gate with LoRA directly. You could switch it to nn.Embedding, which works with LoRA, but that needs a small modification to the code.
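Since LoRA adapters wrap Linear/Embedding modules, a raw nn.Parameter gate can't receive one, but it can simply be unfrozen by name alongside the adapters (which is what modules_to_save does under the hood). A minimal pure-PyTorch sketch with a hypothetical toy layer, not the repo's actual attention class:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Hypothetical layer: a per-head gate stored as a raw nn.Parameter."""
    def __init__(self, num_heads: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(32, 32, bias=False)
        self.gate = nn.Parameter(torch.zeros(num_heads))  # plain parameter, not a module

model = ToyAttention()
for p in model.parameters():               # freeze everything, as LoRA training does
    p.requires_grad = False
for name, p in model.named_parameters():   # then unfreeze the gate by name
    if "gate" in name:
        p.requires_grad = True

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # only the gate remains trainable
```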

Lazy3valuation commented 2 months ago

I'll try, but I'm still studying deep learning and transformer models, so I'm not sure I can make it work. Any chance you'll release a trained model with a 1M-token context? 👀