from .base import Misfit
import torch
class Misfit_global_correlation(Misfit):
    """
    Global correlation misfit function.

    Every observed/synthetic trace pair is scaled to unit L2 norm and scored
    with a normalized zero-lag correlation. The loss is the negated sum of
    those scores multiplied by ``dt``, so better-correlated data give a lower
    (more negative) loss.

    Parameters
    ----------
    dt : float, optional
        Time sampling interval used to scale the summed misfit (default is 1).
    """

    def __init__(self, dt: float = 1) -> None:
        """
        Initialize the Misfit_global_correlation.

        Parameters
        ----------
        dt : float, optional
            Time sampling interval used to scale the summed misfit (default is 1).
        """
        super().__init__()
        self.dt = dt

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """
        Compute the global correlation misfit between observed and synthetic waveforms.

        Parameters
        ----------
        obs : torch.Tensor
            Observed waveforms, 3-D. The correlation is evaluated along dim 1
            (the sample axis of each trace); dim 0 and dim 2 enumerate the
            individual traces (e.g. shots and receivers).
        syn : torch.Tensor
            Synthetic waveforms with the same shape as ``obs``.

        Returns
        -------
        torch.Tensor
            Scalar tensor: ``-dt * sum(corr)`` over all usable trace pairs.
            Trace pairs in which either trace has a zero or non-finite L2 norm
            contribute exactly zero.
        """
        # Put the sample axis last so each [i, j] entry is one complete trace.
        obs_traces = obs.permute(0, 2, 1)
        syn_traces = syn.permute(0, 2, 1)

        obs_norm = obs_traces.norm(dim=-1)
        syn_norm = syn_traces.norm(dim=-1)

        # A pair is usable only when both traces can be normalised. Excluding
        # the others up front (instead of dividing by zero and patching the
        # resulting NaNs afterwards) keeps NaNs out of the backward pass; such
        # pairs evaluated to a zero residual before as well.
        usable = (
            torch.isfinite(obs_norm)
            & torch.isfinite(syn_norm)
            & (obs_norm > 0)
            & (syn_norm > 0)
        )

        # Boolean indexing flattens the two trace axes: (n_usable, n_samples).
        obs_unit = obs_traces[usable] / obs_norm[usable].unsqueeze(-1)
        syn_unit = syn_traces[usable] / syn_norm[usable].unsqueeze(-1)

        # NOTE: `cov` is the plain mean product (no mean removal) while the
        # variances are torch.var's default unbiased estimates. This is the
        # definition the loss has always used; it is kept so values stay
        # comparable with earlier runs.
        cov = torch.mean(obs_unit * syn_unit, dim=-1)
        var_obs = torch.var(obs_unit, dim=-1)
        var_syn = torch.var(syn_unit, dim=-1)

        # The small constant keeps the ratio finite for zero-variance traces.
        corr = cov / (torch.sqrt(var_obs * var_syn) + 1e-8)

        # torch.var is NaN for single-sample traces; count those as zero.
        corr = torch.where(torch.isnan(corr), torch.zeros_like(corr), corr)

        # Negative correlation summed over all trace pairs, scaled by dt.
        # Summing the (possibly empty) selection keeps the result attached to
        # the autograd graph and in the dtype of the inputs.
        loss = -(torch.sum(corr) * self.dt)
        return loss