mathLab / PINA

Physics-Informed Neural networks for Advanced modeling
https://mathlab.github.io/PINA/
MIT License

How to evaluate the trained neural network on custom points #244

Closed: karthikncsuiisc closed this issue 7 months ago

karthikncsuiisc commented 8 months ago

I want to evaluate the trained PINN on custom points.

Is there a way to do it? Could you point me to an example or tutorial?

I tried to follow issue https://github.com/mathLab/PINA/issues/225.

What is the syntax for pts below? I have two input variables.

import matplotlib.pyplot as plt

# evaluation points
pts = ...

# network prediction and true solution
nn = solver.neural_net(pts).detach()
true = compute_true(pts)  # here you compute your true solution

# plot
plt.subplot(1, 3, 1)  # true
plt.plot(pts, true)
plt.subplot(1, 3, 2)  # neural net
plt.plot(pts, nn)
plt.subplot(1, 3, 3)  # absolute difference
plt.plot(pts, (true - nn).abs())
plt.show()
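For the evaluation step itself, here is a minimal sketch in plain PyTorch of running a network over a batch of custom points without building an autograd graph, then comparing it to a reference. The small torch.nn.Sequential model and the compute_true function are hypothetical stand-ins for solver.neural_net and your exact solution. A PINA network may expect a LabelTensor input whose labels match problem.input_variables; treat that as an assumption to check against your PINA version.

```python
import math
import torch

# Hypothetical stand-in for the trained network (solver.neural_net in the thread):
# two inputs (x, t) -> one output u
model = torch.nn.Sequential(
    torch.nn.Linear(2, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

def compute_true(pts: torch.Tensor) -> torch.Tensor:
    """Hypothetical exact solution u(x, t) = sin(pi x) * exp(-t), shape (N, 1)."""
    x, t = pts[:, 0:1], pts[:, 1:2]
    return torch.sin(math.pi * x) * torch.exp(-t)

# Custom evaluation points: any (N, 2) tensor with columns (x, t)
pts = torch.rand(50, 2)

model.eval()
with torch.no_grad():       # no autograd graph is built, so .detach() is not needed
    nn_sol = model(pts)     # shape (50, 1)

true = compute_true(pts)    # shape (50, 1)
abs_err = (true - nn_sol).abs()
max_err = abs_err.max().item()
```

Using torch.no_grad() rather than .detach() avoids recording the forward pass at all, which saves memory when evaluating on many points.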

karthikncsuiisc commented 8 months ago

I know how to extract the solution using the code below:

v = [var for var in pinn.problem.input_variables]
pts = pinn.problem.domain.sample(256, 'grid', variables=v)
usol = pinn.neural_net(pts).detach()
inputs = pts.extract(v)

Is there a way to pass custom inputs instead of sampling? For example, I want to evaluate on a 10x100 grid.
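One way to build such a grid yourself, sketched in plain PyTorch: torch.cartesian_prod pairs every x value with every t value, giving 10 x 100 = 1000 points. The helper name make_grid and the unit ranges are illustrative assumptions. The commented LabelTensor(pts, ['x', 't']) line assumes that the PINA version used in this thread labels columns that way; check it against your installed version.

```python
import torch

def make_grid(nx: int, nt: int, x_range=(0.0, 1.0), t_range=(0.0, 1.0)) -> torch.Tensor:
    """Hypothetical helper: (nx * nt, 2) tensor of (x, t) points on a regular grid."""
    x = torch.linspace(x_range[0], x_range[1], nx)
    t = torch.linspace(t_range[0], t_range[1], nt)
    # cartesian_prod pairs every x with every t; column 0 is x, column 1 is t,
    # and t varies fastest (like itertools.product)
    return torch.cartesian_prod(x, t)

pts = make_grid(10, 100)   # 1000 points

# With PINA, the columns would typically be labelled before calling the network
# (assumed constructor, check your version):
# from pina import LabelTensor
# pts = LabelTensor(pts, ['x', 't'])
# usol = pinn.neural_net(pts).detach()
```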

dario-coscia commented 8 months ago

Hello @karthikncsuiisc πŸ‘‹πŸ» There are no functionalities at the moment in the Plotter class that support loading a solution and plotting it. If the solution is analytical or computable you can pass it in the class problem as done in here.

Anyway, in your specific case for plotting you can use this snippet of code:

# some imports and definitions 
import matplotlib.pyplot as plt
import torch
from pina import LabelTensor
from pina.geometry import CartesianDomain

# defining a fictitious domain (x, t)
domain = CartesianDomain({'x' : [0, 1], 't': [0, 1]})
# evaluation points
pts = domain.sample(20, mode='grid', variables=['x', 't'])
# NN solution (in practice, call solver.neural_net(pts).detach() here instead)
sol_ = torch.rand((pts.shape[0], 1)) # random numbers just matching the dimensionality

# plot
plt.subplot(1, 2, 1) # nn
plt.tricontourf(pts.extract('x').flatten(),
                pts.extract('t').flatten(),
                sol_.flatten()
                )
plt.title('NN solution')

# here I simulate the given solution with another discretization
pts = torch.cartesian_prod(torch.linspace(0, 1, 10), torch.linspace(0, 1, 10))
plt.subplot(1, 2, 2) # true
plt.tricontourf(pts[:, 0],
                pts[:, 1],
                torch.rand(pts.shape[0])
                )
plt.title('Real solution')
plt.show()

It should be really easy to adapt! Let me know how it goes 😃
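If the points come from a regular grid built with torch.cartesian_prod, a possible alternative to tricontourf is to reshape the flat network output back to the grid and use a regular-grid plot such as contourf. This is a sketch under the assumption that t varies fastest in the point ordering; the analytic u_flat is a stand-in for solver.neural_net(pts).detach(), and the matplotlib call is left as a comment.

```python
import math
import torch

nx, nt = 10, 100
x = torch.linspace(0, 1, nx)
t = torch.linspace(0, 1, nt)
pts = torch.cartesian_prod(x, t)          # (nx * nt, 2), t varies fastest

# Stand-in for solver.neural_net(pts).detach(): any (N, 1) output works here
u_flat = (torch.sin(math.pi * pts[:, 0]) * pts[:, 1]).unsqueeze(1)

# Because t varies fastest, a row-major reshape gives U[i, j] = u(x[i], t[j])
U = u_flat.reshape(nx, nt)

# 'ij' indexing produces coordinate arrays with the same (nx, nt) layout as U
X, T = torch.meshgrid(x, t, indexing="ij")

# plt.contourf(X.numpy(), T.numpy(), U.numpy())   # matplotlib call, not run here
```

Mixing up the point ordering and the reshape order silently transposes the plot, so checking that X.reshape(-1) matches pts[:, 0] is a cheap safeguard.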

dario-coscia commented 8 months ago

similar to issue #225

dario-coscia commented 7 months ago

Hello @karthikncsuiisc πŸ‘‹πŸ» There are no functionalities at the moment in the Plotter class that support loading a solution and plotting it. If the solution is analytical or computable you can pass it in the class problem as done in here.

Anyway, in your specific case for plotting you can use this snippet of code:

# some imports and definitions 
import matplotlib.pyplot as plt
import torch
from pina import LabelTensor
from pina.geometry import CartesianDomain

# defining a fictitious domain (x, y)
domain = CartesianDomain({'x' : [0, 1], 't': [0, 1]})
# evaluation points
pts = domain.sample(20, mode='grid', variables=['x', 't'])
# NN solutions (below you should call solver.neural_net(pts).detach()  )
sol_ = torch.rand((pts.shape[0], 1)) # random numbers just matching the dimensionality

# plot
plt.subplot(1, 2, 1) # nn
plt.tricontourf(pts.extract('x').flatten(),
                pts.extract('t').flatten(),
                sol_.flatten()
                )
plt.title('NN solution')

# here I simulate the given solution with another discretization
pts = torch.cartesian_prod(torch.linspace(0, 1, 10), torch.linspace(0, 1, 10))
plt.subplot(1, 2, 2) # true
plt.tricontourf(pts[:, 0],
                pts[:, 1],
                torch.rand(pts.shape[0])
                )
plt.title('Real solution')
plt.show()

It should be really easy to adapt! Let me know how it goesπŸ˜ƒ

πŸ‘‹πŸ» @karthikncsuiisc were you able to plot?

karthikncsuiisc commented 7 months ago

Hi @dario-coscia,

I am able to plot. Thank you for the help.