"""
Adaptive Noise Cancellation (ANC) correction processor.
"""
from typing import Any
import mne
import numpy as np
from loguru import logger
from scipy.signal import butter, filtfilt, firls
from ..console import processor_progress
from ..core import ProcessingContext, Processor, ProcessorValidationError, register_processor
@register_processor
class ANCCorrection(Processor):
    """Remove residual fMRI artifacts using Adaptive Noise Cancellation.

    Uses the estimated noise from a prior correction step (e.g. AAS) as a
    reference signal. The LMS adaptive filter iteratively minimises the
    residual between the EEG and a scaled, filtered copy of the reference,
    yielding a per-channel filtered-noise estimate that is subtracted from
    the EEG.

    The algorithm:

    1. High-pass filters the EEG (the reference itself is not filtered) to
       remove DC / low-frequency drift.
    2. Scales the reference to match the EEG amplitude (Alpha factor).
    3. Adapts filter coefficients using the LMS algorithm.
    4. Subtracts the filtered noise from the EEG and adds it to the
       context's estimated noise.

    Parameters
    ----------
    filter_order : int, optional
        Adaptive filter order. Defaults to the artifact length derived from
        the context; always capped by the acquisition-window length.
    hp_freq : float, optional
        High-pass cutoff frequency in Hz. Auto-derived from the trigger rate
        when not specified (or not positive).
    hp_filter_weights : np.ndarray, optional
        Pre-computed FIR filter weights; overrides ``hp_freq`` when provided
        (``hp_freq`` is then only recorded in the ANC metadata).
    use_c_extension : bool
        Use the optional fastranc C extension for speed (default: True).
        Falls back to the pure-Python LMS implementation automatically.
    mu_factor : float
        Learning-rate numerator; actual µ = mu_factor / (N × var(ref))
        (default: 0.05).
    max_gain : float
        Maximum allowed ratio of peak filtered-noise amplitude to peak EEG
        amplitude. Corrections exceeding this are discarded (default: 50.0).
    """

    name = "anc_correction"
    description = "Adaptive Noise Cancellation for residual artifacts"
    version = "1.0.0"
    requires_triggers = True
    requires_raw = True
    modifies_raw = True
    parallel_safe = False
    channel_wise = True

    def __init__(
        self,
        filter_order: int | None = None,
        hp_freq: float | None = None,
        hp_filter_weights: np.ndarray | None = None,
        use_c_extension: bool = True,
        mu_factor: float = 0.05,
        max_gain: float = 50.0,
    ) -> None:
        self.filter_order_override = max(1, int(filter_order)) if filter_order is not None else None
        self.filter_order = self.filter_order_override
        self.hp_freq = hp_freq
        self.hp_filter_weights = hp_filter_weights
        self.use_c_extension = use_c_extension
        self.mu_factor = mu_factor
        self.max_gain = max_gain
        # Tri-state cache: None = not probed yet, then True/False after _check_fastranc().
        self._fastranc_available = None
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Ensure estimated noise, triggers and a positive artifact length exist.

        Raises
        ------
        ProcessorValidationError
            If any prerequisite of ANC is missing from ``context``.
        """
        super().validate(context)
        if not context.has_estimated_noise():
            raise ProcessorValidationError("Estimated noise not available. Run AAS or other correction first.")
        if not context.has_triggers():
            raise ProcessorValidationError("Triggers not set. Run TriggerDetector first.")
        artifact_length = context.get_artifact_length()
        if artifact_length is None or artifact_length <= 0:
            raise ProcessorValidationError("Artifact length not set. Run TriggerDetector before ANC.")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Apply ANC to every good EEG channel and update the noise estimate.

        Parameters
        ----------
        context : ProcessingContext
            Context providing raw data, estimated noise, triggers, sampling
            frequency and artifact length.

        Returns
        -------
        ProcessingContext
            New context with corrected raw data, the updated estimated noise
            and ANC diagnostics in ``metadata.custom["anc"]``. The input
            context is returned unchanged when no EEG channels are found.
        """
        # --- EXTRACT ---
        raw = context.get_raw().copy()
        estimated_noise = context.get_estimated_noise()
        sfreq = context.get_sfreq()
        artifact_length = context.get_artifact_length()
        eeg_channels = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
        n_channels = len(eeg_channels)

        # --- LOG ---
        logger.info("Applying ANC to {} channels", n_channels)
        if n_channels == 0:
            logger.warning("No EEG channels found, skipping ANC")
            return context

        # --- COMPUTE ---
        hp_weights, hp_cutoff = self._resolve_hp_filter(context, artifact_length, sfreq)
        s_acq_start, s_acq_end = self._get_acquisition_window(context)
        filter_order = self._resolve_filter_order(artifact_length, s_acq_start, s_acq_end)
        if self.use_c_extension and self._fastranc_available is None:
            # Probe the optional C extension once; the result is cached on the instance.
            self._fastranc_available = self._check_fastranc()

        noise_updated = estimated_noise.copy()
        with processor_progress(
            total=n_channels,
            message="Adaptive noise cancellation",
        ) as progress:
            for idx, ch_idx in enumerate(eeg_channels, start=1):
                ch_name = raw.ch_names[ch_idx]
                try:
                    corrected, filtered = self._anc_single_channel(
                        raw._data[ch_idx],
                        estimated_noise[ch_idx],
                        s_acq_start,
                        s_acq_end,
                        hp_weights,
                        filter_order,
                        ch_name,
                    )
                    # Update the noise estimate first: if it fails, the raw data
                    # is left untouched so the two stay consistent.
                    noise_updated[ch_idx, s_acq_start:s_acq_end] += filtered
                    raw._data[ch_idx] = corrected
                    status = f"{idx}/{n_channels} • {ch_name}"
                except Exception as exc:
                    # Best effort: a failing channel is left uncorrected instead
                    # of aborting the whole run.
                    logger.error("ANC failed for channel {}: {}", ch_name, exc)
                    status = f"{idx}/{n_channels} • {ch_name} (skipped)"
                progress.advance(1, message=status)

        # --- NOISE ---
        new_ctx = context.with_raw(raw)
        new_ctx.set_estimated_noise(noise_updated)

        # --- BUILD RESULT + RETURN ---
        logger.info("ANC correction completed")
        return self._with_anc_metadata(new_ctx, hp_cutoff, filter_order)

    # -------------------------------------------------------------------------
    # Private Helpers
    # -------------------------------------------------------------------------

    def _use_fast_path(self) -> bool:
        """Return True when the C extension is both requested and available."""
        return bool(self.use_c_extension and self._fastranc_available)

    def _resolve_hp_filter(
        self,
        context: ProcessingContext,
        artifact_length: int,
        sfreq: float,
    ) -> tuple:
        """Select or derive the high-pass filter weights and cutoff.

        Priority: user-supplied weights, then a positive user ``hp_freq``,
        then parameters derived from the trigger rate.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context.
        artifact_length : int
            Artifact length in samples (used when deriving parameters).
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        tuple
            ``(hp_weights, hp_cutoff_hz)`` where ``hp_weights`` is an
            np.ndarray or None, and ``hp_cutoff_hz`` is a float.
        """
        if self.hp_filter_weights is not None:
            # The weights are used as-is; the cutoff is only needed for the
            # metadata, so derive it without designing a filter to throw away.
            if self.hp_freq is not None:
                hp_cutoff = self.hp_freq
            else:
                hp_cutoff = self._derive_hp_freq(context, sfreq)
            return self.hp_filter_weights, hp_cutoff
        if self.hp_freq is not None and self.hp_freq > 0:
            return self._design_highpass(self.hp_freq, sfreq), self.hp_freq
        derived = self._derive_parameters(context, artifact_length, sfreq)
        return derived["hp_weights"], derived["hp_freq"]

    def _resolve_filter_order(self, artifact_length: int, s_acq_start: int, s_acq_end: int) -> int:
        """Determine the adaptive filter order, capped by the acquisition window.

        Parameters
        ----------
        artifact_length : int
            Default filter order when no override is provided.
        s_acq_start : int
            Acquisition window start sample.
        s_acq_end : int
            Acquisition window end sample.

        Returns
        -------
        int
            Effective adaptive filter order (≥ 1).
        """
        window_length = max(1, s_acq_end - s_acq_start)
        base_order = self.filter_order_override if self.filter_order_override is not None else artifact_length
        return max(1, min(int(base_order), window_length))

    def _with_anc_metadata(self, ctx: ProcessingContext, hp_cutoff: float, filter_order: int) -> ProcessingContext:
        """Return a new context with ANC diagnostics stored in custom metadata.

        Parameters
        ----------
        ctx : ProcessingContext
            Context to augment.
        hp_cutoff : float
            High-pass cutoff reported for this run (Hz).
        filter_order : int
            Effective adaptive filter order used during processing.

        Returns
        -------
        ProcessingContext
            New context with ``metadata.custom["anc"]`` populated.
        """
        new_metadata = ctx.metadata.copy()
        # NOTE(review): if metadata.copy() is shallow, ``custom`` may be the same
        # dict as in ``ctx.metadata`` and this assignment would also be visible
        # there — confirm the copy semantics of the metadata type.
        new_metadata.custom["anc"] = {
            "hp_frequency_hz": hp_cutoff,
            "filter_order": filter_order,
            "mu_factor": self.mu_factor,
            "max_gain": self.max_gain,
            "used_c_extension": self._use_fast_path(),
        }
        return ctx.with_metadata(new_metadata)

    @staticmethod
    def _passthrough(eeg_data: np.ndarray, segment_len: int) -> tuple:
        """Return the channel unchanged together with an all-zero noise update."""
        return eeg_data, np.zeros(segment_len, dtype=float)

    def _anc_single_channel(
        self,
        eeg_data: np.ndarray,
        noise_data: np.ndarray,
        s_acq_start: int,
        s_acq_end: int,
        hp_weights: np.ndarray | None,
        filter_order: int,
        ch_name: str,
    ) -> tuple:
        """Apply ANC to a single channel.

        Any degenerate or unstable case returns the input unchanged with a
        zero noise update instead of raising.

        Parameters
        ----------
        eeg_data : np.ndarray
            Full EEG time series for one channel.
        noise_data : np.ndarray
            Estimated noise time series for the same channel.
        s_acq_start : int
            Acquisition window start sample.
        s_acq_end : int
            Acquisition window end sample.
        hp_weights : np.ndarray or None
            High-pass FIR filter weights (or None to skip filtering).
        filter_order : int
            Adaptive filter order.
        ch_name : str
            Channel name used in log messages.

        Returns
        -------
        tuple
            ``(corrected_data, filtered_noise)`` both as np.ndarray;
            ``filtered_noise`` covers only the acquisition window.
        """
        reference = noise_data[s_acq_start:s_acq_end].astype(float, copy=True)
        segment_len = reference.size
        if segment_len == 0:
            logger.debug("[{}] ANC reference window is empty, skipping", ch_name)
            return self._passthrough(eeg_data, 0)
        if segment_len <= filter_order:
            logger.debug("[{}] ANC reference shorter than filter order, skipping", ch_name)
            return self._passthrough(eeg_data, segment_len)

        # Only the EEG is high-passed; the full channel is filtered and the
        # acquisition window is sliced out afterwards. astype() copies, so the
        # adaptive filter never works on a view of the raw buffer.
        if hp_weights is not None:
            source = filtfilt(hp_weights, 1, eeg_data, axis=0, padtype="odd")
        else:
            source = eeg_data
        data = source[s_acq_start:s_acq_end].astype(float)

        ref_energy = np.sum(reference * reference)
        if not np.isfinite(ref_energy) or ref_energy <= np.finfo(float).eps:
            logger.debug("[{}] Reference energy too small, skipping ANC", ch_name)
            return self._passthrough(eeg_data, segment_len)

        # Least-squares scale factor that best matches the reference to the EEG.
        alpha = np.sum(data * reference) / ref_energy
        if not np.isfinite(alpha):
            logger.debug("[{}] Alpha scaling not finite, skipping ANC", ch_name)
            return self._passthrough(eeg_data, segment_len)
        reference = (alpha * reference).astype(float)

        var_ref = np.var(reference)
        if not np.isfinite(var_ref) or var_ref <= np.finfo(float).eps:
            logger.debug("[{}] Reference variance is zero, skipping ANC", ch_name)
            return self._passthrough(eeg_data, segment_len)

        # Step size normalised by filter length and reference power.
        mu = float(self.mu_factor / (filter_order * var_ref))
        if not np.isfinite(mu) or mu <= 0:
            logger.debug("[{}] Computed ANC learning rate invalid, skipping", ch_name)
            return self._passthrough(eeg_data, segment_len)

        if self._use_fast_path():
            filtered_noise = self._anc_fast(reference, data, mu, filter_order)
        else:
            filtered_noise = self._anc_python(reference, data, mu, filter_order)

        max_filtered = np.max(np.abs(filtered_noise))
        if not np.isfinite(max_filtered):
            logger.error("[{}] ANC produced invalid values (inf/nan), skipping", ch_name)
            return self._passthrough(eeg_data, segment_len)

        # Reject corrections whose peak is far larger than the (unfiltered)
        # EEG peak in the same window — a sign the LMS filter diverged.
        eeg_segment = eeg_data[s_acq_start:s_acq_end]
        baseline = np.max(np.abs(eeg_segment)) if eeg_segment.size else 0.0
        if baseline > 0:
            gain = max_filtered / baseline
        else:
            gain = np.inf if max_filtered > 0 else 0.0
        if gain > self.max_gain:
            logger.error("[{}] ANC produced unstable gain ({:.2e}), skipping", ch_name, gain)
            return self._passthrough(eeg_data, segment_len)

        corrected_data = eeg_data.copy()
        corrected_data[s_acq_start:s_acq_end] -= filtered_noise
        return corrected_data, filtered_noise

    def _anc_fast(
        self,
        reference: np.ndarray,
        data: np.ndarray,
        mu: float,
        filter_order: int,
    ) -> np.ndarray:
        """Apply ANC using the fastranc C extension.

        Parameters
        ----------
        reference : np.ndarray
            Scaled reference (noise) signal.
        data : np.ndarray
            EEG signal for the acquisition window.
        mu : float
            LMS learning rate.
        filter_order : int
            Adaptive filter order.

        Returns
        -------
        np.ndarray
            Filtered noise signal.
        """
        # Optional C extension — kept as lazy import so a missing build does not
        # prevent the module from being imported.
        from ..helpers.fastranc import fastr_anc

        _, filtered_noise = fastr_anc(reference, data, filter_order, mu)
        return filtered_noise

    def _anc_python(
        self,
        reference: np.ndarray,
        data: np.ndarray,
        mu: float,
        filter_order: int,
    ) -> np.ndarray:
        """Apply ANC using the pure-Python LMS fallback.

        Implements a standard LMS transversal filter. An order-``N`` FIR filter
        has ``N + 1`` taps, so the output at sample ``n`` combines
        ``reference[n - N]`` … ``reference[n]``, including the current sample.
        Outputs before index ``N`` (incomplete tap history) stay zero.

        Parameters
        ----------
        reference : np.ndarray
            Scaled reference (noise) signal.
        data : np.ndarray
            EEG signal for the acquisition window.
        mu : float
            LMS learning rate.
        filter_order : int
            Adaptive filter order ``N`` (the filter has ``N + 1`` taps).

        Returns
        -------
        np.ndarray
            Filtered noise signal, same length as ``reference``.
        """
        order = max(1, int(filter_order))
        length = len(reference)
        weights = np.zeros(order + 1)
        filtered = np.zeros(length)
        # Reference and data are sliced from the same window, so reference[n]
        # is time-aligned with data[n]; leaving it out would impose a one-sample
        # lag. Weights are kept in chronological tap order (oldest first), so
        # each window slice is used directly without reversing it.
        for n in range(order, length):
            window = reference[n - order : n + 1]
            filtered[n] = np.dot(weights, window)
            error = data[n] - filtered[n]
            weights += mu * error * window
        return filtered

    def _derive_hp_freq(self, context: ProcessingContext, sfreq: float) -> float:
        """Derive the high-pass cutoff (Hz) from the trigger rate.

        Counts the triggers that follow the first one by less than one second
        (stopping at the first trigger that does not) and returns
        ``0.75 × count``, floored at 0.5 Hz. With fewer than two triggers the
        count is taken as 1.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context (used to read triggers).
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        float
            High-pass cutoff in Hz.
        """
        triggers = context.get_triggers()
        trigger_rate = 1
        if triggers is not None and len(triggers) >= 2:
            one_second = int(sfreq)
            first = triggers[0]
            trigger_rate = len(triggers)
            for idx in range(1, len(triggers)):
                if triggers[idx] - first >= one_second:
                    trigger_rate = idx
                    break
        return max(0.75 * trigger_rate, 0.5)

    def _derive_parameters(
        self,
        context: ProcessingContext,
        artifact_length: int,
        sfreq: float,
    ) -> dict[str, Any]:
        """Derive ANC parameters from the trigger rate and sampling frequency.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context (used to read triggers).
        artifact_length : int
            Artifact length in samples; used as default filter order.
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        dict
            Dictionary with keys ``hp_freq``, ``hp_weights``, ``filter_order``.
        """
        hp_freq = self._derive_hp_freq(context, sfreq)
        hp_weights = self._design_highpass(hp_freq, sfreq)
        filter_order = max(artifact_length, 1)
        return {"hp_freq": hp_freq, "hp_weights": hp_weights, "filter_order": filter_order}

    def _design_highpass(self, cutoff_hz: float, sfreq: float) -> np.ndarray:
        """Design a high-pass FIR filter using ``firls``.

        Falls back to a windowed-sinc FIR (``firwin``) with the same number of
        taps if the least-squares design fails. The result is always an FIR
        numerator meant to be used as ``filtfilt(weights, 1, x)``.

        Parameters
        ----------
        cutoff_hz : float
            Desired high-pass cutoff in Hz (clamped to [0.5, 0.95 × Nyquist]).
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        np.ndarray
            FIR filter weights suitable for ``scipy.signal.filtfilt(w, 1, x)``.
        """
        nyq = 0.5 * sfreq
        cutoff_hz = min(max(cutoff_hz, 0.5), nyq * 0.95)
        trans = 0.15
        pass_edge = cutoff_hz / nyq
        stop_edge = pass_edge * (1 - trans)
        # Odd tap count (``| 1``): required for a type-I linear-phase high-pass.
        taps = max(int(round(1.2 * sfreq / (cutoff_hz * (1 - trans)))) | 1, 3)
        f = [0.0, stop_edge, pass_edge, 1.0]
        a = [0.0, 0.0, 1.0, 1.0]
        try:
            weights = firls(taps, f, a)
        except Exception as exc:
            # A Butterworth numerator on its own (as used with filtfilt(b, 1, ...))
            # is not a flat high-pass, so fall back to another FIR design.
            from scipy.signal import firwin

            logger.warning("FIR design failed ({}); falling back to windowed FIR", exc)
            weights = firwin(taps, 0.5 * (stop_edge + pass_edge), pass_zero=False)
        return np.asarray(weights, dtype=float)

    def _get_acquisition_window(self, context: ProcessingContext) -> tuple:
        """Return the start and end sample indices of the acquisition window.

        The window spans from one artifact length before the first trigger to
        one artifact length after the last, clipped to the recording. The full
        recording is used when triggers or the artifact length are missing.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context.

        Returns
        -------
        tuple
            ``(s_acq_start, s_acq_end)`` as integers.
        """
        n_times = context.get_raw().n_times
        triggers = context.get_triggers()
        artifact_length = context.get_artifact_length()
        if triggers is None or len(triggers) == 0 or artifact_length is None:
            return 0, n_times
        artifact_length = int(artifact_length)
        s_acq_start = max(0, int(np.min(triggers)) - artifact_length)
        s_acq_end = min(n_times, int(np.max(triggers)) + artifact_length)
        if s_acq_end <= s_acq_start:
            return 0, n_times
        return s_acq_start, s_acq_end

    def _check_fastranc(self) -> bool:
        """Check whether the fastranc C extension is importable.

        Returns
        -------
        bool
            True if the extension is available, False otherwise.
        """
        try:
            from ..helpers.fastranc import fastranc
        except Exception as exc:
            logger.info("fastranc C extension not available ({}), using Python fallback", exc)
            return False
        if fastranc is not None:
            logger.debug("Using fastranc C extension for ANC")
            return True
        logger.info("fastranc C extension not available, using Python fallback")
        return False
# Backwards-compatible alias: ``AdaptiveNoiseCancellation`` is bound to the
# very same class object as ``ANCCorrection`` (not a subclass or wrapper).
AdaptiveNoiseCancellation = ANCCorrection