lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License

Easier (and faster) chunk and inplace under nograd #1

Closed: hypnopump closed this 1 year ago

hypnopump commented 1 year ago

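Minor tricks. Here is proof that x.chunk(2, dim=-1) returns the same two halves as rearrange + unbind, plus %timeit comparisons of the two: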

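# cell 1: setup
import torch as th
from einops import rearrange

n = 64
x = th.arange(n).repeat(1, 1)
x.shape  # torch.Size([1, 64]), i.e. (1, n)

%%timeit
# cell 2: rearrange + unbind (a cell magic must be the first line of its cell)
x_ = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x_.unbind(dim=-2)

# cell 3: chunk
%timeit x1, x2 = x.chunk(2, dim=-1)

# cell 4: check that both approaches split the last dim identically
x_ = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x_.unbind(dim=-2)

new_x1, new_x2 = x.chunk(2, dim=-1)
# x holds integers, so compare exactly with th.equal instead of th.allclose
assert th.equal(x1, new_x1)
assert th.equal(x2, new_x2)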
lucidrains commented 1 year ago

@hypnopump haha nice! thanks!