Open antimora opened 7 months ago
bool_full implementation:
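fn bool_full<const D: usize>(
    shape: Shape<D>,
    value: bool,
    device: &Device<B>,
) -> BoolTensor<B, D> {
    // Compare an all-zero int tensor against a scalar:
    // `0 == 0` yields all true, `0 == 1` yields all false.
    B::int_equal_elem(
        B::int_zeros(shape, device),
        if value { 0.elem() } else { 1.elem() },
    )
}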
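To make the comparison trick easier to follow, here is a minimal plain-Rust sketch of the same logic. It uses a `Vec<i64>` as a stand-in for an int tensor; the helpers below only borrow the names `int_zeros` and `int_equal_elem` for readability and are not burn's backend API.

```rust
/// Stand-in for an all-zero int tensor: a flat buffer of `i64` values.
/// Hypothetical helper for illustration, not part of burn.
fn int_zeros(len: usize) -> Vec<i64> {
    vec![0; len]
}

/// Element-wise `tensor == elem`, producing a boolean buffer.
fn int_equal_elem(tensor: &[i64], elem: i64) -> Vec<bool> {
    tensor.iter().map(|&x| x == elem).collect()
}

/// Same idea as `bool_full` above: zeros compared against 0 give all `true`,
/// zeros compared against 1 give all `false`.
fn bool_full(len: usize, value: bool) -> Vec<bool> {
    let rhs = if value { 0 } else { 1 };
    int_equal_elem(&int_zeros(len), rhs)
}
```

Under these assumptions, `bool_full(3, true)` fills the buffer with `true` and `bool_full(3, false)` fills it with `false`, at the cost of allocating an intermediate int buffer.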
We need `full` to work for bool tensors because `ConstantOfShape` can be initialized with bool values, and we can't use `full` today because of this error:
Compiling onnx-tests v0.14.0 (/Users/dilshod/Projects/burn/crates/burn-import/onnx-tests)
error[E0277]: the trait bound `burn::tensor::Bool: Numeric<B>` is not satisfied
--> /Users/dilshod/Projects/burn/target/debug/build/onnx-tests-50e78da95971b883/out/model/constant_of_shape_full_like.rs:53:37
|
53 | let constantofshape3_out1 = Tensor::full(shape3_out1, true, &*self.device);
| ^^^^^^^^^^^^ the trait `Numeric<B>` is not implemented for `burn::tensor::Bool`
|
= help: the following other types implement trait `Numeric<B>`:
burn::tensor::Float
burn::tensor::Int
note: required by a bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
--> /Users/dilshod/Projects/burn/crates/burn-tensor/src/tensor/api/numeric.rs:13:8
|
13 | K: Numeric<B>,
| ^^^^^^^^^^ required by this bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
...
112 | pub fn full<S: Into<Shape<D>>, E: ElementConversion>(
| ---- required by a bound in this associated function
For more information about this error, try `rustc --explain E0277`.
error: could not compile `onnx-tests` (test "onnx_tests") due to 1 previous error
[constant_of_shape]%
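The error comes down to where `full` is defined: it sits behind a `Numeric` bound that the bool kind does not satisfy. The sketch below reproduces that shape in plain Rust. All types and traits here (`Float`, `Int`, `Bool`, `BasicOps`, `Numeric`, `Tensor`, and especially `full_any`) are simplified stand-ins written for illustration, not burn's actual definitions.

```rust
use std::marker::PhantomData;

// Marker types standing in for tensor kinds (simplified, hypothetical).
struct Float;
struct Int;
struct Bool;

// Operations every tensor kind supports.
trait BasicOps {
    type Elem: Copy;
}

// Arithmetic-only operations; deliberately not implemented for `Bool`.
trait Numeric: BasicOps {}

impl BasicOps for Float { type Elem = f32; }
impl BasicOps for Int { type Elem = i64; }
impl BasicOps for Bool { type Elem = bool; }
impl Numeric for Float {}
impl Numeric for Int {}

struct Tensor<K: BasicOps> {
    data: Vec<K::Elem>,
    _kind: PhantomData<K>,
}

impl<K: Numeric> Tensor<K> {
    // Bounded by `Numeric`, like the `full` in the error above.
    fn full(len: usize, value: K::Elem) -> Self {
        Tensor { data: vec![value; len], _kind: PhantomData }
    }
}

impl<K: BasicOps> Tensor<K> {
    // Filling only needs `Copy`, so it can live behind the weaker bound.
    // `full_any` is a made-up name to keep both versions side by side.
    fn full_any(len: usize, value: K::Elem) -> Self {
        Tensor { data: vec![value; len], _kind: PhantomData }
    }
}

// Rejected by the compiler because `Bool` does not implement `Numeric`:
// let t = Tensor::<Bool>::full(2, true);
```

In this model, `Tensor::<Int>::full(3, 7)` compiles, while a bool tensor can only be built through `full_any`. Moving `full` itself behind the weaker bound is the kind of refactor this issue asks for.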
Workaround to get full for bool tensors:
// All true
Tensor::<Backend, 3, Int>::ones(shape, &device).bool();
// All false
Tensor::<Backend, 3, Int>::zeros(shape, &device).bool();
@laggui, you mentioned you were working in this area. Do you think we are close to making it work?
Bool is now an element type :) but the methods mentioned, like `full`, have not yet been refactored.
So this issue is partially completed.
Currently, the `full` function is implemented separately for each tensor type. To make it work for `bool` tensors as well, it was moved to the `BaseTensor` trait. However, this introduced some issues:
- The `bool` type does not implement the `Element` trait, which is required by functions in `BaseTensor`.
- The `Element` trait depends on numeric traits like `ToPrimitive`, `Zero`, and `One`, which `bool` does not implement.
- Using `elem()` to convert values to the element type is a temporary workaround and should be cleaned up.

To properly address this, we should:
- Define `Zero`, `One`, and `ToPrimitive` traits that are not coupled to numeric types.
- Avoid depending on `num_traits` in the base `Element` trait.
- Make the `Element` trait not require numeric traits.
- Implement the `Element` trait for `bool`.
- Move the `full` function to `BaseTensor` once `bool` properly implements `Element`.

This will allow `bool` to be treated as a proper tensor element type, simplify the implementation of `full` and other functions in `BaseTensor`, and remove the need for workarounds like calling `elem()`.

For now, the `pad` function changes should be moved to the numeric tensor implementation to unblock the current PR. The `full` function move and `bool` element type cleanup can be handled separately in this ticket.
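As a rough illustration of the proposed cleanup, the sketch below splits a base element trait from its numeric extension so that `bool` can implement the base trait. Everything here is an assumption made for the example: `NumericElement`, the trait bounds, and the `Vec`-based `full`, `zeros`, and `ones` are hypothetical and do not reflect burn's real `Element` trait or tensor types.

```rust
use std::fmt::Debug;

// Base element trait with no numeric requirements (hypothetical shape).
trait Element: Copy + Debug + PartialEq {}

// Numeric capabilities are layered on top instead of being required up front.
trait NumericElement: Element {
    fn zero() -> Self;
    fn one() -> Self;
}

impl Element for bool {}
impl Element for i64 {}
impl Element for f32 {}

impl NumericElement for i64 {
    fn zero() -> Self { 0 }
    fn one() -> Self { 1 }
}

impl NumericElement for f32 {
    fn zero() -> Self { 0.0 }
    fn one() -> Self { 1.0 }
}

// `full` only needs the base trait, so it accepts `bool` directly,
// without a conversion step such as `elem()`.
fn full<E: Element>(len: usize, value: E) -> Vec<E> {
    vec![value; len]
}

// `zeros` and `ones` genuinely need a numeric notion of zero and one.
fn zeros<E: NumericElement>(len: usize) -> Vec<E> {
    full(len, E::zero())
}

fn ones<E: NumericElement>(len: usize) -> Vec<E> {
    full(len, E::one())
}
```

With this split, `full(2, true)` and `ones::<i64>(3)` both type-check, while `ones::<bool>(3)` is rejected at compile time. That matches the intent of keeping numeric-only operations out of the base trait.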