lstm training slow repro #4905

Open DKormann opened 2 weeks ago

DKormann commented 2 weeks ago

Minimal repro of the LSTM being slow on the first forward pass. It could be sped up with .realize but, as I understand it, should also work without that. The slowness seems to correlate with T^2.

from tinygrad import nn, Tensor, Device
Device.DEFAULT = "CUDA"

D, B, T = 100, 4, 100

# 4-layer LSTM: each block holds the input-to-hidden and hidden-to-hidden projections for the 4 gates
blocks = [[nn.Linear(D, D*4), nn.Linear(D, D*4)] for _ in range(4)]
opt = nn.optim.Adam(nn.state.get_parameters(blocks))
Tensor.training = True

x, out = [Tensor.rand(B, D) for _ in range(T)], []
for ih, hh in blocks:
  h, c = Tensor.zeros(2, B, D)
  for x_ in x:
    i, f, g, o = (ih(x_) + hh(h)).chunk(4, 1)
    out.append(h := o.sigmoid() * (c := (c * f.sigmoid() + i.sigmoid() * g.tanh())).tanh())
  x, out = out, []

Tensor.stack(*x, dim=1).sparse_categorical_crossentropy(Tensor.zeros(B, T)).backward()
for p in opt.params: p.grad.realize()
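
For reference, a rough sketch of where the .realize workaround mentioned above would go. This is just a drop-in for the inner loop of the repro, not a claim that this is the intended fix:

# drop-in replacement for the inner timestep loop above: realizing h and c each step
# means later timesteps read from realized buffers instead of re-deriving the whole history
for x_ in x:
  i, f, g, o = (ih(x_) + hh(h)).chunk(4, 1)
  c = (c * f.sigmoid() + i.sigmoid() * g.tanh()).realize()
  h = (o.sigmoid() * c.tanh()).realize()
  out.append(h)
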
geohot commented 2 weeks ago

I'm sure it can be more minimal than this. Isolate the exact source of the slowness, not a whole LSTM.

DKormann commented 2 weeks ago

After some more digging I found that the CLANG compiler seems to produce n programs of size n.
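
One way to see this without reading the generated C is to count the scheduled kernels and roughly how many ops each AST contains. A sketch, assuming create_schedule lives under tinygrad.engine.schedule and that si.ast is an indexable collection of LazyOps with a .src field (paths and fields may differ between tinygrad versions):

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import create_schedule
Device.DEFAULT = "CLANG"

def count_ops(op):
  # naive recursive node count over a LazyOp tree
  return 1 + sum(count_ops(s) for s in op.src)

x = Tensor.rand(2)
for _ in range(20): x = x.sum() + x

# each scheduled kernel's op count should grow roughly linearly with its index
for idx, si in enumerate(create_schedule([x.lazydata])):
  print(idx, sum(count_ops(ast) for ast in si.ast))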

DKormann commented 2 weeks ago

from tinygrad import Tensor, Device
from tinygrad.engine.realize import method_cache
from tinygrad.helpers import DEBUG

T = 80
Device.DEFAULT = "CLANG"
DEBUG.value = 3        # print the generated kernels
method_cache.clear()   # make sure nothing is reused from a previous run

x = Tensor.rand(2)

# chain T dependent reductions: x -> x.sum() + x -> ...
for _ in range(T): x = x.sum() + x
x.realize()

Simpler repro found. It's something about chained reduce?
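
To tie this back to the T^2 observation, a rough timing sketch over the same chained repro (numbers are only indicative, since the on-disk compile cache can still interfere):

import time
from tinygrad import Tensor, Device
from tinygrad.engine.realize import method_cache
Device.DEFAULT = "CLANG"

def first_realize_time(T):
  method_cache.clear()   # don't reuse kernels compiled by a previous run
  x = Tensor.rand(2)
  for _ in range(T): x = x.sum() + x
  st = time.perf_counter()
  x.realize()
  return time.perf_counter() - st

for T in (20, 40, 80):
  # if the cost really is ~T^2, doubling T should roughly 4x the time
  print(T, first_realize_time(T))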

DKormann commented 2 weeks ago

create_schedule creates ASTs that don't reuse earlier results and so recompute them over and over. As I understand it, this is because:

  1. a kernel can only store once
  2. a reduce op must be the last operation?

which would explain the n programs of size n above.
from tinygrad import Tensor, Device
from tinygrad.ops import ReduceOps, BinaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_linearizer, CompiledRunner
from tinygrad.engine.graph import print_tree  # NOTE: import paths may differ between tinygrad versions

DEVICE = "CLANG"

x = Tensor([1, 2.]).lazydata
del x.srcs  # treat x as a leaf buffer
st = ShapeTracker((View.create((2,), strides=(0,)),))  # broadcast the scalar sum back to shape (2,)

# build the chained reduce directly on LazyBuffers: x = x + broadcast(sum(x))
for i in range(4):
  s = x.r(ReduceOps.SUM, [0])._view(st)
  x = x.e(BinaryOps.ADD, s)

sched = create_schedule([x])

# print each scheduled AST and the C source it compiles to
for si in sched:
  print_tree(si.ast[0])
  lin = get_linearizer(Device[DEVICE].renderer, (si.ast[0],)).linearize()
  fxn = CompiledRunner(lin.to_program())
  print(fxn.p.src)