Source code for ADFWI.fwi.elastic_fwi

from typing import Optional,Union,List
import os
import math
import torch
import numpy as np
from tqdm import tqdm
from ADFWI.model       import AbstractModel,IsotropicElasticModel,AnisotropicElasticModel
from ADFWI.propagator  import ElasticPropagator,GradProcessor
from ADFWI.survey      import SeismicData
from ADFWI.fwi.misfit  import Misfit
from ADFWI.fwi.regularization import Regularization
from ADFWI.utils       import numpy2tensor
from ADFWI.view        import plot_vp_vs_rho,plot_model,plot_eps_delta_gamma

from ADFWI.fwi.multiScaleProcessing import lpass
from ADFWI.utils.first_arrivel_picking import apply_mute
from ADFWI.utils.offset_mute import mute_offset
from ADFWI.fwi.multiScaleProcessing import lpass

class ElasticFWI(torch.nn.Module):
    """Elastic Full Waveform Inversion class.

    Parameters
    ----------
    propagator : ElasticPropagator
        The propagator used for simulating elastic wave propagation.
    model : AbstractModel
        The model class representing the velocity structure.
    loss_fn : Union[Misfit, torch.autograd.Function]
        The loss function used to compute the misfit between observed and
        predicted data.
    obs_data : SeismicData
        The observed seismic data.
    optimizer : Union[torch.optim.Optimizer, List[torch.optim.Optimizer]], optional
        The optimizer or list of optimizers for model parameters.
        Default is None (``forward`` raises ``ValueError`` in that case).
    scheduler : torch.optim.lr_scheduler, optional
        The learning rate scheduler (or list of schedulers). Default is None
        (no scheduler step is performed).
    gradient_processor : Union[GradProcessor, List[GradProcessor]], optional
        Processor(s) for handling gradients, indexed in the order
        vp/vs/rho/epsilon/delta/gamma when a list is given. Default is None
        (gradients are left untouched).
    regularization_fn : Optional[Regularization], optional
        Regularization function applied to vp/vs/rho/epsilon/delta/gamma.
        Default is None.
    regularization_weights_x : Optional[List[float]], optional
        Regularization weights for the x-axis (vp/vs/rho/epsilon/delta/gamma).
        Default is None, which means [0, 0, 0, 0, 0, 0].
    regularization_weights_z : Optional[List[float]], optional
        Regularization weights for the z-axis (vp/vs/rho/epsilon/delta/gamma).
        Default is None, which means [0, 0, 0, 0, 0, 0].
    waveform_normalize : Optional[bool], optional
        Whether to normalize the waveforms during inversion. Default is True.
    waveform_mute_late_window : Optional[float], optional
        Late window mute applied to waveform data. Default is None.
    waveform_mute_offset : Optional[float], optional
        Offset mute applied to waveform data. Default is None.
    cache_result : Optional[bool], optional
        Whether to save intermediate results during the inversion.
        Default is True.
    cache_gradient : Optional[bool], optional
        Whether to save model variations (not gradients) during inversion.
        Default is False.
    save_fig_epoch : Optional[int], optional
        The interval (in epochs) at which to save the inversion result
        figure. Default is -1 (no figure saved).
    save_fig_path : Optional[str], optional
        The path where to save the inversion result figure. Default is an
        empty string (no save path).
    inversion_component : Optional[np.array], optional
        The components of the inversion, any of "pressure", "vx", "vz".
        Default is None, which means ["pressure"].
    """

    def __init__(self,
                 propagator: "ElasticPropagator",
                 model: "AbstractModel",
                 loss_fn: Union["Misfit", torch.autograd.Function],
                 obs_data: "SeismicData",
                 optimizer: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]] = None,
                 scheduler: torch.optim.lr_scheduler = None,
                 gradient_processor: Union["GradProcessor", List["GradProcessor"]] = None,   # vp/vs/rho epsilon/delta/gamma
                 regularization_fn: Optional["Regularization"] = None,
                 regularization_weights_x: Optional[List[float]] = None,                     # vp/vs/rho epsilon/delta/gamma
                 regularization_weights_z: Optional[List[float]] = None,                     # vp/vs/rho epsilon/delta/gamma
                 waveform_normalize: Optional[bool] = True,
                 waveform_mute_late_window: Optional[float] = None,
                 waveform_mute_offset: Optional[float] = None,
                 cache_result: Optional[bool] = True,
                 cache_gradient: Optional[bool] = False,
                 save_fig_epoch: Optional[int] = -1,
                 save_fig_path: Optional[str] = "",
                 inversion_component: Optional[np.array] = None,
                 ):
        super().__init__()
        self.propagator = propagator
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn
        self.regularization_fn = regularization_fn
        # A fresh list per instance: a list literal in the signature would be
        # one object shared by every instance built with the defaults.
        self.regularization_weights_x = [0, 0, 0, 0, 0, 0] if regularization_weights_x is None else regularization_weights_x
        self.regularization_weights_z = [0, 0, 0, 0, 0, 0] if regularization_weights_z is None else regularization_weights_z
        self.obs_data = obs_data
        self.gradient_processor = gradient_processor
        self.device = self.propagator.device
        self.dtype = self.propagator.dtype

        # Real-Case settings: for trace missing, partial data missing
        receiver_masks = self.propagator.receiver_masks
        if receiver_masks is None:
            receiver_masks = np.ones((self.propagator.src_n, self.propagator.rcv_n))
        if not torch.is_tensor(receiver_masks):
            # Covers both the all-ones default and a mask handed over as an
            # array; a mask that already is a tensor is used unchanged.
            receiver_masks = numpy2tensor(receiver_masks)
        self.receiver_masks_2D = receiver_masks                                                                   # [shot, rcv]
        self.receiver_masks_3D = receiver_masks.unsqueeze(1).expand(-1, self.propagator.nt, -1).to(self.device)   # [shot, time, rcv]

        # Real-Case settings: mute late window (by first arrival picking) & mute offset
        self.waveform_normalize = waveform_normalize
        self.waveform_mute_late_window = waveform_mute_late_window
        self.waveform_mute_offset = waveform_mute_offset

        # observed data
        obs_p = -(self.obs_data.data["txx"] + self.obs_data.data["tzz"])
        obs_p = numpy2tensor(obs_p, self.dtype).to(self.device)
        obs_vx = numpy2tensor(self.obs_data.data["vx"], self.dtype).to(self.device)
        obs_vz = numpy2tensor(self.obs_data.data["vz"], self.dtype).to(self.device)
        if self.propagator.receiver_masks_obs:
            # the observed data still contain the masked traces: zero them
            obs_p = obs_p * self.receiver_masks_3D
            obs_vx = obs_vx * self.receiver_masks_3D
            obs_vz = obs_vz * self.receiver_masks_3D
        if self.obs_data.data_masks is not None:
            self.data_masks = numpy2tensor(self.obs_data.data_masks).to(self.device)
        else:
            self.data_masks = None
        if self.data_masks is not None:
            # some of the data samples are not useful
            obs_p = obs_p * self.data_masks
            obs_vx = obs_vx * self.data_masks
            obs_vz = obs_vz * self.data_masks
        self.obs_p = obs_p
        self.obs_vx = obs_vx
        self.obs_vz = obs_vz

        # save result
        self.cache_result = cache_result
        # NOTE(review): cache_gradient is stored but no method of this class
        # reads it; gradients are cached whenever cache_result is set.
        self.cache_gradient = cache_gradient
        self.iter_vp, self.iter_vs, self.iter_rho = [], [], []
        self.iter_eps, self.iter_delta, self.iter_gamma = [], [], []
        self.iter_vp_grad, self.iter_vs_grad, self.iter_rho_grad = [], [], []
        self.iter_eps_grad, self.iter_delta_grad, self.iter_gamma_grad = [], [], []
        self.iter_loss = []

        # save figure
        self.save_fig_epoch = save_fig_epoch
        self.save_fig_path = save_fig_path

        # inversion component
        self.inversion_component = ["pressure"] if inversion_component is None else inversion_component

    def _normalize(self, data):
        """Scale ``data`` by its maximum absolute value taken along axis 1.

        Slices that are entirely zero along axis 1 are divided by 1 instead
        of 0, so they stay zero rather than turning into NaN.
        """
        magnitude = torch.abs(data)
        all_zero = torch.sum(magnitude, dim=1, keepdim=True) == 0
        max_val = torch.max(magnitude, dim=1, keepdim=True).values
        max_val = max_val.masked_fill(all_zero, 1)
        return data / max_val

    @staticmethod
    def _as_list(obj):
        """Return ``obj`` as a list: ``None`` -> ``[]``, a single object -> ``[obj]``."""
        if obj is None:
            return []
        if isinstance(obj, (list, tuple)):
            return list(obj)
        return [obj]

    def _plot_range(self, name, margin):
        """Return the (min, max) colour range used when plotting parameter ``name``.

        When the model reports no bound at all for ``name`` the range is the
        current model minimum/maximum widened by ``margin``; otherwise the
        reported bounds are returned as they are.
        """
        bound = self.model.get_bound(name)
        if bound[0] is None and bound[1] is None:
            return (self.model.get_model(name).min() - margin,
                    self.model.get_model(name).max() + margin)
        return bound[0], bound[1]

    def _should_save_fig(self, epoch_id):
        """Tell whether a figure has to be written for ``epoch_id``.

        Saving is disabled for ``save_fig_epoch`` of -1, 0 or None (0 would
        otherwise be a modulo by zero) and when ``save_fig_path`` does not
        exist.
        """
        if not self.save_fig_epoch or self.save_fig_epoch == -1:
            return False
        if epoch_id % self.save_fig_epoch != 0:
            return False
        return os.path.exists(self.save_fig_path)

    # misfits calculation
    def calculate_loss(self, synthetic_waveform, observed_waveform, normalization, loss_fn,
                       cutoff_freq=None, propagator_dt=None, shot_index=None):
        """Generalized function to calculate misfit loss for a given component.

        The optional steps are applied in this order: offset mute, late-window
        mute, low-pass filter, normalization; the misfit is evaluated last.

        Parameters
        ----------
        synthetic_waveform, observed_waveform : torch.Tensor
            Waveforms of the current batch. With the late-window mute enabled
            both tensors are overwritten in place.
        normalization : bool
            Whether both waveforms are normalized before the misfit.
        loss_fn : Union[Misfit, callable]
            ``Misfit`` instances are evaluated through ``forward``; anything
            else is called directly.
        cutoff_freq : optional
            Cutoff frequency handed to ``lpass``. Default is None (no filter).
        propagator_dt : optional
            Time step used to derive the sampling rate for ``lpass``.
            Default is None, which means ``self.propagator.dt``.
        shot_index : optional
            Indices of the shots in the batch; required for the offset mute.

        Raises
        ------
        ValueError
            If ``waveform_mute_offset`` is set and ``shot_index`` is None.
        """
        # mute data by offset
        if self.waveform_mute_offset is not None:
            if shot_index is None:
                raise ValueError("shot_index is required when waveform_mute_offset is set")
            receiver_mask_2D = self.receiver_masks_2D[shot_index].cpu()   # [shot, rcv]
            src_x = self.propagator.src_x.cpu()[shot_index]
            rcv_x_list = self.propagator.rcv_x.cpu()
            rcv_x = torch.zeros(synthetic_waveform.shape[0], synthetic_waveform.shape[-1])
            for i in range(synthetic_waveform.shape[0]):
                # x-positions of the receivers that are active for shot i
                active = torch.nonzero(receiver_mask_2D[i], as_tuple=True)[0]
                rcv_x[i] = rcv_x_list[active]
            synthetic_waveform = mute_offset(rcv_x, src_x, self.propagator.dx, synthetic_waveform, self.waveform_mute_offset)
            observed_waveform = mute_offset(rcv_x, src_x, self.propagator.dx, observed_waveform, self.waveform_mute_offset)

        # mute data by first arrival & late window
        if self.waveform_mute_late_window is not None:
            # the clones keep the un-muted input readable while the shots are
            # overwritten one by one
            synthetic_waveform_temp = synthetic_waveform.clone()
            observed_waveform_temp = observed_waveform.clone()
            for i in range(synthetic_waveform.shape[0]):
                synthetic_waveform[i] = apply_mute(self.waveform_mute_late_window, synthetic_waveform_temp[i], self.propagator.dt)
                observed_waveform[i] = apply_mute(self.waveform_mute_late_window, observed_waveform_temp[i], self.propagator.dt)

        # Apply low-pass filter if cutoff frequency is provided
        if cutoff_freq is not None:
            dt = self.propagator.dt if propagator_dt is None else propagator_dt
            synthetic_waveform, observed_waveform = lpass(synthetic_waveform, observed_waveform, cutoff_freq, int(1 / dt))

        if normalization:
            observed_waveform = self._normalize(observed_waveform)
            synthetic_waveform = self._normalize(synthetic_waveform)

        if isinstance(loss_fn, Misfit):
            return loss_fn.forward(synthetic_waveform, observed_waveform)
        return loss_fn(synthetic_waveform, observed_waveform)

    # regularization calculation
    def calculate_regularization_loss(self, model_param, weight_x, weight_z, regularization_fn):
        """Generalized function to calculate regularization loss for a given parameter.

        Returns a zero scalar tensor (on the device of ``model_param``) when
        the parameter does not require a gradient or when both weights are not
        positive. When the parameter requires a gradient, ``alphax`` and
        ``alphaz`` of ``regularization_fn`` are overwritten with the weights.
        """
        if not model_param.requires_grad:
            return torch.tensor(0.0, device=model_param.device)
        # Set the regularization weights for x and z directions
        regularization_fn.alphax = weight_x
        regularization_fn.alphaz = weight_z
        if regularization_fn.alphax > 0 or regularization_fn.alphaz > 0:
            return regularization_fn.forward(model_param)
        return torch.tensor(0.0, device=model_param.device)

    def _total_regularization_loss(self):
        """Sum the regularization losses of vp/vs/rho (and eps/delta/gamma for anisotropic models)."""
        names = ["vp", "vs", "rho"]
        if isinstance(self.model, AnisotropicElasticModel):
            names += ["eps", "delta", "gamma"]
        total = torch.tensor(0.0, device=self.device)
        for idx, name in enumerate(names):
            total = total + self.calculate_regularization_loss(
                getattr(self.model, name),
                self.regularization_weights_x[idx],
                self.regularization_weights_z[idx],
                self.regularization_fn,
            )
        return total

    # gradient precondition
    def process_gradient(self, parameter, forw, idx=None):
        """Process the gradient of ``parameter`` and store the result in ``parameter.grad``.

        Parameters
        ----------
        parameter : torch.Tensor
            Model parameter whose ``.grad`` is processed.
        forw :
            Forward-wavefield array handed to the gradient processor.
        idx : int, optional
            Position of the processor to use when ``gradient_processor`` is a
            sequence; ignored for a single ``GradProcessor``.

        Nothing happens when no gradient processor is configured or when the
        parameter carries no gradient.
        """
        if self.gradient_processor is None or parameter.grad is None:
            return
        with torch.no_grad():
            grads = parameter.grad.cpu().detach().numpy()
            vmax = np.max(parameter.cpu().detach().numpy())
            # Apply gradient processor
            if isinstance(self.gradient_processor, GradProcessor):
                processor = self.gradient_processor
            else:
                processor = self.gradient_processor[idx]
            grads = processor.forward(nz=self.model.nz, nx=self.model.nx, vmax=vmax, grad=grads, forw=forw)
            # Convert grads back to tensor and assign
            parameter.grad = numpy2tensor(grads, dtype=self.propagator.dtype).to(self.propagator.device)
        return

    def save_vp_vs_rho_fig(self, epoch_id, vp, vs, rho):
        """Update the vp/vs/rho colour ranges and, when due, save the model figure.

        The ranges are stored on the instance (``vp_min`` ... ``rho_max``) on
        every call; with a water layer mask the lower limits are forced to
        1500 (vp), 0 (vs) and 1000 (rho). The figure is written to
        ``save_fig_path`` as ``model_{epoch_id}.png``.
        """
        self.vp_min, self.vp_max = self._plot_range("vp", 500)
        self.vs_min, self.vs_max = self._plot_range("vs", 500)
        self.rho_min, self.rho_max = self._plot_range("rho", 200)
        if self.model.water_layer_mask is not None:
            self.vp_min = 1500
            self.vs_min = 0
            self.rho_min = 1000

        if self._should_save_fig(epoch_id):
            plot_vp_vs_rho(
                vp=vp, vs=vs, rho=rho,
                figsize=(12, 5), wspace=0.2, cbar_pad_fraction=0.18, cbar_height=0.04,
                dx=self.model.dx, dz=self.model.dz,
                vp_min=self.vp_min, vp_max=self.vp_max,
                vs_min=self.vs_min, vs_max=self.vs_max,
                rho_min=self.rho_min, rho_max=self.rho_max,
                save_path=os.path.join(self.save_fig_path, f"model_{epoch_id}.png"),
                show=False
            )
        return

    def save_eps_delta_gamma_fig(self, epoch_id, eps, delta, gamma):
        """Update the eps/delta/gamma ranges and, when due, save the anisotropy figure.

        The ranges are stored as ``eps_min`` ... ``gamma_max``; they are kept
        apart from the vp/vs/rho ranges and are not passed to the plot call.
        The figure is written to ``save_fig_path`` as
        ``anisotropic_model_{epoch_id}.png``.
        """
        self.eps_min, self.eps_max = self._plot_range("eps", 0.01)
        self.delta_min, self.delta_max = self._plot_range("delta", 0.01)
        self.gamma_min, self.gamma_max = self._plot_range("gamma", 0.01)

        if self._should_save_fig(epoch_id):
            plot_eps_delta_gamma(
                eps=eps, delta=delta, gamma=gamma,
                figsize=(12, 5), wspace=0.3, cbar_pad_fraction=0.01, cbar_height=0.04,
                dx=self.model.dx, dz=self.model.dz,
                save_path=os.path.join(self.save_fig_path, f"anisotropic_model_{epoch_id}.png"),
                show=False
            )
        return

    def save_gradient_fig(self, epoch_id, data, model_type="vp"):
        """Save ``data`` as ``{model_type}_{epoch_id}.png`` when a figure is due for ``epoch_id``."""
        if self._should_save_fig(epoch_id):
            plot_model(data, title=f"Iteration {epoch_id}",
                       dx=self.model.dx, dz=self.model.dz,
                       save_path=os.path.join(self.save_fig_path, f"{model_type}_{epoch_id}.png"),
                       show=False, cmap='seismic')
        return

    def save_model_and_gradients(self, epoch_id, loss_epoch):
        """Append the loss, model parameters and gradients of this epoch to the ``iter_*`` lists.

        The model figure, the anisotropy figure (anisotropic models only) and
        one figure per gradient are handed to the ``save_*_fig`` methods,
        which decide whether anything is written.
        """
        # Save the loss
        self.iter_loss.append(loss_epoch)

        is_anisotropic = isinstance(self.model, AnisotropicElasticModel)
        names = ["vp", "vs", "rho"] + (["eps", "delta", "gamma"] if is_anisotropic else [])

        # Save the model parameters
        for name in names:
            param = getattr(self.model, name, None)
            if param is not None:
                getattr(self, f"iter_{name}").append(param.cpu().detach().numpy())

        # save the figure
        self.save_vp_vs_rho_fig(epoch_id,
                                self.model.vp.cpu().detach().numpy(),
                                self.model.vs.cpu().detach().numpy(),
                                self.model.rho.cpu().detach().numpy())
        if is_anisotropic:
            self.save_eps_delta_gamma_fig(epoch_id,
                                          self.model.eps.cpu().detach().numpy(),
                                          self.model.delta.cpu().detach().numpy(),
                                          self.model.gamma.cpu().detach().numpy())

        # Save the gradient of every parameter that is being inverted
        for name in names:
            if not self.model.get_requires_grad(name):
                continue
            grad = getattr(self.model, name).grad
            if grad is None:
                # nothing was back-propagated into this parameter
                continue
            temp_grad = grad.cpu().detach().numpy()
            getattr(self, f"iter_{name}_grad").append(temp_grad)
            self.save_gradient_fig(epoch_id, temp_grad, model_type=f"grad_{name}")
        return

    def real_case_data_selecting(self, rcv_p, shot_index):
        """Select the useful part of the synthetic data for the shots in ``shot_index``.

        When the synthetic record has the shape of the observed record, the
        masked traces are zeroed. Otherwise only the traces flagged in the
        receiver mask are copied into a tensor shaped like the observed
        record. ``data_masks``, when present, is applied last.
        """
        obs_batch = self.obs_p[shot_index]
        if rcv_p.shape == obs_batch.shape:
            # observed and synthetic data keep all the traces
            syn_p = rcv_p * self.receiver_masks_3D[shot_index]          # [shot, time, rcv]
        else:
            # observed data only keep the useful traces
            receiver_mask_2D = self.receiver_masks_2D[shot_index]       # [shot, rcv]
            syn_p = torch.zeros_like(obs_batch, device=self.device)
            for k in range(rcv_p.shape[0]):
                active = torch.nonzero(receiver_mask_2D[k], as_tuple=True)[0].to(rcv_p.device)
                syn_p[k] = rcv_p[k, ..., active]

        # some of the data are not useful
        if self.data_masks is not None:
            syn_p = syn_p * self.data_masks[shot_index]
        return syn_p

    def forward(self,
                iteration: int,
                fd_order: int = 4,
                batch_size: Optional[int] = None,
                checkpoint_segments: Optional[int] = 1,
                start_iter=0,
                cutoff_freq=None,
                ):
        """iteration of full waveform inversion

        Parameters
        -----------
        iteration : int
            The maximum iteration number in the inversion process.
        fd_order : int, optional
            The order of the finite difference scheme for wave propagation.
            Default is 4.
        batch_size : Optional[int], optional
            The number of shots (data samples) in each batch. Default is None,
            meaning use all available shots.
        checkpoint_segments : Optional[int], optional
            The number of segments into which the time series should be
            divided for memory efficiency. Default is 1, which means no
            segmentation.
        start_iter : int, optional
            The starting iteration for the optimization process (e.g., for
            optimizers like Adam/AdamW, and learning rate schedulers like
            step_lr). Default is 0.
        cutoff_freq : Optional[float], optional
            The cutoff frequency for low-pass filtering, if specified.
            Default is None (no filtering applied).

        Raises
        ------
        ValueError
            If no optimizer is configured, or if ``inversion_component``
            contains none of "pressure", "vx", "vz".
        """
        optimizers = self._as_list(self.optimizer)
        schedulers = self._as_list(self.scheduler)
        if not optimizers:
            raise ValueError("ElasticFWI.forward requires an optimizer")

        use_p = "pressure" in self.inversion_component
        use_vx = "vx" in self.inversion_component
        use_vz = "vz" in self.inversion_component
        if not (use_p or use_vx or use_vz):
            raise ValueError('inversion_component must contain at least one of "pressure", "vx", "vz"')

        n_shots = self.propagator.src_n
        if batch_size is None or batch_size > n_shots:
            batch_size = n_shots
        n_batches = math.ceil(n_shots / batch_size)

        grad_names = ["vp", "vs", "rho"]
        if isinstance(self.model, AnisotropicElasticModel):
            grad_names += ["eps", "delta", "gamma"]

        # epoch
        pbar_epoch = tqdm(range(start_iter, start_iter + iteration), position=0, leave=False, colour='green', ncols=80)
        for i in pbar_epoch:
            for optimizer in optimizers:
                optimizer.zero_grad()
            loss_epoch = 0
            forw_p, forw_vx, forw_vz = None, None, None

            # batch
            pbar_batch = tqdm(range(n_batches), position=1, leave=False, colour='red', ncols=80)
            for batch in pbar_batch:
                # forward simulation
                begin_index = batch * batch_size
                end_index = n_shots if batch == n_batches - 1 else (batch + 1) * batch_size
                shot_index = np.arange(begin_index, end_index)
                record_waveform = self.propagator.forward(fd_order=fd_order, shot_index=shot_index, checkpoint_segments=checkpoint_segments)
                rcv_txx, rcv_tzz = record_waveform["txx"], record_waveform["tzz"]
                rcv_vx, rcv_vz = record_waveform["vx"], record_waveform["vz"]

                # Accumulate the forward wavefields over the batches. The sum
                # is built out of place so the arrays returned by the
                # propagator are never modified.
                if use_p:
                    wavefield = -(record_waveform["forward_wavefield_txx"] + record_waveform["forward_wavefield_tzz"]).cpu().detach().numpy()
                    forw_p = wavefield if forw_p is None else forw_p + wavefield
                if use_vx:
                    wavefield = record_waveform["forward_wavefield_vx"].cpu().detach().numpy()
                    forw_vx = wavefield if forw_vx is None else forw_vx + wavefield
                if use_vz:
                    wavefield = record_waveform["forward_wavefield_vz"].cpu().detach().numpy()
                    forw_vz = wavefield if forw_vz is None else forw_vz + wavefield

                # misfits
                loss_pressure, loss_vx, loss_vz = 0, 0, 0
                if use_p:
                    syn_p = -(rcv_txx + rcv_tzz)
                    syn_p = self.real_case_data_selecting(syn_p, shot_index)
                    loss_pressure = self.calculate_loss(syn_p, self.obs_p[shot_index], self.waveform_normalize, self.loss_fn,
                                                        cutoff_freq, self.propagator.dt, shot_index=shot_index)
                if use_vx:
                    rcv_vx = self.real_case_data_selecting(rcv_vx, shot_index)
                    loss_vx = self.calculate_loss(rcv_vx, self.obs_vx[shot_index], self.waveform_normalize, self.loss_fn,
                                                  cutoff_freq, self.propagator.dt, shot_index=shot_index)
                if use_vz:
                    rcv_vz = self.real_case_data_selecting(rcv_vz, shot_index)
                    loss_vz = self.calculate_loss(rcv_vz, self.obs_vz[shot_index], self.waveform_normalize, self.loss_fn,
                                                  cutoff_freq, self.propagator.dt, shot_index=shot_index)
                data_loss = loss_pressure + loss_vx + loss_vz

                # regularization
                if self.regularization_fn is not None:
                    regularization_loss = self._total_regularization_loss()
                    loss_epoch += data_loss.item() + regularization_loss.item()
                    loss = data_loss + regularization_loss
                else:
                    loss_epoch += data_loss.item()
                    loss = data_loss
                loss.backward()
                if n_batches == 1:
                    pbar_batch.set_description(f"Shot:{begin_index} to {end_index}")

            # gradient process (once per epoch, on the gradient accumulated
            # over all batches)
            if forw_p is not None:
                forw = forw_p
            elif forw_vz is not None:
                forw = forw_vz
            else:
                forw = forw_vx
            if self.gradient_processor is not None:
                for idx, name in enumerate(grad_names):
                    if not self.model.get_requires_grad(name):
                        continue
                    if (name == "gamma"
                            and not isinstance(self.gradient_processor, GradProcessor)
                            and len(self.gradient_processor) <= idx):
                        # processor lists that stop at delta stay valid
                        continue
                    self.process_gradient(getattr(self.model, name), forw=forw, idx=idx)

            # update model parameters
            for optimizer in optimizers:
                optimizer.step()
            for scheduler in schedulers:
                scheduler.step()

            # constrain the velocity model
            self.model.forward()

            # cache results
            if self.cache_result:
                self.save_model_and_gradients(epoch_id=i, loss_epoch=loss_epoch)
            pbar_epoch.set_description("Iter:{},Loss:{:.4}".format(i + 1, loss_epoch))