"""
PCA-based artifact correction processor.
"""
import mne
import numpy as np
from loguru import logger
from scipy.signal import butter, filtfilt
from ..console import processor_progress
from ..core import ProcessingContext, Processor, ProcessorValidationError, register_processor
from ..helpers.utils import split_vector
[docs]
@register_processor
class PCACorrection(Processor):
    """Remove fMRI artifacts from EEG data using Principal Component Analysis.

    For every EEG channel the acquisition window is split into
    trigger-aligned epochs, and a PCA (via SVD) is computed on the
    standardized epoch matrix of that channel. Each epoch is reconstructed
    from the retained components, and the difference between the epoch and
    that reconstruction (the residual) is subtracted from the channel data.
    The same residual is accumulated on the context as the estimated
    artifact.

    The number of retained components can be controlled precisely:

    - An integer keeps exactly that many components (clamped to the range
      supported by the data, and never fewer than one).
    - A float in (0, 1) retains enough components to explain that fraction of
      the total variance (e.g. 0.95 → 95 %).
    - ``"auto"`` uses MATLAB FACET OBS heuristics.
    - 0 skips PCA for all channels.

    Parameters
    ----------
    n_components : int or float or str
        Number of PCA components to retain (int) or variance fraction to
        retain (float in (0, 1)). Use ``"auto"`` for MATLAB-like OBS auto
        selection. Default: 0.95.
    hp_freq : float, optional
        High-pass cutoff frequency in Hz applied before PCA. None (or a
        non-positive value) skips filtering (default: None).
    hp_filter_weights : np.ndarray, optional
        Pre-computed filter weights; overrides ``hp_freq`` when provided.
        The weights are applied with ``scipy.signal.filtfilt`` using a
        denominator of 1, i.e. as FIR numerator coefficients.
    exclude_channels : list, optional
        Channel indices to skip during PCA (default: empty list).
    """

    # Registry / display metadata for this processor.
    name = "pca_correction"
    description = "PCA-based artifact removal"
    version = "1.0.0"

    # Capability flags consumed by the processing framework.
    # NOTE(review): their exact semantics are defined by the Processor base
    # class, which is not visible in this file.
    requires_triggers = True
    requires_raw = True
    modifies_raw = True
    parallel_safe = True
    channel_wise = True

    # Thresholds (in percent of explained variance) used by the "auto"
    # component selection:
    # TH_SLOPE  - a step between consecutive components smaller than this
    #             counts as "flat".
    # TH_CUMVAR - cumulative explained variance that must be exceeded.
    # TH_VAREXP - components explaining less than this are considered minor.
    TH_SLOPE = 2.0
    TH_CUMVAR = 80.0
    TH_VAREXP = 5.0
[docs]
def __init__(
    self,
    n_components: int | float | str = 0.95,
    hp_freq: float | None = None,
    hp_filter_weights: np.ndarray | None = None,
    exclude_channels: list | None = None,
) -> None:
    """Store the PCA configuration and initialise the base processor.

    Parameters
    ----------
    n_components : int or float or str
        Component count, variance fraction in (0, 1), or ``"auto"``.
    hp_freq : float, optional
        High-pass cutoff in Hz; None disables filtering.
    hp_filter_weights : np.ndarray, optional
        Ready-made filter weights that take precedence over ``hp_freq``.
    exclude_channels : list, optional
        Channel indices that must not be processed.
    """
    # Any falsy value (None or an empty list) becomes a fresh empty list.
    self.exclude_channels = exclude_channels if exclude_channels else []
    self.hp_filter_weights = hp_filter_weights
    self.hp_freq = hp_freq
    self.n_components = n_components
    # Base initialisation runs last, once all settings are in place.
    super().__init__()
[docs]
def validate(self, context: ProcessingContext) -> None:
    """Check that the context carries everything PCA correction needs.

    Runs the base-class validation first and then requires an artifact
    length to be available on the context.

    Parameters
    ----------
    context : ProcessingContext
        Context about to be processed.

    Raises
    ------
    ProcessorValidationError
        If ``context.get_artifact_length()`` returns None.
    """
    super().validate(context)
    artifact_length = context.get_artifact_length()
    if artifact_length is not None:
        return
    raise ProcessorValidationError("Artifact length not set. Run TriggerDetector first.")
[docs]
def process(self, context: ProcessingContext) -> ProcessingContext:
    """Apply PCA correction to every EEG channel of the context's raw data.

    Works on a copy of the raw object. For each picked EEG channel the
    PCA residual inside the acquisition window is subtracted from the data
    and added to an artifact estimate of the same shape as the data.

    Parameters
    ----------
    context : ProcessingContext
        Context providing raw data, triggers and artifact length.

    Returns
    -------
    ProcessingContext
        The unchanged input context when no EEG channel is found; otherwise
        a new context holding the corrected raw copy with the artifact
        estimate accumulated as noise.
    """
    # --- EXTRACT ---
    raw = context.get_raw().copy()
    triggers = context.get_triggers()
    artifact_length = context.get_artifact_length()
    eeg_channels = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
    n_channels = len(eeg_channels)

    # --- LOG ---
    logger.info("Applying PCA artifact correction to {} channels", n_channels)
    if n_channels == 0:
        logger.warning("No EEG channels found, skipping PCA")
        return context

    # --- COMPUTE ---
    hp_weights = self._resolve_hp_weights(raw.info["sfreq"])
    s_acq_start, s_acq_end = self._get_acquisition_window(context)
    acq_window = slice(s_acq_start, s_acq_end)
    # Direct _data access avoids a full array copy on large datasets
    estimated_artifacts = np.zeros(raw._data.shape)

    with processor_progress(total=n_channels or None, message="PCA artifact correction") as progress:
        for position, ch_idx in enumerate(eeg_channels, start=1):
            ch_name = raw.ch_names[ch_idx]
            label = f"{position}/{n_channels} • {ch_name}"

            # Channels can be skipped by explicit exclusion or because PCA
            # is switched off altogether (n_components == 0).
            if ch_idx in self.exclude_channels:
                skip_reason = "excluded"
            elif self.n_components == 0:
                skip_reason = "disabled"
            else:
                skip_reason = None
            if skip_reason is not None:
                progress.advance(1, message=f"{label} ({skip_reason})")
                continue

            # A failure on one channel is logged and must not abort the rest.
            try:
                residuals = self._calc_pca_residuals(
                    raw._data[ch_idx],
                    triggers,
                    artifact_length,
                    s_acq_start,
                    s_acq_end,
                    hp_weights,
                )
                raw._data[ch_idx][acq_window] -= residuals
                estimated_artifacts[ch_idx][acq_window] += residuals
                progress.advance(1, message=label)
            except Exception as exc:
                logger.error("PCA failed for channel {}: {}", ch_name, exc)
                progress.advance(1, message=f"{label} (error)")

    # --- NOISE ---
    new_ctx = context.with_raw(raw)
    new_ctx.accumulate_noise(estimated_artifacts)

    # --- RETURN ---
    logger.info("PCA correction completed")
    return new_ctx
# -------------------------------------------------------------------------
# Private Helpers
# -------------------------------------------------------------------------
def _resolve_hp_weights(self, sfreq: float) -> np.ndarray | None:
"""Return the appropriate high-pass filter weights.
Parameters
----------
sfreq : float
Sampling frequency in Hz.
Returns
-------
np.ndarray or None
Filter weights for use with ``scipy.signal.filtfilt``, or None
when high-pass filtering is disabled.
"""
if self.hp_filter_weights is not None:
return self.hp_filter_weights
if self.hp_freq is not None and self.hp_freq > 0:
return self._create_hp_filter(sfreq)
return None
def _calc_pca_residuals(
    self,
    ch_data: np.ndarray,
    triggers: np.ndarray,
    artifact_length: int,
    s_acq_start: int,
    s_acq_end: int,
    hp_weights: np.ndarray | None,
) -> np.ndarray:
    """Compute the PCA residual signal of one channel inside the acquisition window.

    The windowed signal is optionally high-pass filtered, cut into epochs,
    passed through ``_calc_pca`` and the per-epoch residuals are written
    back to their positions on a zero-initialised time axis.

    Parameters
    ----------
    ch_data : np.ndarray
        Full time series for one channel.
    triggers : np.ndarray
        Trigger sample positions.
    artifact_length : int
        Artifact length in samples.
    s_acq_start : int
        Acquisition window start sample.
    s_acq_end : int
        Acquisition window end sample.
    hp_weights : np.ndarray or None
        High-pass filter weights (or None to skip filtering).

    Returns
    -------
    np.ndarray
        Residual (artifact) signal of length ``s_acq_end - s_acq_start``.
    """
    acq_signal = ch_data[s_acq_start:s_acq_end]
    n_acq = len(acq_signal)

    # PCA runs on the high-passed signal when weights are available.
    if hp_weights is None:
        pca_input = acq_signal
    else:
        pca_input = filtfilt(hp_weights, 1, acq_signal)

    # Trigger positions relative to the start of the acquisition window.
    local_triggers = triggers - s_acq_start
    # Small offset prevents epoch boundaries from sitting exactly on the trigger
    offset = int(artifact_length * 0.1)
    # NOTE(review): split_vector is assumed to return one row per trigger,
    # each ``artifact_length`` samples long - confirm against the helper.
    epochs = split_vector(pca_input, local_triggers + offset, artifact_length)
    epoch_residuals = self._calc_pca(epochs)

    fitted = np.zeros(n_acq)
    for row, trig in enumerate(local_triggers):
        first = trig + offset
        if first < 0:
            # Epoch would start before the window: nothing to place.
            continue
        last = first + artifact_length
        if last <= n_acq:
            fitted[first:last] = epoch_residuals[row, :]
            continue
        # Epoch runs past the window end: keep only the part that fits.
        remaining = n_acq - first
        if remaining > 0:
            fitted[first:] = epoch_residuals[row, :remaining]
    return fitted
def _calc_pca(self, epochs: np.ndarray) -> np.ndarray:
"""Apply PCA to epochs and return the artifact residuals.
Parameters
----------
epochs : np.ndarray
Epoch matrix of shape (n_epochs, n_times).
Returns
-------
np.ndarray
Residual (artifact) matrix of shape (n_epochs, n_times).
"""
epochs_t = epochs.T
col_var = np.var(epochs_t, axis=0)
variance_threshold = 1e-12
valid_mask = col_var > variance_threshold
if np.count_nonzero(valid_mask) < 2:
return np.zeros_like(epochs)
X_valid = epochs_t[:, valid_mask]
mean_valid = np.mean(X_valid, axis=0)
std_valid = np.std(X_valid, axis=0, ddof=0)
std_valid = np.where(std_valid < variance_threshold, 1.0, std_valid)
X_centered = (X_valid - mean_valid) / std_valid
try:
U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
except np.linalg.LinAlgError as exc:
logger.warning("PCA SVD failed ({}); skipping channel", exc)
return np.zeros_like(epochs)
max_components = min(X_centered.shape[0], X_centered.shape[1])
if max_components <= 1:
return np.zeros_like(epochs)
n_components = self._select_n_components(S, max_components, X_centered.shape[0])
U_reduced = U[:, :n_components]
S_reduced = S[:n_components]
Vt_reduced = Vt[:n_components, :]
X_reconstructed_valid = ((U_reduced @ np.diag(S_reduced) @ Vt_reduced) * std_valid) + mean_valid
residuals_valid = X_valid - X_reconstructed_valid
residuals_full = np.zeros_like(epochs_t)
residuals_full[:, valid_mask] = residuals_valid
return residuals_full.T
def _select_n_components(self, singular_values: np.ndarray, max_components: int, n_samples: int) -> int:
"""Determine the number of PCA components to retain.
Parameters
----------
singular_values : np.ndarray
Singular values from the SVD decomposition.
max_components : int
Upper bound on the number of components.
n_samples : int
Number of samples (time points) used in the SVD.
Returns
-------
int
Number of components to retain (≥ 1).
"""
if isinstance(self.n_components, int):
return max(1, min(self.n_components, max_components))
if isinstance(self.n_components, str):
if self.n_components.lower() != "auto":
raise ValueError("n_components as string must be 'auto'.")
explained_var = (singular_values**2) / (n_samples - 1)
explained_pct = 100.0 * (explained_var / np.sum(explained_var))
n = self._select_n_components_auto(explained_pct)
return max(1, min(int(n), max_components))
if not 0 < self.n_components < 1:
raise ValueError("n_components as float must be between 0 and 1.")
explained_var = (singular_values**2) / (n_samples - 1)
explained_ratio = np.cumsum(explained_var) / np.sum(explained_var)
n = np.searchsorted(explained_ratio, self.n_components) + 1
return max(1, min(int(n), max_components))
def _select_n_components_auto(self, explained_pct: np.ndarray) -> int:
"""Select component count using MATLAB FACET OBS auto heuristics."""
if len(explained_pct) == 0:
return 1
d_oev = np.where(np.abs(np.diff(explained_pct)) < self.TH_SLOPE)[0] + 1
slope_pc = 1
if len(d_oev) > 0:
if len(d_oev) >= 4:
dd_oev = np.diff(d_oev)
run_pos = None
for i in range(len(dd_oev) - 2):
if dd_oev[i] == 1 and dd_oev[i + 1] == 1 and dd_oev[i + 2] == 1:
run_pos = i
break
idx = run_pos if run_pos is not None else 0
slope_pc = int(max(d_oev[idx] - 1, 1))
else:
slope_pc = int(max(d_oev[0] - 1, 1))
cumvar = np.cumsum(explained_pct)
tmp = np.where(cumvar > self.TH_CUMVAR)[0]
cumvar_pc = int(tmp[0] + 1) if len(tmp) > 0 else len(explained_pct)
tmp = np.where(explained_pct < self.TH_VAREXP)[0]
varexp_pc = int(max(tmp[0], 1)) if len(tmp) > 0 else len(explained_pct)
pcs = int(np.floor(np.mean([slope_pc, cumvar_pc, varexp_pc])))
return max(1, pcs)
def _create_hp_filter(self, sfreq: float) -> np.ndarray:
"""Create Butterworth high-pass filter weights.
Parameters
----------
sfreq : float
Sampling frequency in Hz.
Returns
-------
np.ndarray
Filter weights for use with ``scipy.signal.filtfilt``.
"""
nyq = 0.5 * sfreq
normalized_cutoff = self.hp_freq / nyq
b, _ = butter(5, normalized_cutoff, btype="high")
return b
def _get_acquisition_window(self, context: ProcessingContext) -> tuple:
    """Locate the sample range that contains the fMRI acquisition.

    The window reaches one artifact length before the earliest trigger
    and one artifact length after the latest trigger, clipped to the
    recording. The whole recording is used when there are no triggers,
    no artifact length, or the computed window would be empty.

    Parameters
    ----------
    context : ProcessingContext
        Current processing context.

    Returns
    -------
    tuple
        ``(s_acq_start, s_acq_end)`` as integers.
    """
    raw = context.get_raw()
    triggers = context.get_triggers()
    whole_recording = (0, raw.n_times)

    if len(triggers) == 0:
        return whole_recording

    artifact_length = context.get_artifact_length()
    if artifact_length is None:
        return whole_recording

    first_trigger = int(np.min(triggers))
    last_trigger = int(np.max(triggers))
    window_start = max(0, first_trigger - artifact_length)
    window_end = min(raw.n_times, last_trigger + artifact_length)

    # A degenerate window falls back to the full recording.
    if window_end > window_start:
        return window_start, window_end
    return whole_recording