HajimeKawahara / jaxsot

JAX implementation of SOT package
http://secondearths.sakura.ne.jp/jaxsot/
MIT License

SOU:L2+VR #8

Closed: atsuki-kuwata closed this PR 2 years ago

atsuki-kuwata commented 2 years ago

This PR is a simple implementation based on sot.nmfmap.runnmf_cpu. I use jaxopt.ProjectedGradient for the optimization.

Next, I am going to
 (1) modify the modules and functions to match the "develop" branch,
 (2) use jax.lax.scan in addition to for…
 (3) use jaxopt.BoxOSQP if it works better.
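For reference, projected gradient descent alternates a gradient step with a projection onto the feasible set; for a non-negativity constraint, the projection simply clips negative entries to zero. The NumPy sketch below applies this to a toy non-negative least-squares problem. The names W, y and projected_gradient_nnls are hypothetical and do not come from soul2vr.py or sot.nmfmap. The sketch uses a fixed step of 1/L, whereas jaxopt.ProjectedGradient by default picks its step size with a backtracking line search and also uses acceleration.

```python
import numpy as np

def projected_gradient_nnls(W, y, n_iter=500):
    """Minimize 0.5 * ||W x - y||^2 subject to x >= 0 (hypothetical toy problem)."""
    # Fixed step 1/L, where L = ||W||_2^2 is the Lipschitz constant of the gradient
    step = 1.0 / np.linalg.norm(W, 2) ** 2
    x = np.zeros(W.shape[1])
    for _ in range(n_iter):
        grad = W.T @ (W @ x - y)              # gradient of the least-squares term
        x = np.maximum(x - step * grad, 0.0)  # projection onto the non-negative orthant
    return x

# Usage on synthetic data with a known non-negative solution
rng = np.random.default_rng(0)
W = rng.normal(size=(50, 10))
x_true = np.maximum(rng.normal(size=10), 0.0)
x_hat = projected_gradient_nnls(W, W @ x_true)
```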

HajimeKawahara commented 2 years ago

I measured the computation time for each step and found that the second ProjectedGradient (for Ak) is the slowest.

Step                 Time (s)
preparation for X    0.013753414154052734
PG for X             0.3102376461029053
preparation for A    0.14723467826843262
PG for A             1.024763584136963
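When timing JAX code like this, note that JAX dispatches work asynchronously, so a stage should only be timed once its output is actually ready (JAX arrays provide block_until_ready()). The first call of a jitted function also includes compilation time, so it is worth running each stage once before measuring. A minimal helper, assuming each stage can be wrapped as a callable; the name timed is hypothetical:

```python
import time

def timed(label, fn, *args, **kwargs):
    """Run fn(*args, **kwargs), print 'label, seconds', and return (result, seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    # JAX arrays are computed asynchronously; wait for them before stopping the clock
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    print(f"{label}, {elapsed}")
    return result, elapsed

# Usage with a plain Python stage; a jitted JAX function is wrapped the same way
total, seconds = timed("sum of squares", lambda n: sum(i * i for i in range(n)), 1000)
```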

We currently run a full ProjectedGradient solve (up to maxiter=100 iterations) in every loop:

pg = jaxopt.ProjectedGradient(fun=QP_obj_ak, projection=jaxopt.projection.projection_non_negative, maxiter=100)

How about running just one or a few ProjectedGradient steps per loop? For example:

pg = jaxopt.ProjectedGradient(fun=QP_obj_ak, projection=jaxopt.projection.projection_non_negative)

...
# a single PG update per loop
state = pg.init_state(params)
params, state = pg.update(params, state)
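To illustrate the idea of taking only one step per loop, here is a NumPy sketch of a generic non-negative factorization Y ≈ A X in which each outer loop takes a single fixed-step projected-gradient step on X and then a single step on A, instead of solving each subproblem to convergence. The model, the shapes and the helper one_step_nmf are hypothetical and much simpler than soul2vr.py; in the jaxopt version, each single step would correspond roughly to one pg.update call.

```python
import numpy as np

def frob_obj(Y, A, X):
    """0.5 * squared Frobenius norm of the residual Y - A X."""
    return 0.5 * np.sum((Y - A @ X) ** 2)

def one_step_nmf(Y, A, X, n_loops=200):
    """Alternate ONE projected-gradient step on X and ONE on A per outer loop."""
    history = [frob_obj(Y, A, X)]
    for _ in range(n_loops):
        # Step on X with A fixed; step 1/L with L = ||A||_2^2
        step_x = 1.0 / np.linalg.norm(A, 2) ** 2
        X = np.maximum(X - step_x * (A.T @ (A @ X - Y)), 0.0)
        # Step on A with X fixed; step 1/L with L = ||X||_2^2
        step_a = 1.0 / np.linalg.norm(X, 2) ** 2
        A = np.maximum(A - step_a * ((A @ X - Y) @ X.T), 0.0)
        history.append(frob_obj(Y, A, X))
    return A, X, history

rng = np.random.default_rng(1)
Y = rng.random((30, 4)) @ rng.random((4, 20))
A, X, history = one_step_nmf(Y, rng.random((30, 4)), rng.random((4, 20)))
```

With a step of 1/L, each projected-gradient step cannot increase its convex subproblem's objective, so the recorded history is non-increasing even though neither subproblem is solved exactly.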

Or could we use jaxopt.BlockCoordinateDescent instead?
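As a rough illustration of the coordinate-descent idea (this does not show the jaxopt.BlockCoordinateDescent API), here is cyclic coordinate descent for the same kind of toy non-negative least-squares problem: each coordinate is minimized exactly with the others held fixed, then clipped at zero. The function and variable names are hypothetical.

```python
import numpy as np

def coordinate_descent_nnls(W, y, n_sweeps=200):
    """Minimize 0.5 * ||W x - y||^2 subject to x >= 0, one coordinate at a time."""
    x = np.zeros(W.shape[1])
    r = y - W @ x                    # current residual y - W x
    col_sq = np.sum(W ** 2, axis=0)  # ||W[:, j]||^2 for each column
    for _ in range(n_sweeps):
        for j in range(W.shape[1]):
            # Exact minimizer over x_j with the others fixed, projected onto x_j >= 0
            new_xj = max(0.0, x[j] + W[:, j] @ r / col_sq[j])
            r -= W[:, j] * (new_xj - x[j])  # keep the residual consistent with x
            x[j] = new_xj
    return x

rng = np.random.default_rng(0)
W = rng.normal(size=(50, 10))
x_true = np.maximum(rng.normal(size=10), 0.0)
x_cd = coordinate_descent_nnls(W, W @ x_true)
```

Each coordinate update is cheap, and the residual is updated incrementally rather than recomputed.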

atsuki-kuwata commented 2 years ago

Thank you, I understand.

I think ProjectedGradient for Ak is slow because len(ak)=3072 while len(xk)=10. In any case, it is better to use pg.update (one iteration per loop).

I am not yet sure how to use BlockCoordinateDescent, so I will read the documentation.

I will now add a commit that switches to pg.update and moves the per-loop steps into functions.

atsuki-kuwata commented 2 years ago

I added the functions opt_ref and opt_map, and changed the updates to use pg.update.

Next, I am going to
 (1) modify the modules for making multiband light curves, as in the "develop" branch.
 (2) use jax.lax.scan in addition to for…
 (3) use jaxopt.BoxOSQP or jaxopt.BlockCoordinateDescent if better.
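For item (2), a Python for loop over updates can be written with jax.lax.scan, which traces and compiles the loop body once instead of unrolling it. A minimal sketch, assuming a fixed-step projected-gradient update on a toy non-negative least-squares problem; pg_step and run_scan are hypothetical names, not functions in soul2vr.py.

```python
import jax
import jax.numpy as jnp

def pg_step(x, W, y, step):
    """One fixed-step projected-gradient update for 0.5 * ||W x - y||^2 with x >= 0."""
    grad = W.T @ (W @ x - y)
    return jnp.maximum(x - step * grad, 0.0)

@jax.jit
def run_scan(x0, W, y):
    step = 1.0 / jnp.linalg.norm(W, 2) ** 2  # 1/L with L = ||W||_2^2

    def body(x, _):
        x_new = pg_step(x, W, y, step)
        # Carry the iterate and also emit the objective value for monitoring
        return x_new, 0.5 * jnp.sum((W @ x_new - y) ** 2)

    x_final, obj_history = jax.lax.scan(body, x0, xs=None, length=500)
    return x_final, obj_history

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(k1, (50, 10))
x_true = jnp.maximum(jax.random.normal(k2, (10,)), 0.0)
x_final, obj_history = run_scan(jnp.zeros(10), W, W @ x_true)
```

The carry (here the iterate x) must keep the same shape and dtype on every step, which is what allows scan to compile a single loop body.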

HajimeKawahara commented 2 years ago

@atsuki-kuwata

I was able to run python soul2vr.py, but got the following result:

[Figure: Figure_1]

Is this the result you intended?

atsuki-kuwata commented 2 years ago

That is not what I intended. Let me check again; I am afraid I did not verify it carefully enough.

atsuki-kuwata commented 2 years ago

I found that the objective functions take values of order 10**6, e.g. QP_obj_xk(x_0)=6938592.5.

I think the all-NaN result is caused by overflow, so I modified

QP_obj_xk -> QP_obj_xk / A.shape[0]
QP_obj_ak -> QP_obj_ak / A.shape[0]

With this change I got a reasonable result. Please confirm. @HajimeKawahara @2ndmk2

[Image: screenshot taken 2022-04-11 11:02]
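One way NaNs like this can appear: if an optimizer's step size is too large for the scale of the objective, the iterates diverge, overflow to inf, and then turn into NaN (inf - inf). Dividing the objective by a positive constant such as A.shape[0] scales its gradient by the same factor and leaves the minimizer unchanged. The NumPy sketch below shows this with plain fixed-step gradient descent on a hypothetical least-squares objective. It is not a claim about exactly what happened in soul2vr.py, where jaxopt's own step-size logic is involved.

```python
import numpy as np

def gd(grad_fn, x0, step=0.1, n_iter=200):
    """Plain fixed-step gradient descent."""
    x = x0.copy()
    for _ in range(n_iter):
        x = x - step * grad_fn(x)
    return x

rng = np.random.default_rng(0)
A = rng.normal(size=(1000, 10))
y = rng.normal(size=1000)

grad_sum = lambda x: A.T @ (A @ x - y)          # gradient of 0.5 * ||A x - y||^2
grad_mean = lambda x: grad_sum(x) / A.shape[0]  # same objective divided by A.shape[0]

with np.errstate(all="ignore"):                 # the unnormalized run overflows
    x_sum = gd(grad_sum, np.zeros(10))
x_mean = gd(grad_mean, np.zeros(10))
```

With the same step, the unnormalized run ends in inf/NaN, while the normalized run converges to the ordinary least-squares solution.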

HajimeKawahara commented 2 years ago

@atsuki-kuwata

Thanks. I ran python soul2vr.py but ran out of GPU memory, even on an A100 (40 GB): RuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory. Can you check it?

atsuki-kuwata commented 2 years ago

I fixed soul2vr.py to resolve the out-of-memory error by moving QP_obj_xk() and QP_obj_ak() out of opt_ref() and opt_map(), respectively.

Please confirm. @HajimeKawahara @2ndmk2
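One plausible reason this helps (not verified against soul2vr.py): an objective defined inside opt_ref() or opt_map() is a new Python function object on every call. Anything that jit-compiles it caches per function object, so it can retrace and recompile on every call and accumulate compiled code. The JAX sketch below counts traces to show the difference. All names are hypothetical.

```python
import jax
import jax.numpy as jnp

n_traces = {"inner": 0, "module": 0}  # how often each function body is traced

def obj_module_level(x, A, y):
    n_traces["module"] += 1  # Python side effect: runs only while JAX traces
    return 0.5 * jnp.sum((A @ x - y) ** 2) / A.shape[0]

obj_module_level_jit = jax.jit(obj_module_level)  # compiled once per input shape/dtype

def opt_with_inner_obj(x, A, y):
    def obj_inner(x):  # a NEW function object on every call
        n_traces["inner"] += 1
        return 0.5 * jnp.sum((A @ x - y) ** 2) / A.shape[0]
    return jax.jit(obj_inner)(x)  # new function, so a new trace and compile each call

def opt_with_module_obj(x, A, y):
    return obj_module_level_jit(x, A, y)  # data passed as arguments; cache reused

A = jnp.ones((4, 3))
y = jnp.zeros(4)
x = jnp.ones(3)
for _ in range(3):
    opt_with_inner_obj(x, A, y)
    opt_with_module_obj(x, A, y)
print(n_traces)
```

Passing the data as explicit arguments to a module-level function lets repeated calls with the same shapes reuse one compiled version.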

HajimeKawahara commented 2 years ago

Thanks! I confirmed that it works.