murphyk closed this issue 1 year ago
Also, `ax.set_title(act_fn.name)` does not seem to work (`act_fn.name` is `None`), so the plots are not titled. A quick fix is shown below:
```python
def vis_act_fn(act_fn, ax, x, name):
    # Run activation function
    y = act_fn(x)
    y_grads = get_grads(act_fn, x)
    # Push x, y and gradients back to cpu for plotting
    # x, y, y_grads = x.cpu().numpy(), y.cpu().numpy(), y_grads.cpu().numpy()
    ## Plotting
    ax.plot(x, y, linewidth=2, label=name)  ### NEW
    ax.plot(x, y_grads, linewidth=2, label="Gradient")
    ax.set_title(act_fn.name)
    ax.legend()
    ax.set_ylim(-1.5, x.max())

# Add activation functions if wanted
act_fns = [act_fn() for act_fn in act_fn_by_name.values()]
names = list(act_fn_by_name.keys())  ### NEW
x = np.linspace(-5, 5, 1000)  # Range on which we want to visualize the activation functions
## Plotting
rows = math.ceil(len(act_fns) / 2.0)
fig, ax = plt.subplots(rows, 2, figsize=(8, rows * 4))
for i, act_fn in enumerate(act_fns):
    vis_act_fn(act_fn, ax[divmod(i, 2)], x, names[i])  ### NEW
fig.subplots_adjust(hspace=0.3)
plt.show()
```
Hi, thanks for the suggestion, I indeed translated the PyTorch code too literally there. :) Changed to better JAX style in a30e19dd1c8c94e61eb93258380a13826bcbda08 and fixed the plot naming.
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial3/Activation_Functions.html
I suggest you use this more jaxonic way of getting per-example gradients :)
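For reference, the usual JAX pattern for per-example gradients is to compose `jax.grad` (which differentiates a scalar-to-scalar function) with `jax.vmap` (which maps it over the batch). A minimal sketch, using `jnp.tanh` as a stand-in for the notebook's activation functions:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 1000)

# jax.grad gives the derivative of a scalar function at one point;
# jax.vmap maps that derivative over the batch, yielding one gradient
# per input element (same shape as x).
y_grads = jax.vmap(jax.grad(jnp.tanh))(x)
```

The resulting `y_grads` can be plotted directly, replacing a hand-written `get_grads` helper with a single composed transformation.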