Source code for facet.correction.volume

"""Volume-gap artifact correction processor."""

from __future__ import annotations

import mne
import numpy as np
from loguru import logger

from ..console import processor_progress
from ..core import ProcessingContext, Processor, ProcessorValidationError, register_processor


[docs] @register_processor class VolumeArtifactCorrection(Processor): """Correct volume-transition artifacts around slice-trigger gaps. MATLAB FACET's ``RARemoveVolumeArtifact`` subtracts a transition artifact from the slice before and after each detected volume gap, then linearly interpolates the gap itself. This processor ports that behavior. Parameters ---------- template_count : int Number of neighboring slices used to form pre/post templates (default: 5). weighting_position : float Relative location of the logistic midpoint inside one artifact epoch (default: 0.8). weighting_slope : float Logistic slope used for transition weighting (default: 20.0). """ name = "volume_artifact_correction" description = "Correct transition artifacts around volume gaps" version = "1.0.0" requires_triggers = True requires_raw = True modifies_raw = True parallel_safe = True channel_wise = True
[docs] def __init__( self, template_count: int = 5, weighting_position: float = 0.8, weighting_slope: float = 20.0, ) -> None: self.template_count = int(template_count) self.weighting_position = float(weighting_position) self.weighting_slope = float(weighting_slope) super().__init__()
[docs] def validate(self, context: ProcessingContext) -> None: super().validate(context) if context.get_artifact_length() is None or context.get_artifact_length() <= 1: raise ProcessorValidationError("Artifact length must be > 1 for volume artifact correction.") if self.template_count < 1: raise ProcessorValidationError(f"template_count must be >= 1, got {self.template_count}") if not (0.0 <= self.weighting_position <= 1.0): raise ProcessorValidationError(f"weighting_position must be in [0, 1], got {self.weighting_position}") if self.weighting_slope <= 0: raise ProcessorValidationError(f"weighting_slope must be positive, got {self.weighting_slope}")
[docs] def process(self, context: ProcessingContext) -> ProcessingContext: # --- EXTRACT --- if not context.metadata.volume_gaps: logger.info("VolumeArtifactCorrection skipped: context metadata indicates no volume gaps.") return context triggers = context.get_triggers() gap_pre_indices = self._find_volume_gap_pre_indices(triggers) if gap_pre_indices.size == 0: logger.info("VolumeArtifactCorrection skipped: no large trigger-distance gaps found.") return context raw = context.get_raw().copy() artifact_length = int(context.get_artifact_length()) pre_samples, post_samples = self._resolve_pre_post_samples(context, artifact_length) eeg_channels = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads") # --- LOG --- logger.info( "Applying volume artifact correction: {} channels, {} volume gaps", len(eeg_channels), len(gap_pre_indices), ) # --- COMPUTE --- weight = self._logistic_weight(artifact_length) corrected_pairs = 0 interpolated_gaps = 0 with processor_progress( total=len(eeg_channels) or None, message="Volume artifact correction", ) as progress: for ch_pos, ch_idx in enumerate(eeg_channels): ch_name = raw.ch_names[ch_idx] pairs, gaps = self._correct_channel( ch_data=raw._data[ch_idx], triggers=triggers, gap_pre_indices=gap_pre_indices, pre_samples=pre_samples, post_samples=post_samples, weight=weight, ) corrected_pairs += pairs interpolated_gaps += gaps progress.advance(1, message=f"{ch_pos + 1}/{len(eeg_channels)}{ch_name}") # --- RETURN --- logger.info( "Volume artifact correction complete: {} corrected slice-pairs, {} interpolated gaps", corrected_pairs, interpolated_gaps, ) return context.with_raw(raw)
def _find_volume_gap_pre_indices(self, triggers: np.ndarray) -> np.ndarray: """Return indices of triggers directly before a volume gap.""" if len(triggers) < 2: return np.array([], dtype=int) diffs = np.diff(triggers) middle = float(np.mean([np.min(diffs), np.max(diffs)])) return np.where(diffs > middle)[0].astype(int) def _resolve_pre_post_samples(self, context: ProcessingContext, artifact_length: int) -> tuple[int, int]: """Derive pre/post trigger sample counts for one artifact epoch.""" metadata = context.metadata sfreq = context.get_sfreq() max_pre = max(0, artifact_length - 1) if metadata.pre_trigger_samples is not None: pre_samples = int(max(0, min(metadata.pre_trigger_samples, max_pre))) else: offset_samples = int(round(metadata.artifact_to_trigger_offset * sfreq)) pre_samples = int(max(0, min(-offset_samples, max_pre))) if offset_samples < 0 else 0 max_post = max(0, artifact_length - pre_samples - 1) if metadata.post_trigger_samples is not None: post_samples = int(max(0, min(metadata.post_trigger_samples, max_post))) else: post_samples = int(max_post) if pre_samples + post_samples + 1 < artifact_length: post_samples = artifact_length - pre_samples - 1 return pre_samples, post_samples def _logistic_weight(self, artifact_length: int) -> np.ndarray: """Build MATLAB-style logistic transition weights for one epoch.""" x = np.arange(1, artifact_length + 1, dtype=float) / float(artifact_length) return 1.0 / (1.0 + np.exp(-self.weighting_slope * (x - self.weighting_position))) def _correct_channel( self, ch_data: np.ndarray, triggers: np.ndarray, gap_pre_indices: np.ndarray, pre_samples: int, post_samples: int, weight: np.ndarray, ) -> tuple[int, int]: """Correct all detected volume gaps for one EEG channel.""" corrected_pairs = 0 interpolated_gaps = 0 epoch_len = pre_samples + post_samples + 1 n_times = ch_data.shape[0] for gap_pre_idx in gap_pre_indices: trig_pre_idx = int(gap_pre_idx) trig_post_idx = trig_pre_idx + 1 pre_start = int(triggers[trig_pre_idx] - pre_samples) pre_stop = pre_start + epoch_len post_start = int(triggers[trig_post_idx] - pre_samples) post_stop = post_start + epoch_len if pre_start < 0 or post_start < 0 or pre_stop > n_times or post_stop > n_times: continue prev_indices = np.arange(trig_pre_idx - self.template_count, trig_pre_idx, dtype=int) next_indices = np.arange(trig_post_idx + 1, trig_post_idx + 1 + self.template_count, dtype=int) if prev_indices[0] < 0 or next_indices[-1] >= len(triggers): continue data_pre = ch_data[pre_start:pre_stop].copy() data_post = ch_data[post_start:post_stop].copy() template_pre = self._mean_template(ch_data, triggers[prev_indices], pre_samples, epoch_len) template_post = self._mean_template(ch_data, triggers[next_indices], pre_samples, epoch_len) if template_pre is None or template_post is None: continue vol_art_pre = (data_pre - template_pre) * weight vol_art_post = (data_post - template_post) * weight[::-1] corrected_pre = data_pre - vol_art_pre corrected_post = data_post - vol_art_post ch_data[pre_start:pre_stop] = corrected_pre ch_data[post_start:post_stop] = corrected_post corrected_pairs += 1 gap_start = int(triggers[trig_pre_idx] + post_samples + 1) gap_end = int(triggers[trig_post_idx] - pre_samples - 1) if gap_end >= gap_start: gap_len = gap_end - gap_start + 1 gap_values = corrected_pre[-1] + ( (np.arange(1, gap_len + 1, dtype=float) / float(gap_len + 1)) * (corrected_post[0] - corrected_pre[-1]) ) ch_data[gap_start : gap_end + 1] = gap_values interpolated_gaps += 1 return corrected_pairs, interpolated_gaps def _mean_template( self, ch_data: np.ndarray, template_triggers: np.ndarray, pre_samples: int, epoch_len: int, ) -> np.ndarray | None: """Average valid template epochs for one side of a volume gap.""" segments = [] n_times = ch_data.shape[0] for trigger in template_triggers: start = int(trigger - pre_samples) stop = start + epoch_len if start < 0 or stop > n_times: continue segments.append(ch_data[start:stop]) if len(segments) == 0: return None return np.mean(np.vstack(segments), axis=0)
# Alias for backwards compatibility RemoveVolumeArtifactCorrection = VolumeArtifactCorrection