Source code for ADFWI.fwi.misfit.L2
from .base import Misfit
import torch
class Misfit_waveform_L2(Misfit):
    """
    Waveform L2-norm difference misfit (Tarantola, 1984).

    For every trace, the residual between observed and synthetic data is
    squared, weighted by the sampling interval ``dt`` and summed along axis 1.
    The per-trace L2 norm is the square root of that quantity. The returned
    misfit is the sum of these per-trace norms over all remaining indices::

        J = sum_i sqrt( sum_k dt * (obs[i, k] - syn[i, k])**2 )

    where ``k`` runs over axis 1 and ``i`` over every other index.

    Parameters
    ----------
    dt : float, optional
        Time sampling interval used to weight the squared residuals
        (default is 1).
    """

    def __init__(self, dt: float = 1) -> None:
        """
        Initialize the Misfit_waveform_L2 class.

        Parameters
        ----------
        dt : float, optional
            Time sampling interval (default is 1).
        """
        super().__init__()
        self.dt = dt

    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """
        Compute the L2-norm waveform misfit between observed and synthetic data.

        Parameters
        ----------
        obs : torch.Tensor
            Observed waveforms. The squared residual is reduced along axis 1.
            Because that reduction is weighted by ``dt``, axis 1 is presumably
            the time axis. NOTE(review): an earlier docstring described the
            layout as (batch, channels, time); confirm the actual layout
            against the code that produces the data.
        syn : torch.Tensor
            Synthetic waveforms, same shape as ``obs``.

        Returns
        -------
        torch.Tensor
            Scalar tensor: the sum over all traces of
            ``sqrt(sum(dt * (obs - syn)**2, dim=1))``.
        """
        rsd = obs - syn
        # dt-weighted residual energy of each trace (reduced along axis 1).
        energy = torch.sum(rsd * rsd * self.dt, dim=1)
        # Exclude traces whose residual energy is exactly zero. They add
        # sqrt(0) = 0 to the loss, but d/dx sqrt(x) is infinite at 0, so the
        # chain rule (inf * 0) would make their gradients NaN. This covers
        # traces where obs and syn are both all-zero, and also traces where
        # obs == syn exactly. The loss value is unchanged by the masking.
        mask = energy != 0
        loss = torch.sum(torch.sqrt(energy[mask]))
        return loss