tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
9.06k stars 449 forks

Tensor operation for creating one-hot tensor from batched token sequences #874

Open shouya opened 1 year ago

shouya commented 1 year ago

Feature description

Burn currently supports creating a one-hot vector for a single token: pub fn one_hot(index: usize, num_classes: usize) -> Tensor<B, 1>.

However, it is common to need one-hot encodings for a whole sequence of tokens, or for multiple sequences when batching. It would therefore be useful to have a more versatile one_hot function that accepts indices of higher rank.

Here's a proposed function signature:

fn one_hot(indices: Tensor<B, D, Int>, num_classes: usize) -> Tensor<B, D+1, Float>
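For the batched case with D = 2, this signature maps an index tensor of shape [batch, seq] to [batch, seq, num_classes], where out[b][t][c] is 1.0 exactly when indices[b][t] == c and 0.0 otherwise. The std-only sketch below spells out those semantics on nested vectors; the name one_hot_batch is hypothetical and not part of Burn.

```rust
/// Hypothetical illustration of the proposed semantics for D = 2:
/// indices of shape [batch, seq] become a [batch, seq, num_classes] output.
fn one_hot_batch(batch: &[Vec<usize>], num_classes: usize) -> Vec<Vec<Vec<f32>>> {
    batch
        .iter()
        .map(|seq| {
            seq.iter()
                .map(|&token| {
                    assert!(token < num_classes, "token {token} out of range");
                    // One row per token: 1.0 at the token id, 0.0 elsewhere.
                    (0..num_classes)
                        .map(|c| if c == token { 1.0f32 } else { 0.0 })
                        .collect::<Vec<f32>>()
                })
                .collect::<Vec<Vec<f32>>>()
        })
        .collect()
}

fn main() {
    // Two sequences of three token ids each, vocabulary of 4.
    let batch = vec![vec![0, 2, 1], vec![3, 3, 0]];
    let encoded = one_hot_batch(&batch, 4);
    // Output shape is [2, 3, 4].
    assert_eq!((encoded.len(), encoded[0].len(), encoded[0][0].len()), (2, 3, 4));
    // Token at [1, 0] is class 3.
    assert_eq!(encoded[1][0], vec![0.0, 0.0, 0.0, 1.0]);
    println!("{:?}", encoded[0]);
}
```

Writing D+1 directly in a return type requires the unstable generic_const_exprs feature, which is why the version in the next section takes a second const parameter D2 and checks D + 1 == D2 instead.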

Suggest a Solution

I crafted this version to use in my project:

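use burn::tensor::{backend::Backend, Float, Int, Tensor};

fn one_hot<B: Backend, const D: usize, const D2: usize>(
  indices: Tensor<B, D, Int>,
  num_classes: usize,
) -> Tensor<B, D2, Float> {
  // Stable Rust cannot express `D + 1` in the return type, so check it here.
  debug_assert!(D + 1 == D2);

  // Output shape: the indices' shape with `num_classes` appended.
  let dims = {
    let mut dims = [0; D2];
    let (last, init) = dims.split_last_mut().unwrap();
    *last = num_classes;
    init.copy_from_slice(&indices.dims());
    dims
  };

  // Same shape, but with a size-1 class axis to repeat the indices along.
  let alt_dims = {
    let mut alt_dims = dims; // arrays are `Copy`, no clone needed
    alt_dims[D] = 1;
    alt_dims
  };

  // [..., 1] -> [..., num_classes]: each index repeated along the last axis.
  let indices = indices
    .unsqueeze::<D2>()
    .reshape(alt_dims)
    .repeat(D, num_classes);

  // `scatter` sums the values it writes, so the hot slot receives
  // `num_classes` ones; dividing brings it back to 1.0.
  Tensor::zeros(dims).scatter(D, indices, Tensor::ones(dims))
    / num_classes as f32
}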

But I'm not sure whether this can be improved further.
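The final / num_classes as f32 is needed because Burn's scatter accumulates (sums) values instead of overwriting them, so after .repeat(D, num_classes) each hot slot receives num_classes ones. The std-only sketch below simulates such a sum-reducing scatter along the last axis of flat row-major buffers (scatter_add_last is a hypothetical helper, not Burn code) and compares the repeat-then-divide variant with one that keeps the index tensor's last axis at size 1 and scatters a ones tensor of that same shape, which needs neither the repeat nor the division.

```rust
/// Hypothetical sum-reducing scatter along the last axis of row-major buffers.
/// `target` has rows of length `target_width`; `indices` and `values` have rows
/// of length `src_width` and the same number of rows as `target`.
fn scatter_add_last(
    target: &mut [f32],
    target_width: usize,
    indices: &[usize],
    values: &[f32],
    src_width: usize,
) {
    assert_eq!(indices.len(), values.len());
    assert_eq!(indices.len() % src_width, 0);
    let rows = indices.len() / src_width;
    assert_eq!(target.len(), rows * target_width);
    for row in 0..rows {
        for k in 0..src_width {
            let src = row * src_width + k;
            assert!(indices[src] < target_width, "index out of range");
            // Accumulate rather than overwrite.
            target[row * target_width + indices[src]] += values[src];
        }
    }
}

fn main() {
    let num_classes = 4;
    let tokens = [2usize, 0, 3]; // three tokens, batch dims already flattened

    // Variant from the issue: repeat each index `num_classes` times.
    let repeated: Vec<usize> = tokens
        .iter()
        .flat_map(|&t| std::iter::repeat(t).take(num_classes))
        .collect();
    let ones = vec![1.0f32; repeated.len()];
    let mut a = vec![0.0f32; tokens.len() * num_classes];
    scatter_add_last(&mut a, num_classes, &repeated, &ones, num_classes);
    // The hot slot of the first token received `num_classes` ones.
    assert_eq!(a[2], num_classes as f32);
    a.iter_mut().for_each(|x| *x /= num_classes as f32);

    // Alternative: size-1 last axis, no repeat and no division.
    let ones = vec![1.0f32; tokens.len()];
    let mut b = vec![0.0f32; tokens.len() * num_classes];
    scatter_add_last(&mut b, num_classes, &tokens, &ones, 1);

    assert_eq!(a, b);
    println!("{:?}", b);
}
```

Carried back to the Burn code, the second variant would scatter indices.reshape(alt_dims) together with Tensor::ones(alt_dims) into Tensor::zeros(dims). This assumes the Burn version in use accepts an index tensor that is shorter than the target along the scatter dimension, which is worth confirming against its scatter shape checks.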

laggui commented 5 days ago

Completed in #2413

shouya commented 4 days ago

@laggui I'd argue that #2413 does not implement the function this issue proposes. #2413 implements a special case, Tensor<B, 1, Int> -> Tensor<B, 2, Float>, but not the general form Tensor<B, D, Int> -> Tensor<B, D+1, Float> that works for indices of any dimensionality.

Shall we reopen the issue until the problem is fully addressed?
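In principle, the general form can be layered on top of the 1-D special case: flatten the D-dimensional indices to one axis of length N (the product of their dims), one-hot encode that to [N, num_classes], then reshape to the original dims with num_classes appended. In row-major layout the last step changes only the shape, not the data. The std-only sketch below models this on flat buffers; one_hot_1d and one_hot_general are hypothetical stand-ins, and nothing here assumes the exact name or signature of the operation added in #2413.

```rust
/// Stand-in for the 1-D special case ([N] -> [N, C]), returning a flat,
/// row-major buffer of length N * C. Hypothetical helper, not Burn API.
fn one_hot_1d(indices: &[usize], num_classes: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; indices.len() * num_classes];
    for (row, &class) in indices.iter().enumerate() {
        assert!(class < num_classes, "index {class} out of range");
        out[row * num_classes + class] = 1.0;
    }
    out
}

/// General form built on the 1-D case: flatten to [N], encode to [N, C],
/// then reinterpret the result as `shape ++ [C]`.
fn one_hot_general(indices: &[usize], shape: &[usize], num_classes: usize) -> (Vec<f32>, Vec<usize>) {
    // A row-major D-dim buffer already is its flattened [N] view.
    let n: usize = shape.iter().product();
    assert_eq!(indices.len(), n, "indices do not match shape");
    let data = one_hot_1d(indices, num_classes);
    // Reshape [N, C] -> shape ++ [C]: only the shape metadata changes.
    let mut out_shape = shape.to_vec();
    out_shape.push(num_classes);
    (data, out_shape)
}

fn main() {
    // Indices of shape [2, 2, 1] with 3 classes.
    let (data, shape) = one_hot_general(&[1, 0, 2, 1], &[2, 2, 1], 3);
    assert_eq!(shape, vec![2, 2, 1, 3]);
    // Element [1, 0, 0] is flat row 2 and holds class 2 -> slots 6..9.
    assert_eq!(&data[6..9], &[0.0, 0.0, 1.0]);
    println!("shape = {:?}", shape);
}
```

In Burn this would amount to a reshape to one dimension, the 1-D one-hot call, and a reshape to D + 1 dimensions. Because reshape reinterprets elements in row-major order, the class axis ends up last, as the proposal requires.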

laggui commented 4 days ago

Ahh, you're right, sorry! I'll leave this open.