elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0
Compile Error due to a type mismatch #568

Open Mostafa86 opened 5 months ago

Mostafa86 commented 5 months ago

I am getting the following error:

(CompileError) deps/axon/lib/axon/loop.ex:469: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

It occurs while trying to enforce the following policy:

policy = Axon.MixedPrecision.create_policy(params: {:f, 64}, compute: {:f, 64}, output: {:f, 64})

The error seems to be caused by the gradient_state at the line below, which is initialized as :f32 and therefore does not match the f64 policy:

https://github.com/elixir-nx/axon/blob/ddc49cc3ce847d8eb033a8a0729f11ed70875f15/lib/axon/loop.ex#L360
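For context, the same rule can be reproduced with plain Nx, outside of Axon. The following is only a minimal sketch: the module name `WhileTypeDemo`, the function `accumulate/2`, and the `~> 0.7` version requirement are made up for illustration and are not taken from Axon's loop code. It shows that a `while` inside `defn` works when the initial accumulator has the same type as the values added to it.

```elixir
# Assumption: any reasonably recent Nx release; the version here is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule WhileTypeDemo do
  import Nx.Defn

  # Hypothetical helper: adds `x` to `acc` three times inside a defn `while`.
  # Every variable used in the body must be passed as an initial argument,
  # and the body must return tensors with the same shape and type as it received.
  defn accumulate(acc, x) do
    {acc, _i, _x} =
      while {acc, i = 0, x}, i < 3 do
        {acc + x, i + 1, x}
      end

    acc
  end
end

x = Nx.tensor(1.5, type: :f64)

# Accumulator and input are both f64, so the loop body preserves the type.
WhileTypeDemo.accumulate(Nx.tensor(0.0, type: :f64), x)

# With an f32 accumulator, `acc + x` is promoted to f64 inside the body,
# which no longer matches the f32 initial argument. This call is expected
# to raise the same "do-block in while must return tensors with the same
# shape, type, and names" error:
# WhileTypeDemo.accumulate(Nx.tensor(0.0, type: :f32), x)
```

If the linked line is indeed the cause, creating the initial state with the policy's type rather than a fixed f32 would be the analogous fix, though that is a guess on my part rather than a confirmed patch.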

seanmor5 commented 5 months ago

Ah yeah, the gradient state needs to match the policy. I will look into this. In the meantime, is there a reason you need f64 specifically?

Mostafa86 commented 5 months ago

@seanmor5 Thanks 👍

Regarding f64: not really. It is just an error that I came across while trying to hunt down the cause of a NaN in one of the models.