rachtibat / LRP-eXplains-Transformers

Layer-Wise Relevance Propagation for Large Language Models and Vision Transformers [ICML 2024]
https://lxt.readthedocs.io

Error with torch.dtype=float16 #7

Closed: Patrick-Ni closed this issue 1 month ago

Patrick-Ni commented 1 month ago

Hello, when I run this code: https://lxt.readthedocs.io/en/latest/quickstart.html#tinyllama on my GPU server, I get the following error:

[Screenshot of the error message]

However, when I change torch_dtype in model = LlamaForCausalLM.from_pretrained("llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map='auto', quantization_config=quantization_config) to bfloat16, it works. Do I have to use bf16, or is there a way to fix this "bug"?

Patrick-Ni commented 1 month ago

By the way, how can I get the LRP relevance scores for each layer in Llama?

rachtibat commented 1 month ago

Hey @Patrick-Ni,

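Since we are computing gradients, we need the bfloat16 datatype, which has a much larger numerical range: bfloat16 keeps the 8 exponent bits of float32, so its largest finite value is about 3.4e38, whereas float16 overflows to inf above 65504. (The models are also trained in bfloat16, which prevents numerical overflow etc.) So we can't use float16; we must use bfloat16.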
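A quick way to see the range difference, using only torch.finfo and dtype casts. The value 70000.0 is an arbitrary stand-in for a large intermediate value during the backward pass, not something taken from the issue:

```python
import torch

# Largest finite value each 16-bit floating-point format can represent.
fp16_max = torch.finfo(torch.float16).max    # 65504.0
bf16_max = torch.finfo(torch.bfloat16).max   # about 3.39e38, close to float32's range

# An arbitrary large intermediate value, standing in for a large activation or gradient.
x = torch.tensor([70000.0])

x_fp16 = x.to(torch.float16)    # exceeds 65504, so it overflows to inf
x_bf16 = x.to(torch.bfloat16)   # stays finite but is rounded to 70144.0

print(fp16_max, bf16_max)
print(x_fp16, x_bf16)
```

bfloat16 pays for its range with fewer significant bits (8 versus 11 for float16), so values are coarser but do not overflow, which is the property that matters when relevance or gradients grow large.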

Yes. The LRP rules are implemented as ordinary torch modules, so you can simply register backward hooks on them. If you print the Llama model after registering the attnlrp rules, you can see which modules are available for hooking; in Llama these are, for example, the lxt.modules.LinearEpsilon and lxt.modules.SoftmaxDT modules.
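A minimal sketch of the hook approach, shown on a toy nn.Sequential in plain PyTorch so it runs without LXT or model weights. The helper name register_relevance_hooks is hypothetical. With plain autograd the captured tensors are ordinary gradients; the assumption (based on the reply above) is that once the attnlrp rules are registered, the same hooks on the LXT modules receive relevance instead. For Llama you would pass the LXT module classes, e.g. (lxt.modules.LinearEpsilon, lxt.modules.SoftmaxDT), as module_types:

```python
import torch
import torch.nn as nn


def register_relevance_hooks(model: nn.Module, module_types: tuple):
    """Hypothetical helper: attach a full backward hook to every submodule of
    the given types and store what flows backward into its output.

    With plain autograd the stored tensor is the gradient; with LXT's rules
    applied it is assumed to be the relevance at that module's output.
    """
    captured = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, module_types):
            # Bind `name` as a default argument so each hook keeps its own key.
            def hook(mod, grad_input, grad_output, name=name):
                captured[name] = grad_output[0].detach().clone()

            handles.append(module.register_full_backward_hook(hook))
    return captured, handles


# Toy usage: a small MLP standing in for the real model.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
captured, handles = register_relevance_hooks(model, (nn.Linear,))

x = torch.randn(1, 4, requires_grad=True)
logits = model(x)
target = logits[0].argmax()
logits[0, target].backward()   # start the backward pass from the top logit

for handle in handles:
    handle.remove()            # stop the hooks from firing on later passes

for name, value in captured.items():
    print(name, tuple(value.shape), value.sum().item())
```

register_full_backward_hook calls the hook with (module, grad_input, grad_output); the sketch keeps grad_output[0], the value arriving at the module's output. Store grad_input instead if you want the value at the module's input.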

Btw, I would start by testing LRP on the standard Llama model without quantization. If you want to use quantized weights, you need a different attnlrp Composite (like the one in the Mixtral example); I can help you with that another day.

Hope it helps, Reduan

Patrick-Ni commented 1 month ago

This has resolved most of my issues. Thank you for your helpful response!