Closed CloudyDory closed 7 months ago
@CloudyDory Thanks for the question. This is actually very similar to the examples you linked. Here is an example that visualizes the computation graph of a LIF neuron model:
import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

hh = bp.dyn.LifRef(10)

def run_fun(inputs):
    # Run the LIF model over every time step.
    return bm.for_loop(hh.step_run, (np.arange(inputs.shape[0]), inputs))

# Trace run_fun into an XLA computation and write its HLO graph in Graphviz dot format.
z = jax.xla_computation(run_fun)(np.random.uniform(2., 6., 10000))
with open("lif.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())
Then run the following command in a terminal (it requires Graphviz) to render the graph:
dot lif.dot -Tpng > lif.png
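Note that jax.xla_computation is deprecated in newer JAX releases. As a hedged alternative sketch (plain JAX, no BrainPy; the function f and input x below are illustrative stand-ins for run_fun and its input), you can print the jaxpr of a function, or lower the jitted function and print its HLO text. Applying jax.make_jaxpr to jax.grad(f) also lets you inspect the backpropagation graph:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative stand-in for a model step function.
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.linspace(0., 1., 16)

# Jaxpr of the forward function.
print(jax.make_jaxpr(f)(x))

# Jaxpr of the gradient function, i.e. the backprop graph.
print(jax.make_jaxpr(jax.grad(f))(x))

# HLO text of the jitted function, via the ahead-of-time lowering API.
print(jax.jit(f).lower(x).as_text())
```

The same pattern applies to a jitted function: lowering goes through jax.jit, and make_jaxpr traces the un-jitted Python function directly.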
Hi, is there a way to view the computational graph when training BrainPy models with backpropagation? I have found some tutorials on viewing the computational graph of JAX functions (https://bnikolic.co.uk/blog/python/jax/2022/02/22/jax-outputgraph-rev.html), but I am not sure how to do this in BrainPy, for both jitted and un-jitted functions.
Thanks!