from typing import Optional,Dict
import numpy as np
import torch
from torch import Tensor
import matplotlib.pyplot as plt
from ADFWI.model import AbstractModel
from ADFWI.survey import Survey
from ADFWI.utils import numpy2tensor
from .boundary_condition import bc_pml,bc_gerjan,bc_sincos
from .acoustic_kernels import forward_kernel
[docs]class AcousticPropagator(torch.nn.Module):
"""
Defines the propagator for the isotropic acoustic wave equation (stress-velocity form), solved by the finite difference method.
Parameters
----------
model : AbstractModel
The model object that defines the physical properties of the medium.
survey : Survey
The survey object containing the source and receiver information.
device : str, optional
The device type for the computation. Default is 'cpu'.
cpu_num : int, optional
The number of CPU threads to use. Default is 1.
gpu_num : int, optional
The number of GPU devices to use. Default is 1.
dtype : torch.dtype, optional
The data type for tensors. Default is torch.float32.
"""
def __init__(self,
model : AbstractModel,
survey : Survey,
device : Optional[str] = 'cpu',
cpu_num: Optional[int] = 1,
gpu_num: Optional[int] = 1,
dtype : torch.dtype = torch.float32
):
super().__init__()
# Validate model and survey types
if not isinstance(model, AbstractModel):
raise ValueError("model is not an instance of AbstractModel")
if not isinstance(survey, Survey):
raise ValueError("survey is not an instance of Survey")
# ---------------------------------------------------------------
# set the model and survey
# ---------------------------------------------------------------
self.model = model
self.survey = survey
self.device = device
self.dtype = dtype
self.cpu_num = cpu_num
self.gpu_num = gpu_num
# ---------------------------------------------------------------
# parse parameters for model
# ---------------------------------------------------------------
self.ox, self.oz = model.ox,model.oz
self.dx, self.dz = model.dx,model.dz
self.nx, self.nz = model.nx,model.nz
self.nt = survey.source.nt
self.dt = survey.source.dt
self.f0 = survey.source.f0
# ---------------------------------------------------------------
# set the boundary: [top, bottom, left, right]
# ---------------------------------------------------------------
self.abc_type = model.abc_type
self.nabc = model.nabc
self.free_surface = model.free_surface
self.bcx,self.bcz,self.damp = None,None,None
self.boundary_condition()
# ---------------------------------------------------------------
# parameters for source
# ---------------------------------------------------------------
self.source = self.survey.source
self.src_loc = self.source.get_loc()
self.src_x = numpy2tensor(self.src_loc[...,0],torch.long).to(self.device)
self.src_z = numpy2tensor(self.src_loc[...,1],torch.long).to(self.device)
self.src_n = self.source.num
self.wavelet = numpy2tensor(self.source.get_wavelet(),self.dtype).to(self.device)
self.moment_tensor = numpy2tensor(self.source.get_moment_tensor(),self.dtype).to(self.device)
# ---------------------------------------------------------------
# parameters for receiver
# ---------------------------------------------------------------
self.receiver = self.survey.receiver
self.rcv_loc = self.receiver.get_loc()
self.rcv_x = numpy2tensor(self.rcv_loc[:,0],torch.long).to(self.device)
self.rcv_z = numpy2tensor(self.rcv_loc[:,1],torch.long).to(self.device)
self.rcv_n = self.receiver.num
self.receiver_masks = self.survey.receiver_masks
self.receiver_masks_obs = self.survey.receiver_masks_obs
[docs] def boundary_condition(self, vmax=None):
"""Set boundary conditions based on the specified ABC type."""
if self.abc_type.lower() == "pml":
if vmax is not None:
damp = bc_pml(self.nx, self.nz, self.dx, self.dz, pml=self.nabc, vmax=vmax, free_surface=False)
else:
damp = bc_pml(self.nx, self.nz, self.dx, self.dz, pml=self.nabc,
vmax=self.model.vp.cpu().detach().numpy().max(),
free_surface=False)
elif self.abc_type.lower() == 'gerjan':
damp = bc_gerjan(self.nx, self.nz, self.dx, self.dz, pml=self.nabc, alpha=self.model.abc_jerjan_alpha,
free_surface=False)
else:
damp = bc_sincos(self.nx, self.nz, self.dx, self.dz, pml=self.nabc,
free_surface=False)
self.damp = numpy2tensor(damp, self.dtype).to(self.device)
[docs] def forward(self,
model: Optional[AbstractModel] = None,
shot_index: Optional[int] = None,
checkpoint_segments: int = 1,
) -> Dict[str, Tensor]:
"""
Forward simulation for selected shots.
Parameters
----------
model : Optional[AbstractModel], optional
Model to use for simulation. If not provided, defaults to the instance's model.
shot_index : Optional[int], optional
Index of the shot to simulate.
checkpoint_segments : int, optional
Number of segments for checkpointing to save memory. Default is 1.
Returns
-------
record_waveform : dict
Dictionary containing recorded waveforms.
"""
# calculate the thomson/lame and elastic moduli parameters
model = self.model if model is None else model
model.forward()
# foward simulation for select shots
src_x = self.src_x[shot_index] if shot_index is not None else self.src_x
src_z = self.src_z[shot_index] if shot_index is not None else self.src_z
src_n = len(src_x)
wavelet = self.wavelet[shot_index] if shot_index is not None else self.wavelet
record_waveform = forward_kernel(
self.nx,self.nz,self.dx,self.dz,self.nt,self.dt,
self.nabc,self.free_surface,
src_x,src_z,src_n,wavelet,
self.rcv_x,self.rcv_z,self.rcv_n,
self.damp,
model.vp,model.rho,
checkpoint_segments=checkpoint_segments,
device=self.device,dtype=self.dtype
)
return record_waveform