kyegomez / zeta

Build high-performance AI models with modular building blocks
https://zeta.apac.ai
Apache License 2.0

[BUG] utils/main/pad_at_dim - recommend refactoring to use torch.nn.functional.pad #78

Closed evelynmitchell closed 6 months ago

evelynmitchell commented 9 months ago

The original function as written:

import torch.nn.functional as F

def pad_at_dim(t, pad, dim=-1, value=0.0):
    # Number of dimensions to the right of `dim`.
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    # F.pad takes (left, right) pairs ordered from the last dimension
    # inward, so insert (0, 0) for each dimension right of `dim`.
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)

This assumes simpler behavior than the reference implementations. The PyTorch and TensorFlow pad functions:

torch.nn.functional.pad (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html) https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/array_ops.py#L3452-L3508

have more complex, and correct, behavior.
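For illustration, the pad-tuple ordering that torch.nn.functional.pad expects can be checked with a small pure-Python sketch of how pad_at_dim builds its argument (fpad_arg is a hypothetical helper name, not part of zeta):

```python
def fpad_arg(ndim, pad, dim=-1):
    # F.pad takes (left, right) pairs ordered from the LAST dimension
    # inward, so a pad for `dim` must be preceded by one (0, 0) pair
    # for every dimension to its right.
    dims_from_right = (-dim - 1) if dim < 0 else (ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return (*zeros, *pad)

# Padding dim 0 of a 3-D tensor by (1, 2) requires two leading
# (0, 0) pairs for dims 2 and 1:
fpad_arg(3, (1, 2), dim=0)   # -> (0, 0, 0, 0, 1, 2)

# Padding the last dim needs no leading zeros:
fpad_arg(3, (3, 4), dim=-1)  # -> (3, 4)
```

This only covers the constant-pad case; F.pad additionally supports reflect, replicate, and circular modes, which the helper above does not model.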

I noticed this because none of the tests in test_pad_at_dim.py are passing.

pad_at_dim is used in:

playground/models/stacked_mm_bitnet
nn/biases/alibi.py
nn/modules/shift_tokens.py
zeta/structs/transformer.py


github-actions[bot] commented 7 months ago

Stale issue message