# Source code for facet.correction.farm

"""FARM-based artifact correction processor."""

from __future__ import annotations

import numpy as np
from loguru import logger

from ..core import ProcessingContext, ProcessorValidationError, register_processor
from .aas import AASCorrection


@register_processor
class FARMCorrection(AASCorrection):
    """Remove fMRI artifacts using the MATLAB FACET FARM weighting strategy.

    This processor reuses the AAS subtraction pipeline but replaces template
    selection with the FARM rule from MATLAB ``AvgArtWghtFARM``:

    - Compute epoch-to-epoch correlations for one channel.
    - For each epoch, search within a wide neighborhood.
    - Keep up to ``window_size`` most correlated epochs above threshold.
    - Average the selected epochs with equal weights.

    Parameters
    ----------
    window_size : int
        Number of similar epochs to average (default: 30).
    correlation_threshold : float
        Minimum absolute correlation for candidate selection
        (default: 0.9, matching MATLAB FARM).
    search_half_window : int, optional
        Half-window in epochs used for candidate search. If ``None``,
        derived from ``search_half_window_factor * window_size``.
    search_half_window_factor : float
        Multiplier used when ``search_half_window`` is not set
        (default: 3.0).
    plot_artifacts : bool
        If ``True``, plot one random averaged artifact (default: False).
    realign_after_averaging : bool
        If ``True``, realign triggers after template averaging
        (default: True).
    search_window_factor : float
        Trigger-realignment search-window factor (default: 3.0).
    interpolate_volume_gaps : bool
        If ``True``, interpolate estimated artifact/noise in inter-epoch
        gaps (default: False).
    apply_epoch_alpha_scaling : bool
        If ``True``, apply MATLAB-like per-epoch least-squares alpha scaling
        before subtraction (default: False).
    """

    name = "farm_correction"
    description = "AAS with MATLAB FACET FARM template weighting"
    version = "1.0.0"

    def __init__(
        self,
        window_size: int = 30,
        correlation_threshold: float = 0.9,
        search_half_window: int | None = None,
        search_half_window_factor: float = 3.0,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        interpolate_volume_gaps: bool = False,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        """Store the FARM search settings and configure the AAS base class.

        The FARM-specific attributes are assigned before the parent
        initializer runs; the parent always receives
        ``rel_window_position=0.0``.
        """
        self.search_half_window = search_half_window
        self.search_half_window_factor = search_half_window_factor
        super().__init__(
            window_size=window_size,
            rel_window_position=0.0,
            correlation_threshold=correlation_threshold,
            plot_artifacts=plot_artifacts,
            realign_after_averaging=realign_after_averaging,
            search_window_factor=search_window_factor,
            interpolate_volume_gaps=interpolate_volume_gaps,
            apply_epoch_alpha_scaling=apply_epoch_alpha_scaling,
        )

    def validate(self, context: ProcessingContext) -> None:
        """Run the parent validation, then check the FARM search settings.

        Raises
        ------
        ProcessorValidationError
            If ``search_half_window`` is set but smaller than 1, or if
            ``search_half_window_factor`` is not strictly positive.
        """
        super().validate(context)

        explicit_half_window = self.search_half_window
        if explicit_half_window is not None:
            if explicit_half_window < 1:
                raise ProcessorValidationError(
                    f"search_half_window must be >= 1 when set, got {self.search_half_window}"
                )

        if not self.search_half_window_factor > 0:
            raise ProcessorValidationError(
                f"search_half_window_factor must be positive, got {self.search_half_window_factor}"
            )

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Build the FARM averaging-weight matrix for every epoch.

        Parameters
        ----------
        epochs : np.ndarray
            Epoch matrix with shape ``(n_epochs, n_times)``.
        window_size : int
            Maximum number of epochs averaged per row.
        rel_window_offset : float
            Ignored by FARM; present only for API compatibility.
        correlation_threshold : float
            Minimum absolute Pearson correlation for inclusion.

        Returns
        -------
        np.ndarray
            Matrix of shape ``(n_epochs, n_epochs)``. Each row holds equal
            weights ``1 / k`` on the ``k`` epochs chosen for that row.
        """
        del rel_window_offset  # Signature compatibility only.

        epoch_count = int(epochs.shape[0])
        weights = np.zeros((epoch_count, epoch_count), dtype=float)

        # Degenerate inputs: nothing to correlate against.
        if epoch_count <= 1:
            if epoch_count == 1:
                weights[0, 0] = 1.0
            return weights

        half_window = self._resolve_search_half_window(window_size)

        # Non-finite correlations are treated as "no correlation"; the
        # diagonal is then forced to 1.
        correlations = np.nan_to_num(
            np.corrcoef(epochs), nan=0.0, posinf=0.0, neginf=0.0
        )
        np.fill_diagonal(correlations, 1.0)

        short_rows = 0
        for row in range(epoch_count):
            chosen = self._select_row_indices(
                corr_mat=correlations,
                row=row,
                n_epochs=epoch_count,
                search_half_window=half_window,
                window_size=window_size,
                correlation_threshold=correlation_threshold,
            )
            if chosen.size < window_size:
                short_rows += 1
            if chosen.size == 0:
                # No neighbour qualified: fall back to the epoch itself so
                # the row still carries a nonzero weight.
                chosen = np.array([row], dtype=int)
            weights[row, chosen] = 1.0 / float(chosen.size)

        if short_rows > 0:
            logger.warning(
                "FARM found fewer than {} similar epochs in {} rows; using reduced averages.",
                window_size,
                short_rows,
            )

        return weights

    def _resolve_search_half_window(self, window_size: int) -> int:
        """Return the candidate-search half-window in epochs.

        An explicitly configured ``search_half_window`` wins; otherwise the
        value is ``search_half_window_factor * window_size`` rounded to an
        integer and clamped to at least 1.
        """
        configured = self.search_half_window
        if configured is None:
            derived = int(round(self.search_half_window_factor * window_size))
            return max(1, derived)
        return int(configured)

    def _select_row_indices(
        self,
        corr_mat: np.ndarray,
        row: int,
        n_epochs: int,
        search_half_window: int,
        window_size: int,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Pick the epoch indices averaged for one row of the FARM matrix.

        The search window spans ``2 * search_half_window + 1`` epochs around
        ``row`` and is shifted inward at either end of the recording. Within
        it, candidates are ranked by absolute correlation (descending), the
        row itself is dropped, entries below ``correlation_threshold`` are
        dropped, and at most ``window_size`` indices are returned. The result
        may be empty.
        """
        span = 2 * search_half_window + 1

        # Clamp the window to [0, n_epochs); if the right edge was cut off,
        # slide the left edge back to keep the full span where possible.
        start = max(0, row - search_half_window)
        stop = min(start + span, n_epochs)
        if stop - start < span:
            start = max(0, stop - span)

        candidates = np.arange(start, stop, dtype=int)
        strengths = np.abs(corr_mat[row, start:stop])

        descending = np.argsort(strengths)[::-1]
        candidates = candidates[descending]
        strengths = strengths[descending]

        keep = (candidates != row) & (strengths >= correlation_threshold)
        return candidates[keep][:window_size]
# Alias for backwards compatibility: the old class name stays importable and
# refers to the same class object as FARMCorrection.
FARMArtifactCorrection = FARMCorrection