import numpy as np
from ADFWI.utils.utils import numpy2tensor
def brutal_picker(trace, threshold_ratio=0.001):
    """
    Picks the first arrival time of seismic traces using a simple thresholding method.

    A sample is taken as the first arrival when its absolute amplitude exceeds
    ``threshold_ratio`` times the peak absolute amplitude of its own trace.

    Parameters
    ----------
    trace : numpy.ndarray
        A 2D array of seismic traces with shape (n_traces, n_samples), where each row
        represents a single seismic trace. A 1D array is treated as one trace.
    threshold_ratio : float, optional
        Fraction of each trace's peak absolute amplitude used as the picking
        threshold. Defaults to 0.001.

    Returns
    -------
    numpy.ndarray
        A 1D integer array containing the first arrival index for each trace
        (a scalar index for 1D input). A trace that never exceeds its threshold
        (e.g. an all-zero trace) is picked at index 0.
    """
    amplitude = np.abs(np.asarray(trace))
    # keepdims so each trace's threshold broadcasts along its sample axis
    thresholds = threshold_ratio * amplitude.max(axis=-1, keepdims=True)
    # argmax over a boolean array returns the index of the first True
    return (amplitude > thresholds).argmax(axis=-1)
def mask(itmin, itmax, nt, length):
    """
    Constructs a tapered mask that can be applied to a seismic trace to mute early or late arrivals.

    The mask is 0 for samples before ``itmin``. Over ``[itmin, itmax)`` it follows
    the rising first half of a sine window, and from ``itmax`` onward it is 1. The
    taper is clipped to the trace, so ``itmin`` / ``itmax`` may lie before sample 0
    or past sample ``nt - 1``.

    Parameters
    ----------
    itmin : int
        The starting index for tapering. Samples before it are set to 0.
    itmax : int
        The ending index (exclusive) for tapering. Expected to equal
        ``itmin + length``. If the in-trace part of the span runs past
        ``itmin + length``, NumPy raises a broadcasting ValueError.
    nt : int
        The total number of time samples in the trace.
    length : int
        The length of the tapering window.

    Returns
    -------
    numpy.ndarray
        A 1D float mask array of shape (nt,) with tapered values applied
        to the specified range.
    """
    # Rising half of a sine window: starts at 0 and approaches 1.
    taper = np.sin(np.linspace(0, np.pi, 2 * length))[:length]
    out = np.ones(nt)
    # Zero everything before the taper. Clamp so a negative itmin cannot wrap the
    # slice and an itmin >= nt zeroes the whole trace.
    out[:min(max(itmin, 0), nt)] = 0.0
    # Part of [itmin, itmax) that lies inside the trace, and the matching taper samples.
    lo = max(itmin, 0)
    hi = min(itmax, nt)
    if lo < hi:
        out[lo:hi] = taper[lo - itmin:hi - itmin]
    return out
def mute_arrival(trace, itmin, itmax, mutetype, nt, length):
    """
    Mutes the late part of a seismic trace or record section with a sine taper.

    The taper produced by ``mask`` is inverted (``1 - mask``). For a taper lying
    inside the trace, samples before ``itmin`` are kept, the amplitude fades out
    over ``[itmin, itmax)``, and samples from ``itmax`` onward are zeroed.

    Parameters
    ----------
    trace : torch.Tensor
        A 1D or 2D tensor representing the seismic trace or record section.
        The weights have shape (nt,) and broadcast along the last axis, so
        time must be the last dimension.
    itmin : int
        Sample index where the fade-out begins.
    itmax : int
        Sample index where the fade-out ends.
    mutetype : str
        Currently ignored; the late part of the trace is always muted.
    nt : int
        The total number of time samples in the trace.
    length : int
        The length of the tapering window.

    Returns
    -------
    torch.Tensor
        ``trace`` multiplied element-wise by the inverted mask. The mask is
        converted to the dtype and device of ``trace``.
    """
    keep = numpy2tensor(
        1 - mask(itmin, itmax, nt, length),
        dtype=trace.dtype,
        device=trace.device,
    )
    return trace * keep
def apply_mute(mute_late_window, shot, dt, taper_length=100):
    """
    Applies a time window mute based on the first arrival pick for each trace.

    For every receiver, the first arrival is picked with ``brutal_picker`` and
    shifted later by ``ceil(mute_late_window / dt)`` samples. A sine fade-out of
    ``taper_length`` samples is centred on that time, and everything after it is
    zeroed.

    Parameters
    ----------
    mute_late_window : float
        The time window for late arrival muting, in the same units as ``dt``.
    shot : torch.Tensor
        A 2D tensor of shape (nt, nrcv), representing the seismic shot gather.
    dt : float
        The time sampling interval.
    taper_length : int, optional
        Length in samples of the fade-out taper. Defaults to 100.

    Returns
    -------
    torch.Tensor
        The muted shot gather with applied time window muting, with the same
        shape, dtype and device as ``shot``.
    """
    # detach before cpu so the host copy is not recorded in the autograd graph
    shot_np = shot.detach().cpu().numpy()
    nt, nrcv = shot_np.shape
    # First-arrival sample per receiver, shifted by the late-mute window.
    pick = brutal_picker(shot_np.T) + np.ceil(mute_late_window / dt)

    # One keep-weight column per receiver (the same 1 - mask weights mute_arrival uses),
    # built on the host and transferred to the tensor's device in a single call.
    weights = np.empty((nt, nrcv))
    for i in range(nrcv):
        itmin = int(pick[i] - taper_length / 2)  # taper is centred on the pick
        itmax = int(itmin + taper_length)
        weights[:, i] = 1 - mask(itmin, itmax, nt, taper_length)
    weights = numpy2tensor(weights, dtype=shot.dtype, device=shot.device)
    return shot * weights