ins-amu / vbjax

A nascent Jax-based package for virtual brain modeling.
Apache License 2.0

Kuramoto model #64

Closed Ziaeemehr closed 7 months ago

Ziaeemehr commented 8 months ago

This is a Kuramoto model on a complete network. I pass sigma (noise amplitude) and G (global coupling strength) as inputs and produce time series. I get strange behavior when changing sigma: it seems to change the frequency! Where am I going wrong?

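import collections

import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
import vbjax as vb

KMTheta = collections.namedtuple(typename="KMTheta", field_names="G omega".split(" "))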
km_default_theta = KMTheta(G=0.05, omega=1.0)
KMState = collections.namedtuple(typename="KMState", field_names="x".split(" "))

def km_dfun(x, c, p: KMTheta):
    "Kuramoto model"
    # elementwise product: jnp.vdot(p.G, c) would reduce the coupling to a single scalar shared by all nodes
    dx = p.omega + p.G * c
    return dx

def network(x, p):
    weights, node_params = p
    c = jnp.sum(weights * jnp.sin(x - x[:, None]), axis=1)
    dx = km_dfun(x, c, node_params)
    return dx

def get_ts(params, dt=0.1, T=50.0, G=0.0, sigma=0.1):
    '''Run the Kuramoto model'''
    omega, weights, par = params    
    nn = weights.shape[0]
    G = jnp.ones(nn) * G
    _, loop = vb.make_sde(dt, dfun=network, gfun=sigma)
    par = par._replace(G=G, omega=omega)
    nt = int(T / dt)
    zs = vb.rand(nt, nn) * 2 * jnp.pi
    xs = loop(zs[0], zs[1:], (weights, par))
    ts = jnp.linspace(0, nt * dt, len(xs))
    return xs, ts 

nn = 3
weights = nx.to_numpy_array(nx.complete_graph(nn))
dt = 0.1

omega = vb.randn(nn) * 1.0
xs, ts = get_ts((omega, weights, km_default_theta), dt=dt, G=0.1, sigma=0.0)
plt.figure(figsize=(10, 3))
plt.plot(ts, jnp.sin(xs))   

Cheers

maedoc commented 7 months ago

Thanks for the report. You don't have to follow the convention used in the README, so your dfun could just be

def f(x, p):
    return p.omega + p.G*jnp.mean(weights*jnp.sin(x-x[:,None]),axis=1)

Using mean instead of sum scales the afferent coupling by 1/N, as in the Kuramoto equations; that scaling is missing in your code.
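
To make the 1/N point concrete, here is a minimal sketch in plain JAX (no vbjax calls) that compares the sum and mean forms of the coupling term. The function names `coupling_sum` and `coupling_mean`, the phases, and the 3-node weight matrix are illustrative assumptions, not part of the original code.

```python
import jax.numpy as jnp

def coupling_sum(x, weights):
    # c_i = sum_j w_ij * sin(x_j - x_i); grows with the number of nodes N
    return jnp.sum(weights * jnp.sin(x - x[:, None]), axis=1)

def coupling_mean(x, weights):
    # c_i = (1/N) * sum_j w_ij * sin(x_j - x_i); the 1/N-scaled form
    return jnp.mean(weights * jnp.sin(x - x[:, None]), axis=1)

# Hypothetical example: 3 nodes on a complete graph without self-connections
x = jnp.array([0.0, 0.5, 1.0])
weights = jnp.ones((3, 3)) - jnp.eye(3)

c_sum = coupling_sum(x, weights)
c_mean = coupling_mean(x, weights)  # equals c_sum / N, here N = 3
```

With the sum form, the same G produces a stronger effective coupling as the network grows, so results for different N are only comparable with the mean form.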

However, I'm not familiar enough with the Kuramoto model to know what the expected behavior is with noise scaling. My experience using Kuramoto to test solvers is mainly related to coherence of the population as a function of parameters, e.g. https://github.com/SciML/StochasticDelayDiffEq.jl/issues/24
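
Population coherence is commonly summarised with the Kuramoto order parameter r = |mean_j exp(i * theta_j)|, which is 1 for fully synchronised phases and near 0 for evenly spread phases. The sketch below is one way to compute it in JAX; the function name `order_parameter` is an assumption for illustration and is not taken from vbjax or the linked issue.

```python
import jax.numpy as jnp

def order_parameter(theta):
    # r = |mean_j exp(i * theta_j)|, averaged over the last (node) axis
    return jnp.abs(jnp.mean(jnp.exp(1j * theta), axis=-1))

# All phases equal: fully coherent population
r_sync = order_parameter(jnp.zeros(4))

# Phases evenly spread around the circle: incoherent population
r_spread = order_parameter(jnp.array([0.0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2]))
```

Applied to a time series of shape (time, nodes), such as the `xs` returned above, the same function gives one coherence value per time step.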

If I had to guess, the effect is correct and a result of noise moving the system past "ghost" attractors more quickly. In any case, we don't need to guess: Kuramoto is implemented in TVB proper, so we can perhaps take that as a reference solution and then check here.

maedoc commented 7 months ago

I added a script with the above. Once we have a clear idea of what the solution is, we can convert it to a unit test to ensure the problem does not appear again. It would probably be nice to have the test reported in the Julia repository also.

maedoc commented 7 months ago

Just noticed that this example uses vb.rand to generate the noise. This is allowed of course, but then the noise samples are distributed uniformly between 0 and 1, so they have a nonzero mean (0.5) that acts as a constant positive drift on the phase, which explains why the frequency increases. When you switch to vb.randn, the noise is normally distributed with zero mean and the frequency does not increase.
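
The drift can be reproduced without vbjax. The sketch below writes out an Euler-Maruyama update by hand for a single uncoupled phase, assuming the update x += dt*omega + sqrt(dt)*sigma*z; this is not vbjax's integrator, and the function name `mean_phase_velocity` and all parameter values are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def mean_phase_velocity(zs, omega=1.0, sigma=0.5, dt=0.1):
    # Euler-Maruyama for one uncoupled phase:
    #   x_{t+1} = x_t + dt * omega + sqrt(dt) * sigma * z_t
    steps = dt * omega + jnp.sqrt(dt) * sigma * zs
    x_final = jnp.sum(steps)  # phase accumulated from x_0 = 0
    return x_final / (zs.shape[0] * dt)  # average frequency over the run

key = jax.random.PRNGKey(0)
k_uniform, k_normal = jax.random.split(key)
nt = 20000

z_uniform = jax.random.uniform(k_uniform, (nt,))  # mean 0.5, like uniform [0, 1) noise
z_normal = jax.random.normal(k_normal, (nt,))     # mean 0

# Uniform samples shift the frequency by about sigma * 0.5 / sqrt(dt);
# zero-mean normal samples leave it close to omega.
v_uniform = mean_phase_velocity(z_uniform)
v_normal = mean_phase_velocity(z_normal)
```

Under this update rule the shift grows linearly with sigma, which matches the reported symptom of the frequency changing with the noise amplitude.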


Reopen if you have further problems.