Closed: copybara-service[bot] closed this 1 year ago.
Explicitly use int32 for the argmax output.
Otherwise, the argmax output would be int64 when jax_enable_x64 is enabled.
Fixes #69.
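The change itself is not shown here. The sketch below only illustrates the behavior being fixed. Under jax_enable_x64, `jnp.argmax` returns the default integer dtype, which is int64. `jax.lax.argmax` takes an explicit `index_dtype`, so the result can be pinned to int32. The helper name `argmax_int32` is hypothetical and is not taken from the patched code.

```python
import jax

# Turn on 64-bit mode up front. In this mode the default integer dtype is int64.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def argmax_int32(x, axis=-1):
    """Hypothetical helper: an argmax that returns int32 indices even under x64."""
    # lax.argmax expects a non-negative axis, so normalize negative values first.
    axis = axis % x.ndim
    # Passing the index dtype explicitly keeps the output at int32.
    return jax.lax.argmax(x, axis, jnp.int32)


x = jnp.array([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])
print(jnp.argmax(x, axis=-1).dtype)  # int64 (x64 enabled)
print(argmax_int32(x).dtype)         # int32
print(argmax_int32(x))               # [1 0]
```

Another option is to cast the result of `jnp.argmax` with `.astype(jnp.int32)`. Passing the dtype to `lax.argmax` avoids creating the int64 intermediate in the first place.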