DedalusProject / dedalus

A flexible framework for solving PDEs with modern spectral methods.
http://dedalus-project.org/
GNU General Public License v3.0

Operator substitutions not recognized in function #252

Closed. iamlll closed this issue 1 year ago.

iamlll commented 1 year ago

Hello, sorry if this is a very naive question! I couldn't find information about this topic elsewhere, so I thought I'd ask here. I've been following this tutorial (https://dedalus-project.readthedocs.io/en/latest/notebooks/dedalus_tutorial_3.html) and wanted to define separate functions to save/run the simulation and then plot the results. However, the operator substitutions (e.g. `dx`) don't seem to be picked up when I call `problem.add_equation(...)` inside a function rather than in a script, and I get `NameError: name 'dx' is not defined`. Do you know how I could fix this?

The code I'm using is essentially a copy-paste of the tutorial above into a function `Run()` that gets called from the main block:

```python
import subprocess

import numpy as np
import dedalus.public as d3


def Run():
    # Bases
    xcoord = d3.Coordinate('x')
    dist = d3.Distributor(xcoord, dtype=np.complex128)
    xbasis = d3.Chebyshev(xcoord, 1024, bounds=(0, 300), dealias=2)

    # Fields
    u = dist.Field(name='u', bases=xbasis)
    tau1 = dist.Field(name='tau1')
    tau2 = dist.Field(name='tau2')

    # Problem
    problem = d3.IVP([u, tau1, tau2], namespace=locals())

    # Substitutions
    dx = lambda A: d3.Differentiate(A, xcoord)
    magsq_u = u * np.conj(u)
    b = 0.5
    c = -1.76

    # Tau polynomials
    tau_basis = xbasis.derivative_basis(2)  # my guess is that the tau basis should match the highest derivative order in the problem
    p1 = dist.Field(bases=tau_basis)
    p2 = dist.Field(bases=tau_basis)
    p1['c'][-1] = 1
    p2['c'][-2] = 2
    print(p1['c'])
    print(p1['g'])

    # Add main equation, with linear terms on the LHS and nonlinear terms on the RHS
    problem.add_equation(f"dt(u) - u - (1 + 1j*{b})*dx(dx(u)) + tau1*p1 + tau2*p2 = - (1 + 1j*{c}) * magsq_u * u")

    # Add boundary conditions
    problem.add_equation("u(x='left') = 0")
    problem.add_equation("u(x='right') = 0")

    # Build solver
    solver = problem.build_solver(d3.RK222)

    # Stopping criteria
    solver.stop_sim_time = 500

    # Set up initial conditions (sine wave)
    x = dist.local_grid(xbasis)
    u['g'] = 1e-3 * np.sin(5 * np.pi * x / 300)

    # Add a file handler evaluated every 'iter' time steps, split into sets of at most 'max_writes' writes each
    analysis = solver.evaluator.add_file_handler('analysis', iter=10, max_writes=400)
    # Track the average magnitude of the solution
    analysis.add_task(d3.Integrate(np.sqrt(magsq_u), 'x') / 300, layout='g', name='<|u|>')
    # Save all state variables for checkpointing
    analysis.add_tasks(solver.state, layout='g')

    # Main loop
    timestep = 0.05
    while solver.proceed:
        solver.step(timestep)
        if solver.iteration % 1000 == 0:
            print('Completed iteration {}'.format(solver.iteration))

    print(subprocess.check_output("find analysis | sort", shell=True).decode())


if __name__ == "__main__":
    Run()
```

kburns commented 1 year ago

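Hi, the issue is just that the `locals()` namespace is combined with some built-in names when you build the problem object, and that combined dictionary is what's used to parse the equations. Anything defined after `d3.IVP(...)` is called (here `dx`, `magsq_u`, `p1`, and `p2`) is therefore missing from it. You just need to put all the substitution definitions (and the tau polynomials) before you instantiate the problem object.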
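To make the ordering issue concrete, here is a minimal pure-Python sketch of the mechanism described above. `ToyProblem`, `run_wrong_order`, and `run_right_order` are hypothetical names made up for illustration; `ToyProblem` is not Dedalus's API. It only assumes the behaviour described in the reply: the namespace is copied and merged with built-in names when the object is constructed, and equation strings are later evaluated against that copy. The real Dedalus parser does much more than a plain `eval`.

```python
class ToyProblem:
    """Hypothetical stand-in for a problem object; NOT the Dedalus API."""

    # Hypothetical "built-in" names merged into the user's namespace
    BUILTINS = {"dt": lambda A: f"dt({A})"}

    def __init__(self, namespace):
        # The namespace is copied here, at construction time, so names the
        # caller assigns afterwards never make it into self.namespace.
        self.namespace = dict(self.BUILTINS)
        self.namespace.update(namespace)
        self.equations = []

    def add_equation(self, equation):
        # Evaluate the equation string against the stored copy only
        # (a crude stand-in for real symbolic parsing).
        self.equations.append(eval(equation, {}, self.namespace))


def run_wrong_order():
    u = "u"
    problem = ToyProblem(namespace=locals())  # snapshot contains only 'u'
    dx = lambda A: f"dx({A})"                 # defined too late to be captured
    problem.add_equation("dx(dx(u))")         # raises NameError for 'dx'
    return problem


def run_right_order():
    u = "u"
    dx = lambda A: f"dx({A})"                 # substitution defined first
    problem = ToyProblem(namespace=locals())  # snapshot now contains 'u' and 'dx'
    problem.add_equation("dx(dx(u))")
    return problem


try:
    run_wrong_order()
except NameError as err:
    print("wrong order:", err)
print("right order:", run_right_order().equations)  # ['dx(dx(u))']
```

In the Dedalus script itself, the equivalent fix is to move the `# Substitutions` and `# Tau polynomials` blocks above the `problem = d3.IVP(...)` line, which is the order the tutorial uses.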