m2lines / L96_demo

Lorenz 1996 two time-scale model for learning machine learning
https://m2lines.github.io/L96_demo
MIT License

Gradient descent video does not render consistently #66

Closed: rabernat closed this issue 1 year ago

rabernat commented 2 years ago

On @yaniyuval's laptop the video looks right, but @adcroft and I see something like this:

https://user-images.githubusercontent.com/1197350/169144615-25b28123-2aef-4f4a-9411-704b504aa267.mp4

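The relevant code is:

```python
import matplotlib.pyplot as plt
import torch
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from torch import nn

# x, y, lin() and update2() are defined in earlier cells of the notebook;
# update2(t) appears to take one gradient step on the global w and return the loss.

# Initial (slope, intercept) guess, wrapped as a trainable parameter
w = torch.as_tensor([-2.0, -3])
w = nn.Parameter(w)
lr = 0.001

# Static figure: data points plus the line for the initial guess
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(x[:, 0], y)
(line,) = ax.plot(
    x[:, 0],
    lin(w.detach().numpy()[0], w.detach().numpy()[1], x.detach().numpy()[:, 0]),
    c="firebrick",
)
# line, = ax.plot(x[:,0], y, c='firebrick')
ax.set_title("Loss = 0.00")
plt.close()  # keep Jupyter from also displaying the static figure


def animate(i):
    # Each frame runs 100 more training steps, then redraws the fitted line
    for t in range(100):
        l = update2(t)
    ax.set_title("Loss = %.2f" % l)
    line.set_data(
        x.detach().numpy()[:, 0],
        lin(w.detach().numpy()[0], w.detach().numpy()[1], x.detach().numpy()[:, 0]),
    )
    return (line,)


anim = FuncAnimation(fig, animate, frames=70, interval=150, blit=True)

# to_html5_video() needs the ffmpeg binary; if it is missing, install it with:
# conda install -c conda-forge ffmpeg
HTML(anim.to_html5_video())
```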
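The thread does not pin down the cause, but one thing worth ruling out is that `animate` above is not idempotent: every call runs 100 more `update2` steps on the global `w`, and `FuncAnimation` can call the frame function more often than there are frames (when no `init_func` is given, it draws the first frame to initialize the animation). The sketch below is a minimal illustration under stated assumptions: a synthetic dataset, a hypothetical `lin(a, b, x)`, and a hand-written NumPy gradient step stand in for the notebook's `x`, `y`, `lin`, and `update2`. It precomputes the trajectory once so each frame is a pure lookup, and uses `blit=False` so the title is redrawn on every frame.

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.animation import FuncAnimation

# Hypothetical stand-ins for the notebook's data (not from the issue):
# a noisy straight line y = 3 x + 2.
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=100)
y = 3.0 * x + 2.0 + 0.1 * rng.standard_normal(100)


def lin(a, b, x):
    # Hypothetical linear model: slope a, intercept b
    return a * x + b


def precompute_trajectory(w0, lr=0.001, n_frames=70, steps_per_frame=100):
    """Run gradient descent on the MSE once; store (w, loss) after each frame's steps."""
    w = np.array(w0, dtype=float)
    ws, losses = [], []
    for _ in range(n_frames):
        for _ in range(steps_per_frame):
            resid = lin(w[0], w[1], x) - y
            # Gradient of mean((a*x + b - y)**2) with respect to (a, b)
            grad = np.array([2 * np.mean(resid * x), 2 * np.mean(resid)])
            w -= lr * grad
        ws.append(w.copy())
        losses.append(np.mean((lin(w[0], w[1], x) - y) ** 2))
    return np.array(ws), np.array(losses)


ws, losses = precompute_trajectory([-2.0, -3.0])

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(x, y)
order = np.argsort(x)  # sort x so the fitted line is drawn left to right
(line,) = ax.plot(x[order], lin(*ws[0], x[order]), c="firebrick")
plt.close(fig)  # keep Jupyter from also displaying the static figure


def animate(i):
    # Pure lookup: animate(i) draws the same frame no matter how often it is called
    a, b = ws[i]
    line.set_data(x[order], lin(a, b, x[order]))
    ax.set_title("Loss = %.2f" % losses[i])
    return (line,)


# blit=False so the title (drawn outside the axes area) is updated on every frame
anim = FuncAnimation(fig, animate, frames=len(ws), interval=150, blit=False)


def to_html(anim):
    # to_html5_video() requires ffmpeg; fall back to the JavaScript player otherwise
    if animation.writers.is_available("ffmpeg"):
        return anim.to_html5_video()
    return anim.to_jshtml()
```

In a notebook, `HTML(to_html(anim))` would display the result. `to_jshtml()` avoids the ffmpeg dependency entirely, typically at the cost of a larger output cell, which can also help tell apart a codec or player problem from a problem in the animation itself.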
dhruvbalwada commented 1 year ago

@IamShubhamGupto, this issue is still open. Can you close it once you are done, based on our discussion today?