pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Generate Command phi3 Error #1581

Open sgupta1007 opened 5 days ago

sgupta1007 commented 5 days ago

I have used the command `tune run generate --config custom_quantization.yaml prompt='Explain some topic'` to generate inference from a finetuned phi3 model through torchtune.

Config custom_quantization.yaml

model:
  _component_: torchtune.models.phi3.qlora_phi3_mini

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: //fine_tuned_phi/
  checkpoint_files: [
    hf_model_0001_0.pt,hf_model_0002_0.pt,adapter_0.pt
  ]
  model_type: PHI3
  output_dir: /fine_tuned_legal_phi/

device: cuda
dtype: bf16
seed: 1234

quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256

Error flagged: KeyError: 'PHI3'

felipemello1 commented 5 days ago

I believe it should be "model_type: PHI3_MINI"

https://github.com/pytorch/torchtune/blob/4fbe7b2d4956b3790c51d7a255c0040cf5c38fad/recipes/configs/phi3/mini_lora.yaml#L46
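
With that one-line change, the checkpointer section from the config above would look like this (paths and filenames kept exactly as in your config; just a sketch of the fix):

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: //fine_tuned_phi/
  checkpoint_files: [
    hf_model_0001_0.pt,hf_model_0002_0.pt,adapter_0.pt
  ]
  model_type: PHI3_MINI
  output_dir: /fine_tuned_legal_phi/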

sgupta1007 commented 5 days ago

The model type change resolved this error but led to a `FullModelHFCheckpointer.load_checkpoint() got an unexpected keyword argument 'weights_only'` error.

felipemello1 commented 5 days ago

@joecummings , have you seen this before?

felipemello1 commented 5 days ago

@sgupta1007, I am not too familiar with the generate recipe; however, we are working on a V2 of it (https://github.com/pytorch/torchtune/pull/1563). There are opportunities to improve the quantization experience in it.

To unblock you for now, are you able to use generate without the quantization?
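
If it helps, running without quantization should only require dropping the quantizer section from your config, i.e. (a sketch, mirroring the config shared later in this thread):

quantizer: null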

sgupta1007 commented 3 days ago

I am not able to use generate without quantization.

I will try to explain my approach for generation:

1. Perform phi3 QLoRA finetuning for 1 epoch.
2. Supply the adapter and model weights to the checkpointer files in the config file.
3. Keep the model component as torchtune.models.phi3.qlora_phi3_mini.
4. Run the generation command `tune run generate --config custom_quantization.yaml prompt='Explain some topic'`.

apthagowda97 commented 2 days ago

@sgupta1007, since the adapter is already merged, why do we need to give both the adapter and the model weights?
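
In other words, because the merged weights are already written out after finetuning, a phi3 generation config should only need the merged hf_model files and the base model builder. A sketch (filenames taken from the earlier config, minus the adapter; torchtune.models.phi3.phi3_mini is the non-LoRA builder, mirroring the llama3_1_8b builder used in the config below):

model:
  _component_: torchtune.models.phi3.phi3_mini

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: //fine_tuned_phi/
  checkpoint_files: [
    hf_model_0001_0.pt,
    hf_model_0002_0.pt
  ]
  model_type: PHI3_MINI
  output_dir: /fine_tuned_legal_phi/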

apthagowda97 commented 2 days ago

model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /path/output/
  checkpoint_files: [
    hf_model_0001_0.pt,
    hf_model_0002_0.pt,
    hf_model_0003_0.pt,
    hf_model_0004_0.pt
  ]
  output_dir: /path/output/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 1234

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /path/llama3.1-8b/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Tell me a joke?"
instruct_template: null
chat_format: null
max_new_tokens: 300
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300
# It is recommended to set enable_kv_cache=False for long-context models like Llama3.1
enable_kv_cache: True

quantizer: null

I am getting CUDA out of memory with this on an A100 GPU for an 8B model ... strange!!!

felipemello1 commented 2 days ago

Can you run 'nvidia-smi' and confirm that there isn't any dead process consuming your memory before you run generate.py?

However, there was a known issue where the KV cache was in FP32 and was initialized with max_seq_len=131k, consuming a lot of memory before generation even started. There were a couple of PRs up to fix this.
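
For a rough sense of scale, here is a back-of-envelope sizing (my own numbers, assuming Llama 3.1 8B's shapes of 32 layers, 8 KV heads, and head dim 128; this is not torchtune code):

# Rough FP32 KV cache size for a full 131k-context Llama 3.1 8B (assumed shapes)
layers, kv_heads, head_dim = 32, 8, 128
seq_len = 131_072        # full context length the cache is initialized for
bytes_per_elem = 4       # FP32
batch = 1

# K and V tensors per layer, each of shape [batch, kv_heads, seq_len, head_dim]
cache_bytes = 2 * layers * batch * kv_heads * seq_len * head_dim * bytes_per_elem
print(f"KV cache: {cache_bytes / 2**30:.1f} GiB")  # ~32 GiB

That is roughly 32 GiB for the cache alone, on top of ~16 GB of bf16 weights for an 8B model, which would already exceed a 40 GB A100 before generation starts.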

I will let @joecummings and @SalmanMohammadi reply, since they were working on this.

Thanks for sharing this info!

joecummings commented 2 days ago

> Can you run 'nvidia-smi' and confirm that there isn't any dead process consuming your memory before you run generate.py?
>
> However, there was a known issue where the KV cache was in FP32 and was initialized with max_seq_len=131k, consuming a lot of memory before generation even started. There were a couple of PRs up to fix this.
>
> I will let @joecummings and @SalmanMohammadi reply, since they were working on this.
>
> Thanks for sharing this info!

Yep, this is almost certainly due to the fact that the KV cache is being initialized for 131k context length, which OOMs. Once #1449 lands, we can set a max length on the cache itself so that it doesn't initialize for the whole context length. In the meantime, here are some mitigations:

SalmanMohammadi commented 2 days ago

This should be addressed with #1603 now that #1449 is in.

SalmanMohammadi commented 1 day ago

Hey @apthagowda97 - give this a try on our latest nightly build; it should work for you :)
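
(One way to pull a nightly, assuming the wheels are published on the PyTorch nightly index; treat this as a sketch and check the torchtune README for the authoritative command and index URL:)

pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu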