jonescompneurolab / hnn-core

Simulation and optimization of neural circuits for MEG/EEG source estimates
https://jonescompneurolab.github.io/hnn-core/
BSD 3-Clause "New" or "Revised" License

networkplotter on jupyter-notebook #872

Open jasmainak opened 2 months ago

jasmainak commented 2 months ago

I'm trying to add the NetworkPlotter to the jupyter notebook using the following code:

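from hnn_core import jones_2009_model, simulate_dipole
from hnn_core.network_models import add_erp_drives_to_jones_model
from hnn_core.viz import NetworkPlotter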

net = jones_2009_model(mesh_shape=(3, 3))

# Note that we move the cells further apart to allow better visualization of
# the network (default inplane_distance=1.0 µm).
net.set_cell_positions(inplane_distance=300)

add_erp_drives_to_jones_model(net)
dpl = simulate_dipole(net, tstop=170, record_vsec='all')

net_plot = NetworkPlotter(net)
net_plot.export_movie('animation_demo.gif', dpi=100, fps=30, interval=100)

from IPython.display import Image
Image(url='animation_demo.gif')

but it's showing two images. Maybe we should have net_plot.show() that explicitly brings up the plot?
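
A possible workaround for the duplicate image, sketched with plain matplotlib rather than NetworkPlotter: the inline backend renders any figure that is still open at the end of a cell, so closing the figure after saving the GIF should leave only the Image output. The sine-wave animation and the file name are placeholders, and whether NetworkPlotter exposes its figure for closing (for example as net_plot.fig) is an assumption to verify.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend for running outside a notebook; drop this line in Jupyter
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Placeholder animation standing in for the network plot
fig, ax = plt.subplots()
x = np.linspace(0, 2 * np.pi, 50)
(line,) = ax.plot(x, np.sin(x))


def update(frame):
    # Shift the sine wave a little on every frame
    line.set_ydata(np.sin(x + 0.3 * frame))
    return (line,)


ani = animation.FuncAnimation(fig, update, frames=5, interval=100)
gif_path = os.path.join(tempfile.mkdtemp(), "demo.gif")
ani.save(gif_path, writer="pillow", fps=30, dpi=50)

# Close the figure so the notebook does not also render it as a static image
plt.close(fig)
```

In a notebook, Image(filename=gif_path) in the same cell would then be the only thing displayed.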

@ntolley ... hoping to get this in for the workshop

jasmainak commented 2 months ago

I tried also with:

%matplotlib ipympl

But, this seems to hammer the kernel of the jupyter-notebook ... any chance we could make it so that it runs just once rather than in a loop?

jasmainak commented 2 months ago

also are you doing the simulation on a 3x3 net because it's too slow otherwise? Could we sample fewer time points? it's very high freq anyway and we don't smooth ...

I'm hoping to show it as something that can actually be used for real insights rather than just a cool gif
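
To make "sample fewer time points" concrete, here is a minimal sketch of frame decimation: animate only every decim-th time index. The helper name decimated_frames is hypothetical, and the 0.025 ms step and the inclusion of the end point at tstop are assumptions about the simulation output.

```python
def decimated_frames(n_times, decim):
    """Return every `decim`-th time index in [0, n_times)."""
    return list(range(0, n_times, decim))


# 170 ms at an assumed 0.025 ms step, end point included
n_times = round(170 / 0.025) + 1

# Keep one frame per simulated millisecond (40 steps of 0.025 ms)
frames = decimated_frames(n_times, decim=40)
print(len(frames), frames[:3], frames[-1])
```

A list like this could be passed as the frames argument of matplotlib's FuncAnimation, so the number of rendered frames drops by a factor of decim.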

ntolley commented 2 months ago

@jasmainak here's a notebook I've used in the past to generate multipanel animations: https://github.com/ntolley/hnn_latent/blob/dev/notebooks/animate_beta.ipynb

it won't work off the shelf but I'll comment below with the relevant lines for setting up the plotting

ntolley commented 2 months ago

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from hnn_core import calcium_model, simulate_dipole
from hnn_core.viz import NetworkPlotter
from neurodsp.timefrequency.wavelets import compute_wavelet_transform

net = calcium_model()

# drives you want to add

dt = 0.025  # integration time step in ms (assumed here; this is the hnn-core default)
dpl = simulate_dipole(net, tstop=1000, dt=dt, record_vsec='all')

# Make plotting frame
times = net.cell_response.times
plot_mask = times > 0.2
times = times[plot_mask]

n_times = times.shape[0]
sig = dpl[0].data['agg'][plot_mask]

fig = plt.figure(figsize=(14,8))
ax1 = plt.subplot2grid((3, 3), (0, 0), rowspan=1, colspan=1)
ax2 = plt.subplot2grid((3, 3), (1, 0), rowspan=1, colspan=1)
ax3 = plt.subplot2grid((3, 3), (2, 0), rowspan=1, colspan=1)
ax4 = plt.subplot2grid((3, 3), (0, 1), rowspan=3, colspan=2, projection='3d')
ax4.set_facecolor('k')

# Plot dipole
ax1.plot(times, sig)
ax1.set_ylabel('Dipole (nAm)', fontsize=15)
line1 = ax1.axvline(0.2, color='k', linestyle='--')
ax1.set_xlim(times[0], times[-1])
ax1.legend(loc='upper right')
ax1.set_xticks([])

# Plot spectrogram
plot_freqs = np.arange(10, 50, 1)
fs = 1000. / (times[1] - times[0])  # sampling rate in Hz, assuming times are in ms
mwt = compute_wavelet_transform(sig, fs, plot_freqs, n_cycles=5)
plot_power = np.abs(mwt)
ax2.pcolormesh(times, plot_freqs, plot_power)
ax2.axhline(13, color='w', linestyle='--')
ax2.axhline(30, color='w', linestyle='--')
ax2.set_ylabel('Frequency (Hz)', fontsize=15)

line2 = ax2.axvline(0, linewidth=2, color='k')
ax2.set_xlim(times[0], times[-1])
ax2.legend(loc='upper right')
ax2.set_xticks([])

# Plot spikes
net.cell_response.plot_spikes_raster(ax=ax3, show=False)
line3 = ax3.axvline(0, linewidth=2, color='w')
ax3.get_legend().remove()
ax3.set_xlim(times[0], times[-1])

# Add network
net.set_cell_positions(inplane_distance=300.)
net_plot = NetworkPlotter(net, ax=ax4)

def update_frame(time_idx):
    line1.set_xdata([times[time_idx]])
    line2.set_xdata([times[time_idx]])
    line3.set_xdata([times[time_idx]])

    net_plot.time_idx = time_idx
    n_rotations = 0.4
    rot_pos = -100 + ((n_rotations / n_times) * time_idx * 360)
    net_plot.azim = rot_pos

frame_start = 0
frame_stop = len(times) - 1
decim = 10
interval = 30
fps = 30
dpi = 50
writer = 'pillow'

frames = np.arange(frame_start, frame_stop, decim)
ani = animation.FuncAnimation(
    fig, update_frame, frames, interval=interval)

writer = animation.writers[writer](fps=fps)
ani.save('mid_beta.gif', writer=writer, dpi=dpi)

ntolley commented 2 months ago

I'll test this code locally when I get a chance. This was made with an old version of NetworkPlotter, so there's a chance I missed something.
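
For anyone adapting update_frame, the camera rotation is just a linear map from the time index to an azimuth angle. The sketch below restates that arithmetic on its own. The function name azimuth_at and the value n_times=1000 are illustrative only; the start angle of -100 degrees and n_rotations=0.4 are taken from the code above.

```python
import math


def azimuth_at(time_idx, n_times, n_rotations=0.4, start=-100.0):
    """Azimuth in degrees after a fraction time_idx / n_times of the sweep."""
    return start + (n_rotations / n_times) * time_idx * 360


n_times = 1000  # illustrative number of time points
first = azimuth_at(0, n_times)
last = azimuth_at(n_times, n_times)
print(first, last)
```

With n_rotations=0.4 the camera sweeps 144 degrees in total, from -100 to roughly 44 degrees, so raising n_rotations gives a faster spin over the same number of frames.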

jasmainak commented 2 months ago

Oh boy, this is too much code for a tutorial. But thanks for sharing. A fix for plt.show might be nice though!