from typing import Optional, Union, List
import os
import math
import torch
import numpy as np
from tqdm import tqdm
from ADFWI.model import AbstractModel
from ADFWI.propagator import AcousticPropagator, GradProcessor
from ADFWI.survey import SeismicData
from ADFWI.fwi.misfit import Misfit, Misfit_NIM
from ADFWI.fwi.regularization import Regularization
from ADFWI.utils import numpy2tensor
from ADFWI.view import plot_model
from ADFWI.fwi.multiScaleProcessing import lpass
from ADFWI.utils.first_arrivel_picking import apply_mute
from ADFWI.utils.offset_mute import mute_offset
class DIP_AcousticFWI(torch.nn.Module):
"""Acoustic full waveform inversion (FWI) class.
Parameters
----------
propagator : AcousticPropagator
The propagator used for simulating acoustic wave propagation.
model : AbstractModel
The model class representing the velocity or acoustic property structure.
loss_fn : Union[Misfit, torch.autograd.Function]
The loss function or misfit function used to compute the difference between predicted and observed data.
obs_data : SeismicData
The observed seismic data for comparison against the model predictions.
optimizer : Union[torch.optim.Optimizer, List[torch.optim.Optimizer]], optional
The optimizer used for parameter optimization (e.g., SGD, Adam). Default is None.
scheduler : torch.optim.lr_scheduler, optional
The learning rate scheduler for adjusting the learning rate during training. Default is None.
gradient_processor : Union[GradProcessor, List[GradProcessor]], optional
The gradient processor or list of processors for handling gradients, applied to different parameters if specified. Default is None.
regularization_fn : Optional[Regularization], optional
The regularization function for model parameters (e.g., for smoothing or penalty terms). Default is None.
regularization_weights_x : Optional[List[float]], optional
Regularization weights in the x direction for [vp, rho]. Default is [0, 0].
regularization_weights_z : Optional[List[float]], optional
Regularization weights in the z direction for [vp, rho]. Default is [0, 0].
waveform_normalize : Optional[bool], optional
Whether to normalize the waveform during inversion. Default is True (waveforms are normalized).
waveform_mute_late_window : Optional[float], optional
Late window mute applied to waveform data. Default is None.
waveform_mute_offset : Optional[float], optional
Offset mute applied to waveform data. Default is None.
cache_result : Optional[bool], optional
Whether to cache intermediate inversion results for later use. Default is True.
save_fig_epoch : Optional[int], optional
The interval (in epochs) at which to save the inversion result as a figure. Default is -1 (no figure saved).
save_fig_path : Optional[str], optional
The directory in which the inversion result figures are saved. Default is an empty string (no path specified).
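Examples
--------
A minimal, hypothetical construction sketch: ``my_propagator``, ``my_model``,
``my_misfit``, ``my_obs_data``, ``my_optimizer`` and ``my_regularization`` are
placeholder objects built elsewhere, and the mute/regularization values are
illustrative only.

>>> fwi = DIP_AcousticFWI(propagator=my_propagator, model=my_model,
...                       loss_fn=my_misfit, obs_data=my_obs_data,
...                       optimizer=my_optimizer,
...                       regularization_fn=my_regularization,
...                       regularization_weights_x=[1e-4, 0.0],
...                       regularization_weights_z=[1e-4, 0.0],
...                       waveform_mute_offset=2000.0,
...                       waveform_normalize=True)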
"""
def __init__(self,
propagator:AcousticPropagator,model:AbstractModel,
loss_fn:Union[Misfit,torch.autograd.Function],
obs_data:SeismicData,
optimizer:Union[torch.optim.Optimizer,List[torch.optim.Optimizer]] = None,
scheduler:torch.optim.lr_scheduler = None,
gradient_processor: Union[GradProcessor,List[GradProcessor]] = None,
regularization_fn:Optional[Regularization] = None,
regularization_weights_x:Optional[List[float]] = [0,0], # [vp, rho] weights in x direction
regularization_weights_z:Optional[List[float]] = [0,0], # [vp, rho] weights in z direction
waveform_normalize:Optional[bool] = True,
waveform_mute_late_window:Optional[float] = None,
waveform_mute_offset:Optional[float] = None,
cache_result:Optional[bool] = True,
save_fig_epoch:Optional[int] = -1,
save_fig_path:Optional[str] = "",
):
super().__init__()
self.propagator = propagator
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.loss_fn = loss_fn
self.regularization_fn = regularization_fn
self.regularization_weights_x = regularization_weights_x
self.regularization_weights_z = regularization_weights_z
self.obs_data = obs_data
self.gradient_processor = gradient_processor
self.device = self.propagator.device
self.dtype = self.propagator.dtype
# optimizer
if not isinstance(self.optimizer, list):
self.optimizer = [self.optimizer]
if not isinstance(self.scheduler, list):
self.scheduler = [self.scheduler]
# Real-Case settings: for trace missing, partial data missing
receiver_masks = self.propagator.receiver_masks
if receiver_masks is None:
receiver_masks = np.ones((self.propagator.src_n,self.propagator.rcv_n))
receiver_masks = numpy2tensor(receiver_masks)
self.receiver_masks_2D = receiver_masks # [shot, rcv]
self.receiver_masks_3D = receiver_masks.unsqueeze(1).expand(-1, self.propagator.nt, -1).to(self.device) # [shot, time, rcv]
# Real-Case settings: mute late window (by first arrival picking) & mute offset
self.waveform_normalize = waveform_normalize
self.waveform_mute_late_window = waveform_mute_late_window
self.waveform_mute_offset = waveform_mute_offset
# observed data
obs_p = self.obs_data.data["p"]
obs_p = numpy2tensor(obs_p,self.dtype).to(self.device)
if self.propagator.receiver_masks_obs: # whether the observed data also need to be masked
obs_p = obs_p*self.receiver_masks_3D
self.data_masks = numpy2tensor(self.obs_data.data_masks).to(self.device) if self.obs_data.data_masks is not None else None
if self.data_masks is not None: # mask out samples of the observed data that are not usable
obs_p = obs_p*self.data_masks
self.obs_p = obs_p
# model boundary
vp_bound = self.model.get_bound("vp")
if vp_bound[0] is None and vp_bound[1] is None:
self.vp_min = self.model.get_model("vp").min() - 500
self.vp_max = self.model.get_model("vp").max() + 500
else:
self.vp_min = vp_bound[0]
self.vp_max = vp_bound[1]
rho_bound = self.model.get_bound("rho")
if rho_bound[0] is None and rho_bound[1] is None:
self.rho_min = self.model.get_model("rho").min() - 500
self.rho_max = self.model.get_model("rho").max() + 500
else:
self.rho_min = rho_bound[0]
self.rho_max = rho_bound[1]
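# Illustration of the fallback above (assumed values): with no explicit bounds and an
# initial vp in [1500, 4500] m/s, vp_min/vp_max become 1000/5000 (model minimum - 500 /
# maximum + 500); the same +/-500 margin is applied to rho when its bounds are not set.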
# save result
self.cache_result = cache_result
self.iter_vp, self.iter_rho = [],[]
self.iter_loss = []
# save figure
self.save_fig_epoch = save_fig_epoch
self.save_fig_path = save_fig_path
def _normalize(self,data):
"""Normalizes waveform data by dividing each component by its maximum value."""
mask = torch.sum(torch.abs(data),axis=1,keepdim=True) == 0
max_val = torch.max(torch.abs(data),axis=1,keepdim=True).values
max_val = max_val.masked_fill(mask, 1)
data = data/max_val
return data
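# Hedged illustration (assumed [shot, time, rcv] layout): a trace [0.0, 2.0, -4.0] is
# scaled to [0.0, 0.5, -1.0], while an all-zero trace keeps a divisor of 1 and is
# returned unchanged instead of producing NaNs.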
# misfits calculation
def calculate_loss(self, synthetic_waveform, observed_waveform, normalization, loss_fn, cutoff_freq=None, propagator_dt=None, shot_index=None):
"""Calculates the misfit loss between synthetic and observed waveforms.
The following processing steps are applied before the misfit is evaluated:
(1) mute data by offset
(2) mute data after the first arrival plus the given late window (first-arrival picking is performed inside apply_mute)
(3) low-pass filter (if a cutoff frequency is given)
(4) normalize data (if requested)
Parameters
----------
synthetic_waveform : torch.Tensor
The predicted waveform data.
observed_waveform : torch.Tensor
The observed waveform data.
normalization : bool
Whether to normalize the waveform.
loss_fn : Union[Misfit, Misfit_NIM, torch.autograd.Function]
The loss function to calculate the misfit.
cutoff_freq : Optional[float], optional
The cutoff frequency for the low-pass filter. Default is None.
propagator_dt : Optional[float], optional
The time step used by the propagator. Default is None.
shot_index : Optional[int], optional
The index of the shot for processing. Default is None.
Returns
-------
torch.Tensor
The calculated loss between the synthetic and observed waveforms.
"""
# mute data by offset
if self.waveform_mute_offset is not None:
receiver_mask_2D = self.receiver_masks_2D[shot_index].cpu() # [shot, rcv]
src_x = self.propagator.src_x.cpu()[shot_index]
rcv_x_list = self.propagator.rcv_x.cpu()
rcv_x = torch.zeros(synthetic_waveform.shape[0],synthetic_waveform.shape[-1])
for i in range(synthetic_waveform.shape[0]):
rcv_x[i] = rcv_x_list[np.argwhere(receiver_mask_2D[i]).tolist()].squeeze()
synthetic_waveform = mute_offset(rcv_x,src_x,self.propagator.dx,synthetic_waveform,self.waveform_mute_offset)
observed_waveform = mute_offset(rcv_x,src_x,self.propagator.dx,observed_waveform,self.waveform_mute_offset)
# mute data by first arrival & late window
if self.waveform_mute_late_window is not None:
synthetic_waveform_temp = synthetic_waveform.clone()
observed_waveform_temp = observed_waveform.clone()
for i in range(synthetic_waveform.shape[0]):
synthetic_waveform[i] = apply_mute(self.waveform_mute_late_window, synthetic_waveform_temp[i], self.propagator.dt)
observed_waveform[i] = apply_mute(self.waveform_mute_late_window, observed_waveform_temp[i], self.propagator.dt)
# Apply low-pass filter if cutoff frequency is provided
if cutoff_freq is not None:
synthetic_waveform, observed_waveform = lpass(synthetic_waveform, observed_waveform, cutoff_freq, int(1 / propagator_dt))
if normalization:
synthetic_waveform = self._normalize(synthetic_waveform)
observed_waveform = self._normalize(observed_waveform)
if isinstance(loss_fn, Misfit):
return loss_fn.forward(synthetic_waveform, observed_waveform)
elif isinstance(loss_fn,Misfit_NIM):
return loss_fn.apply(synthetic_waveform,observed_waveform,loss_fn.p,loss_fn.trans_type,loss_fn.theta)
else:
return loss_fn.apply(synthetic_waveform, observed_waveform)
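# Hedged sketch of how the processing above is exercised for one batch. The 5 Hz cutoff
# and the 4-shot index range are placeholder values, not defaults of this class:
#
#   shot_index = np.arange(0, 4)
#   loss = self.calculate_loss(syn_p, self.obs_p[shot_index],
#                              normalization=True, loss_fn=self.loss_fn,
#                              cutoff_freq=5.0, propagator_dt=self.propagator.dt,
#                              shot_index=shot_index)
#
# The offset mute, first-arrival/late-window mute, low-pass filter and normalization are
# applied in that order, and only then is the misfit function evaluated.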
# regularization calculation
def calculate_regularization_loss(self, model_param, weight_x, weight_z, regularization_fn):
"""Generalized function to calculate regularization loss for a given parameter.
"""
regularization_loss = torch.tensor(0.0, device=model_param.device)
# Check if the parameter requires gradient
if model_param.requires_grad:
# Set the regularization weights for x and z directions
regularization_fn.alphax = weight_x
regularization_fn.alphaz = weight_z
# Calculate regularization loss if any weight is greater than zero
if regularization_fn.alphax > 0 or regularization_fn.alphaz > 0:
regularization_loss = regularization_fn.forward(model_param)
return regularization_loss
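# Illustration of the guard above: when both weight_x and weight_z are zero (the default
# weights), the function returns a zero tensor without ever calling regularization_fn.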
def forward(self,
iteration:int,
batch_size:Optional[int] = None,
checkpoint_segments:Optional[int] = 1,
start_iter = 0,
cutoff_freq = None,
):
"""Iteration of full waveform inversion.
Parameters
----------
iteration : int
The maximum iteration number in the inversion process.
batch_size : Optional[int], optional
The number of shots (data samples) in each batch. Default is None, meaning use all available shots.
checkpoint_segments : Optional[int], optional
The number of segments into which the time series should be divided for memory efficiency. Default is 1, which means no segmentation.
start_iter : int, optional
The starting iteration for the optimization process (e.g., for optimizers like Adam/AdamW, and learning rate schedulers like step_lr). Default is 0.
cutoff_freq : Optional[float], optional
The cutoff frequency for low-pass filtering, if specified. Default is None (no filtering applied).
"""
n_shots = self.propagator.src_n
if batch_size is None or batch_size > n_shots:
batch_size = n_shots
# epoch
pbar_epoch = tqdm(range(start_iter,start_iter+iteration),position=0,leave=False,colour='green',ncols=80)
for i in pbar_epoch:
for opt in self.optimizer:
opt.zero_grad()
# batch
loss_batch = 0
pbar_batch = tqdm(range(math.ceil(n_shots/batch_size)),position=1,leave=False,colour='red',ncols=80)
for batch in pbar_batch:
# forward simulation
begin_index = 0 if batch==0 else batch*batch_size
end_index = n_shots if batch==math.ceil(n_shots/batch_size)-1 else (batch+1)*batch_size
shot_index = np.arange(begin_index,end_index)
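# e.g., with n_shots = 10 and batch_size = 4, the three batches cover shots [0, 4), [4, 8) and [8, 10)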
record_waveform = self.propagator.forward(shot_index=shot_index,checkpoint_segments=checkpoint_segments)
rcv_p,rcv_u,rcv_w = record_waveform["p"],record_waveform["u"],record_waveform["w"]
forward_wavefield_p,forward_wavefield_u,forward_wavefield_w = record_waveform["forward_wavefield_p"],record_waveform["forward_wavefield_u"],record_waveform["forward_wavefield_w"]
if batch == 0:
forw = forward_wavefield_p.cpu().detach().numpy()
else:
forw += forward_wavefield_p.cpu().detach().numpy()
# misfit
if rcv_p.shape == self.obs_p[shot_index].shape: # observed and synthetic data have the same shape (only parts of the samples are missing)
receiver_mask_3D = self.receiver_masks_3D[shot_index] # [shot, time, rcv]
syn_p = rcv_p*receiver_mask_3D
else: # observed and synthetic data have different shapes (entire traces are missing)
receiver_mask_2D = self.receiver_masks_2D[shot_index] # [shot, rcv]
syn_p = torch.zeros_like(self.obs_p[shot_index],device=self.device)
for k in range(rcv_p.shape[0]):
syn_p[k] = rcv_p[k,...,np.argwhere(receiver_mask_2D[k]).tolist()].squeeze()
if self.data_masks is not None:
data_mask = self.data_masks[shot_index]
syn_p = syn_p * data_mask
data_loss = self.calculate_loss(syn_p, self.obs_p[shot_index], self.waveform_normalize, self.loss_fn, cutoff_freq, self.propagator.dt,shot_index)
# regularization
if self.regularization_fn is not None:
regularization_loss_vp = self.calculate_regularization_loss(self.model.vp , self.regularization_weights_x[0], self.regularization_weights_z[0], self.regularization_fn)
regularization_loss_rho = self.calculate_regularization_loss(self.model.rho, self.regularization_weights_x[1], self.regularization_weights_z[1], self.regularization_fn)
regularization_loss = regularization_loss_vp+regularization_loss_rho
loss_batch = loss_batch + data_loss.item() + regularization_loss.item()
loss = data_loss + regularization_loss
else:
loss_batch = loss_batch + data_loss.item()
loss = data_loss
loss.backward()
if math.ceil(n_shots/batch_size) == 1:
pbar_batch.set_description(f"Shot:{begin_index} to {end_index}")
# gradient postprocess
def grad_post_process(grads, parameter, forw=None, idx=None):
grads = grads.cpu().detach().numpy()
with torch.no_grad():
param = getattr(self.model, parameter).cpu().detach().numpy()
vmax = np.max(param)
# Apply gradient processor
if isinstance(self.gradient_processor, GradProcessor):
grads = self.gradient_processor.forward(nz=self.propagator.model.nz, nx=self.propagator.model.nx, vmax=vmax, grad=grads, forw=forw)
else:
grads = self.gradient_processor[idx].forward(nz=self.propagator.model.nz, nx=self.propagator.model.nx, vmax=vmax, grad=grads, forw=forw)
grads = numpy2tensor(grads, dtype=self.propagator.dtype).to(self.propagator.device)
return grads
# Register hooks for each model parameter
if self.propagator.model.get_requires_grad("vp"):
self.propagator.model.vp.register_hook(lambda grad: grad_post_process(grad, "vp", forw=forw, idx=0))
if self.propagator.model.get_requires_grad("rho"):
self.propagator.model.rho.register_hook(lambda grad: grad_post_process(grad, "rho", forw=forw, idx=1))
for opt in self.optimizer:
opt.step()
for schdul in self.scheduler:
schdul.step()
# constrain the model parameters
self.propagator.model.forward()
# save the temp result
if self.cache_result:
temp_vp = self.propagator.model.vp.cpu().detach().numpy()
temp_rho = self.propagator.model.rho.cpu().detach().numpy()
self.iter_vp.append(temp_vp)
self.iter_rho.append(temp_rho)
self.iter_loss.append(loss_batch)
# save the result
self.save_figure(i,temp_vp,model_type="vp")
self.save_figure(i,temp_rho,model_type="rho")
pbar_epoch.set_description("Iter:{},Loss:{:.4}".format(i+1,loss_batch))
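# Hypothetical end-to-end usage sketch. All names below are placeholders created
# elsewhere, and the optimizer/scheduler choices are assumptions rather than values
# prescribed by this module (it is only assumed that ``model`` exposes its trainable
# tensors through ``parameters()``):
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=10)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.75)
#   fwi = DIP_AcousticFWI(propagator, model, loss_fn, obs_data,
#                         optimizer=optimizer, scheduler=scheduler,
#                         gradient_processor=grad_processor)
#   fwi.forward(iteration=200, batch_size=4, checkpoint_segments=2, cutoff_freq=5.0)
#   inverted_vp = fwi.iter_vp[-1]   # last cached velocity model (requires cache_result=True)
#   loss_curve = fwi.iter_loss      # misfit history, one value per iteration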