atsuki-kuwata closed this 2 years ago.
I measured the computation time of each step and found that the second ProjectedGradient (for Ak) is the slowest:
preparation for X: 0.0138 sec
PG for X: 0.310 sec
preparation for A: 0.147 sec
PG for A: 1.025 sec
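JAX dispatches work asynchronously and compiles a jitted function on its first call, so per-step timings can be skewed unless the first (compiling) call is excluded and the result is waited on with block_until_ready(). Below is a minimal timing sketch under those assumptions. Here `step` is a hypothetical stand-in workload, not the actual PG code, and this is not necessarily how the numbers above were measured.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in workload: a (3072 x 10) matrix-vector product plus a ReLU.
    W = jnp.ones((3072, 10))
    return jnp.maximum(W @ x, 0.0)


def timed(fn, *args, repeats=10):
    """Return (output, mean seconds per call), excluding compilation time."""
    out = fn(*args).block_until_ready()  # warm-up call: triggers tracing and compilation
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args).block_until_ready()  # wait for the async computation to finish
    return out, (time.perf_counter() - t0) / repeats


x = jnp.arange(10.0)
out, sec = timed(step, x)
print(f"step: {sec:.6f} sec")
```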
We currently run a full ProjectedGradient optimization (up to maxiter=100 iterations) in each loop:
pg = jaxopt.ProjectedGradient(fun=QP_obj_ak, projection=jaxopt.projection.projection_non_negative, maxiter=100)
How about running just one or a few ProjectedGradient steps in each loop instead? For example:
pg = jaxopt.ProjectedGradient(fun=QP_obj_ak, projection=jaxopt.projection.projection_non_negative)
...
# a single PG update per loop
state = pg.init_state(params)
params, state = pg.update(params, state)
Or can we use BlockCoordinateDescent instead?
Thank you, I understand.
I think ProjectedGradient for Ak is slow because len(ak)=3072 while len(xk)=10, but it is still better to use pg.update (one iteration per loop).
I do not understand how to use BlockCoordinateDescent yet, so I am going to read the documentation.
Now I am going to add a commit that switches to pg.update and adds functions for the loops.
I added the functions opt_ref and opt_map, and modified the code to use pg.update.
Next, I am going to
(1) modify the modules for making multiband light curves, as in the "develop" branch.
(2) use jax.lax.scan in addition to for…
(3) use jaxopt.BoxOSQP or jaxopt.BlockCoordinateDescent, if better.
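Here is a minimal sketch of item (2): the alternating loop written with jax.lax.scan so that the whole loop is traced and compiled once. To keep it self-contained, it swaps in a hand-written projected-gradient step with a fixed step size lr instead of jaxopt.ProjectedGradient, on a hypothetical toy factorization Y ≈ A @ X. The value lr=1.0 is an assumption that is only safe while it stays below 1/L, where L is the gradient's Lipschitz constant for each block.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss(A, X, Y):
    # mean squared residual of the hypothetical factorization Y ≈ A @ X
    return jnp.mean((Y - A @ X) ** 2)


@partial(jax.jit, static_argnames="n_iter")
def fit(A, X, Y, lr=1.0, n_iter=300):
    def step(carry, _):
        A, X = carry
        # projected-gradient step on A (X fixed): gradient step, then clip at 0
        A = jnp.maximum(A - lr * jax.grad(loss, argnums=0)(A, X, Y), 0.0)
        # projected-gradient step on X using the updated A
        X = jnp.maximum(X - lr * jax.grad(loss, argnums=1)(A, X, Y), 0.0)
        return (A, X), loss(A, X, Y)

    # scan replaces the Python for-loop; losses stacks the per-iteration loss
    (A, X), losses = jax.lax.scan(step, (A, X), xs=None, length=n_iter)
    return A, X, losses


key_A, key_X = jax.random.split(jax.random.PRNGKey(0))
Y = jax.random.uniform(key_A, (50, 3)) @ jax.random.uniform(key_X, (3, 20))
A, X, losses = fit(jnp.full((50, 3), 0.5), jnp.full((3, 20), 0.5), Y)
print(losses[0], losses[-1])
```

Because n_iter sets the scan length, it is marked static, so changing it triggers a recompilation.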
@atsuki-kuwata
I was able to run python soul2vr.py but got the following results:
Is this the result you intended?
That is not what I intended. Let me check again; I am afraid I did not test it thoroughly enough.
I found that the objective functions take values of ~10**6, e.g. QP_obj_xk(x_0) = 6938592.5.
I think the all-NaN result is due to overflow. I modified
QP_obj_xk -> QP_obj_xk / A.shape[0]
QP_obj_ak -> QP_obj_ak / A.shape[0]
and got a reasonable result. Please confirm. @HajimeKawahara @2ndmk2
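For reference, dividing an objective f by the positive constant N = A.shape[0], as in the modification above, leaves the non-negatively constrained minimizer unchanged. It only rescales the objective value and gradient by 1/N. For a least-squares objective f(x) = ||d - W x||^2 with a hypothetical matrix W, the gradient's Lipschitz constant L is rescaled by 1/N as well. So the normalization should not change the solution, only the numerical scale the solver works at:

```latex
\tilde{f}(x) = \frac{1}{N} f(x), \qquad N > 0
\quad\Longrightarrow\quad
\nabla \tilde{f}(x) = \frac{1}{N} \nabla f(x),
\qquad
\operatorname*{arg\,min}_{x \ge 0} \tilde{f}(x) = \operatorname*{arg\,min}_{x \ge 0} f(x)

f(x) = \lVert d - W x \rVert_2^2
\quad\Longrightarrow\quad
L_f = 2\,\sigma_{\max}(W)^2,
\qquad
L_{\tilde{f}} = \frac{2\,\sigma_{\max}(W)^2}{N}
```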
@atsuki-kuwata
Thanks, I ran python soul2vr.py but ran out of GPU memory:
RuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
even on an A100 (40 GB). Can you check it?
I fixed soul2vr.py to resolve the out-of-memory error. I moved QP_obj_xk() and QP_obj_ak() out of opt_ref() and opt_map(), respectively.
Please confirm. @HajimeKawahara @2ndmk2
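One plausible explanation for why this helped, though it is not confirmed in this thread: if an objective is defined inside opt_ref() or opt_map() and jitted there (directly or, presumably, through the solver), every call creates a new Python function. JAX then traces and compiles it again on each call, which costs time and may hold on to extra memory. The sketch below uses the hypothetical names opt_inner and obj_outer and counts how often each version is traced.

```python
import jax
import jax.numpy as jnp

trace_count = {"inner": 0, "outer": 0}


def opt_inner(W, x):
    # Objective defined inside the function: a new Python function on every call.
    @jax.jit
    def obj(x):
        trace_count["inner"] += 1  # Python side effect: runs only during tracing
        return jnp.sum((W @ x) ** 2)

    return obj(x)


@jax.jit
def obj_outer(x, W):
    # Objective defined once at module level; W is passed as an argument.
    trace_count["outer"] += 1
    return jnp.sum((W @ x) ** 2)


def opt_outer(W, x):
    return obj_outer(x, W)


W = jnp.ones((4, 3))
x = jnp.ones(3)
for _ in range(5):
    opt_inner(W, x)
    opt_outer(W, x)
print(trace_count)  # {'inner': 5, 'outer': 1}
```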
Thanks! I confirmed it worked.
This PR is a simple implementation based on sot.nmfmap.runnmf_cpu. I use jaxopt.ProjectedGradient for optimization. Next, I am going to
(1) modify modules and functions, as in the "develop" branch.
(2) use jax.lax.scan in addition to for…
(3) use jaxopt.BoxOSQP, if better.