ar4 / deepwave

Wave propagation modules for PyTorch.
MIT License
212 stars 47 forks source link

No change in elastic FWI inversion parameter Vp Vs #85

Open Lujiang-ECUST opened 4 hours ago

Lujiang-ECUST commented 4 hours ago

Hello, thank you very much for posting your code for free! I'm trying to perform elastic FWI on my model, when I invert it with reference to the example, it seems that the LBFGS optimiser only updates the value of rho, and the values of vp and vs are always the same as the initial values no matter how many iterations I do, I'd like to know what's going wrong and how can I improve it! This is the result after performing elastic fwi example_elastic7000000 04

Lujiang-ECUST commented 4 hours ago

I'll attach my part of the code, the loading part of the speed model is omitted

ny = 100
nx = 100
dx = 0.5e-4
vmin=5800
vmax=6300
Model = md.ModelGenerator('emnist')
model = Model()
vp_background = torch.ones(nx,ny, device=device) *5800.0
vs_background = torch.ones(nx,ny, device=device)*3348.6315918
rho_background = torch.ones(nx,ny, device=device)*2.70532036

vp_true = torch.tensor(model['vp'], device=device,dtype=torch.float32)
vs_true = torch.tensor(model['vs'], device=device,dtype=torch.float32)
rho_true = torch.tensor(model['rho'], device=device,dtype=torch.float32)
n_shots = 64

n_sources_per_shot = 1
d_source = 1
first_source = 14
source_depth = 0

n_receivers_per_shot = 64
d_receiver = 1
first_receiver = 14
receiver_depth = 0

freq = 5e6
nt = 600
dt = 0.4e-8
peak_time = 1.5 / freq

# source_locations
source_locations = torch.zeros(n_shots, n_sources_per_shot, 2,device=device)
source_locations[..., 0] = source_depth
source_locations[:, 0, 1] = (torch.arange(n_shots) * d_source +
                             first_source)

# receiver_locations
receiver_locations = torch.zeros(n_shots, n_receivers_per_shot, 2, device=device)
receiver_locations[..., 0] = receiver_depth
receiver_locations[:, :, 1] = (
    (torch.arange(n_receivers_per_shot) * d_receiver +
     first_receiver)
    .repeat(n_shots, 1)
)

# source_amplitudes
source_amplitudes = (
    (deepwave.wavelets.ricker(freq, nt, dt, peak_time))
    .repeat(n_shots, n_sources_per_shot, 1).to(device)
)

# Create observed data using true models
observed_data = elastic(
    *deepwave.common.vpvsrho_to_lambmubuoyancy(vp_true, vs_true,
                                               rho_true),
    dx, dt,
    accuracy=4,
    source_amplitudes_y=source_amplitudes,
    source_locations_y=source_locations,
    receiver_locations_y=receiver_locations,
    pml_freq=freq,
)
vp = vp_background.clone().requires_grad_()
vs = vs_background.clone().requires_grad_()
rho = rho_background.clone().requires_grad_()
# optimiser = torch.optim.LBFGS([vp, vs, rho])
loss_fn = torch.nn.MSELoss()
def taper(x):
    return deepwave.common.cosine_taper_end(x, 100)
# Run optimisation/inversion
n_epochs = 10
observed_data = taper(observed_data)
for cutoff_freq in [4e6,5e6,6e6,7e6]:
    sos = butter(6, cutoff_freq, fs=1/dt, output='sos')
    sos = [torch.tensor(sosi).to(observed_data.dtype).to(device)
           for sosi in sos]

    def filt(x):
        return biquad(biquad(biquad(x, *sos[0]), *sos[1]), *sos[2])
    observed_data_filt = filt(observed_data)
    # optimiser = torch.optim.LBFGS(model.parameters(),
    #                               line_search_fn='strong_wolfe')
    optimiser = torch.optim.LBFGS([vp,vs,rho],line_search_fn='strong_wolfe')
    for epoch in range(n_epochs):
        def closure():
            optimiser.zero_grad()
            out = elastic(
                *deepwave.common.vpvsrho_to_lambmubuoyancy(vp, vs, rho),
                dx, dt,
                source_amplitudes_y=source_amplitudes,
                source_locations_y=source_locations,
                receiver_locations_y=receiver_locations,
                pml_freq=freq,
            )[-2]
            # out = deepwave.scalar(
            #     vp, dx, dt,
            #     source_amplitudes=source_amplitudes,
            #     source_locations=source_locations,
            #     receiver_locations=receiver_locations,
            #     max_vel=vp_true.max().max(),
            #     pml_freq=freq,)
            out_filt = filt(taper(out))
            loss = 1e30 * loss_fn(out_filt, observed_data_filt)
            loss.backward()
            return loss

        loss = closure()
        print("Iter %d, Loss: %.4e" % \
              (epoch, loss.item()))
        optimiser.step(closure)
        vpmin = vp_true.min()
        vpmax = vp_true.max()
        vsmin = vs_true.min()
        vsmax = vs_true.max()
        rhomin = rho_true.min()
        rhomax = rho_true.max()
        _, ax = plt.subplots(2, 3, figsize=(10.5, 5.25), sharex=True,
                             sharey=True)
        ax[0, 0].imshow(vp_true.cpu(), aspect='auto', cmap='coolwarm',
                        vmin=vpmin, vmax=vpmax)
        ax[0, 0].set_title("True vp")
        ax[0, 1].imshow(vs_true.cpu(), aspect='auto', cmap='coolwarm',
                        vmin=vsmin, vmax=vsmax)
        ax[0, 1].set_title("True vs")
        ax[0, 2].imshow(rho_true.cpu(), aspect='auto', cmap='coolwarm',
                        vmin=rhomin, vmax=rhomax)
        ax[0, 2].set_title("True rho")
        ax[1, 0].imshow(vp.detach().cpu(), aspect='auto', cmap='coolwarm',
                        )
        ax[1, 0].set_title("Out vp")
        ax[1, 1].imshow(vs.detach().cpu(), aspect='auto', cmap='coolwarm',
                        )
        ax[1, 1].set_title("Out vs")
        ax[1, 2].imshow(rho.detach().cpu(), aspect='auto', cmap='coolwarm',
                        )
        ax[1, 2].set_title("Out rho")
        plt.tight_layout()
        plt.savefig(f'example_elastic{cutoff_freq}{epoch}.jpg')