huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Assigning value to Tensor (translate torch expression to Candle) #2344

Open wiktorkujawa opened 1 month ago

wiktorkujawa commented 1 month ago

Hi, how should I translate this expression from Python (torch) to Candle? count, mask and result are all tensors:

```python
mask = (count.squeeze(-1) > 0)
result[mask] = result[mask] / count[mask].repeat(1, C)
```
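To pin down what that PyTorch line actually does, here is a plain-Rust reference sketch over flat row-major buffers (no Candle involved). It assumes result has shape [H, W, C] and count has shape [H, W, 1], matching the shapes given later in the thread; the name `masked_divide_reference` is hypothetical. Every pixel with a positive count has all C channels divided by that count, and every other pixel is left untouched.

```rust
/// Hypothetical reference for the PyTorch update
///   mask = (count.squeeze(-1) > 0)
///   result[mask] = result[mask] / count[mask].repeat(1, C)
/// `result` is a row-major [H, W, C] buffer, `count` a row-major [H, W, 1] buffer.
fn masked_divide_reference(result: &mut [f32], count: &[f32], c: usize) {
    assert_eq!(
        result.len(),
        count.len() * c,
        "result must be [H, W, C] and count [H, W, 1]"
    );
    for (pixel, &n) in count.iter().enumerate() {
        // Only pixels with a positive count are updated; the rest keep their values.
        if n > 0.0 {
            // Dividing every channel by the same count is what `.repeat(1, C)` sets up.
            for v in &mut result[pixel * c..(pixel + 1) * c] {
                *v /= n;
            }
        }
    }
}
```

With three pixels, C = 2 and counts [2, 3, 0], the values [2, 4, 6, 9, 1, 1] become [1, 2, 2, 3, 1, 1]: the third pixel keeps its values because its count is 0.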

I tried it this way, but it seems wrong:

```rust
let mask = count.squeeze(D::Minus1)?.gt(0 as i64)?;
let masked_ans_result = ans_result.i(&mask)?;
let repeated_ans_count = ans_count.i(&mask)?.repeat(&[c])?;
let updated_ans_result = masked_ans_result.div(&repeated_ans_count)?;
ans_result.i(&mask)?.eq(&updated_ans_result)?;
```

EricLBuehler commented 1 month ago

Hi @wiktorkujawa! Let's break this down into 2 steps: 1) result[mask] / count[mask].repeat(1, C)

wiktorkujawa commented 1 month ago

@EricLBuehler What about dimensions? Both gather and scatter_add always require a dimension value (gather takes two arguments and scatter_add takes three), so it seems the dim argument is not optional.

EricLBuehler commented 1 month ago

@wiktorkujawa what are the shapes of result, mask, and count?

wiktorkujawa commented 1 month ago

@EricLBuehler result and count look something like this:

```python
result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
mask = (count.squeeze(-1) > 0)
```

where: