rahulgaur104 opened 6 months ago
> try float32 first (more accurate)
So I ran some preliminary memory-profiling tests with float32. After loosening the atol in desc/objectives/utils.py from 1e-6 to 4e-4, the attached script gives the following total memory consumption.
```python
#!/usr/bin/env python3
from desc import set_device

set_device("gpu")

import numpy as np

from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.objectives import (
    ObjectiveFunction,
    ForceBalance,
    get_fixed_boundary_constraints,
)
from desc.optimize import Optimizer
from desc.plotting import plot_1d, plot_section, plot_surfaces
from desc.profiles import PowerSeriesProfile
import sys
import nvtx

with nvtx.annotate("surface 2D"):
    surface_2D = FourierRZToroidalSurface(
        R_lmn=np.array([10, -1]),  # boundary coefficients
        Z_lmn=np.array([1]),
        modes_R=np.array([[0, 0], [1, 0]]),  # [M, N] boundary Fourier modes
        modes_Z=np.array([[-1, 0]]),
        NFP=5,  # number of (toroidal) field periods (does not matter for 2D, but will for 3D solution)
    )

# axisymmetric & stellarator symmetric equilibrium
with nvtx.annotate("Equilibrium"):
    eq = Equilibrium(surface=surface_2D, sym=True)
    eq.change_resolution(
        L=int(sys.argv[1]),
        M=int(sys.argv[1]),
        N=int(sys.argv[1]),
        L_grid=2 * int(sys.argv[1]),  # convert before multiplying (sys.argv entries are strings)
        M_grid=2 * int(sys.argv[1]),
        N_grid=2 * int(sys.argv[1]),
    )

with nvtx.annotate("Objectives"):
    objective = ObjectiveFunction(ForceBalance(eq=eq))

with nvtx.annotate("Constraints"):
    constraints = get_fixed_boundary_constraints(eq=eq)

with nvtx.annotate("optimizer"):
    optimizer = Optimizer("lsq-exact")

with nvtx.annotate("solve"):
    eq, solver_outputs = eq.solve(
        objective=objective, constraints=constraints, optimizer=optimizer, verbose=3
    )
```
Here are the memory plots for L = M = N = 15.
with float64: *(memory profile plot attached in the issue)*
with float32: *(memory profile plot attached in the issue)*
For the low-memory (float32) case, the xtol condition is satisfied early, so there are fewer spikes. If we can get away with reduced precision while keeping "reasonable" accuracy, that would be great.
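If we go this route, it may be better to loosen the solver tolerances explicitly rather than rely on where float32 happens to stall. Something like the sketch below (the ftol/xtol/gtol keywords are how I recall eq.solve exposing them; worth double-checking):

```python
# Sketch: request "reasonable" rather than machine-precision convergence,
# which is all float32 can deliver anyway. Keyword names assumed from eq.solve.
eq, solver_outputs = eq.solve(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    ftol=1e-4,  # relative decrease of the objective
    xtol=1e-4,  # relative step size
    gtol=1e-4,  # gradient norm
    verbose=3,
)
```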
which atol do you refer to?
Line 177 of desc/objectives/utils.py
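For context on why that tolerance has to move (my back-of-the-envelope reasoning, not anything measured in DESC): float32 machine epsilon is about 1.2e-7, so residuals built from O(1) quantities are already a few hundred ULPs, and an atol of 1e-6 is below what single precision can reliably hit:

```python
import numpy as np

print(np.finfo(np.float32).eps)  # ~1.19e-07
print(np.finfo(np.float64).eps)  # ~2.22e-16

# Even a single rounded operation on O(1) values leaves an absolute error
# near 1e-7 in float32; after a linear solve it is typically 1e-5 to 1e-4,
# which is why atol=1e-6 fails while atol=4e-4 leaves some headroom.
x = np.float32(1.0) + np.float32(1e-7)
print(x - np.float32(1.0))  # 1.1920929e-07, not 1e-7
```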
I have heard of people using mixed-precision bfloat16 (brain floating point) instead of regular floating-point precision in low-level code to save memory. It doesn't usually work for nonlinear, time-dependent solvers like GX because the error accumulates over time, but since DESC solves a steady-state problem, error accumulation may not be as much of an issue.
It would be interesting to see if bfloat16 can help us save memory in DESC. (https://www.cerebras.net/machine-learning/to-bfloat-or-not-to-bfloat-that-is-the-question/)
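A quick way to see what bfloat16 gives up relative to float32 (just JAX's built-in bfloat16 dtype, nothing wired into DESC):

```python
import jax.numpy as jnp

# bfloat16 keeps float32's exponent range but only 8 mantissa bits,
# so it halves memory per element at the cost of ~3 decimal digits.
print(jnp.finfo(jnp.float32).eps)   # ~1.2e-07
print(jnp.finfo(jnp.bfloat16).eps)  # ~7.8e-03

one = jnp.array(1.0, dtype=jnp.bfloat16)
print(one + jnp.array(1e-3, dtype=jnp.bfloat16))  # still 1: the increment is below bfloat16 resolution
```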