PredictiveIntelligenceLab / jaxpi

Other
231 stars 52 forks source link

About formula and figure in PiraNet paper #13

Closed HydrogenSulfate closed 6 months ago

HydrogenSulfate commented 6 months ago

In the PirateNet paper, the input coordinates are first embedded by the Coordinate Embedding module, represented by $\Phi(\mathbf{x})$ in the paper, but in the PirateNet block $\mathbf{x}$ is still used as the input rather than $\Phi(\mathbf{x})$.

I also noticed that $x^{(l)}$ represents the input of the $l$-th PirateNet block, so does $x^{(0)}$ represent the embedded coordinates $\Phi(\mathbf{x})$?
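For reference, the forward pass as I understand it from the paper can be sketched as follows (a minimal numpy sketch under my reading; the random Fourier embedding and the adaptive residual are simplified assumptions here, not the authors' exact code):

```python
import numpy as np

def fourier_embedding(x, B):
    """Coordinate embedding Phi(x): random Fourier features.
    B is a fixed random projection matrix sampled at initialization."""
    proj = x @ B
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)

def piratenet_block(x_l, params):
    """One PirateNet-style block: x_l is the block input,
    with x^(0) = Phi(x) (the embedded coordinates, not raw x)."""
    W1, b1, W2, b2, alpha = params
    f = np.tanh(np.tanh(x_l @ W1 + b1) @ W2 + b2)
    # adaptive residual connection: alpha is trainable and
    # initialized to 0, so each block starts as the identity map
    return alpha * f + (1.0 - alpha) * x_l
```

With `alpha = 0` at initialization the block reduces to the identity, which is the property that makes deep stacks of these blocks easy to train.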

(screenshots of the relevant equations from the paper)

I have tried PirateNet in the Paddle framework and it seems to work well, achieving good accuracy on allen_cahn (l2_err = $8e-6$, slightly better than the $2.24e-5$ reported in the paper). That's great work.
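For concreteness, the relative l2 error I assume is being reported (the standard metric in the jaxpi examples) is:

```python
import numpy as np

def relative_l2_error(u_pred, u_ref):
    """Relative L2 error: ||u_pred - u_ref||_2 / ||u_ref||_2."""
    return np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref)
```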

sifanexisted commented 6 months ago

Thank you for bringing this issue to our attention. We apologize for the error and will correct it in both the arXiv and published versions.

It's great to hear that you're achieving even better results with Paddle! Although I didn't fine-tune the hyperparameters too much, I understand how challenging it is to get below 1e-5. It is quite impressive!

HydrogenSulfate commented 6 months ago

> Thank you for bringing this issue to our attention. We apologize for the error and will correct it in both the arXiv and published versions.
>
> It's great to hear that you're achieving even better results with Paddle! Although I didn't fine-tune the hyperparameters too much, I understand how challenging it is to get below 1e-5. It is quite impressive!

Thanks for the reply. I just used the SOTA config of allen_cahn with the two changes below:

  1. Replacing the MLP with a 3-block PirateNet (same as in the paper).
  2. Using grad norm weighting (same settings as the default config) instead of NTK weighting, since the NTK computation requires a [batch, len(params)] matrix that is too large to compute, as Paddle does not support vmap yet.
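The grad norm weighting I used can be sketched as below (my paraphrase of the balancing rule, not the library's exact code): each loss term gets a weight inversely proportional to its gradient norm, so all weighted terms contribute gradients of comparable scale.

```python
import numpy as np

def grad_norm_weights(grad_norms):
    """Given ||grad L_i|| for each loss term, return weights
    lambda_i = (sum_j ||grad L_j||) / ||grad L_i||,
    so every weighted term has the same gradient magnitude."""
    g = np.asarray(grad_norms, dtype=float)
    return g.sum() / g

# In practice the weights are updated slowly, e.g. with an
# exponential moving average across training iterations:
# lam = momentum * lam + (1 - momentum) * grad_norm_weights(norms)
```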

PirateNet seems to converge much faster than vanilla MLPs (it reaches 1% l2 error in only 10 epochs). I will try more examples with it!