PlasmaControl / DESC

Stellarator Equilibrium and Optimization Suite
MIT License

Using float32 or bfloat (mixed precision) instead of float64 #1033

Open rahulgaur104 opened 4 months ago

rahulgaur104 commented 4 months ago

I have heard of people using mixed-precision bfloat16 (brain floating point) instead of standard floating-point precision in low-level code to save memory. This usually doesn't work for nonlinear, time-dependent solvers like GX because the error accumulates over time. But since DESC solves a steady-state problem, error accumulation may be less of an issue.
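The accumulation concern can be illustrated with a toy example that has nothing to do with GX or DESC internals: a fixed step of 0.1 is added repeatedly, as a time-stepping loop would, and every partial sum is rounded to float32. This is only a sketch; `to_float32` and `accumulate` are hypothetical helpers, and float32 is emulated with the standard `struct` module rather than with any DESC or JAX setting.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(step, n_steps, single_precision):
    """Add `step` to a running total `n_steps` times, like a time-stepping loop."""
    if single_precision:
        step = to_float32(step)
    total = 0.0
    for _ in range(n_steps):
        total += step
        if single_precision:
            # emulate float32 arithmetic by rounding after every addition
            total = to_float32(total)
    return total


n_steps = 100_000
exact = n_steps / 10  # 10000.0, exactly representable

err32 = abs(accumulate(0.1, n_steps, single_precision=True) - exact)
err64 = abs(accumulate(0.1, n_steps, single_precision=False) - exact)

print(f"float32 accumulated error: {err32:.3e}")
print(f"float64 accumulated error: {err64:.3e}")
```

In this toy loop the float32 error grows to order one while the float64 error stays many orders of magnitude smaller. A steady-state solve does not have this kind of long running sum, which is the motivation given above, though whether DESC tolerates the reduced precision is exactly what the issue sets out to test.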

It would be interesting to see if bfloat can help us save memory in DESC. (https://www.cerebras.net/machine-learning/to-bfloat-or-not-to-bfloat-that-is-the-question/)
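As a rough guide to the possible savings, the storage for a dense array scales with the bytes per element: 8 for float64, 4 for float32, and 2 for bfloat16. The sketch below does that arithmetic for a hypothetical dense Jacobian; the shape is an assumption chosen for illustration and is not taken from a DESC run, and `dense_array_mib` is a hypothetical helper.

```python
# bytes per element for each floating-point format
BYTES_PER_ELEMENT = {"float64": 8, "float32": 4, "bfloat16": 2}


def dense_array_mib(shape, dtype_name):
    """Storage of a dense array of the given shape, in MiB."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    return n_elements * BYTES_PER_ELEMENT[dtype_name] / 2**20


# hypothetical dense Jacobian shape, chosen only for illustration
jacobian_shape = (20_000, 5_000)

sizes = {name: dense_array_mib(jacobian_shape, name) for name in BYTES_PER_ELEMENT}
for name, mib in sizes.items():
    print(f"{name}: {mib:.1f} MiB")
```

These ratios are an upper bound on the saving for array storage only; total memory use also includes anything that does not scale with the element size.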

dpanici commented 4 months ago

try float32 first (more accurate)
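The accuracy gap between the formats follows from the number of explicitly stored significand bits: 52 for float64, 23 for float32, and 7 for bfloat16. A minimal sketch of the resulting machine epsilon (the gap between 1.0 and the next representable number), using only the standard library; `machine_epsilon` is a hypothetical helper:

```python
import sys

# explicitly stored significand (mantissa) bits of each format
MANTISSA_BITS = {"float64": 52, "float32": 23, "bfloat16": 7}


def machine_epsilon(mantissa_bits):
    """Gap between 1.0 and the next larger representable number."""
    return 2.0 ** -mantissa_bits


eps = {name: machine_epsilon(bits) for name, bits in MANTISSA_BITS.items()}

# sanity check against Python's own float, which is float64
assert eps["float64"] == sys.float_info.epsilon

for name, value in eps.items():
    print(f"{name}: eps = {value:.3e}")
```

float32 resolves about seven significant decimal digits, while bfloat16 resolves only two to three, which is why float32 is the gentler first step down from float64.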

rahulgaur104 commented 4 months ago

I did some preliminary memory profiling tests with float32. After increasing atol in desc/objectives/utils.py from 1e-6 to 4e-4, I get the total memory consumption shown below with the attached script.

#!/usr/bin/env python3

from desc import set_device
set_device("gpu")
import numpy as np

from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.objectives import (
    ObjectiveFunction,
    ForceBalance,
    get_fixed_boundary_constraints,
)
from desc.optimize import Optimizer
from desc.plotting import plot_1d, plot_section, plot_surfaces
from desc.profiles import PowerSeriesProfile
import sys
import nvtx

with nvtx.annotate("surface 2D"):
    surface_2D = FourierRZToroidalSurface(
        R_lmn=np.array([10, -1]),  # boundary coefficients
        Z_lmn=np.array([1]),
        modes_R=np.array([[0, 0], [1, 0]]),  # [M, N] boundary Fourier modes
        modes_Z=np.array([[-1, 0]]),
        NFP=5,  # number of (toroidal) field periods (does not matter for 2D, but will for 3D solution)
    )
# axisymmetric & stellarator symmetric equilibrium
with nvtx.annotate("Equilibrium"):
    eq = Equilibrium(surface=surface_2D, sym=True)
    # resolution comes from the first command-line argument
    res = int(sys.argv[1])
    # double the integer, not the string: 2 * "15" would give "1515"
    eq.change_resolution(
        L=res, M=res, N=res, L_grid=2 * res, M_grid=2 * res, N_grid=2 * res
    )
with nvtx.annotate("Objectives"):
    objective = ObjectiveFunction(ForceBalance(eq=eq))
with nvtx.annotate("Constraints"):
    constraints = get_fixed_boundary_constraints(eq=eq)
with nvtx.annotate("optimizer"):
    optimizer = Optimizer("lsq-exact")
with nvtx.annotate("solve"):
    eq, solver_outputs = eq.solve(
        objective=objective, constraints=constraints, optimizer=optimizer, verbose=3
    )

Here are the memory plots for L = M = N = 15.

With float64: [plot: test0_15]

With float32: [plot: test0_15_lowp]

In the float32 (low-memory) run, the xtol stopping condition is satisfied earlier, so there are fewer memory spikes. If reduced precision still gives "reasonable" accuracy, that would be great.

dpanici commented 4 months ago

Which atol are you referring to?

rahulgaur104 commented 4 months ago

Line 177 of desc/objectives/utils.py
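The check at that line is not reproduced in this thread, so the following is only a generic illustration of why an absolute tolerance of 1e-6 can be too tight once values are stored in float32. The value 1000.1 is a hypothetical quantity picked for the example, `to_float32` and `within_atol` are hypothetical helpers, and float32 is emulated with the standard `struct` module.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def within_atol(a, b, atol):
    """Purely absolute comparison: |a - b| <= atol."""
    return math.isclose(a, b, rel_tol=0.0, abs_tol=atol)


value = 1000.1  # hypothetical quantity of order 1e3
stored = to_float32(value)  # what float32 storage keeps of it
rounding_error = abs(stored - value)

passes_tight = within_atol(stored, value, atol=1e-6)
passes_loose = within_atol(stored, value, atol=4e-4)

print(f"float32 rounding error: {rounding_error:.2e}")
print(f"passes atol=1e-6: {passes_tight}")
print(f"passes atol=4e-4: {passes_loose}")
```

Because float32 spacing grows with magnitude, a single rounding of a value near 1e3 already exceeds 1e-6, while it stays well inside 4e-4. A relative tolerance would scale with the value instead, which may be worth considering if the absolute tolerance has to be loosened this much.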