mathLab / PINA

Physics-Informed Neural networks for Advanced modeling
MIT License
378 stars 65 forks source link

plot_samples errors for Condition(input_points=..., output_points=...) in pinn #101

Open Bovhasselt opened 1 year ago

Bovhasselt commented 1 year ago

Describe the bug: First off, this is not an urgent matter, since the plot is still shown correctly after the error. The function `plotter.plot_samples()` raises an error when the problem definition includes a data Condition (one built from `input_points` and `output_points`).

Problem definition

class Heat2D(TimeDependentProblem, SpatialProblem):
    '''2D heat problem on [0, LENGTH_X] x [0, LENGTH_Y] for t in [0, DURATION].

    The residual functions below are referenced by ``conditions``. The
    ``'data'`` entry is an input/output-points condition, which is the kind
    of condition this report says breaks ``Plotter.plot_samples``.
    '''

    # Define these yourself
    # Domain extents and simulation length (units are not stated here).
    LENGTH_X = 82
    LENGTH_Y = 70
    DURATION = 2284

    output_variables = ['u']
    spatial_domain = Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y]})
    temporal_domain = Span({'t': [0, DURATION]})

    def heat_equation_2D(input_, output_):
        '''Residual du/dt - c**2 * (d2u/dx2 + d2u/dy2) of the heat equation.'''
        # c**2 (= 0.01/pi here) plays the role of the thermal diffusivity in
        # the residual below; the appropriate value depends on the material.
        c = (0.01/torch.pi) ** 0.5

        du = grad(output_, input_)
        # Second derivatives of u, taken from the 'dudx'/'dudy' components;
        # presumably labelled 'ddudxdx'/'ddudydy' by grad -- TODO confirm.
        ddu = grad(du, input_, components=['dudx','dudy'])
        return (
            du.extract(['dudt']) -
            (c**2)*(ddu.extract(['ddudxdx']) + ddu.extract(['ddudydy']))
            # NOTE(review): the closing parenthesis of this return is missing
            # in this paste.

    def nil_dirichlet_x(input_, output_):
        '''Residual du/dx - 0 on the x-boundaries (original note: "2 and 3").'''
        # NOTE(review): despite the "dirichlet" name, this constrains the
        # gradient du/dx (a zero-flux / Neumann-type condition), not u itself.
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudx']) - u_expected_boundary

    def nil_dirichlet_yL(input_, output_):
        '''Residual du/dy - 0, used on the y = LENGTH_Y boundary.'''
        # NOTE(review): constrains du/dy (Neumann-type), not u.
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def nil_dirichlet_y0(input_, output_):
        '''Residual du/dy - 0, used on the y = 0 boundary.'''
        # TODO: make this conditional on whether the door is open
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def initial_condition(input_, output_):
        '''Residual u - sin(pi*x), used on the t = 0 slice ('initial').'''
        u_expected_initial = torch.sin(torch.pi*input_.extract(['x']))
        return output_.extract(['u']) - u_expected_initial

    # Maps each condition name to its sampling location and residual. The
    # 'data' entry instead supplies input/output tensors (X_input_tensor and
    # X_output_tensor are defined elsewhere, not shown).
    # NOTE(review): the closing brace of this dict is missing in this paste.
    conditions = {
        'boundx0': Condition(location=Span({'x': 0, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundxL': Condition(location=Span({'x': LENGTH_X, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundy0': Condition(location=Span({'x': [0, LENGTH_X], 'y': 0, 't': [0, DURATION]}), function=nil_dirichlet_y0),
        'boundyL': Condition(location=Span({'x': [0, LENGTH_X], 'y': LENGTH_Y, 't': [0, DURATION]}), function=nil_dirichlet_yL),
        'initial': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': 0}), function=initial_condition),
        'heat_eq': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=heat_equation_2D),
        'data': Condition(input_points=X_input_tensor , output_points=X_output_tensor),


class myFeature(torch.nn.Module):
    """Feature: sin(pi*x)."""

    def __init__(self, idx):
        super().__init__()
        # NOTE(review): ``idx`` is stored but never read by ``forward``, which
        # always extracts the 'x' column -- confirm this is intended.
        self.idx = idx

    def forward(self, x):
        """Return sin(pi * x['x']) as a LabelTensor labelled 'sin(x)'."""
        return LabelTensor(torch.sin(torch.pi * x.extract(['x'])), ['sin(x)'])

heat_problem = Heat2D()
# NOTE(review): the FeedForward(...) call below is truncated in this paste;
# its remaining arguments and closing parenthesis are missing.
model = FeedForward(
    layers=[30, 20, 10, 5],

# NOTE(review): the PINN(...) call is likewise truncated. The two dicts below
# look like per-variable sampling specs (10 'grid' points for 't' and for
# ['x', 'y']), presumably arguments of a span_pts call not shown -- confirm.
pinn = PINN(

    {'n': 10, 'mode': 'grid', 'variables': 't'},
    {'n': 10, 'mode': 'grid', 'variables': ['x', 'y']},
# Presumably samples 20 random points on each listed boundary/initial location.
pinn.span_pts(20, 'random', locations=['boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0'])
pinn.train(1000, 100)




dict_keys(['heat_eq', 'boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0', 'data'])

and when plotting

# plot samples
plotter = Plotter()
# In the reported version this raises AttributeError on the 'data'
# (input/output points) condition, after which the plot is still shown.
plotter.plot_samples(pinn=pinn)

I get

AttributeError                            Traceback (most recent call last)
Cell In[58], line 3
      1 # plot samples
      2 plotter = Plotter()
----> 3 plotter.plot_samples(pinn=pinn)

File ~/opt/anaconda3/lib/python3.8/site-packages/pina/, in Plotter.plot_samples(self, pinn, variables)
     44 ax = fig.add_subplot(projection=proj)
     45 for location in pinn.input_pts:
---> 46     coords = pinn.input_pts[location].extract(variables).T.detach()
     47     if coords.shape[0] == 1:  # 1D samples
     48         ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
     49                 label=location)

AttributeError: 'Condition' object has no attribute 'extract'

followed by a plot showing all the function samples correctly.

Expected behavior: I think `plot_samples` should check whether a location is an input/output (data) condition and skip it, instead of calling `.extract()` on the `Condition` object.

dario-coscia commented 1 year ago

Hi, thanks for the report! We will try to fix it in #85