google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Cast type for inputs before kernel call #759

Closed: RissyRan closed this 2 months ago

RissyRan commented 2 months ago

Description

Cast the kernel to `dtype` before the kernel call, so that weights can be initialized and stored in float32 while the kernel call itself still runs in `dtype`.
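A minimal JAX sketch of the idea, assuming a plain matmul layer. `init_kernel` and `dense_apply` are hypothetical names used for illustration only, not the code changed in this PR. Without the cast, JAX's type promotion turns a bf16 × float32 matmul into a float32 one, so casting the stored float32 kernel (and the inputs) to `dtype` right before the call keeps the computation in the intended precision.

```python
import jax
import jax.numpy as jnp

def init_kernel(key, in_features, out_features, weight_dtype=jnp.float32):
    # Hypothetical initializer: the kernel is created and stored in weight_dtype (float32 here).
    return jax.random.normal(key, (in_features, out_features), dtype=weight_dtype) * 0.02

def dense_apply(kernel, inputs, dtype=jnp.bfloat16):
    # Hypothetical layer: cast the stored kernel and the inputs to the
    # computation dtype right before the matmul, so float32 storage does
    # not promote the whole computation to float32.
    kernel = kernel.astype(dtype)
    inputs = inputs.astype(dtype)
    return jnp.dot(inputs, kernel)

key = jax.random.PRNGKey(0)
kernel = init_kernel(key, 4, 3)
x = jnp.ones((2, 4), dtype=jnp.float32)
y = dense_apply(kernel, x)
print(kernel.dtype, y.dtype)
```

The parameters stay float32 (useful for initialization and optimizer updates), while the matmul runs in bf16.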

Test

After change, with default weight type float32: 246.002 (Test link)
Before change, with weight type bf16: 245.814 (Test link)