eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Refactor Configuration and API change #22

Closed. eelregit closed this 1 year ago.

eelregit commented 1 year ago

The units, constants, dtype settings, transfer settings, growth settings, and linear variance settings in Configuration are more tightly coupled with Cosmology, and should be moved there. After that, boltzmann.py can also become independent of Configuration and, together with cosmology.py, function like jax-cosmo.
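To illustrate what moving these settings into Cosmology buys, here is a minimal standard-library sketch, assuming a frozen dataclass stands in for the Cosmology pytree. The field names `transfer_size`, `growth_steps` and `growth_table`, and the function `toy_boltz`, are hypothetical placeholders and not pmwd's API; the point is only that once the settings live on the cosmology, a boltzmann-style function needs `cosmo` alone.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Cosmology:
    """Hypothetical Cosmology that owns settings formerly kept in Configuration."""
    Omega_m: float
    h: float
    # Settings moved from Configuration (illustrative placeholder names).
    dtype: str = "float64"
    transfer_size: int = 1024
    growth_steps: int = 128
    # Derived table, filled in by toy_boltz; None until then.
    growth_table: tuple[float, ...] | None = None


def toy_boltz(cosmo: Cosmology) -> Cosmology:
    """Hypothetical boltzmann-style step that takes no Configuration argument.

    As a placeholder it tabulates the Einstein-de Sitter growth D(a) = a on a
    uniform grid in a; a real implementation would solve the growth ODE.
    """
    n = cosmo.growth_steps
    a_grid = [(i + 1) / n for i in range(n)]
    return replace(cosmo, growth_table=tuple(a_grid))


cosmo = toy_boltz(Cosmology(Omega_m=0.3, h=0.7))
print(len(cosmo.growth_table), cosmo.growth_table[-1])  # 128 1.0
```

Because the dataclass is frozen, `toy_boltz` returns a new object via `dataclasses.replace` rather than mutating its input, which fits the functional style that JAX transformations expect.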

Likewise, I am also trying to make Particles more independent of conf, and to give Cosmology and Particles different dtype designs.
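One possible reading of a separate dtype design for Particles, assumed here rather than taken from the proposal: store each particle as an integer lattice index plus a small float displacement, with the dtypes chosen by the Particles constructor instead of looked up in conf. The class and its `from_lattice` and `positions` methods below are hypothetical, and Python's `array` typecodes stand in for array dtypes.

```python
from __future__ import annotations

from array import array
from dataclasses import dataclass


@dataclass(frozen=True)
class Particles:
    """Hypothetical Particles holding their own dtypes, independent of conf."""
    pmid: array  # integer lattice index of each particle's initial site
    disp: array  # displacement from that site, stored in single precision

    @classmethod
    def from_lattice(cls, n: int, id_code: str = "q", float_code: str = "f"):
        """Place n particles on a 1D lattice with zero displacement.

        The typecodes are constructor arguments ('q': signed long long,
        'f': 32-bit float) rather than values read from a global configuration.
        """
        return cls(array(id_code, range(n)), array(float_code, [0.0] * n))

    def positions(self, spacing: float) -> list[float]:
        """Recombine index and displacement into absolute positions (Python floats)."""
        return [i * spacing + d for i, d in zip(self.pmid, self.disp)]


ptcl = Particles.from_lattice(4)
print(ptcl.pmid.typecode, ptcl.disp.typecode)  # q f
print(ptcl.positions(spacing=2.0))  # [0.0, 2.0, 4.0, 6.0]
```

Under this assumption, the large-magnitude lattice offset is exact in the integer, so single precision is spent only on the small displacement.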

To accommodate new features, e.g., sCOLA-like parallelization and observables, the plan is to change the old API

conf = Configuration(...)
cosmo = Cosmology(conf, ...)
modes = white_noise(seed, conf)

def model(modes, cosmo, conf):
    cosmo = boltzmann(cosmo, conf)
    modes = linear_modes(modes, cosmo, conf)
    ptcl, obsvbl = lpt(modes, cosmo, conf)
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
    return obj(obsvbl)

to

def config(...):
    cosmo = Cosmology(...)

    solver = Solver(...)  # replaces the old Configuration; hierarchical, with sub-solvers and their boundary conditions (BCs)
    modes = white_noise(seed, solver)
    ptcl = pre_init_cond(...)  # or None to generate it from solver; an additional input to lpt
    obsvbl = Snapshots(...), Lightcone(...)  # observable settings; also an additional input to lpt

    return modes, ptcl, obsvbl, cosmo, solver

def model(modes, ptcl, obsvbl, cosmo, solver):
    cosmo = boltz(cosmo)  # renamed from boltzmann to avoid collision with the module
    # traverse the solver pytree (ptcl and modes too?): depth- or breadth-first? (flattening gives the former)
    return obj(traverse(solve(modes, ptcl, obsvbl, cosmo, solver)))

def solve(modes, ptcl, obsvbl, cosmo, solver):
    modes = linear_modes(None, modes, cosmo, solver)  # the a=None argument is moved to the front and made explicit
    ptcl, obsvbl, solver = lpt(modes, ptcl, obsvbl, cosmo, solver)  # compute BCs for all sub-solvers
    ptcl, obsvbl, solver = nbody(ptcl, obsvbl, cosmo, solver)  # lpt and nbody also call obj_imdt -> obsvbl
    obsvbl, solver = obj(obsvbl, solver)  # compute the current-level obj -> obsvbl
    return obsvbl, solver
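The open question in the `model` comment, whether to traverse the solver hierarchy depth- or breadth-first, can be made concrete with a toy tree. This is a standard-library sketch, assuming a hypothetical `SubSolver` node with named children in place of the proposed `Solver` pytree; depth-first (pre-order) corresponds to the flattening order the comment mentions, while breadth-first visits one level at a time.

```python
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field


@dataclass
class SubSolver:
    """Hypothetical node in a hierarchical solver (not pmwd's Solver class)."""
    name: str
    children: list[SubSolver] = field(default_factory=list)


def depth_first(root: SubSolver) -> list[str]:
    """Pre-order traversal: a node, then each child's whole subtree in turn."""
    order = [root.name]
    for child in root.children:
        order.extend(depth_first(child))
    return order


def breadth_first(root: SubSolver) -> list[str]:
    """Level-order traversal: every node at one level before the next level."""
    order, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        order.append(node.name)
        queue.extend(node.children)
    return order


# Toy two-level hierarchy: a global solver with two regions, each split in two.
tree = SubSolver("global", [
    SubSolver("A", [SubSolver("A1"), SubSolver("A2")]),
    SubSolver("B", [SubSolver("B1"), SubSolver("B2")]),
])

print(depth_first(tree))    # ['global', 'A', 'A1', 'A2', 'B', 'B1', 'B2']
print(breadth_first(tree))  # ['global', 'A', 'B', 'A1', 'A2', 'B1', 'B2']
```

Both orders visit a parent before its sub-solvers, so either would respect a scheme in which a parent level computes the BCs its children need. They differ in whether each subtree is finished before its siblings (depth-first) or all sub-solvers at one level are handled together (breadth-first), which might batch more naturally across a level.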

Discussions and comments are welcome.

eelregit commented 1 year ago

This draft PR was accidentally merged and closed. It cannot be reopened, so I created a new one: https://github.com/eelregit/pmwd/pull/25.