# Source code for ADFWI.fwi.misfit.GlobalCorrelation

from .base import Misfit
import torch

class Misfit_global_correlation(Misfit):
    """Global correlation misfit function.

    Computes a correlation-based misfit between observed and synthetic
    waveforms trace by trace. Each trace is first normalized by its L2 norm
    along the time axis; the negated correlation coefficients of all
    (shot, trace) pairs are then summed and scaled by ``dt``.

    Parameters
    ----------
    dt : float, optional
        Time sampling interval; the summed misfit is multiplied by it
        (default is 1).
    """

    def __init__(self, dt=1) -> None:
        """Initialize the Misfit_global_correlation.

        Parameters
        ----------
        dt : float, optional
            Time sampling interval used to scale the loss (default is 1).
        """
        super().__init__()
        self.dt = dt

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """Compute the global correlation misfit between obs and syn.

        For every (shot, trace) pair both traces are L2-normalized along the
        time axis, and a correlation value is formed as

            corr = mean(o * s) / (sqrt(var(o) * var(s)) + 1e-8)

        where ``var`` is ``torch.var`` (mean-removed, unbiased). The loss is
        ``sum(-corr * dt)`` over all shots and traces.

        Degenerate pairs contribute exactly 0 to the loss and receive a zero
        (not NaN) gradient. A pair is degenerate when the product of the
        normalized variances is zero or NaN. This covers all-zero traces on
        either side, which the previous per-trace loop also mapped to 0 in the
        forward pass.

        Parameters
        ----------
        obs : torch.Tensor
            Observed waveforms, shape (shots, time samples, traces).
        syn : torch.Tensor
            Synthetic waveforms, same shape as ``obs``.

        Returns
        -------
        torch.Tensor
            Scalar correlation-based misfit, in the dtype of the inputs.

        Raises
        ------
        ValueError
            If ``obs`` is not 3-D or ``obs`` and ``syn`` differ in shape.
        """
        if obs.dim() != 3 or obs.shape != syn.shape:
            raise ValueError(
                "obs and syn must be 3-D tensors of identical shape "
                f"(shots, time, traces); got {tuple(obs.shape)} and "
                f"{tuple(syn.shape)}"
            )

        # L2-normalize every trace along the time axis (dim=1). All-zero
        # traces get a unit denominator so they stay zero: plain division
        # would give 0/0 = NaN. The NaN could be masked in the forward value,
        # but it would still leak into the gradient w.r.t. syn.
        obs_norm = obs.norm(dim=1, keepdim=True)  # (shots, 1, traces)
        syn_norm = syn.norm(dim=1, keepdim=True)
        obs_hat = obs / torch.where(obs_norm > 0, obs_norm, torch.ones_like(obs_norm))
        syn_hat = syn / torch.where(syn_norm > 0, syn_norm, torch.ones_like(syn_norm))

        # NOTE(review): `cov` is the raw mean of the products (no mean
        # removal), while torch.var removes the mean and divides by T-1.
        # The resulting `corr` is therefore neither exactly Pearson nor
        # exactly the cosine similarity. This is kept as in the original
        # implementation; confirm it is intended.
        cov = torch.mean(obs_hat * syn_hat, dim=1)  # (shots, traces)
        var_obs = torch.var(obs_hat, dim=1)  # (shots, traces)
        var_syn = torch.var(syn_hat, dim=1)  # (shots, traces)
        var_prod = var_obs * var_syn

        # sqrt has an infinite derivative at 0, so degenerate entries are
        # routed through a dummy value of 1 and then overwritten with 0.
        # This "double where" keeps the backward pass finite.
        valid = var_prod > 0  # False for zero variance or NaN
        safe_prod = torch.where(valid, var_prod, torch.ones_like(var_prod))
        corr = cov / (torch.sqrt(safe_prod) + 1e-8)
        corr = torch.where(valid, corr, torch.zeros_like(corr))

        # Sum the negated correlations over all shots and traces, scaled by dt.
        loss = torch.sum(-corr * self.dt)
        return loss