DewEfresh opened this issue 6 months ago
Could you share changes to main.py, please?
As for the model loading into CPU RAM instead of GPU RAM, it's probably because your PyTorch version is incorrect.
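A quick sanity check (a minimal sketch, not specific to this repo): if PyTorch was installed as a CPU-only build, the model has nowhere to go but CPU RAM.

import torch

# If this prints False, the installed PyTorch build has no CUDA support
# and any model you load will stay in CPU RAM.
print(torch.cuda.is_available())
print(torch.__version__)  # a "+cpu" suffix here also indicates a CPU-only wheel

# With a working CUDA build, the loaded model can be moved to the GPU explicitly:
# model = model.to("cuda")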
Could you share changes to main.py, please?
model_path = "./models/models--mustafaaljadery--gemma-2B-10M"
#tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./models")
model = GemmaForCausalLM.from_pretrained(
    #model_path,
    model_name,
    cache_dir="./models",
    torch_dtype=torch.bfloat16
)
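Note that the snippet above is not self-contained: model_name is never defined and the imports are omitted. A possible completion is sketched below; the model_name value and the import path for GemmaForCausalLM are assumptions (this repo defines its own patched GemmaForCausalLM, so the import may differ from stock transformers).

import torch
from transformers import AutoTokenizer
# Assumption: GemmaForCausalLM is the patched class from this repo's main.py,
# not necessarily the stock transformers.GemmaForCausalLM.
from main import GemmaForCausalLM

# Assumption: the Hub repo id matching the cached folder name used above.
model_name = "mustafaaljadery/gemma-2B-10M"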
Does it work for you? I'm hitting error after error.
I can't get past GemmaModel.forward() got an unexpected keyword argument 'cache_position'.
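For what it's worth, this kind of error usually means a newer transformers release is passing keyword arguments (here cache_position) that the custom forward() in this repo does not declare. The common fixes are pinning transformers to an older release or filtering out the extra kwargs; below is a hedged sketch of the latter, not something taken from this repo.

import inspect

def drop_unknown_kwargs(forward_fn):
    """Wrap a forward() so keyword arguments it does not declare
    (e.g. cache_position from newer transformers) are silently discarded."""
    accepted = set(inspect.signature(forward_fn).parameters)

    def wrapped(*args, **kwargs):
        filtered = {k: v for k, v in kwargs.items() if k in accepted}
        return forward_fn(*args, **filtered)

    return wrapped

# Hypothetical usage, assuming model.model is the GemmaModel instance:
# model.model.forward = drop_unknown_kwargs(model.model.forward)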
I "solved" this problem, but it turned out not to be the end look at my issue of Some errors
Same problem.
Me too.
After solving lots of "unexpected keyword argument" errors,
I got: RuntimeError: The size of tensor a (8) must match the size of tensor b (9) at non-singleton dimension 3
I'm so tired of code full of bugs.
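For context, that RuntimeError is PyTorch's broadcasting check failing: two tensors being combined elementwise disagree in a non-singleton dimension (here 8 vs 9 in dimension 3, which often points at an attention mask or position tensor that is off by one). A minimal reproduction of the same message, unrelated to this repo's internals:

import torch

a = torch.zeros(1, 1, 8, 8)  # e.g. attention scores over 8 key positions
b = torch.zeros(1, 1, 8, 9)  # a mask built for 9 positions, off by one
a + b  # RuntimeError: The size of tensor a (8) must match the size of tensor b (9) at non-singleton dimension 3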
I made a colab (https://colab.research.google.com/drive/1Z3NdoT0WS8KXnSUS3_xxT39NBZD6eGcN?usp=sharing) to test it and ran into an issue: GemmaModel.forward() got an unexpected keyword argument 'cache_position'. I had to change some of main.py to get the model to load correctly. The model loads into system RAM, not onto the GPU; I don't know if that is the cause of the GemmaModel.forward() error.
I have some other questions: is the context length set in the generate function? Does memory balloon as the context and hidden state grow? In config.json, "torch_dtype" is "float32"; is there a reason for this? In Google's gemma-2b, "torch_dtype" is "bfloat16".
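As an aside on the dtype question (a hedged sketch; the repo id below is an assumption): the torch_dtype field in config.json only records how the checkpoint was saved, and it can be inspected and overridden at load time.

import torch
from transformers import AutoConfig

# Assumption: the Hub repo id for this checkpoint.
cfg = AutoConfig.from_pretrained("mustafaaljadery/gemma-2B-10M", cache_dir="./models")
print(cfg.torch_dtype)  # what config.json declares (reportedly float32 here)

# Passing torch_dtype to from_pretrained overrides the load dtype, e.g.:
# model = GemmaForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)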
TypeError Traceback (most recent call last)