# Source code for ADFWI.fwi.misfit.Envelope

from .base import Misfit
import torch
import numpy as np
import torch.nn.functional as F

@torch.jit.script
def hilbert(x: torch.Tensor) -> torch.Tensor:
    """
    Return the analytic signal of `x` along its last axis.

    The analytic signal is ``x + i * H{x}``, where ``H`` is the Hilbert
    transform. It is built in the frequency domain. The signal is first
    zero-padded to twice its length. The spectrum is then multiplied by
    1 at DC and Nyquist, by 2 at positive frequencies and by 0 at negative
    frequencies. The inverse FFT is truncated back to the original length.

    Parameters
    ----------
    x : torch.Tensor
        Real input signal of shape (..., N); the transform runs along the
        last axis.

    Returns
    -------
    torch.Tensor
        Complex tensor of shape (..., N). Its real part reproduces `x` and
        its imaginary part is the Hilbert transform of `x`.
    """
    n_orig = x.shape[-1]
    n_fft = 2 * n_orig  # zero-padded FFT length
    spectrum = torch.fft.fft(x, n=n_fft)

    # One-sided spectral weights that turn the spectrum into an analytic one.
    weights = torch.zeros(n_fft, dtype=spectrum.dtype, device=x.device)
    weights[0] = 1
    positive_end = (n_fft + 1) // 2
    weights[1:positive_end] = 2
    if n_fft % 2 == 0:
        weights[n_fft // 2] = 1  # Nyquist bin is shared by both halves

    analytic = torch.fft.ifft(spectrum * weights)
    return analytic[..., :n_orig]


@torch.jit.script
def diff(x: torch.Tensor, dim: int = -1, same_size: bool = False) -> torch.Tensor:
    """
    Compute the first-order discrete difference along an axis.

    ``out[i] = x[i + 1] - x[i]`` along ``dim``. If ``same_size`` is True, a zero
    is prepended along ``dim``. The output then has the same shape as ``x``,
    and its first element along ``dim`` is 0.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor with at least one dimension.

    dim : int, optional
        Axis along which the difference is taken (default is -1, the last
        axis). Negative values count from the end.

    same_size : bool, optional
        If True, left-pad the result with a zero along ``dim`` so its shape
        matches ``x`` (default is False).

    Returns
    -------
    torch.Tensor
        Difference tensor. Along ``dim`` its length is ``x.size(dim) - 1`` when
        ``same_size`` is False and ``x.size(dim)`` when it is True. An input
        that is empty along ``dim`` yields an empty result of the same shape.
    """
    n = x.size(dim)
    if n == 0:
        # Nothing to difference; keep the (empty) input shape.
        return x.clone()
    out = x.narrow(dim, 1, n - 1) - x.narrow(dim, 0, n - 1)
    if same_size:
        # Length-1 zero slice along `dim` (same dtype/device as x), prepended.
        pad = torch.zeros_like(x.narrow(dim, 0, 1))
        out = torch.cat([pad, out], dim=dim)
    return out
    
@torch.jit.script
def unwrap(phi: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Unwrap phase by correcting for phase discontinuities.

    Wherever the jump between consecutive samples along ``dim`` has magnitude
    >= pi, a multiple of 2*pi is added so the jump falls into [-pi, pi]. The
    corrections accumulate along ``dim``. This follows the same steps as
    ``numpy.unwrap`` with its default ``discont`` and ``period``.

    Parameters
    ----------
    phi : torch.Tensor
        Phase tensor to be unwrapped. The tensor is assumed to contain phase
        values in radians.

    dim : int, optional
        Axis along which to unwrap the phase (default is -1, i.e., the last
        axis).

    Returns
    -------
    torch.Tensor
        The unwrapped phase tensor, with the same shape as ``phi``.
    """
    n = phi.size(dim)
    if n < 2:
        # Fewer than two samples: there are no jumps to correct.
        return phi.clone()
    # First difference along `dim`, left-padded with a zero so it matches phi's
    # shape (the first sample therefore never receives a correction).
    step = phi.narrow(dim, 1, n - 1) - phi.narrow(dim, 0, n - 1)
    dphi = torch.cat([torch.zeros_like(phi.narrow(dim, 0, 1)), step], dim=dim)
    dphi_m = ((dphi + np.pi) % (2 * np.pi)) - np.pi  # wrap jumps into [-pi, pi)
    # A jump of exactly +pi wraps to -pi above; keep it as +pi (as numpy does).
    dphi_m[(dphi_m == -np.pi) & (dphi > 0)] = np.pi
    phi_adj = dphi_m - dphi  # correction needed at each jump
    phi_adj[dphi.abs() < np.pi] = 0  # only correct jumps of magnitude >= pi
    return phi + phi_adj.cumsum(dim)  # accumulate corrections along `dim`

class Misfit_envelope(Misfit):
    """
    Envelope (or instantaneous-phase) misfit for initial velocity model estimation.

    The analytic signal of every trace is computed with ``hilbert``. The
    residual is then one of two quantities:

    * the difference of the envelopes raised to the power ``p``:
      ``env_syn**p - env_obs**p``;
    * when ``instaneous_phase`` is True, the difference of the unwrapped
      instantaneous phases: ``phase_obs - phase_syn``.

    References: Wu et al., 2014; Yuan et al., 2015

    Parameters
    ----------
    dt : float, optional
        Time sampling interval. It scales the L2 loss only (default is 1).
    p : float, optional
        Exponent applied to the envelopes before differencing (default is 1.5).
    instaneous_phase : bool, optional
        If True, use the instantaneous phase instead of the envelope
        (default is False). The attribute keeps this spelling for
        backward compatibility.
    norm : str, optional
        "L1" gives ``sum(|rsd|)``. Any other value gives the L2 loss
        ``0.5 * sum(rsd**2 * dt)`` (default is "L2").
    """

    def __init__(self, dt: float = 1, p: float = 1.5, instaneous_phase: bool = False, norm: str = "L2") -> None:
        """Store the misfit configuration."""
        super().__init__()
        self.p = p
        self.instaneous_phase = instaneous_phase
        self.dt = dt
        self.norm = norm

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """
        Compute the envelope or instantaneous-phase misfit.

        Traces that are identically zero in both ``obs`` and ``syn`` are
        skipped, so their residual is 0.

        Parameters
        ----------
        obs : torch.Tensor
            Observed waveforms indexed as [shot, time, trace] (time on axis 1).
        syn : torch.Tensor
            Synthetic waveforms with the same shape as ``obs``.

        Returns
        -------
        torch.Tensor
            Scalar loss built from the residual of shape [shot, trace, time].
        """
        # Reorder to [shot, trace, time] so the transforms run along time.
        obs_t = obs.transpose(1, 2)
        syn_t = syn.transpose(1, 2)

        # Keep every trace that is non-zero in obs or syn. NOTE(review):
        # excluding all-zero traces presumably avoids an ill-defined phase of
        # zero signals — confirm.
        obs_silent = torch.sum(torch.abs(obs), dim=1) == 0  # [shot, trace]
        syn_silent = torch.sum(torch.abs(syn), dim=1) == 0
        live = ~(obs_silent & syn_silent)

        # Boolean-mask gathering stacks the live traces of all shots:
        # [n_live, time]. Each row is transformed independently.
        analytic_obs = hilbert(obs_t[live])
        analytic_syn = hilbert(syn_t[live])

        if self.instaneous_phase:
            # Unwrapped instantaneous-phase difference.
            residual = unwrap(torch.angle(analytic_obs)) - unwrap(torch.angle(analytic_syn))
        else:
            # Envelope (magnitude of the analytic signal) difference, power p.
            residual = torch.abs(analytic_syn) ** self.p - torch.abs(analytic_obs) ** self.p

        # Match the residual's dtype so float64 input is assigned without a
        # dtype mismatch in the masked assignment.
        rsd = torch.zeros(obs_t.shape, dtype=residual.dtype, device=obs.device)
        rsd[live] = residual

        if self.norm == "L1":
            loss = torch.sum(torch.abs(rsd))
        else:
            loss = 0.5 * torch.sum(rsd * rsd * self.dt)  # L2 loss with time weighting
        return loss