ar4 / deepwave

Wave propagation modules for PyTorch.
MIT License

Elastic wave gradient calculation #75

Closed: lnnnn123 closed this issue 1 month ago

lnnnn123 commented 1 month ago

Hello, sorry to bother you again. I recently ran into a problem with elastic-wave inversion. My code is below. When I multiply the loss function by a coefficient of 1e16, the model does get updated, but the updates are barely noticeable. When I change the coefficient to 1e17, a few gradient updates succeed and then the following error occurs. I don't know whether this error is caused by my coefficient setting or by something else.

```python
# Imports and device setup were not shown in the original post;
# they are added here so the script is complete.
import torch
import deepwave
from deepwave import elastic

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ny = 320
nx = 128
dx = 20
n_shots = 1
nt = 6000
dt = 0.002
epoches = 10

for m in range(epoches):
    freq = 3 + m
    for i in range(63):
        vp_true = torch.from_file('Mouslabels.bin',
                                  size=ny * nx).reshape(ny, nx).to(device)
        vs_true = torch.from_file('Mouslabels.bin',
                                  size=ny * nx).reshape(ny, nx).to(device)
        rho_true = torch.from_file('Mouslabels.bin',
                                   size=ny * nx).reshape(ny, nx).to(device)
        vp_true = vp_true.T
        vs_true = vs_true / 1.5
        vs_true = vs_true.T
        rho_true = rho_true / 1.3
        rho_true = rho_true.T

        if m == 0 and i == 0:
            vp_background = torch.from_file('MOUSsmooth.bin',
                                            size=ny * nx).reshape(ny, nx).to(device)
            vs_background = torch.from_file('MOUSsmooth.bin',
                                            size=ny * nx).reshape(ny, nx).to(device)
            rho_background = torch.from_file('MOUSsmooth.bin',
                                             size=ny * nx).reshape(ny, nx).to(device)
            vp_background = vp_background
            vs_background = vs_background / 1.5
            rho_background = rho_background / 1.3
        elif m != 0 and i == 0:
            vp_background = torch.from_file('vp{}.{}.bin'.format(63, freq - 1),
                                            size=ny * nx).reshape(ny, nx).to(device)
            vs_background = torch.from_file('vs{}.{}.bin'.format(63, freq - 1),
                                            size=ny * nx).reshape(ny, nx).to(device)
            rho_background = torch.from_file('rho{}.{}.bin'.format(63, freq - 1),
                                             size=ny * nx).reshape(ny, nx).to(device)
        else:
            vp_background = torch.from_file('vp{}.{}.bin'.format(i, freq),
                                            size=ny * nx).reshape(ny, nx).to(device)
            vs_background = torch.from_file('vs{}.{}.bin'.format(i, freq),
                                            size=ny * nx).reshape(ny, nx).to(device)
            rho_background = torch.from_file('rho{}.{}.bin'.format(i, freq),
                                             size=ny * nx).reshape(ny, nx).to(device)

        n_sources_per_shot = 1
        d_source = 0
        first_source = 0 + 5 * i
        source_depth = 1
        print(vp_true.shape)
        n_receivers_per_shot = 319
        d_receiver = 1
        first_receiver = 0
        receiver_depth = 0
        peak_time = 1.5 / freq

        # source_locations
        source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,
                                       dtype=torch.long, device=device)
        source_locations[..., 0] = source_depth
        source_locations[:, 0, 1] = (torch.arange(n_shots) * d_source +
                                     first_source)

        # receiver_locations
        receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2,
                                         dtype=torch.long, device=device)
        receiver_locations[..., 0] = receiver_depth
        receiver_locations[:, :, 1] = (
            (torch.arange(n_receivers_per_shot) * d_receiver +
             first_receiver)
            .repeat(n_shots, 1)
        )

        # source_amplitudes
        source_amplitudes = (
            deepwave.wavelets.ricker(freq, nt, dt, peak_time)
            .repeat(n_shots, n_sources_per_shot, 1).to(device)
        )

        # Create observed data using true models
        observed_data = elastic(
            *deepwave.common.vpvsrho_to_lambmubuoyancy(vp_true, vs_true,
                                                       rho_true),
            dx, dt,
            source_amplitudes_y=source_amplitudes,
            source_locations_y=source_locations,
            receiver_locations_y=receiver_locations,
            pml_freq=freq,
        )[-2]
        observed_data2 = observed_data.detach().cpu().numpy()
        observed_data2.tofile("observed_data2.bin")

        # Setup optimiser to perform inversion
        vp = vp_background.clone().requires_grad_()
        vs = vs_background.clone().requires_grad_()
        rho = rho_background.clone().requires_grad_()
        optimiser = torch.optim.LBFGS([vp, vs, rho])
        loss_fn = torch.nn.MSELoss()
        vp = vp.T
        vs = vs.T
        rho = rho.T

        # Run optimisation/inversion
        n_epochs = 2
        for epoch in range(n_epochs):
            def closure():
                optimiser.zero_grad()
                out = elastic(
                    *deepwave.common.vpvsrho_to_lambmubuoyancy(vp, vs, rho),
                    dx, dt,
                    source_amplitudes_y=source_amplitudes,
                    source_locations_y=source_locations,
                    receiver_locations_y=receiver_locations,
                    pml_freq=freq,
                )[-2]
                loss = 1e16 * loss_fn(out, observed_data)
                print(1e1 * loss)
                loss.backward()
                return loss

            optimiser.step(closure)

        vp = vp.T
        vs = vs.T
        rho = rho.T
        vp = vp.detach().cpu().numpy()
        vs = vs.detach().cpu().numpy()
        rho = rho.detach().cpu().numpy()
        vp.tofile("vp{}.{}.bin".format(i + 1, freq))
        vs.tofile("vs{}.{}.bin".format(i + 1, freq))
        rho.tofile("rho{}.{}.bin".format(i + 1, freq))
        print("fm===============================", freq)
```

lnnnn123 commented 1 month ago

[Screenshot of the error message; image not available]

ar4 commented 1 month ago

Hello again,

I am glad to hear that Deepwave continues to be useful to you.

The error that you encountered is caused by your parameters containing NaNs after several steps of optimization. The elastic propagator is unfortunately not very stable, so it is quite common for this to occur. I suggest three changes that might help:

1. Apply a little smoothing to your model parameters before they are passed to the propagator. This helps to avoid sharp changes in the model that can cause instability. An example of this can be seen in the Joint Migration-Inversion example.
2. Restrict the parameters to an expected range, for example by using the Model class in the FWI example. This prevents the parameters from taking unrealistically extreme values that cause instability.
3. Increase the size of the PML region (using the pml_width parameter).

The combination of these changes would look something like this (untested):

```python
import torch
import torchvision
import deepwave
from deepwave import elastic

# Assumes device, vp, vs, rho (starting models, in the same orientation
# as the true models), n_epochs, loss_fn, dx, dt, freq, source_amplitudes,
# source_locations, receiver_locations and observed_data are defined as
# in your script.


class Model(torch.nn.Module):
    def __init__(self, initial, min_val, max_val):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        # Store the logit of the normalised value; the sigmoid in
        # forward() keeps the returned values within (min_val, max_val)
        self.model = torch.nn.Parameter(
            torch.logit((initial - min_val) /
                        (max_val - min_val))
        )

    def forward(self):
        return (torch.sigmoid(self.model) *
                (self.max_val - self.min_val) +
                self.min_val)


# You should ensure that the initial and expected parameter
# values are within these ranges
model_vp = Model(vp, 1400, 5000).to(device)
model_vs = Model(vs, 900, 3000).to(device)
model_rho = Model(rho, 1000, 3000).to(device)

# Using 'strong_wolfe' seems to improve stability, but may not
# work if running on multiple GPUs
optimiser = torch.optim.LBFGS(list(model_vp.parameters()) +
                              list(model_vs.parameters()) +
                              list(model_rho.parameters()),
                              line_search_fn='strong_wolfe')

for epoch in range(n_epochs):
    def closure():
        optimiser.zero_grad()
        # Smooth each parameter before passing it to the propagator
        vp_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_vp()[None], [5, 5]
            ).squeeze()
        )
        vs_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_vs()[None], [5, 5]
            ).squeeze()
        )
        rho_smooth = (
            torchvision.transforms.functional.gaussian_blur(
                model_rho()[None], [5, 5]
            ).squeeze()
        )
        out = elastic(
            *deepwave.common.vpvsrho_to_lambmubuoyancy(vp_smooth, vs_smooth,
                                                       rho_smooth),
            dx, dt,
            source_amplitudes_y=source_amplitudes,
            source_locations_y=source_locations,
            receiver_locations_y=receiver_locations,
            pml_freq=freq,
            pml_width=30,  # increased PML width
        )[-2]

        loss = 1e17 * loss_fn(out, observed_data)
        loss.backward()
        return loss

    optimiser.step(closure)
```
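The reparameterisation in the Model class can be checked on its own, without running the propagator. The sketch below reuses the same class definition on a small made-up vp array (values chosen arbitrarily inside the 1400 to 5000 range). It checks that forward() returns the initial values and that even a very large update to the stored logits cannot move vp outside the bounds. Note that an initial value exactly equal to min_val or max_val would give an infinite logit, which is one reason to keep the starting model strictly inside the range.

```python
import torch


class Model(torch.nn.Module):
    """Same reparameterisation as above: the stored parameter is a logit,
    and forward() maps it into (min_val, max_val) with a sigmoid."""

    def __init__(self, initial, min_val, max_val):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.model = torch.nn.Parameter(
            torch.logit((initial - min_val) / (max_val - min_val))
        )

    def forward(self):
        return (torch.sigmoid(self.model) * (self.max_val - self.min_val)
                + self.min_val)


# Made-up vp values, strictly inside (1400, 5000)
initial_vp = torch.tensor([[1500.0, 2500.0], [3500.0, 4800.0]])
model_vp = Model(initial_vp, 1400, 5000)

# 1) forward() reproduces the initial values (up to float32 rounding)
recovered = model_vp().detach()

# 2) A huge step in logit space saturates the sigmoid at the upper bound
#    instead of producing an unrealistic velocity
with torch.no_grad():
    model_vp.model.add_(1e3)
clamped = model_vp().detach()

print(recovered)
print(clamped)
```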

It may also simply be that using a weight of 1e17 causes the update step size to be too big, which will also lead to instability. A step size that is too large can cause bad updates that may, for example, result in vs becoming larger than vp at some point in your model, in which case the elastic propagator will become unstable and start producing NaNs.
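To catch the failure described above before it reaches the propagator, the parameters can be inspected after each optimiser step. A minimal sketch follows; check_elastic_params is a hypothetical helper (not part of Deepwave) that only tests the conditions mentioned in this thread (non-finite values and vs not below vp), plus non-positive density. A stricter physical bound on the vs/vp ratio may be appropriate for a real model.

```python
import torch


def check_elastic_params(vp, vs, rho):
    """Hypothetical helper: return a list of problems found in a set of
    elastic model parameters (an empty list means no problem detected)."""
    problems = []
    # NaNs or infinities, e.g. produced by an unstable propagation
    for name, p in (("vp", vp), ("vs", vs), ("rho", rho)):
        n_bad = int((~torch.isfinite(p)).sum().item())
        if n_bad:
            problems.append(f"{name} has {n_bad} non-finite value(s)")
    # vs must stay below vp (comparisons involving NaN are False)
    n_vs_ge_vp = int((vs >= vp).sum().item())
    if n_vs_ge_vp:
        problems.append(f"vs >= vp at {n_vs_ge_vp} grid point(s)")
    # Density must be positive
    n_bad_rho = int((rho <= 0).sum().item())
    if n_bad_rho:
        problems.append(f"rho <= 0 at {n_bad_rho} grid point(s)")
    return problems


vp = torch.full((4, 4), 3000.0)
vs = vp / 1.5
rho = torch.full((4, 4), 2000.0)
print(check_elastic_params(vp, vs, rho))  # []

vs_bad = vs.clone()
vs_bad[0, 0] = 3500.0         # vs larger than vp at one point
vs_bad[1, 1] = float("nan")   # a NaN at another point
problems = check_elastic_params(vp, vs_bad, rho)
print(problems)
```

In the inversion loop, such a check could be run after optimiser.step(closure), stopping (or reverting to the last good model) as soon as the list is non-empty.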

I hope this helps. If you continue to have difficulty then perhaps you might have more success with a different optimizer, such as Adam.
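A minimal sketch of swapping LBFGS for Adam. To keep it self-contained, a toy MSE misfit between a three-value stand-in model and a made-up target replaces the elastic forward modelling; in the real script the loss would be computed from the elastic output as in the closure above. Unlike LBFGS, Adam does not need a closure, and the learning rate of 10 is only an illustrative value.

```python
import torch

# Made-up stand-in "true" model and starting model
target = torch.tensor([2000.0, 2500.0, 3000.0])
vp = torch.full((3,), 1800.0, requires_grad=True)

optimiser = torch.optim.Adam([vp], lr=10.0)
loss_fn = torch.nn.MSELoss()

initial_loss = loss_fn(vp, target).item()
for epoch in range(200):
    optimiser.zero_grad()
    # Toy misfit; in the real script: loss_fn(out, observed_data)
    loss = loss_fn(vp, target)
    loss.backward()
    optimiser.step()  # no closure needed
final_loss = loss_fn(vp, target).item()

print(initial_loss, final_loss)
```

One property that may be relevant to the 1e16 versus 1e17 question: because Adam divides each update by a running root-mean-square of the gradient, a constant multiplier on the loss largely cancels out (apart from the small eps term), so the step size is governed mainly by lr rather than by the loss weight.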

ar4 commented 1 month ago

I should mention that if you use the Model class, then you may need to adjust your scaling factor of 1e17, since the model parameters will now be stored in a different range.
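One way to see why the factor may need adjusting: with the Model class, the optimiser updates the stored logit theta rather than the velocity itself, so by the chain rule the gradient it sees is the velocity gradient multiplied by dv/dtheta = (max_val - min_val) * s * (1 - s), where s = sigmoid(theta). The sketch below checks this with autograd for the vp bounds used above (1400 and 5000) and a made-up starting value of 3200 m/s. That value is the midpoint of the range, where s = 0.5 and the factor is 3600 * 0.25 = 900. Toward the edges of the range the factor shrinks toward zero, so the effective scaling also varies from one grid point to another.

```python
import torch

min_val, max_val = 1400.0, 5000.0
initial = torch.tensor(3200.0)  # midpoint of the range, so sigmoid(theta) = 0.5

# Stored parameter, as in the Model class: logit of the normalised value
theta = torch.logit((initial - min_val) / (max_val - min_val)).requires_grad_()
vp = torch.sigmoid(theta) * (max_val - min_val) + min_val

# Backpropagate a unit velocity gradient (dLoss/dvp = 1)
vp.backward()
factor = theta.grad.item()  # equals dvp/dtheta

print(vp.item(), factor)  # 3200.0 900.0
```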

lnnnn123 commented 1 month ago

Thank you for your answer, and sorry for the late reply; I have been very busy lately. The issue has been resolved. Thank you again for your response.

ar4 commented 1 month ago

That's excellent news. Thank you for letting me know. May I ask if there were any particular changes that resolved the problem (if it is not too complicated to explain)?

lnnnn123 commented 1 month ago

The error was caused by my initial model: I had been using a linear velocity model as the starting model. When I replaced it with a smooth velocity model, the problem was resolved.
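The smooth starting model in the script is read from a file (MOUSsmooth.bin), and the thread does not say how it was made. In synthetic tests, a common approach is to low-pass filter the true model. A minimal sketch in plain PyTorch, assuming a 2D (ny, nx) float tensor; gaussian_smooth and the sigma value are illustrative choices, not something used in this thread:

```python
import torch
import torch.nn.functional as F


def gaussian_smooth(model, sigma):
    """Illustrative helper: smooth a 2D model (ny, nx) with a separable
    Gaussian of standard deviation `sigma` grid cells (reflect padding)."""
    radius = int(3 * sigma)
    x = torch.arange(-radius, radius + 1, dtype=model.dtype,
                     device=model.device)
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    kernel = kernel / kernel.sum()  # normalise so a constant model is unchanged

    img = model[None, None]  # (1, 1, ny, nx), the layout conv2d expects
    # Smooth along the last dimension (nx)
    img = F.pad(img, (radius, radius, 0, 0), mode="reflect")
    img = F.conv2d(img, kernel.view(1, 1, 1, -1))
    # Smooth along the first dimension (ny)
    img = F.pad(img, (0, 0, radius, radius), mode="reflect")
    img = F.conv2d(img, kernel.view(1, 1, -1, 1))
    return img[0, 0]


# Made-up two-layer model with a sharp interface
vp = torch.full((40, 30), 2000.0)
vp[20:] = 3500.0
smooth = gaussian_smooth(vp, 3.0)

# The 1500 m/s jump is now spread over several cells
print((smooth[1:] - smooth[:-1]).abs().max().item())
```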

ar4 commented 1 month ago

Ah, I see. Thank you for the explanation. I will close this issue now, but please feel free to reopen it or to create another if you have further questions.