karpathy / nn-zero-to-hero

Neural Networks: Zero to Hero
MIT License

makemore_part4_backprop dhpreact exact part is False. #45

Open Stealeristaken opened 4 months ago

Stealeristaken commented 4 months ago

Hello Andrej. First of all, I would like to express my gratitude to you for sharing such valuable videos with us for free.

While watching the 'makemore part 4' video, I was also applying it to a dataset I created myself. When I compute the chained derivative at the 'dhpreact' step, the exact check fails, and since the gradients are chained, the mismatch propagates to the subsequent checks as well. I share the code and the output below. Please share any solution if you have one. Using different Torch versions and changing the dtype to 'double', as suggested in the comments, didn't work for me.


```python
# manual backprop through the cross-entropy / softmax chain
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
dcounts += torch.ones_like(counts) * dcounts_sum
dnorm_logits = counts * dcounts
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
# second linear layer
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# tanh non-linearity: h = torch.tanh(hpreact), so d(tanh)/d(hpreact) = 1 - h**2
dhpreact = (1.0 - h**2) * dh
# batchnorm gain and raw activation
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = bngain * dhpreact
```
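For context, the `exact` / `approximate` / `maxdiff` columns in the output below come from the small comparison helper used in the lesson notebook. A sketch reconstructed to match that output format (the exact implementation in the notebook may differ slightly):

```python
import torch

def cmp(s, dt, t):
    # compare a manually derived gradient `dt` against autograd's `t.grad`
    ex = torch.all(dt == t.grad).item()      # bit-for-bit equality
    app = torch.allclose(dt, t.grad)         # equality up to float tolerance
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')
```

The key point is that `exact` demands bit-identical floats, while `approximate` only requires agreement within `torch.allclose`'s default tolerances, so `exact: False | approximate: True` with a maxdiff around 1e-9 indicates round-off noise, not a wrong derivative.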

Output:

```text
logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnbias          | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
bnraw           | exact: False | approximate: True  | maxdiff: 6.984919309616089e-10
bnvar_inv       | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
bnvar           | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
```
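The failure starts exactly at `hpreact`, i.e. at `dhpreact = (1.0 - h**2) * dh`. A plausible explanation is that newer PyTorch builds evaluate the tanh backward with a different internal operation order (or a fused kernel) than the literal expression `(1 - h**2) * dh`, so the two results can differ in the last float32 bits even though both are mathematically correct. A minimal standalone repro of this comparison, independent of the notebook:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 64, requires_grad=True)   # stand-in for hpreact
h = torch.tanh(x)
dh = torch.randn_like(h)                       # stand-in upstream gradient
h.backward(dh)                                 # autograd's tanh backward

dx_manual = (1.0 - h.detach()**2) * dh         # the formula from the video
exact = torch.all(dx_manual == x.grad).item()  # may be False on newer torch versions
maxdiff = (dx_manual - x.grad).abs().max().item()
print(exact, maxdiff)                          # maxdiff stays at float round-off scale

# the derivative is still correct up to floating-point tolerance
assert torch.allclose(dx_manual, x.grad)
```

If `exact` prints `False` here with a maxdiff around 1e-9, the derivation in the issue is fine; only the bit-exactness check is too strict for that PyTorch build.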
zzkzzkjsw commented 1 month ago

This looks like a PyTorch version problem. It works with torch == 1.12.0, but not with the latest.

Stealeristaken commented 1 month ago

> This looks like a PyTorch version problem. It works with torch == 1.12.0, but not with the latest.

Yes, but honestly I'd like to know what changed between versions that causes the mismatch.