huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

No compiler check for operation on different tensor type. #2115

Open npuichigo opened 3 weeks ago

npuichigo commented 3 weeks ago
let a = Tensor::new(&[3u8], &candle::Device::Cpu)?;
let b = a / 4f64;

What is the value of tensor b? Surprisingly, it's Ok(Tensor[0; u8]). The Rust compiler rejects code like 3u8 / 4f64, yet the equivalent tensor operation compiles and silently returns a wrong result.
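One plausible way to arrive at that zero can be modeled in plain Rust, without candle. The sketch below assumes (this thread does not confirm it) that dividing by a scalar is carried out as `v * (1 / scalar) + 0` and that the f64 constants are cast to the tensor's element type before the arithmetic. The functions `affine_u8` and `div_as_f64` are hypothetical stand-ins, not candle APIs.

```rust
// Hypothetical model of a scalar affine op on u8 data: v * mul + add.
// Assumption: the f64 constants are cast to the element type first.
fn affine_u8(values: &[u8], mul: f64, add: f64) -> Vec<u8> {
    let mul = mul as u8; // float-to-int `as` casts truncate toward zero, so 0.25 becomes 0
    let add = add as u8;
    values
        .iter()
        .map(|&v| v.wrapping_mul(mul).wrapping_add(add))
        .collect()
}

// Converting the data to floating point first keeps the fractional part.
fn div_as_f64(values: &[u8], divisor: f64) -> Vec<f64> {
    values.iter().map(|&v| f64::from(v) / divisor).collect()
}

fn main() {
    // Division by 4.0 expressed as multiplication by 1/4 with a zero offset.
    let b = affine_u8(&[3], 1.0 / 4.0, 0.0);
    println!("{:?}", b); // prints [0]

    // Widening the data before dividing preserves the expected value.
    let c = div_as_f64(&[3], 4.0);
    println!("{:?}", c); // prints [0.75]
}
```

Under this model the data is never the problem: the scalar 0.25 is lost as soon as it is cast to u8, so every element is multiplied by zero.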

Because of this, I can tell that these lines are wrong for the CIFAR dataset: every image produced here is an all-zero u8 tensor. https://github.com/huggingface/candle/blob/8a05743a21768405217576a1b9557936be74ed90/candle-datasets/src/vision/cifar.rs#L84-L86

Any insights on adding a compile-time check for operations on tensors of different types? Or is this by design? @LaurentMazare

LaurentMazare commented 3 weeks ago

It's kind of by design, though I agree that it's not ideal. The Tensor type does not expose the tensor dtype, so we cannot really express the constraints of a scalar multiplication or addition at compile time. Still, it's quite convenient to be able to, say, multiply the values in a tensor by a constant or add a constant, so affine operations using an f64 are supported for all tensors, with "rounded" semantics.
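To illustrate the trade-off described here, the following is a minimal sketch of what a compile-time check could look like if the element type were part of the tensor's Rust type. `TypedTensor<T>` and `to_f64` are hypothetical names invented for this example. They are not part of candle, whose Tensor type does not expose its dtype to the type system.

```rust
use std::ops::Div;

// Hypothetical wrapper, not a candle type: the element type is a type parameter,
// so the compiler can see it.
#[derive(Debug, PartialEq)]
struct TypedTensor<T> {
    data: Vec<T>,
}

// Division by an f64 scalar is only implemented for f64 tensors.
impl Div<f64> for TypedTensor<f64> {
    type Output = TypedTensor<f64>;

    fn div(self, rhs: f64) -> Self::Output {
        TypedTensor {
            data: self.data.into_iter().map(|v| v / rhs).collect(),
        }
    }
}

impl TypedTensor<u8> {
    // Explicit conversion, so the change of element type is visible at the call site.
    fn to_f64(&self) -> TypedTensor<f64> {
        TypedTensor {
            data: self.data.iter().map(|&v| f64::from(v)).collect(),
        }
    }
}

fn main() {
    let a = TypedTensor { data: vec![3u8] };

    // let b = a / 4f64; // would not compile: `Div<f64>` is not implemented for `TypedTensor<u8>`

    let b = a.to_f64() / 4f64;
    println!("{:?}", b); // prints TypedTensor { data: [0.75] }
}
```

The cost of this design is that every function handling tensors becomes generic over the element type, which is presumably part of why a single runtime-typed Tensor is convenient.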