tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust, with extreme flexibility, compute efficiency, and portability as its primary goals.
https://burn.dev
Apache License 2.0

Move `full` and `pad` functions to `BaseTensor` trait and make `bool` a proper tensor element type #1535

Open · antimora opened this issue 7 months ago

antimora commented 7 months ago

Currently, the `full` function is implemented separately for each tensor type. To make it work for bool tensors as well, it was moved to the `BaseTensor` trait. However, this introduced some issues:

  1. The `bool` type does not implement the `Element` trait, which is required by functions in `BaseTensor`.
  2. The `Element` trait depends on numeric traits like `ToPrimitive`, `Zero`, and `One`, which `bool` does not implement.
  3. Calling `elem()` to convert values to the element type is a temporary workaround and should be cleaned up.
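
Schematically, the conflict looks like this (a simplified sketch, not Burn's exact trait definitions):

    use num_traits::{One, ToPrimitive, Zero};

    // Simplified view of the current coupling: Element requires numeric traits.
    pub trait Element: Zero + One + ToPrimitive {}

    // Rejected: bool implements none of the numeric traits.
    // impl Element for bool {}
    // error[E0277]: the trait bound `bool: Zero` is not satisfied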

To properly address this, we should:

  1. Create our own versions of the `Zero`, `One`, and `ToPrimitive` traits that are not coupled to numeric types.
  2. Remove the dependency on `num_traits` in the base `Element` trait.
  3. Make the `Element` trait not require numeric traits.
  4. Implement the new `Element` trait for `bool`.
  5. Move the `full` function to `BaseTensor` once `bool` properly implements `Element`.

This will allow `bool` to be treated as a proper tensor element type, simplify the implementation of `full` and other functions in `BaseTensor`, and remove the need for workarounds like calling `elem()`.
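
A minimal sketch of what the decoupled traits could look like (the names mirror `num_traits` but the definitions here are hypothetical, not Burn's final API):

    // Home-grown replacements for the num_traits versions; no numeric bounds,
    // so they can be implemented for bool. A decoupled conversion trait could
    // replace ToPrimitive in the same spirit.
    pub trait Zero {
        fn zero() -> Self;
    }

    pub trait One {
        fn one() -> Self;
    }

    // Element no longer pulls in num_traits, so non-numeric types qualify.
    pub trait Element: Zero + One + Clone + Send + Sync + 'static {}

    impl Zero for bool {
        fn zero() -> Self {
            false // conventional "zero" for bool
        }
    }

    impl One for bool {
        fn one() -> Self {
            true // conventional "one" for bool
        }
    }

    impl Element for bool {}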

For now, the `pad` function changes should be moved to the numeric tensor implementation to unblock the current PR. The `full` function move and `bool` element type cleanup can be handled separately in this ticket.

antimora commented 6 months ago

A `bool_full` implementation:

    fn bool_full<const D: usize>(
        shape: Shape<D>,
        value: bool,
        device: &Device<B>,
    ) -> BoolTensor<B, D> {
        // Start from an all-zeros int tensor and compare element-wise:
        // zeros == 0 yields all true, zeros == 1 yields all false,
        // hence the inverted-looking scalar choice below.
        B::int_equal_elem(
            B::int_zeros(shape, device),
            if value { 0.elem() } else { 1.elem() },
        )
    }

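For comparison, the same equality trick at the user-facing API level (a sketch; `Backend` stands in for any concrete backend type, and the shape is illustrative):

    // All-true 2x3 bool tensor: zeros compared against 0 are all equal.
    let all_true = Tensor::<Backend, 2, Int>::zeros([2, 3], &device).equal_elem(0);

    // All-false variant: zeros are never equal to 1.
    let all_false = Tensor::<Backend, 2, Int>::zeros([2, 3], &device).equal_elem(1);
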
antimora commented 3 months ago

We need `full` to work for bool tensors because `ConstantOfShape` can initialize with bool values, and we can't use `full` because of this error:

   Compiling onnx-tests v0.14.0 (/Users/dilshod/Projects/burn/crates/burn-import/onnx-tests)
error[E0277]: the trait bound `burn::tensor::Bool: Numeric<B>` is not satisfied
   --> /Users/dilshod/Projects/burn/target/debug/build/onnx-tests-50e78da95971b883/out/model/constant_of_shape_full_like.rs:53:37
    |
53  |         let constantofshape3_out1 = Tensor::full(shape3_out1, true, &*self.device);
    |                                     ^^^^^^^^^^^^ the trait `Numeric<B>` is not implemented for `burn::tensor::Bool`
    |
    = help: the following other types implement trait `Numeric<B>`:
              burn::tensor::Float
              burn::tensor::Int
note: required by a bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
   --> /Users/dilshod/Projects/burn/crates/burn-tensor/src/tensor/api/numeric.rs:13:8
    |
13  |     K: Numeric<B>,
    |        ^^^^^^^^^^ required by this bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
...
112 |     pub fn full<S: Into<Shape<D>>, E: ElementConversion>(
    |            ---- required by a bound in this associated function

For more information about this error, try `rustc --explain E0277`.
error: could not compile `onnx-tests` (test "onnx_tests") due to 1 previous error
antimora commented 3 months ago

Workaround to get `full` for bool tensors:

    // All true
    Tensor::<Backend, 3, Int>::ones(shape, &device).bool();

    // All false
    Tensor::<Backend, 3, Int>::zeros(shape, &device).bool();

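Wrapped as a small generic helper, for convenience (a hypothetical sketch, built only from the calls shown above):

    use burn::tensor::{backend::Backend, Bool, Int, Shape, Tensor};

    /// Fill a bool tensor by filling an int tensor and casting to bool.
    fn bool_full<B: Backend, const D: usize>(
        shape: Shape<D>,
        value: bool,
        device: &B::Device,
    ) -> Tensor<B, D, Bool> {
        if value {
            Tensor::<B, D, Int>::ones(shape, device).bool()
        } else {
            Tensor::<B, D, Int>::zeros(shape, device).bool()
        }
    }
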
antimora commented 3 months ago

@laggui, you mentioned you were working in this area. Do you think we are close to making it work?

laggui commented 3 months ago

> @laggui, you mentioned you were working in this area. Do you think we are close to making it work?

`Bool` is now an element type :) but the methods mentioned, like `full`, have not yet been refactored.

So this issue is partially completed.