Source code for facet.preprocessing.prefilter

"""MATLAB-style prefilter processor."""

from __future__ import annotations

from collections.abc import Sequence

import mne
import numpy as np
from loguru import logger

from ..core import ProcessingContext, Processor, ProcessorValidationError, register_processor


[docs] @register_processor class MATLABPreFilter(Processor): """Apply MATLAB FACET-style FFT prefiltering before artifact correction. Supports: - Piecewise-linear transfer-function filtering in frequency domain. - Gaussian high-pass filtering (DC-safe) in frequency domain. - Optional padded filtering inside the acquisition window. Parameters ---------- lp_frequency : float or sequence of float, optional Low-pass cutoff in Hz (scalar or per-channel values). ``None`` means no LP from generated transfer functions. hp_frequency : float or sequence of float, optional High-pass cutoff in Hz (scalar or per-channel values). ``None`` means no HP from generated transfer functions. transfer_frequencies : np.ndarray or sequence of np.ndarray, optional Custom normalized frequency points in ``[0, 1]`` for transfer-function filtering. If set, overrides LP/HP synthesis. transfer_amplitudes : np.ndarray or sequence of np.ndarray, optional Amplitude points corresponding to ``transfer_frequencies``. gauss_hp_frequency : float or sequence of float, optional Gaussian high-pass cutoff in Hz (scalar or per-channel values). picks : str or sequence of str Channels to filter (default: ``"eeg"``). pad_acquisition : bool Use acquisition-window padding during filtering (default: True). """ name = "matlab_prefilter" description = "Apply MATLAB FACET-style FFT/gaussian prefilter" version = "1.0.0" requires_triggers = False requires_raw = True modifies_raw = True parallel_safe = True channel_wise = True
[docs] def __init__( self, lp_frequency: float | Sequence[float] | None = None, hp_frequency: float | Sequence[float] | None = None, transfer_frequencies: np.ndarray | Sequence[np.ndarray] | None = None, transfer_amplitudes: np.ndarray | Sequence[np.ndarray] | None = None, gauss_hp_frequency: float | Sequence[float] | None = None, picks: str | Sequence[str] = "eeg", pad_acquisition: bool = True, ) -> None: self.lp_frequency = lp_frequency self.hp_frequency = hp_frequency self.transfer_frequencies = transfer_frequencies self.transfer_amplitudes = transfer_amplitudes self.gauss_hp_frequency = gauss_hp_frequency self.picks = picks self.pad_acquisition = pad_acquisition super().__init__()
[docs] def validate(self, context: ProcessingContext) -> None: super().validate(context) if self.transfer_frequencies is not None and self.transfer_amplitudes is None: raise ProcessorValidationError("transfer_amplitudes must be provided with transfer_frequencies.") if self.transfer_amplitudes is not None and self.transfer_frequencies is None: raise ProcessorValidationError("transfer_frequencies must be provided with transfer_amplitudes.")
[docs] def process(self, context: ProcessingContext) -> ProcessingContext: # --- EXTRACT --- raw = context.get_raw().copy() sfreq = float(raw.info["sfreq"]) target_indices = self._resolve_target_indices(raw) acq_start, acq_end = self._resolve_acquisition_bounds(context) # --- LOG --- logger.info( "Applying MATLABPreFilter to {} channels (pad_acquisition={})", len(target_indices), self.pad_acquisition, ) # --- COMPUTE --- for i, ch_idx in enumerate(target_indices): ch_name = raw.ch_names[ch_idx] raw._data[ch_idx] = self._filter_channel( data=raw._data[ch_idx], sfreq=sfreq, channel_pos=i, context=context, acq_start=acq_start, acq_end=acq_end, ) logger.debug("MATLABPreFilter processed channel {}", ch_name) result = context.with_raw(raw) # --- NOISE --- if context.has_estimated_noise(): noise = context.get_estimated_noise().copy() for i, ch_idx in enumerate(target_indices): noise[ch_idx] = self._filter_channel( data=noise[ch_idx], sfreq=sfreq, channel_pos=i, context=context, acq_start=acq_start, acq_end=acq_end, ) result.set_estimated_noise(noise) # --- METADATA --- md = result.metadata.copy() md.custom["matlab_prefilter"] = { "pad_acquisition": self.pad_acquisition, "num_channels": len(target_indices), "acq_start_sample": int(acq_start), "acq_end_sample": int(acq_end), } return result.with_metadata(md)
def _resolve_target_indices(self, raw: mne.io.Raw) -> list[int]: """Resolve picks to channel indices.""" try: picked = raw.copy().pick(picks=self.picks, verbose=False) except Exception as exc: raise ProcessorValidationError(f"Invalid picks '{self.picks}': {exc}") from exc return [raw.ch_names.index(ch) for ch in picked.ch_names] def _resolve_acquisition_bounds(self, context: ProcessingContext) -> tuple[int, int]: """Resolve acquisition bounds for padded filtering.""" raw = context.get_raw() md = context.metadata n_times = raw.n_times if md.acq_start_sample is not None and md.acq_end_sample is not None: return max(0, int(md.acq_start_sample)), min(n_times - 1, int(md.acq_end_sample)) if context.has_triggers() and context.get_artifact_length(): triggers = context.get_triggers() pre = md.pre_trigger_samples post = md.post_trigger_samples if pre is None: offset = int(round(md.artifact_to_trigger_offset * context.get_sfreq())) pre = max(0, -offset) if offset < 0 else 0 if post is None: post = max(int(context.get_artifact_length()) - int(pre) - 1, 0) start = max(0, int(triggers[0] - pre)) end = min(n_times - 1, int(triggers[-1] + post)) return start, end return 0, n_times - 1 def _filter_channel( self, data: np.ndarray, sfreq: float, channel_pos: int, context: ProcessingContext, acq_start: int, acq_end: int, ) -> np.ndarray: """Apply configured prefiltering to one channel.""" out = data.astype(float, copy=True) if not self.pad_acquisition or acq_start >= acq_end: return self._apply_filters(out, sfreq, channel_pos) if acq_start > 0: out[:acq_start] = self._apply_filters(out[:acq_start], sfreq, channel_pos) out[acq_start : acq_end + 1] = self._apply_filters_padded( out[acq_start : acq_end + 1], sfreq, channel_pos, context, ) if acq_end < len(out) - 1: out[acq_end + 1 :] = self._apply_filters(out[acq_end + 1 :], sfreq, channel_pos) return out def _apply_filters_padded( self, segment: np.ndarray, sfreq: float, channel_pos: int, context: ProcessingContext, ) -> np.ndarray: """Apply filters with repeated-edge padding during acquisition.""" if len(segment) < 4: return self._apply_filters(segment, sfreq, channel_pos) triggers = context.get_triggers() if context.has_triggers() else None if triggers is None or len(triggers) < 2: return self._apply_filters(segment, sfreq, channel_pos) art_len_start = max(1, int(triggers[1] - triggers[0])) art_len_end = max(1, int(triggers[-1] - triggers[-2])) art_len_start = min(art_len_start, len(segment)) art_len_end = min(art_len_end, len(segment)) art_s = segment[:art_len_start] art_e = segment[-art_len_end:] num = int(np.ceil(sfreq / max(art_len_start, 1)) + 1) padded = np.concatenate([np.tile(art_s, num), segment, np.tile(art_e, num)]) filtered = self._apply_filters(padded, sfreq, channel_pos) return filtered[(len(art_s) * num) : (len(filtered) - len(art_e) * num)] def _apply_filters(self, signal: np.ndarray, sfreq: float, channel_pos: int) -> np.ndarray: """Apply transfer-function and Gaussian HP filters to one signal.""" if len(signal) < 2: return signal out = signal tf = self._resolve_transfer_function(channel_pos, sfreq) if tf is not None: f, a = tf out = self._fft_transfer_filter(out, f, a) gauss_hp = self._resolve_per_channel(self.gauss_hp_frequency, channel_pos) if gauss_hp is not None and gauss_hp > 0: out = self._fft_gaussian_highpass(out, float(gauss_hp), sfreq) return np.real(out) def _resolve_transfer_function(self, channel_pos: int, sfreq: float) -> tuple[np.ndarray, np.ndarray] | None: """Return normalized transfer-function vectors for one channel.""" if self.transfer_frequencies is not None: if isinstance(self.transfer_frequencies, np.ndarray) and self.transfer_frequencies.ndim == 1: freqs = self.transfer_frequencies elif ( isinstance(self.transfer_frequencies, (list, tuple)) and len(self.transfer_frequencies) > 0 and isinstance(self.transfer_frequencies[0], (list, tuple, np.ndarray)) ): freqs = self._resolve_per_channel(self.transfer_frequencies, channel_pos) else: freqs = self._resolve_per_channel(self.transfer_frequencies, channel_pos) if isinstance(self.transfer_amplitudes, np.ndarray) and self.transfer_amplitudes.ndim == 1: ampls = self.transfer_amplitudes elif ( isinstance(self.transfer_amplitudes, (list, tuple)) and len(self.transfer_amplitudes) > 0 and isinstance(self.transfer_amplitudes[0], (list, tuple, np.ndarray)) ): ampls = self._resolve_per_channel(self.transfer_amplitudes, channel_pos) else: ampls = self._resolve_per_channel(self.transfer_amplitudes, channel_pos) if freqs is None or ampls is None: return None f = np.asarray(freqs, dtype=float).ravel() a = np.asarray(ampls, dtype=float).ravel() self._validate_transfer_function(f, a) return f, a lp = self._resolve_per_channel(self.lp_frequency, channel_pos) hp = self._resolve_per_channel(self.hp_frequency, channel_pos) nyq = sfreq / 2.0 lp = float(lp) if lp is not None else 0.0 hp = float(hp) if hp is not None else 0.0 if lp <= 0 and hp <= 0: return None if lp > 0 and hp > 0 and hp < lp: f = np.array([0, hp * 0.99, hp * 1.01, lp * 0.99, lp * 1.01, nyq], dtype=float) / nyq a = np.array([0, 0, 1, 1, 0, 0], dtype=float) elif lp > 0 and (hp <= 0 or hp >= lp): f = np.array([0, lp * 0.95, lp * 1.05, nyq], dtype=float) / nyq a = np.array([1, 1, 0, 0], dtype=float) else: f = np.array([0, hp * 0.95, hp * 1.05, nyq], dtype=float) / nyq a = np.array([0, 0, 1, 1], dtype=float) f = np.clip(f, 0.0, 1.0) if np.any(np.diff(f) < 0): f = np.maximum.accumulate(f) return f, a def _resolve_per_channel(self, value, channel_pos: int): """Resolve scalar/vector parameter values per channel position.""" if value is None: return None if isinstance(value, np.ndarray): if value.ndim == 1 and np.issubdtype(value.dtype, np.number): if len(value) == 0: return None if len(value) > channel_pos: return value[channel_pos] return value[-1] if value.ndim == 1 and len(value) > 0 and isinstance(value[0], (np.ndarray, list, tuple)): return value[channel_pos] if len(value) > channel_pos else value[-1] return value if isinstance(value, (list, tuple)): if len(value) == 0: return None if isinstance(value[0], (list, tuple, np.ndarray)): return value[channel_pos] if len(value) > channel_pos else value[-1] return value[channel_pos] if len(value) > channel_pos else value[-1] return value @staticmethod def _validate_transfer_function(f: np.ndarray, a: np.ndarray) -> None: """Validate normalized transfer-function vectors.""" if f.ndim != 1 or a.ndim != 1 or len(f) != len(a): raise ProcessorValidationError( "transfer_frequencies and transfer_amplitudes must be 1D vectors of same length." ) if len(f) < 2: raise ProcessorValidationError("transfer_frequencies must contain at least 2 points.") if f[0] != 0.0 or f[-1] != 1.0: raise ProcessorValidationError("transfer_frequencies must start at 0.0 and end at 1.0.") if np.any(np.diff(f) < 0): raise ProcessorValidationError("transfer_frequencies must be sorted ascending.") if np.min(f) < 0.0 or np.max(f) > 1.0: raise ProcessorValidationError("transfer_frequencies must be in [0, 1].") @staticmethod def _fft_transfer_filter(signal: np.ndarray, f: np.ndarray, a: np.ndarray) -> np.ndarray: """Apply piecewise-linear transfer function in the FFT domain.""" n = len(signal) X = np.fft.rfft(signal) w = np.linspace(0.0, 1.0, len(X)) H = np.interp(w, f, a) y = np.fft.irfft(X * H, n=n) return y @staticmethod def _fft_gaussian_highpass(signal: np.ndarray, cutoff_hz: float, sfreq: float) -> np.ndarray: """Apply MATLAB-style Gaussian high-pass in the FFT domain.""" n = len(signal) X = np.fft.rfft(signal) freqs = np.fft.rfftfreq(n, d=1.0 / sfreq) sigma = cutoff_hz / (2 * np.sqrt(-np.log(1 - 1 / np.sqrt(2)))) H = 1.0 - np.exp(-((freqs / (2 * sigma)) ** 2)) y = np.fft.irfft(X * H, n=n) return y