Source code for ADFWI.fwi.misfit.SoftDTW

from .base import Misfit
import pysdtw
import torch
from typing import Optional


class Misfit_sdtw(Misfit):
    """
    Soft-DTW misfit function, which calculates the soft Dynamic Time Warping
    (DTW) divergence between observed and synthetic waveforms.

    Origin
    ------
    https://github.com/toinsson/pysdtw

    Soft-DTW divergence
    -------------------
    https://tslearn.readthedocs.io/en/stable/gen_modules/metrics/tslearn.metrics.SoftDTWLossPyTorch.html

    Mathieu Blondel, Arthur Mensch & Jean-Philippe Vert.
    "Differentiable divergences between time series,"
    International Conference on Artificial Intelligence and Statistics, 2021.

    Parameters
    ----------
    gamma : float, optional
        Regularization parameter; must be strictly positive. Lower values
        give less smoothing (closer to true DTW). Default is 1.
    sparse_sampling : int, optional
        Down-sampling factor applied to the time axis to accelerate the
        inversion. Default is 1 (no down-sampling).
    dt : float, optional
        Time interval between samples. Default is 1.
    obs : torch.Tensor
        The observed waveform with shape [num_shots, num_time_steps, num_receivers],
        passed to ``forward``.
    syn : torch.Tensor
        The synthetic waveform with shape [num_shots, num_time_steps, num_receivers],
        passed to ``forward``.

    Returns
    -------
    torch.Tensor
        The computed soft-DTW loss.
    """

    def __init__(self,
                 gamma: Optional[float] = 1,
                 sparse_sampling: Optional[int] = 1,
                 dt: Optional[float] = 1) -> None:
        """Initialize the Misfit_sdtw class."""
        super().__init__()
        self.gamma = gamma
        self.sparse_sampling = sparse_sampling
        self.dt = dt
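    # Note on the quantity computed in `forward` (added for clarity; it
    # restates the Blondel et al. (2021) definition the code implements):
    #
    #     D_gamma(x, y) = SDTW_gamma(x, y)
    #                     - 0.5 * (SDTW_gamma(x, x) + SDTW_gamma(y, y))
    #
    # Unlike raw soft-DTW, this divergence is non-negative and vanishes when
    # x == y, which makes it better behaved as an FWI misfit.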
    def forward(self, obs: torch.Tensor, syn: torch.Tensor) -> torch.Tensor:
        """
        Compute the soft-DTW divergence between the observed and synthetic waveforms.

        Parameters
        ----------
        obs : torch.Tensor
            The observed waveform with shape [num_shots, num_time_steps, num_receivers].
        syn : torch.Tensor
            The synthetic waveform with shape [num_shots, num_time_steps, num_receivers].

        Returns
        -------
        torch.Tensor
            The computed soft-DTW loss.
        """
        device = obs.device

        # Keep only traces that are non-zero in at least one of the two
        # wavefields: a trace is excluded when it is all-zero in both.
        mask1 = torch.sum(torch.abs(obs), dim=1) == 0  # all-zero traces in obs
        mask2 = torch.sum(torch.abs(syn), dim=1) == 0  # all-zero traces in syn
        mask = ~(mask1 & mask2)

        # Pairwise L2 squared distance used as the soft-DTW ground cost
        fun = pysdtw.distance.pairwise_l2_squared_exact

        # Preallocate the per-trace divergence tensor: [num_shots, num_receivers]
        rsd = torch.zeros((obs.shape[0], obs.shape[2]), device=device)

        # Initialize SoftDTW once and reuse it for every shot
        sdtw = pysdtw.SoftDTW(gamma=self.gamma, dist_func=fun,
                              use_cuda=device.type == "cuda")

        # Loop over shots and compute the soft-DTW divergence trace by trace
        for ishot in range(obs.shape[0]):
            trace_idx = torch.argwhere(mask[ishot]).reshape(-1)

            # Extract and down-sample the shot gathers; each trace becomes a
            # batch item, giving batches of shape [num_traces, T, 1]
            obs_shot = obs[ishot, ::self.sparse_sampling, trace_idx].T.unsqueeze(2)
            syn_shot = syn[ishot, ::self.sparse_sampling, trace_idx].T.unsqueeze(2)

            # Soft-DTW terms of the divergence
            sdtw_obs = sdtw(obs_shot, obs_shot)
            sdtw_syn = sdtw(syn_shot, syn_shot)
            sdtw_obs_syn = sdtw(obs_shot, syn_shot)

            # Soft-DTW divergence (Blondel et al., 2021)
            divergence = sdtw_obs_syn - 0.5 * (sdtw_obs + sdtw_syn)
            rsd[ishot, trace_idx] = divergence.reshape(-1)

        # Sum over all shots and traces, scaled by the time interval
        loss = torch.sum(rsd * self.dt)
        return loss
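
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes pysdtw
# is installed, that the tensors follow the documented
# [num_shots, num_time_steps, num_receivers] layout, and that the base
# Misfit class needs no extra setup; the shapes and values below are made up,
# and running this directly requires the package context for the relative
# import of `Misfit` above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_shots, num_time_steps, num_receivers = 2, 400, 8

    obs = torch.randn(num_shots, num_time_steps, num_receivers)
    syn = torch.randn(num_shots, num_time_steps, num_receivers, requires_grad=True)

    # Down-sample the time axis by 4 to speed up the soft-DTW evaluation
    misfit = Misfit_sdtw(gamma=1.0, sparse_sampling=4, dt=0.002)
    loss = misfit.forward(obs, syn)

    # The loss is differentiable w.r.t. the synthetic data, so gradients can
    # flow back to whatever model parameters produced `syn`.
    loss.backward()
    print(f"soft-DTW loss: {loss.item():.6f}")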