import torch
from torch.autograd import Function
def trace_sum_normalize(x, time_dim=0):
    """
    Normalize each trace by its sum along the specified dimension.
    Parameters
    ----------
    x : Tensor
        Input tensor.
    time_dim : int, optional
        Dimension for time steps (default is 0).
    Returns
    -------
    Tensor
        Normalized tensor.
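
    Examples
    --------
    A minimal sketch of typical usage (shapes are illustrative):

    >>> x = torch.rand(100, 8)  # [num_time_steps, num_traces]
    >>> y = trace_sum_normalize(x, time_dim=0)
    >>> torch.allclose(y.sum(dim=0), torch.ones(8))
    True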
    """
    x = x / (x.sum(dim=time_dim, keepdim=True) + 1e-18)  # Avoid division by zero
    return x
def trace_max_normalize(x, time_dim=0):
    """
    Normalize each trace by its maximum value along the specified dimension.
    The value of each trace will be in the range [-1, 1] after the processing.
    Note that the channel should be 1.
    Parameters
    ----------
    x : Tensor
        Input tensor.
    time_dim : int, optional
        Dimension for time steps (default is 0).
    Returns
    -------
    Tensor
        Normalized tensor.
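
    Examples
    --------
    A minimal sketch of typical usage (shapes are illustrative):

    >>> x = torch.randn(100, 8)  # [num_time_steps, num_traces]
    >>> y = trace_max_normalize(x, time_dim=0)
    >>> bool((y.abs() <= 1.0).all())
    True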
    """
    x_max, _ = torch.max(x.abs(), dim=time_dim, keepdim=True)
    x = x / (x_max + 1e-18)  # Avoid division by zero
    return x
class Misfit_NIM(Function):
    """
    Normalized Integration Method (NIM). Computes the misfit between the
    cumulative distributions of transformed synthetic and observed signals.

    Parameters
    ----------
    p : int, optional
        Norm degree, default is 2.
    trans_type : str, optional
        Type of non-negative transform, default is 'linear'.
    theta : float, optional
        Parameter for the non-negative transform, default is 1.
    dt : float, optional
        Time sampling interval (stored but not used in the current computation).

    Notes
    -----
    NIM is equivalent to the Wasserstein-1 distance when p=1:
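
    .. math:: W_1(F, G) = \int \lvert F(t) - G(t) \rvert \, dt

    where :math:`F` and :math:`G` are the cumulative sums (approximate CDFs)
    of the non-negatively transformed synthetic and observed traces, as
    computed in ``forward``.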
    """
    def __init__(self, p=2, trans_type='linear', theta=1, dt=1):
        """
        Initialize the Misfit_NIM class.

        Note that the misfit is evaluated through the static
        ``forward``/``backward`` methods (via ``Misfit_NIM.apply``), which do
        not read these instance attributes.
        """
        self.p = p
        self.trans_type = trans_type
        self.theta = theta
        self.dt = dt
    
    @staticmethod
    def forward(ctx, syn, obs, p=2, trans_type='linear', theta=1.):
        """
        Forward pass to compute the misfit between transformed synthetic and observed signals.
        Parameters
        ----------
        syn : Tensor
            Synthetic signal data.
        obs : Tensor
            Observed signal data.
        p : int, optional
            Norm degree, default is 2.
        trans_type : str, optional
            Type of non-negative transform, default is 'linear'.
        theta : float, optional
            Parameter for non-negative transform, default is 1.
        Returns
        -------
        Tensor
            Computed misfit value.
        """
        assert p >= 1, "Norm degree must be >= 1"
        assert syn.shape == obs.shape, "Shape mismatch between synthetic and observed data"
        
        # Reshape inputs so that time is the leading dimension
        # (final shape: [num_time_steps, num_shots * num_receivers])
        num_shots, num_time_steps, num_receivers = syn.shape
        
        # [num_shots, num_time_steps, num_receivers] --> [num_time_steps, num_shots, num_receivers]
        syn_transposed = syn.permute(1, 0, 2)
        obs_transposed = obs.permute(1, 0, 2)
        # Merge shot and receiver dimensions into a single trace dimension
        syn_flat = syn_transposed.reshape(num_time_steps, num_shots * num_receivers)
        obs_flat = obs_transposed.reshape(num_time_steps, num_shots * num_receivers)
        
        # Transform signals to ensure non-negativity. `transform` is expected
        # to be defined elsewhere in this module; it returns the transformed
        # signals and the derivative term `d` used in the backward pass.
        mu, nu, d = transform(syn_flat, obs_flat, trans_type, theta)
        
        # Normalize each trace by its sum
        mu = trace_sum_normalize(mu, time_dim=0)
        nu = trace_sum_normalize(nu, time_dim=0)
        
        # Compute cumulative sums along the time dimension (approximate CDFs)
        F = torch.cumsum(mu, dim=0)
        G = torch.cumsum(nu, dim=0)
        
        # Save tensors and shape information for the backward pass
        ctx.save_for_backward(F - G, mu, d)
        ctx.p = p
        ctx.num_shots = num_shots
        ctx.num_time_steps = num_time_steps
        ctx.num_receivers = num_receivers
        
        return (torch.abs(F - G) ** p).sum()
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass to compute the gradients of the misfit with respect to the inputs.
        Parameters
        ----------
        grad_output : Tensor
            The gradient of the loss with respect to the output.
        Returns
        -------
        Tuple
            Gradients of the inputs.
        """
        residual, mu, d = ctx.saved_tensors
        p = ctx.p
        num_shots = ctx.num_shots
        num_time_steps = ctx.num_time_steps
        num_receivers = ctx.num_receivers
        
        # Gradient with respect to the transformed synthetic signal,
        # scaled by the derivative of the non-negative transform (chain rule)
        if p == 1:
            df = torch.sign(residual) * mu * d
        else:
            df = residual ** (p - 1) * mu * d
        
        # Restore the original [num_shots, num_time_steps, num_receivers] layout
        df = df.reshape(num_time_steps, num_shots, num_receivers).permute(1, 0, 2)
        # Scale by the incoming gradient for correct chaining
        # (grad_output is 1 when the misfit is the final scalar loss)
        return -df * grad_output, None, None, None, None
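

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercises the forward and
    # backward passes via ``Function.apply``. Assumes the module's
    # `transform` helper is defined; signal values and shapes are arbitrary.
    syn = torch.randn(2, 100, 8, requires_grad=True)  # [shots, time, receivers]
    obs = torch.randn(2, 100, 8)
    loss = Misfit_NIM.apply(syn, obs, 2, 'linear', 1.)
    loss.backward()
    print("NIM misfit:", loss.item())
    print("Gradient shape:", tuple(syn.grad.shape))  # (2, 100, 8), matching syn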