abauville / blog

https://abauville.github.io/blog/
Apache License 2.0

Convolutional neural networks for solving PDEs | Arthur Bauville #2

Open utterances-bot opened 2 weeks ago

utterances-bot commented 2 weeks ago

Convolutional neural networks for solving PDEs | Arthur Bauville

Writing an iterative multigrid PDE solver as a convolutional neural network

https://abauville.github.io/blog/neural%20networks/2021/06/17/CNN-solver.html

ycebear commented 2 weeks ago

Thank you for this inspiring post, which builds a bridge between PDE methods and CNN methods! So far I have run the single-grid code and it works :-)! But I still have some questions:

  1. Technical questions
     1.1 I can see and understand the forward() function, but I could not find any call to it, such as net.forward(). Where does this happen?
     1.2 Where and how are the targets defined? I cannot see anything like (T_guess - T_target). Why is loss_fn(y_hat) defined as torch.mean((y_hat)**2), and why is the expected outcome 0?
     1.3 Why do you use a 4-dimensional tensor for a 2-dimensional problem? (The first two dimensions do not seem to be used.)
  2. General remarks
     2.1 The DL (deep learning) method for solving the PDE is slow compared to iterative solvers. What advantage could there be in using it?
     2.2 Did you also explore the other direction: using fast iterative solvers to speed up the DL solution process?

Thank you in advance for your answer! Aicebear

abauville commented 1 week ago

Hello Aicebear,

I'm glad you found my blogpost useful. To answer your questions:

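1.1 net.forward() is called when net() is called. That is a PyTorch convention: nn.Module.__call__ runs forward() (together with any registered hooks), so you call net(x) rather than net.forward(x).

1.2 We minimize the L2 norm (here, the mean of the squares) of the right-hand side of the equation. You can choose another norm; the important thing is to reduce the right-hand side vector to a scalar value. Because I am solving the steady-state equation, the time derivative is 0, so the right-hand side must also be 0. That is the implicit target, and (y_hat - 0)**2 is simply y_hat**2.

1.3 PyTorch's conv2d function expects a 4D tensor of shape (batch, channels, height, width) (cf. doc https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html). For a single 2D field, the first two dimensions both have size 1.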
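A minimal sketch of points 1.1 to 1.3, assuming a steady-state 2D diffusion equation dT/dt = laplacian(T) + q, discretized with a 5-point stencil (grid spacing 1) and zero Dirichlet boundaries imposed through zero padding. The class name `ResidualNet`, the source term `q` and the grid size are hypothetical illustrations, not code from the blog post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualNet(nn.Module):
    """Hypothetical module returning the residual laplacian(T) + q."""

    def __init__(self, q: torch.Tensor):
        super().__init__()
        # 5-point Laplacian stencil (grid spacing 1). conv2d wants weights of
        # shape (out_channels, in_channels, kH, kW), hence view(1, 1, 3, 3).
        # The stencil is symmetric, so conv2d's cross-correlation equals a
        # true convolution here.
        stencil = torch.tensor([[0.0, 1.0, 0.0],
                                [1.0, -4.0, 1.0],
                                [0.0, 1.0, 0.0]])
        self.register_buffer("kernel", stencil.view(1, 1, 3, 3))
        self.register_buffer("q", q)

    def forward(self, T: torch.Tensor) -> torch.Tensor:
        # T has shape (N, C, H, W) = (1, 1, H, W): one sample, one channel.
        # padding=1 pads with zeros, i.e. T = 0 just outside the domain.
        return F.conv2d(T, self.kernel, padding=1) + self.q


def loss_fn(y_hat: torch.Tensor) -> torch.Tensor:
    # Steady state: dT/dt = 0, so the right-hand side must vanish. The target
    # is therefore 0 and (y_hat - 0)**2 is just y_hat**2.
    return torch.mean(y_hat ** 2)


H, W = 16, 16
q = torch.ones(1, 1, H, W)    # hypothetical uniform source term
net = ResidualNet(q)
T = torch.zeros(1, 1, H, W)   # initial guess

y_hat = net(T)                # nn.Module.__call__ runs hooks, then forward()
print(y_hat.shape)            # torch.Size([1, 1, 16, 16])
print(loss_fn(y_hat).item())  # 1.0: for T = 0 the residual equals q everywhere
```

Calling `net.forward(T)` directly returns the same tensor in this example, but it bypasses any hooks registered on the module, which is why PyTorch code conventionally calls `net(T)`.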

2.1 This is an iterative solver. Iterative solvers often use a numerical approximation of the derivatives, which is expensive and has limited accuracy. Here, the solver uses backpropagation (automatic differentiation) to compute the derivatives exactly, so it should be quite fast and converge well. If speed is your concern, for relatively small problems (on the order of 1M degrees of freedom, i.e. a 1000x1000 grid) the fastest method would be a direct solver. One limiting factor for direct solvers is memory consumption.

2.2 PyTorch already uses optimized, efficient iterative solvers (its optimizers) for DL. In DL, convergence is often limited by additional problems, such as local minima.
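To make 2.1 concrete, here is a toy sketch under the same hypothetical setup (steady-state diffusion with source q = 1, zero Dirichlet boundaries, grid spacing 1, here on an 8x8 grid): backpropagation supplies the gradient of the mean squared residual, `torch.optim.SGD` with momentum performs the iterations, and the result is checked against a direct dense solve with `torch.linalg.solve`. The optimizer, learning rate, momentum and step count are assumptions tuned for this small problem, not the settings used in the blog post.

```python
import torch
import torch.nn.functional as F

n = 8                                    # hypothetical grid size: n x n unknowns
dtype = torch.float64
q = torch.ones(1, 1, n, n, dtype=dtype)  # hypothetical uniform source term
kernel = torch.tensor([[0.0, 1.0, 0.0],
                       [1.0, -4.0, 1.0],
                       [0.0, 1.0, 0.0]], dtype=dtype).view(1, 1, 3, 3)


def residual(T):
    # Right-hand side of dT/dt = laplacian(T) + q; zero padding = Dirichlet T = 0.
    return F.conv2d(T, kernel, padding=1) + q


# Iterative solve: autograd (backpropagation) gives the exact gradient of the
# loss with respect to every grid value; SGD with momentum does the updates.
T = torch.zeros(1, 1, n, n, dtype=dtype, requires_grad=True)
optimizer = torch.optim.SGD([T], lr=0.5, momentum=0.95)
for _ in range(2000):
    optimizer.zero_grad()
    loss = torch.mean(residual(T) ** 2)
    loss.backward()
    optimizer.step()

# Direct solve of the same linear system A T = -q. A is the 5-point Laplacian,
# assembled densely from Kronecker products (only sensible for tiny grids).
D = (-2.0 * torch.eye(n, dtype=dtype)
     + torch.diag(torch.ones(n - 1, dtype=dtype), 1)
     + torch.diag(torch.ones(n - 1, dtype=dtype), -1))
I = torch.eye(n, dtype=dtype)
A = torch.kron(I, D) + torch.kron(D, I)
T_direct = torch.linalg.solve(A, -q.flatten())

max_diff = (T.detach().flatten() - T_direct).abs().max().item()
print(f"final loss: {loss.item():.3e}, max |T_iter - T_direct|: {max_diff:.3e}")
```

The dense matrix A has (n*n)² entries, so this direct solve only illustrates the idea on a toy grid. Practical direct solvers use sparse factorizations, and the memory those factorizations need is what eventually limits them on large grids.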

I hope I could give you useful answers. Best,

Arthur

(Reply sent by email to ycebear's comment of Fri, Aug 30, 2024, 2:09 AM.)