phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
MIT License
2.59k stars 590 forks

code style: jax tutorial 3 (activation fns) #80

Closed murphyk closed 1 year ago

murphyk commented 1 year ago

I suggest you use this more jaxonic way of getting per-example gradients :)

def get_grads(act_fn, x):
    """
    Computes the gradients of an activation function at specified positions.

    Inputs:
        act_fn - A module or function implementing the forward pass of the activation function.
        x - 1D input array.
    Output:
        An array with the same size as x containing the gradients of act_fn at x.
    """
    # return jax.grad(lambda inp: act_fn(inp).sum())(x)  # obscure
    return jax.vmap(jax.grad(act_fn))(x)
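
For readers comparing the two variants above, here is a minimal sketch of why they agree for an elementwise function. It uses jnp.tanh as a stand-in activation (an assumption for illustration, not one of the tutorial's modules): the summation trick differentiates one scalar that depends on every input, while vmap over grad differentiates the scalar function once per element.

```python
import jax
import jax.numpy as jnp

# Stand-in elementwise activation (assumption: any function that maps a
# scalar to a scalar behaves the same way here).
act_fn = jnp.tanh

x = jnp.linspace(-2.0, 2.0, 5)

# Summation trick: because act_fn is elementwise, d(sum)/dx[i] equals
# the derivative of act_fn at x[i].
grads_sum = jax.grad(lambda inp: act_fn(inp).sum())(x)

# vmap over grad: jax.grad handles a single scalar input, and jax.vmap
# maps that scalar gradient over the leading axis of x.
grads_vmap = jax.vmap(jax.grad(act_fn))(x)

# Analytic derivative of tanh, for comparison.
grads_exact = 1.0 - jnp.tanh(x) ** 2
```

Note that jax.grad requires a scalar output, so the vmap form only works when act_fn returns a scalar for a scalar input, which holds for elementwise activations.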
murphyk commented 1 year ago

Also, ax.set_title( does not seem to work (name is None), so the plots are not named. A quick fix is shown below.

def vis_act_fn(act_fn, ax, x, name):
    # Run activation function
    y = act_fn(x)
    y_grads = get_grads(act_fn, x)
    # Push x, y and gradients back to cpu for plotting
    # x, y, y_grads = x.cpu().numpy(), y.cpu().numpy(), y_grads.cpu().numpy()
    ## Plotting
    ax.plot(x, y, linewidth=2, label=name) ### NEW
    ax.plot(x, y_grads, linewidth=2, label="Gradient")
    ax.legend()  # needed for the labels above to be displayed
    ax.set_ylim(-1.5, x.max())

# Add activation functions if wanted
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
names = list(act_fn_by_name.keys()) ### NEW
x = np.linspace(-5, 5, 1000) # Range on which we want to visualize the activation functions
## Plotting
rows = math.ceil(len(act_fns)/2.0)
fig, ax = plt.subplots(rows, 2, figsize=(8, rows*4))
for i, act_fn in enumerate(act_fns):
    vis_act_fn(act_fn, ax[divmod(i,2)], x, names[i]) ### NEW
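
The loop above relies on divmod(i, 2) to turn a flat plot index into a (row, column) position in the two-column grid. A small sketch of that mapping, using a hypothetical helper grid_positions that is not part of the notebook:

```python
import math

def grid_positions(n_plots, n_cols=2):
    """Return the number of rows and the (row, col) slot for each plot index."""
    rows = math.ceil(n_plots / n_cols)
    # divmod(i, n_cols) == (i // n_cols, i % n_cols): row first, then column.
    positions = [divmod(i, n_cols) for i in range(n_plots)]
    return rows, positions

rows, positions = grid_positions(5)
# rows == 3
# positions == [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0)]
```

One caveat, assuming matplotlib's default squeeze=True: with only one row, plt.subplots returns a 1D array of axes and indexing it with a (row, col) tuple fails, so passing squeeze=False keeps the 2D indexing valid when there are two or fewer activation functions.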
phlippe commented 1 year ago

Hi, thanks for the suggestion, I indeed translated the PyTorch code too literally there. :) Changed to better JAX style in a30e19dd1c8c94e61eb93258380a13826bcbda08 and fixed the plot naming.