"""
Averaged Artifact Subtraction (AAS) correction processor.
"""
import random
import mne
import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from ..console import processor_progress
from ..core import ProcessingContext, Processor, ProcessorValidationError, register_processor
from ..helpers.crosscorr import crosscorrelation
from ..helpers.utils import split_vector
@register_processor
class AASCorrection(Processor):
    """Remove fMRI gradient artifacts using Averaged Artifact Subtraction.

    Trigger epochs are processed in consecutive blocks of ``window_size``
    epochs. For each block, epochs from a candidate window whose Pearson
    correlation with a running average (seeded from the first epochs of the
    block) exceeds ``correlation_threshold`` are selected, and their mean is
    used as the artifact template for every epoch in the block. The template
    is then subtracted from each epoch. Correlation-based selection lets the
    template adapt to non-stationary artifacts.

    References
    ----------
    Allen et al., 2000. "A method for removing imaging artifact from continuous
    EEG recorded during functional MRI." NeuroImage, 12(2), 230-239.

    Parameters
    ----------
    window_size : int
        Number of epochs per block, and size of the candidate window
        (default: 30).
    rel_window_position : float
        Shift of the candidate window relative to the start of the current
        block, as a fraction of ``window_size`` in [-1, 1]. ``0`` makes the
        candidate window coincide with the block (default: 0.0).
    correlation_threshold : float
        Minimum Pearson r with the running average required to include an
        epoch in the average (default: 0.975).
    plot_artifacts : bool
        If True, plots a randomly selected averaged artifact after computation
        (default: False).
    realign_after_averaging : bool
        If True, realigns trigger positions to the averaged artifact templates
        using cross-correlation (default: True).
    search_window_factor : float
        Multiplier of the upsampling factor used as the cross-correlation
        search window (default: 3.0).
    interpolate_volume_gaps : bool
        If ``True``, linearly interpolate estimated artifact/noise values in
        gaps between consecutive artifact windows (default: False).
    apply_epoch_alpha_scaling : bool
        If ``True``, scale each epoch template by a least-squares ``alpha``
        factor before subtraction, similar to MATLAB FACET ``CalcAvgArt``
        (default: False).
    """

    name = "aas_correction"
    description = "Averaged Artifact Subtraction for fMRI artifacts"
    version = "1.0.0"
    requires_triggers = True
    requires_raw = True
    modifies_raw = True
    parallel_safe = True
    channel_wise = True

    def __init__(
        self,
        window_size: int = 30,
        rel_window_position: float = 0.0,
        correlation_threshold: float = 0.975,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        interpolate_volume_gaps: bool = False,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        self.window_size = window_size
        self.rel_window_position = rel_window_position
        self.correlation_threshold = correlation_threshold
        self.plot_artifacts = plot_artifacts
        self.realign_after_averaging = realign_after_averaging
        self.search_window_factor = search_window_factor
        self.interpolate_volume_gaps = interpolate_volume_gaps
        self.apply_epoch_alpha_scaling = apply_epoch_alpha_scaling
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Check parameters and context prerequisites before processing.

        Raises
        ------
        ProcessorValidationError
            If a parameter is out of range, the artifact length is not set,
            or the raw data contains no non-bad EEG channels.
        """
        super().validate(context)
        if self.window_size < 1:
            raise ProcessorValidationError(f"window_size must be >= 1, got {self.window_size}")
        if not (0 < self.correlation_threshold <= 1):
            raise ProcessorValidationError(f"correlation_threshold must be in (0, 1], got {self.correlation_threshold}")
        if not (-1.0 <= self.rel_window_position <= 1.0):
            raise ProcessorValidationError(f"rel_window_position must be in [-1, 1], got {self.rel_window_position}")
        if self.search_window_factor <= 0:
            raise ProcessorValidationError(f"search_window_factor must be positive, got {self.search_window_factor}")
        if context.get_artifact_length() is None:
            raise ProcessorValidationError("Artifact length not set. Run TriggerDetector first.")
        n_triggers = len(context.get_triggers())
        if n_triggers < self.window_size:
            # Not fatal: the block/candidate ranges are clipped to n_epochs.
            logger.warning(
                "Number of triggers ({}) is less than window size ({}). Using smaller window.",
                n_triggers,
                self.window_size,
            )
        eeg_channels = mne.pick_types(
            context.get_raw().info, meg=False, eeg=True, stim=False, eog=False, exclude="bads"
        )
        if len(eeg_channels) == 0:
            raise ProcessorValidationError("No EEG channels found in raw data.")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Apply AAS to the picked channels and return an updated context.

        The raw data is copied. Artifact templates are estimated and
        subtracted in place on the copy, the estimated artifacts are
        accumulated as noise, and realigned triggers are stored when
        realignment is enabled and changed the triggers.
        """
        # --- EXTRACT ---
        raw = context.get_raw().copy()
        triggers = context.get_triggers()
        artifact_length = context.get_artifact_length()
        sfreq = context.get_sfreq()
        upsampling_factor = context.metadata.upsampling_factor
        artifact_offset = context.metadata.artifact_to_trigger_offset
        # NOTE(review): unlike validate() (eog=False), this pick includes EOG
        # channels, so they are corrected too. Confirm that this is intended.
        eeg_channels = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=True, exclude="bads")

        # --- LOG ---
        logger.info(
            "Applying AAS correction: {} channels, {} triggers, window={}",
            len(eeg_channels),
            len(triggers),
            self.window_size,
        )

        # --- COMPUTE ---
        averaging_matrices = self._compute_averaging_matrices(
            raw, eeg_channels, raw.ch_names, triggers, artifact_length, artifact_offset, sfreq
        )
        artifacts_per_channel = self._calc_averaged_artifacts(
            raw, averaging_matrices, triggers, artifact_length, artifact_offset, sfreq
        )
        if self.plot_artifacts and artifacts_per_channel:
            self._plot_artifact_debug(raw, averaging_matrices, artifacts_per_channel)

        aligned_triggers = self._get_aligned_triggers(
            raw,
            averaging_matrices,
            artifacts_per_channel,
            triggers,
            artifact_offset,
            artifact_length,
            sfreq,
            upsampling_factor,
        )

        artifact_offset_samples = self._offset_to_samples(artifact_offset, sfreq)
        estimated_artifacts = self._remove_artifacts(
            raw,
            averaging_matrices,
            artifacts_per_channel,
            aligned_triggers,
            artifact_offset_samples,
            artifact_length,
        )
        if self.interpolate_volume_gaps:
            self._interpolate_volume_gap_artifacts(
                raw=raw,
                estimated_artifacts=estimated_artifacts,
                aligned_triggers=aligned_triggers,
                artifact_offset_samples=artifact_offset_samples,
                artifact_length=artifact_length,
                channel_indices=list(averaging_matrices.keys()),
            )

        # --- NOISE ---
        new_ctx = context.with_raw(raw)
        new_ctx.accumulate_noise(estimated_artifacts)

        # --- BUILD RESULT ---
        if self.realign_after_averaging and not np.array_equal(aligned_triggers, triggers):
            logger.debug("Triggers realigned after AAS averaging")
            new_ctx = new_ctx.with_triggers(aligned_triggers)

        # --- RETURN ---
        logger.info("AAS correction complete: {} artifacts, {} channels", len(triggers), len(eeg_channels))
        return new_ctx

    # -------------------------------------------------------------------------
    # Private Helpers
    # -------------------------------------------------------------------------

    @staticmethod
    def _offset_to_samples(artifact_offset: float, sfreq: float) -> int:
        """Convert an artifact-to-trigger offset in seconds to whole samples.

        ``round`` is used instead of ``int`` truncation so that floating-point
        noise cannot shift the window by one sample. For example,
        ``-0.57 * 100`` evaluates to ``-56.99999999999999`` and must map to
        ``-57``, not ``-56``. All helpers use this conversion so that every
        epoch window shares the same offset.
        """
        return int(round(artifact_offset * sfreq))

    def _compute_averaging_matrices(
        self,
        raw: mne.io.Raw,
        eeg_channels: np.ndarray,
        ch_names: list[str],
        triggers: np.ndarray,
        artifact_length: int,
        artifact_offset: float,
        sfreq: float,
    ) -> dict[int, np.ndarray]:
        """Compute per-channel averaging matrices by slicing raw data directly.

        Epochs are extracted one channel at a time from ``raw._data``, so only
        one channel's epochs are held at a time rather than all channels'.

        Parameters
        ----------
        raw : mne.io.Raw
            EEG data (already a copy; modified in-place by the caller).
        eeg_channels : np.ndarray
            Channel indices to process.
        ch_names : List[str]
            Full channel name list from raw (indexed by ch_idx).
        triggers : np.ndarray
            Trigger sample positions.
        artifact_length : int
            Length of each artifact in samples.
        artifact_offset : float
            Time offset of artifact relative to trigger, in seconds.
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        dict
            Mapping from channel index to averaging matrix (n_epochs, n_epochs).
        """
        logger.debug("Computing averaging matrices for {} channels", len(eeg_channels))
        averaging_matrices: dict[int, np.ndarray] = {}
        epoch_starts = triggers + self._offset_to_samples(artifact_offset, sfreq)
        n_channels = len(eeg_channels)

        with processor_progress(
            total=n_channels or None,
            message="Averaging matrices",
        ) as progress:
            for idx, ch_idx in enumerate(eeg_channels):
                ch_name = ch_names[ch_idx]
                channel_epochs = split_vector(raw._data[ch_idx], epoch_starts, artifact_length)
                averaging_matrices[ch_idx] = self._calc_averaging_matrix(
                    channel_epochs,
                    window_size=self.window_size,
                    rel_window_offset=self.rel_window_position,
                    correlation_threshold=self.correlation_threshold,
                )
                progress.advance(
                    1,
                    message=f"{idx + 1}/{len(eeg_channels)} • {ch_name}",
                )
        return averaging_matrices

    def _get_aligned_triggers(
        self,
        raw: mne.io.Raw,
        averaging_matrices: dict[int, np.ndarray],
        artifacts_per_channel: list[np.ndarray],
        triggers: np.ndarray,
        artifact_offset: float,
        artifact_length: int,
        sfreq: float,
        upsampling_factor: int,
    ) -> np.ndarray:
        """Return realigned triggers or the original triggers unchanged.

        Alignment uses only the first processed channel and its templates.
        The original triggers are returned when realignment is disabled or
        when no channel was processed.

        Parameters
        ----------
        raw : mne.io.Raw
            EEG data (used to read the first processed channel).
        averaging_matrices : dict
            Averaging matrices keyed by channel index.
        artifacts_per_channel : List[np.ndarray]
            Averaged artifact arrays, one per channel.
        triggers : np.ndarray
            Original trigger sample positions.
        artifact_offset : float
            Artifact-to-trigger offset in seconds.
        artifact_length : int
            Artifact length in samples.
        sfreq : float
            Sampling frequency in Hz.
        upsampling_factor : int
            Current upsampling factor (used to scale the search window).

        Returns
        -------
        np.ndarray
            Aligned or unchanged trigger positions.
        """
        if not self.realign_after_averaging:
            return triggers
        if not averaging_matrices or not artifacts_per_channel:
            # Nothing was processed, so there is no template to align to.
            return triggers

        search_window = int(self.search_window_factor * upsampling_factor)
        first_ch_idx = next(iter(averaging_matrices))
        # Direct _data access avoids a full array copy for this read-only use
        first_channel_data = raw._data[first_ch_idx]
        return self._align_triggers_to_artifacts(
            first_channel_data,
            artifacts_per_channel[0],
            triggers,
            self._offset_to_samples(artifact_offset, sfreq),
            artifact_length,
            search_window,
        )

    def _remove_artifacts(
        self,
        raw: mne.io.Raw,
        averaging_matrices: dict[int, np.ndarray],
        artifacts_per_channel: list[np.ndarray],
        aligned_triggers: np.ndarray,
        artifact_offset_samples: int,
        artifact_length: int,
    ) -> np.ndarray:
        """Subtract averaged artifacts from raw data in-place.

        Epochs whose start lies outside the recording are skipped. Epochs
        running past the end are truncated. Overlapping epochs accumulate in
        the returned estimate.

        Parameters
        ----------
        raw : mne.io.Raw
            EEG data modified in-place; must be a copy of the original.
        averaging_matrices : dict
            Averaging matrices keyed by channel index.
        artifacts_per_channel : List[np.ndarray]
            Averaged artifact arrays, one per channel.
        aligned_triggers : np.ndarray
            (Possibly realigned) trigger sample positions.
        artifact_offset_samples : int
            Artifact-to-trigger offset in samples.
        artifact_length : int
            Artifact length in samples.

        Returns
        -------
        np.ndarray
            Estimated artifact array, same shape as ``raw._data``.
        """
        smin = artifact_offset_samples
        smax = smin + artifact_length
        n_samples = raw._data.shape[1]
        # Direct _data access avoids a full array copy on large datasets
        estimated_artifacts = np.zeros(raw._data.shape)
        n_channels = len(averaging_matrices)

        with processor_progress(
            total=n_channels or None,
            message="Removing artifacts",
        ) as progress:
            for ch_list_idx, ch_idx in enumerate(averaging_matrices):
                ch_name = raw.ch_names[ch_idx]
                artifacts = artifacts_per_channel[ch_list_idx]
                alpha_values = np.ones(len(aligned_triggers), dtype=float)

                # Snapshot the channel (zero-mean) before it is modified in
                # place, so alpha is always fitted against the uncorrected signal.
                ch_data_zero_mean = None
                if self.apply_epoch_alpha_scaling:
                    ch_data_zero_mean = raw._data[ch_idx] - np.mean(raw._data[ch_idx])

                for epoch_idx, trigger_pos in enumerate(aligned_triggers):
                    start = trigger_pos + smin
                    stop = min(trigger_pos + smax, n_samples)
                    if start < 0 or start >= n_samples:
                        continue
                    artifact_segment = artifacts[epoch_idx, : stop - start]

                    if self.apply_epoch_alpha_scaling:
                        # Least-squares scale: alpha = <data, tmpl> / <tmpl, tmpl>;
                        # keep alpha = 1 when the template is ~zero or alpha is non-finite.
                        data_segment = ch_data_zero_mean[start:stop]
                        denom = float(np.dot(artifact_segment, artifact_segment))
                        if denom > np.finfo(float).eps:
                            alpha = float(np.dot(data_segment, artifact_segment) / denom)
                            if np.isfinite(alpha):
                                alpha_values[epoch_idx] = alpha
                        artifact_segment = alpha_values[epoch_idx] * artifact_segment

                    raw._data[ch_idx, start:stop] -= artifact_segment
                    estimated_artifacts[ch_idx, start:stop] += artifact_segment

                if self.apply_epoch_alpha_scaling and alpha_values.size:
                    alpha_min = float(np.min(alpha_values))
                    alpha_mean = float(np.mean(alpha_values))
                    alpha_max = float(np.max(alpha_values))
                    if alpha_min < 0 or (alpha_mean > 0 and alpha_max > (2.0 * alpha_mean)):
                        logger.warning(
                            "[{}] AAS alpha scaling produced unusual values: min={:.3f}, mean={:.3f}, max={:.3f}",
                            ch_name,
                            alpha_min,
                            alpha_mean,
                            alpha_max,
                        )

                progress.advance(
                    1,
                    message=(f"{ch_name} cleaned ({ch_list_idx + 1}/{len(averaging_matrices)})"),
                )
        return estimated_artifacts

    def _interpolate_volume_gap_artifacts(
        self,
        raw: mne.io.Raw,
        estimated_artifacts: np.ndarray,
        aligned_triggers: np.ndarray,
        artifact_offset_samples: int,
        artifact_length: int,
        channel_indices: list[int],
    ) -> None:
        """Interpolate estimated artifacts in gaps between consecutive epochs.

        For each pair of consecutive triggers, the samples strictly between
        the last sample of the previous artifact window and the first sample
        of the next are filled with a straight line between those two
        estimated values. The line is written into ``estimated_artifacts``
        and subtracted from ``raw``.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw data modified in-place.
        estimated_artifacts : np.ndarray
            Estimated artifact signal, modified in-place.
        aligned_triggers : np.ndarray
            Trigger positions used for subtraction.
        artifact_offset_samples : int
            Artifact start offset relative to trigger.
        artifact_length : int
            Artifact length in samples.
        channel_indices : List[int]
            Processed channel indices.
        """
        if len(aligned_triggers) < 2 or artifact_length <= 0:
            return

        n_samples = raw._data.shape[1]
        smin = artifact_offset_samples
        # Inclusive index of the last sample of an artifact window.
        smax = artifact_offset_samples + artifact_length - 1

        for i in range(1, len(aligned_triggers)):
            start_this = int(aligned_triggers[i] + smin)
            end_prev = int(aligned_triggers[i - 1] + smax)
            if start_this <= 0 or start_this >= n_samples:
                continue
            if end_prev < 0 or end_prev >= n_samples:
                continue
            gap_len = start_this - end_prev - 1
            if gap_len <= 0:
                # Adjacent or overlapping windows: nothing to fill.
                continue

            gap_start = end_prev + 1
            gap_stop = start_this
            for ch_idx in channel_indices:
                end_val = estimated_artifacts[ch_idx, end_prev]
                start_val = estimated_artifacts[ch_idx, start_this]
                diff = start_val - end_val
                gap = end_val + (np.arange(1, gap_len + 1) * (diff / (gap_len + 1)))
                estimated_artifacts[ch_idx, gap_start:gap_stop] = gap
                raw._data[ch_idx, gap_start:gap_stop] -= gap

    def _plot_artifact_debug(
        self,
        raw: mne.io.Raw,
        averaging_matrices: dict[int, np.ndarray],
        artifacts_per_channel: list[np.ndarray],
    ) -> None:
        """Plot a randomly selected averaged artifact for visual debugging.

        Plotting is best-effort: any failure is logged as a warning and does
        not interrupt processing.

        Parameters
        ----------
        raw : mne.io.Raw
            EEG data (used for channel name lookup).
        averaging_matrices : dict
            Averaging matrices keyed by channel index.
        artifacts_per_channel : List[np.ndarray]
            Averaged artifact arrays, one per channel.
        """
        try:
            processed_channels = list(averaging_matrices.keys())
            random_ch_list_idx = random.randint(0, len(processed_channels) - 1)
            ch_idx = processed_channels[random_ch_list_idx]
            ch_name = raw.ch_names[ch_idx]
            artifacts = artifacts_per_channel[random_ch_list_idx]
            if len(artifacts) == 0:
                return
            random_epoch_idx = random.randint(0, len(artifacts) - 1)
            artifact_segment = artifacts[random_epoch_idx]
            logger.debug(
                "Plotting random artifact for channel {}, epoch {}",
                ch_name,
                random_epoch_idx,
            )
            plt.figure(figsize=(10, 4))
            plt.plot(artifact_segment)
            plt.title(f"Estimated Artifact: Channel {ch_name} (Epoch {random_epoch_idx})")
            plt.xlabel("Samples")
            plt.ylabel("Amplitude")
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()
        except Exception as exc:
            # Deliberately broad: a debug plot must never abort the correction.
            logger.warning("Failed to plot random artifact: {}", exc)

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Calculate averaging matrix using correlation-based epoch selection.

        Epochs are processed in consecutive, non-overlapping blocks of
        ``window_size``. For each block, the running average is seeded with
        up to the first 5 epochs of the block. Epochs from a candidate window
        starting at ``block_start + int(window_size * rel_window_offset)`` are
        then added while they correlate with that average. Every epoch in
        the block receives the same equal-weight average of the chosen
        epochs.

        Parameters
        ----------
        epochs : np.ndarray
            Epochs array of shape (n_epochs, n_times).
        window_size : int
            Block length and candidate-window length, in epochs.
        rel_window_offset : float
            Shift of the candidate window as a fraction of ``window_size``.
        correlation_threshold : float
            Minimum Pearson r for an epoch to be included.

        Returns
        -------
        np.ndarray
            Averaging matrix of shape (n_epochs, n_epochs) where each row sums to 1.
        """
        n_reference = 5  # number of seed epochs taken from the start of each block
        n_epochs = len(epochs)
        averaging_matrix = np.zeros((n_epochs, n_epochs))
        window_offset = int(window_size * rel_window_offset)

        for block_start in range(0, n_epochs, window_size):
            block_stop = min(block_start + window_size, n_epochs)
            reference_indices = np.arange(block_start, min(block_start + n_reference, n_epochs))

            candidate_start = block_start + window_offset
            candidates = np.arange(candidate_start, min(candidate_start + window_size, n_epochs))
            candidates = candidates[candidates >= 0]

            chosen = self._find_correlated_epochs(epochs, candidates, reference_indices, correlation_threshold)
            if len(chosen) == 0:
                chosen = reference_indices

            target_indices = np.arange(block_start, block_stop)
            averaging_matrix[np.ix_(target_indices, chosen)] = 1.0 / len(chosen)
        return averaging_matrix

    def _find_correlated_epochs(
        self,
        all_epochs: np.ndarray,
        candidate_indices: np.ndarray,
        reference_indices: np.ndarray,
        threshold: float,
    ) -> np.ndarray:
        """Find epochs that are highly correlated with the running average.

        The running average starts as the mean of the reference epochs.
        Candidates are visited in order, and each one whose Pearson r with
        the current average exceeds *threshold* is added to it.

        Parameters
        ----------
        all_epochs : np.ndarray
            All epochs of shape (n_epochs, n_times).
        candidate_indices : np.ndarray
            Indices of candidate epochs to check.
        reference_indices : np.ndarray
            Indices of seed epochs used to initialise the running average.
        threshold : float
            Minimum correlation to accept a candidate.

        Returns
        -------
        np.ndarray
            Integer indices of accepted epochs (reference epochs first).
        """
        if len(reference_indices) == 0:
            # Integer dtype so the result stays usable as an index array.
            return np.array([], dtype=int)

        sum_data = np.sum(all_epochs[reference_indices], axis=0)
        chosen = list(reference_indices)
        chosen_set = set(chosen)
        for idx in candidate_indices:
            if idx in chosen_set:
                continue
            avg_data = sum_data / len(chosen)
            corr = np.corrcoef(avg_data.squeeze(), all_epochs[idx].squeeze())[0, 1]
            # A constant epoch yields NaN, which compares False and is rejected.
            if corr > threshold:
                sum_data += all_epochs[idx]
                chosen.append(idx)
                chosen_set.add(idx)
        return np.array(chosen)

    def _calc_averaged_artifacts(
        self,
        raw: mne.io.Raw,
        averaging_matrices: dict[int, np.ndarray],
        triggers: np.ndarray,
        artifact_length: int,
        artifact_offset: float,
        sfreq: float,
    ) -> list[np.ndarray]:
        """Calculate averaged artifact templates for each channel.

        Templates are computed from the zero-mean channel signal as
        ``avg_matrix @ epochs``. If fewer templates than triggers result, the
        last template is repeated. If none result, a zero template is used
        and a warning is logged.

        Parameters
        ----------
        raw : mne.io.Raw
            EEG data.
        averaging_matrices : dict
            Averaging matrices keyed by channel index.
        triggers : np.ndarray
            Trigger sample positions.
        artifact_length : int
            Artifact length in samples.
        artifact_offset : float
            Artifact-to-trigger offset in seconds.
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        List[np.ndarray]
            Averaged artifact arrays per channel; each element has shape
            (n_epochs, n_times).
        """
        artifacts_per_channel: list[np.ndarray] = []
        epoch_starts = triggers + self._offset_to_samples(artifact_offset, sfreq)
        n_triggers = len(triggers)

        for ch_idx, avg_matrix in averaging_matrices.items():
            # Direct _data access avoids a full array copy on large datasets
            ch_data = raw._data[ch_idx]
            ch_data_zero_mean = ch_data - np.mean(ch_data)
            epoch_data = split_vector(ch_data_zero_mean, epoch_starts, artifact_length)
            # The matrix was built from the same epoch split, so this trim is
            # normally a no-op; it guards against a length mismatch.
            epoch_data = epoch_data[: len(avg_matrix)]

            averaged_artifacts = np.dot(avg_matrix, epoch_data)
            missing = n_triggers - len(averaged_artifacts)
            if missing > 0:
                if len(averaged_artifacts) == 0:
                    logger.warning(
                        "[{}] No artifact epochs could be extracted; using a zero template",
                        raw.ch_names[ch_idx],
                    )
                    averaged_artifacts = np.zeros((n_triggers, artifact_length))
                else:
                    padding = np.repeat(averaged_artifacts[-1:], missing, axis=0)
                    averaged_artifacts = np.vstack([averaged_artifacts, padding])
            artifacts_per_channel.append(averaged_artifacts)
        return artifacts_per_channel

    def _align_triggers_to_artifacts(
        self,
        channel_data: np.ndarray,
        artifacts: np.ndarray,
        triggers: np.ndarray,
        smin: int,
        smax: int,
        search_window: int,
    ) -> np.ndarray:
        """Align triggers to averaged artifacts using cross-correlation.

        For each trigger, the segment ``channel_data[trigger + smin :
        trigger + smax + search_window]`` is cross-correlated with that
        epoch's template. The trigger is shifted by the best lag. Triggers
        whose segment would start before sample 0 or end past the data are
        returned unchanged.

        Parameters
        ----------
        channel_data : np.ndarray
            Single-channel time series.
        artifacts : np.ndarray
            Averaged artifacts of shape (n_epochs, n_times).
        triggers : np.ndarray
            Original trigger sample positions.
        smin : int
            Artifact start offset in samples relative to trigger.
        smax : int
            Upper segment bound relative to the trigger, before adding
            ``search_window``. Note: the caller in this class passes
            ``artifact_length`` here, not ``smin + artifact_length``.
        search_window : int
            Half-width of the cross-correlation search window in samples.

        Returns
        -------
        np.ndarray
            Realigned trigger sample positions.
        """
        n_samples = len(channel_data)
        aligned_triggers = []
        for i, trigger in enumerate(triggers):
            start = trigger + smin
            stop = trigger + smax + search_window
            # A negative start would wrap around in NumPy slicing and select
            # samples from the end of the recording, so keep such triggers as-is.
            if start < 0 or stop > n_samples:
                aligned_triggers.append(trigger)
                continue
            segment = channel_data[start:stop]
            corr = crosscorrelation(segment, artifacts[i, :], search_window)
            best_shift = np.argmax(corr) - search_window
            aligned_triggers.append(trigger + best_shift)
        return np.array(aligned_triggers)
# Older name kept so existing imports of ``AveragedArtifactSubtraction``
# keep resolving to the same processor class.
AveragedArtifactSubtraction: type[AASCorrection] = AASCorrection