Source code for facet.preprocessing.trigger_detection

"""Trigger Detection Processors Module

Processors for detecting triggers and events in EEG data recorded during
simultaneous EEG-fMRI acquisition.
"""

import re

import mne
import numpy as np
from loguru import logger

from ..core import (
    ProcessingContext,
    Processor,
    ProcessorError,
    ProcessorValidationError,
    register_processor,
)
from ..helpers.crosscorr import crosscorrelation


@register_processor
class TriggerDetector(Processor):
    """Detect triggers from stim channels or annotations using a regex pattern.

    The first stim channel is searched when the recording has one; otherwise
    the raw annotations are searched. Events whose integer value (stim
    channel) or description (annotation) matches the regular expression are
    kept, and their sample positions are stored in
    ``context.metadata.triggers`` relative to the start of the raw data.

    The artifact length is the largest inter-trigger interval. When the
    intervals differ by more than three samples the sequence is treated as
    slice-triggered with volume gaps, and the largest *within-volume*
    interval is used instead.

    Parameters
    ----------
    regex : str
        Regular expression pattern to match trigger values.
    save_to_annotations : bool, optional
        If ``True``, return a raw copy whose annotations are set to the
        detected triggers (default: False).
    """

    name = "trigger_detector"
    description = "Detect triggers using regex pattern"
    version = "1.0.0"

    requires_triggers = False
    requires_raw = True
    modifies_raw = False
    parallel_safe = False

    def __init__(self, regex: str, save_to_annotations: bool = False) -> None:
        self.regex = regex
        self.save_to_annotations = save_to_annotations
        super().__init__()

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Find matching triggers and store them, with timing metadata.

        Parameters
        ----------
        context : ProcessingContext
            Context providing the raw recording.

        Returns
        -------
        ProcessingContext
            The unchanged context when nothing matched; otherwise a context
            with updated metadata (and an annotated raw copy when
            ``save_to_annotations`` is set).
        """
        raw = context.get_raw()
        logger.info("Detecting triggers with pattern: {}", self.regex)

        matches = self._find_events(raw)
        if len(matches) == 0:
            logger.warning("No triggers found!")
            return context

        # MNE event samples are absolute (they include ``first_samp``);
        # subtracting it yields 0-based positions into the current data.
        triggers = np.array([row[0] for row in matches]) - raw.first_samp
        logger.info("Found {} triggers", len(triggers))

        timing = self._compute_artifact_metadata(triggers)

        updated = context.metadata.copy()
        updated.triggers = triggers
        updated.trigger_regex = self.regex
        updated.artifact_length = timing["artifact_length"]
        updated.volume_gaps = timing["volume_gaps"]
        slices = timing.get("slices_per_volume")
        if slices is not None:
            updated.slices_per_volume = slices

        logger.debug("Artifact length: {} samples", updated.artifact_length)
        logger.debug("Volume gaps: {}", updated.volume_gaps)

        if not self.save_to_annotations:
            return context.with_metadata(updated)

        sfreq = raw.info["sfreq"]
        n_triggers = len(triggers)
        annotated = raw.copy()
        annotated.set_annotations(
            mne.Annotations(
                onset=triggers / sfreq,
                duration=np.zeros(n_triggers),
                description=["Trigger"] * n_triggers,
            )
        )
        return context.with_raw(annotated).with_metadata(updated)

    def _find_events(self, raw: mne.io.Raw) -> list:
        """Collect matching events from the first stim channel, else annotations.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object to search.

        Returns
        -------
        list
            MNE-style event rows ``[sample, prev_id, event_id]``.
        """
        # Compiled up front so an invalid pattern fails on either path.
        matcher = re.compile(self.regex)

        stim_picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=True)
        if len(stim_picks) == 0:
            logger.debug("No stim channels, searching annotations")
            # Unfiltered pass is only used to log the available event types.
            _, event_ids = mne.events_from_annotations(raw, verbose=False)
            logger.debug("Event types: {}", event_ids)
            matched_rows, _ = mne.events_from_annotations(raw, regexp=self.regex, verbose=False)
            return list(matched_rows)

        logger.debug("Found {} stim channels", len(stim_picks))
        stim_name = raw.ch_names[stim_picks[0]]
        stim_events = mne.find_events(
            raw,
            stim_channel=stim_name,
            initial_event=True,
            verbose=False,
        )
        return [row for row in stim_events if matcher.search(str(row[2]))]

    def _compute_artifact_metadata(self, triggers: np.ndarray) -> dict:
        """Derive artifact length and volume-gap flag from trigger spacing.

        Parameters
        ----------
        triggers : np.ndarray
            Detected trigger sample positions.

        Returns
        -------
        dict
            Keys: ``artifact_length``, ``volume_gaps``, optionally
            ``slices_per_volume``. ``artifact_length`` is ``None`` when fewer
            than two triggers are given.
        """
        if len(triggers) < 2:
            return {"artifact_length": None, "volume_gaps": False}

        intervals = np.diff(triggers)
        # A spread of at most three samples is treated as uniform spacing.
        if np.ptp(intervals) <= 3:
            return {
                "artifact_length": int(np.max(intervals)),
                "volume_gaps": False,
            }
        return self._compute_slice_volume_metadata(triggers, intervals)

    def _compute_slice_volume_metadata(self, triggers: np.ndarray, trigger_diffs: np.ndarray) -> dict:
        """Compute metadata when volume-level gaps are present.

        Intervals are split at the midpoint between their median and their
        maximum: shorter ones count as slice intervals, the rest as volume
        gaps.

        Parameters
        ----------
        triggers : np.ndarray
            All trigger sample positions.
        trigger_diffs : np.ndarray
            Differences between consecutive triggers.

        Returns
        -------
        dict
            Keys: ``artifact_length``, ``volume_gaps``, ``slices_per_volume``.
        """
        threshold = np.mean((np.median(trigger_diffs), np.max(trigger_diffs)))
        is_gap = trigger_diffs >= threshold

        artifact_length = int(np.max(trigger_diffs[~is_gap]))
        gap_positions = np.flatnonzero(is_gap)

        slices_per_volume = None
        if len(gap_positions) > 0:
            # Triggers per volume: distance between consecutive gap
            # positions, counting from just before the first trigger.
            counts = np.diff(gap_positions, prepend=-1).tolist()
            trailing = len(triggers) - 1 - int(gap_positions[-1])
            if trailing > 0:
                counts.append(trailing)
            if counts:
                slices_per_volume = int(np.median(counts))
                logger.info("Estimated slices per volume: {}", slices_per_volume)

        return {
            "artifact_length": artifact_length,
            "volume_gaps": True,
            "slices_per_volume": slices_per_volume,
        }
@register_processor
class QRSTriggerDetector(Processor):
    """Detect QRS complexes (heartbeats) for BCG artifact correction.

    Peak detection is delegated to ``fmrib_qrsdetect`` from
    ``facet.helpers.bcg_detector``; if that helper module cannot be imported
    a ``ProcessorError`` pointing at ``pip install facetpy[all]`` is raised.

    With at least two peaks, the artifact length is set to half the median
    RR interval and the artifact-to-trigger offset to minus half of that
    length (in seconds), so the window is centred on each peak.

    Parameters
    ----------
    save_to_annotations : bool, optional
        If ``True``, return a raw copy whose annotations are set to the
        detected QRS peaks (default: False).
    """

    name = "qrs_trigger_detector"
    description = "Detect QRS complexes for BCG correction"
    version = "1.0.0"

    requires_triggers = False
    requires_raw = True
    modifies_raw = False
    parallel_safe = False

    def __init__(self, save_to_annotations: bool = False) -> None:
        self.save_to_annotations = save_to_annotations
        super().__init__()

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Run QRS detection and store the peaks as triggers.

        Parameters
        ----------
        context : ProcessingContext
            Context providing the raw recording.

        Returns
        -------
        ProcessingContext
            Context with updated metadata (and an annotated raw copy when
            ``save_to_annotations`` is set).

        Raises
        ------
        ProcessorError
            If the BCG detector helper cannot be imported.
        """
        raw = context.get_raw()
        sfreq = raw.info["sfreq"]

        logger.info("Detecting QRS complexes")

        # Imported lazily so the module loads without the optional dependency.
        try:
            from ..helpers import bcg_detector
        except ImportError as err:
            raise ProcessorError(
                "neurokit2 is required for QRSTriggerDetector. Install with: pip install facetpy[all]"
            ) from err

        triggers = np.array(bcg_detector.fmrib_qrsdetect(raw), dtype=np.int32)
        n_peaks = len(triggers)
        logger.info("Found {} QRS peaks", n_peaks)

        updated = context.metadata.copy()
        updated.triggers = triggers
        updated.trigger_regex = "QRS"
        updated.volume_gaps = True  # QRS peaks have variable spacing

        if n_peaks > 1:
            median_rr = int(np.median(np.diff(triggers)))
            updated.artifact_length = median_rr // 2
            # Half an artifact length before the peak, expressed in seconds.
            updated.artifact_to_trigger_offset = -updated.artifact_length / (2 * sfreq)

        if not self.save_to_annotations:
            return context.with_metadata(updated)

        annotated = raw.copy()
        annotated.set_annotations(
            mne.Annotations(
                onset=triggers / sfreq,
                duration=np.zeros(n_peaks),
                description=["QRS"] * n_peaks,
            )
        )
        return context.with_raw(annotated).with_metadata(updated)
@register_processor
class MissingTriggerDetector(Processor):
    """Detect and insert missing triggers by template matching.

    Scans the trigger sequence for gaps larger than 1.9× the artifact length
    and tries to locate one missing artifact epoch per gap by aligning a
    reference template against the signal. The search is also extended one
    artifact length before the first trigger and one after the last, when
    the data bounds allow it.

    Parameters
    ----------
    add_to_context : bool, optional
        If ``True``, insert found triggers into metadata and annotations;
        otherwise only record them under
        ``metadata.custom["missing_triggers"]`` (default: True).
    correlation_threshold : float, optional
        Minimum absolute Pearson correlation to accept a candidate trigger
        (default: 0.9).
    search_window_factor : float, optional
        Search window as a fraction of artifact length (default: 0.1).
    ref_channel : int, optional
        Reference channel index for template matching (default: 0).
    """

    name = "missing_trigger_detector"
    description = "Detect and add missing triggers"
    version = "1.0.0"

    requires_triggers = True
    requires_raw = True
    modifies_raw = False
    parallel_safe = False

    def __init__(
        self,
        add_to_context: bool = True,
        correlation_threshold: float = 0.9,
        search_window_factor: float = 0.1,
        ref_channel: int = 0,
    ) -> None:
        self.add_to_context = add_to_context
        self.correlation_threshold = correlation_threshold
        self.search_window_factor = search_window_factor
        self.ref_channel = ref_channel
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Require a known artifact length in addition to the base checks.

        Raises
        ------
        ProcessorValidationError
            If the context has no artifact length.
        """
        super().validate(context)
        if context.get_artifact_length() is None:
            raise ProcessorValidationError("Artifact length not set. Run TriggerDetector first.")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Search for missing triggers and merge or record them.

        Parameters
        ----------
        context : ProcessingContext
            Context providing raw data, triggers and artifact length.

        Returns
        -------
        ProcessingContext
            The unchanged context when nothing was found; otherwise an
            updated context.

        Raises
        ------
        ProcessorError
            If no template epoch fits inside the recording.
        """
        # --- EXTRACT ---
        raw = context.get_raw()
        triggers = context.get_triggers().copy()
        artifact_length = context.get_artifact_length()
        sfreq = raw.info["sfreq"]
        # Artifact window relative to a trigger, in samples.
        tmin = int(context.metadata.artifact_to_trigger_offset * sfreq)
        tmax = tmin + artifact_length

        # --- LOG ---
        logger.info("Searching for missing triggers")

        # --- COMPUTE ---
        ref_channel_data = raw.get_data(picks=[self.ref_channel])[0]
        template = self._build_template(ref_channel_data, triggers, tmin, tmax)
        search_window = int(self.search_window_factor * artifact_length)

        missing_triggers = self._find_missing_triggers(
            ref_channel_data, triggers, template, artifact_length, search_window, tmin, tmax
        )
        logger.info("Found {} missing triggers", len(missing_triggers))

        if len(missing_triggers) == 0:
            return context

        # --- BUILD RESULT ---
        if self.add_to_context:
            return self._build_context_with_missing(context, raw, triggers, missing_triggers, sfreq)

        new_metadata = context.metadata.copy()
        new_metadata.custom["missing_triggers"] = missing_triggers

        # --- RETURN ---
        return context.with_metadata(new_metadata)

    def _build_template(
        self,
        ref_data: np.ndarray,
        triggers: np.ndarray,
        tmin: int,
        tmax: int,
    ) -> np.ndarray:
        """Build a reference artifact template from the first few triggers.

        Up to five leading triggers are used. Epochs whose window does not
        lie completely inside ``ref_data`` are skipped, so every averaged
        epoch has the same length.

        Parameters
        ----------
        ref_data : np.ndarray
            1-D reference channel signal.
        triggers : np.ndarray
            Trigger sample positions.
        tmin : int
            Start offset from trigger to artifact onset.
        tmax : int
            End offset from trigger to artifact end.

        Returns
        -------
        np.ndarray
            Averaged template epoch.

        Raises
        ------
        ProcessorError
            If none of the candidate epochs fits inside ``ref_data``.
        """
        n_samples = len(ref_data)
        template_epochs = []
        for trigger in triggers[:5]:
            start = trigger + tmin
            end = trigger + tmax
            # A negative start would wrap around in the slice and yield a
            # wrong (or differently sized) epoch, so check both bounds.
            if start >= 0 and end <= n_samples:
                template_epochs.append(ref_data[start:end])

        if not template_epochs:
            raise ProcessorError(
                "Cannot build artifact template: no trigger epoch lies fully inside the recording."
            )
        return np.mean(template_epochs, axis=0)

    def _find_missing_triggers(
        self,
        ref_data: np.ndarray,
        triggers: np.ndarray,
        template: np.ndarray,
        artifact_length: int,
        search_window: int,
        tmin: int,
        tmax: int,
    ) -> list:
        """Search for missing triggers before, within, and after the sequence.

        Parameters
        ----------
        ref_data : np.ndarray
            1-D reference channel signal.
        triggers : np.ndarray
            Known trigger positions.
        template : np.ndarray
            Reference artifact template.
        artifact_length : int
            Expected artifact length in samples.
        search_window : int
            Cross-correlation search radius.
        tmin : int
            Artifact onset offset from trigger.
        tmax : int
            Artifact end offset from trigger.

        Returns
        -------
        list of int
            Positions of detected missing triggers.
        """

        def probe(search_pos):
            """Return the aligned candidate near ``search_pos`` or ``None``."""
            candidate = self._align_to_template(ref_data, template, search_pos, search_window, tmin, tmax)
            if self._is_artifact(ref_data, template, candidate, tmin, artifact_length):
                return candidate
            return None

        missing_triggers = []

        # One candidate per oversized gap between consecutive triggers.
        for current, following in zip(triggers[:-1], triggers[1:]):
            if following - current > artifact_length * 1.9:
                found = probe(current + artifact_length)
                if found is not None:
                    missing_triggers.append(found)

        # One step before the first trigger; the artifact window must not
        # start before sample 0, otherwise the slice would wrap around.
        search_pos = triggers[0] - artifact_length
        if search_pos > 0 and search_pos + tmin >= 0:
            found = probe(search_pos)
            if found is not None:
                missing_triggers.insert(0, found)

        # One step after the last trigger, if the window still fits.
        search_pos = triggers[-1] + artifact_length
        if search_pos + tmax < len(ref_data):
            found = probe(search_pos)
            if found is not None:
                missing_triggers.append(found)

        return missing_triggers

    def _build_context_with_missing(
        self,
        context: ProcessingContext,
        raw: mne.io.Raw,
        triggers: np.ndarray,
        missing_triggers: list,
        sfreq: float,
    ) -> ProcessingContext:
        """Merge missing triggers into the context and add annotations.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context.
        raw : mne.io.Raw
            Current raw object.
        triggers : np.ndarray
            Original trigger positions.
        missing_triggers : list of int
            Newly detected missing trigger positions.
        sfreq : float
            Sampling frequency in Hz.

        Returns
        -------
        ProcessingContext
            Updated context with merged triggers and annotations.
        """
        all_triggers = np.sort(np.concatenate([triggers, missing_triggers]))

        new_metadata = context.metadata.copy()
        new_metadata.triggers = all_triggers

        raw_copy = raw.copy()
        existing_annot = raw.annotations
        new_annot = mne.Annotations(
            onset=np.array(missing_triggers) / sfreq,
            duration=np.zeros(len(missing_triggers)),
            description=["missing_trigger"] * len(missing_triggers),
        )
        # NOTE(review): the merged annotations are built without the existing
        # ``orig_time`` — confirm this keeps existing onsets aligned when the
        # recording has a measurement date and a non-zero ``first_samp``.
        combined = mne.Annotations(
            onset=np.concatenate([existing_annot.onset, new_annot.onset]),
            duration=np.concatenate([existing_annot.duration, new_annot.duration]),
            description=list(existing_annot.description) + list(new_annot.description),
        )
        raw_copy.set_annotations(combined)

        return context.with_raw(raw_copy).with_metadata(new_metadata)

    def _align_to_template(
        self,
        data: np.ndarray,
        template: np.ndarray,
        position: int,
        search_window: int,
        tmin: int,
        tmax: int,
    ) -> int:
        """Find the best-matching position near ``position`` via cross-correlation.

        Parameters
        ----------
        data : np.ndarray
            1-D signal array.
        template : np.ndarray
            Reference artifact template.
        position : int
            Initial candidate trigger position.
        search_window : int
            Search radius in samples.
        tmin : int
            Artifact onset offset from trigger.
        tmax : int
            Artifact end offset from trigger.

        Returns
        -------
        int
            Refined trigger position.
        """
        segment = data[position + tmin : position + tmax + search_window]
        corr = crosscorrelation(segment, template, search_window)
        # Index ``search_window`` of the result is taken as zero lag.
        shift = int(np.argmax(corr)) - search_window
        return position + shift

    def _is_artifact(
        self,
        data: np.ndarray,
        template: np.ndarray,
        position: int,
        tmin: int,
        artifact_length: int,
    ) -> bool:
        """Return True if the segment at ``position`` correlates with the template.

        Parameters
        ----------
        data : np.ndarray
            1-D signal array.
        template : np.ndarray
            Reference artifact template.
        position : int
            Candidate trigger position.
        tmin : int
            Artifact onset offset from trigger.
        artifact_length : int
            Length of artifact in samples.

        Returns
        -------
        bool
            ``True`` if the absolute Pearson correlation exceeds
            ``self.correlation_threshold``; ``False`` when the window falls
            outside the data or is too short to correlate.
        """
        start = position + tmin
        end = start + min(artifact_length, len(template))
        # Reject windows outside the data; a negative start would otherwise
        # wrap around and compare against the wrong samples.
        if start < 0 or end > len(data):
            return False

        segment = data[start:end]
        template_segment = template[: len(segment)]
        if len(template_segment) < 3:
            return False

        try:
            corr = float(np.abs(np.corrcoef(segment, template_segment)[0, 1]))
        except ValueError:
            return False
        return corr > self.correlation_threshold
@register_processor
class MissingTriggerCompleter(Processor):
    """Deterministically complete missing triggers using volume/slice counts.

    MATLAB FACET parity processor for ``FindMissingTriggers(volumes, slices)``
    based on the ``CompleteTriggers`` reconstruction heuristic.

    Parameters
    ----------
    volumes : int
        Total number of fMRI volumes expected.
    slices : int
        Number of slices per volume.
    add_to_context : bool, optional
        If ``True``, replace ``metadata.triggers`` with the completed list
        (default: True).
    add_annotations : bool, optional
        If ``True`` and raw is available, add annotations for inserted
        triggers (default: True).
    strict : bool, optional
        If ``True``, mismatches raise ``ProcessorValidationError``.
        If ``False``, best-effort completion is used (default: True).
    """

    name = "missing_trigger_completer"
    description = "Deterministically complete missing triggers from volume/slice counts"
    version = "1.0.0"

    requires_triggers = True
    requires_raw = False
    modifies_raw = False
    parallel_safe = False

    def __init__(
        self,
        volumes: int,
        slices: int,
        add_to_context: bool = True,
        add_annotations: bool = True,
        strict: bool = True,
    ) -> None:
        self.volumes = int(volumes)
        self.slices = int(slices)
        self.add_to_context = add_to_context
        self.add_annotations = add_annotations
        self.strict = strict
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Require positive volume and slice counts on top of the base checks.

        Raises
        ------
        ProcessorValidationError
            If ``volumes`` or ``slices`` is smaller than 1.
        """
        super().validate(context)
        if self.volumes < 1:
            raise ProcessorValidationError(f"volumes must be >= 1, got {self.volumes}")
        if self.slices < 1:
            raise ProcessorValidationError(f"slices must be >= 1, got {self.slices}")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Complete the trigger list to ``volumes * slices`` entries.

        A summary is always stored under
        ``metadata.custom["missing_trigger_completer"]``. The trigger list
        and annotations are only updated when the corresponding options are
        enabled.

        Parameters
        ----------
        context : ProcessingContext
            Context providing the current triggers.

        Returns
        -------
        ProcessingContext
            Context with updated metadata and, optionally, an annotated raw
            copy.

        Raises
        ------
        ProcessorValidationError
            If the triggers cannot be completed (see ``_complete_triggers``).
        """
        # --- EXTRACT ---
        triggers = np.sort(np.asarray(context.get_triggers(), dtype=int))
        expected = int(self.volumes * self.slices)

        # --- LOG ---
        logger.info(
            "Completing triggers deterministically (expected={}, current={})",
            expected,
            len(triggers),
        )

        # --- COMPUTE ---
        completed = self._complete_triggers(triggers, expected, self.slices)
        inserted = np.setdiff1d(completed, triggers)
        logger.info("Inserted {} trigger(s)", len(inserted))

        # --- BUILD RESULT ---
        metadata = context.metadata.copy()
        metadata.custom["missing_trigger_completer"] = {
            "volumes": self.volumes,
            "slices": self.slices,
            "expected_triggers": expected,
            "original_triggers": int(len(triggers)),
            "completed_triggers": int(len(completed)),
            "inserted_triggers": inserted.tolist(),
        }

        if not self.add_to_context:
            return context.with_metadata(metadata)

        metadata.triggers = completed
        result = context.with_metadata(metadata)

        if self.add_annotations and context.has_raw() and len(inserted) > 0:
            raw = context.get_raw().copy()
            sfreq = raw.info["sfreq"]
            extra = mne.Annotations(
                onset=np.asarray(inserted, dtype=float) / sfreq,
                duration=np.zeros(len(inserted)),
                description=["completed_trigger"] * len(inserted),
            )
            existing = raw.annotations
            # NOTE(review): the merged annotations are built without the
            # existing ``orig_time`` — confirm this keeps existing onsets
            # aligned when the recording has a measurement date and a
            # non-zero ``first_samp``.
            merged = mne.Annotations(
                onset=np.concatenate([existing.onset, extra.onset]),
                duration=np.concatenate([existing.duration, extra.duration]),
                description=list(existing.description) + list(extra.description),
            )
            raw.set_annotations(merged)
            return result.with_raw(raw)

        # --- RETURN ---
        return result

    def _complete_triggers(self, triggers: np.ndarray, expected: int, slices: int) -> np.ndarray:
        """Reconstruct missing triggers with MATLAB-style heuristics.

        Parameters
        ----------
        triggers : np.ndarray
            Sorted trigger sample positions.
        expected : int
            Number of triggers the completed sequence should contain.
        slices : int
            Number of slices per volume.

        Returns
        -------
        np.ndarray
            Sorted, de-duplicated trigger positions. ``triggers`` itself is
            returned when it already has ``expected`` entries.

        Raises
        ------
        ProcessorValidationError
            If there are no triggers, too many triggers, too few triggers to
            derive intervals, or (strict mode only) the detected gaps do not
            account for the number of missing triggers.
        """
        if len(triggers) == expected:
            return triggers
        if len(triggers) == 0:
            raise ProcessorValidationError("No triggers available for completion.")

        err_num = expected - len(triggers)
        if err_num < 0:
            raise ProcessorValidationError(f"Found more triggers ({len(triggers)}) than expected ({expected}).")

        diffs = np.diff(triggers)
        if len(diffs) == 0:
            raise ProcessorValidationError("Not enough triggers to infer missing positions.")

        # Interval statistics are taken over sections of five volumes each;
        # with too few triggers for one section, global statistics are used.
        trigs_per_section = 5 * slices
        sections = len(diffs) // trigs_per_section
        if sections > 0:
            diff_mat = diffs[: sections * trigs_per_section].reshape(sections, trigs_per_section)
            section_max = np.max(diff_mat, axis=1)
            diff_max = int(np.min(section_max))
            diff_min = int(np.round(np.mean(np.min(diff_mat, axis=1))))
            diff_err = int(np.max(section_max))
        else:
            diff_min = int(np.round(np.median(diffs)))
            diff_max = int(np.max(diffs))
            diff_err = diff_max

        # Intervals above the midpoint of ``diff_min`` and ``diff_err`` are
        # taken to hide a missing trigger.
        err_threshold = float(np.mean([diff_min, diff_err]))
        err_indices = np.where(diffs > err_threshold)[0]

        if len(err_indices) != err_num:
            if self.strict:
                raise ProcessorValidationError(
                    f"Cannot determine missing triggers: expected {err_num}, detected {len(err_indices)} gaps."
                )
            # Best effort: treat the ``err_num`` largest intervals as gaps.
            ranked = np.argsort(diffs)[::-1]
            err_indices = np.sort(ranked[:err_num])

        volume_trigs = np.where((diffs >= (diff_max - 2)) & (diffs <= (diff_max + 2)))[0]
        reconstructed = triggers.tolist()

        for err_idx in err_indices:
            post_volume_candidates = np.where((volume_trigs - err_idx) > 0)[0]
            if len(post_volume_candidates) == 0 or post_volume_candidates[0] == 0:
                interval = diff_min
            else:
                post_volume_index = int(post_volume_candidates[0])
                if (volume_trigs[post_volume_index] - volume_trigs[post_volume_index - 1]) > slices:
                    interval = diff_max
                else:
                    interval = diff_min
            reconstructed.append(int(triggers[err_idx] + interval))

        all_trigs = np.array(sorted(set(reconstructed)), dtype=int)

        # Best-effort fallback for ambiguous cases in non-strict mode.
        if not self.strict and len(all_trigs) < expected:
            all_trigs = self._fill_remaining(all_trigs, expected, diff_min, diff_max)

        if len(all_trigs) != expected and self.strict:
            raise ProcessorValidationError(
                f"Trigger completion produced {len(all_trigs)} triggers, expected {expected}."
            )

        return all_trigs

    @staticmethod
    def _fill_remaining(all_trigs: np.ndarray, expected: int, diff_min: int, diff_max: int) -> np.ndarray:
        """Insert triggers until ``expected`` unique positions exist.

        Each round proposes a trigger inside the currently largest gap. If
        that position already exists (for example when the spacing is
        already uniform), the sequence is extended past the last trigger
        instead, so every round adds exactly one new position and the loop
        always terminates.
        """
        present = set(all_trigs.tolist())
        ordered = sorted(present)
        # Extrapolation step; forced positive so it always yields a new value.
        tail_step = max(diff_min, 1)

        while len(ordered) < expected:
            candidate = None
            if len(ordered) > 1:
                gaps = np.diff(ordered)
                idx = int(np.argmax(gaps))
                step = diff_max if gaps[idx] > (1.5 * diff_min) else diff_min
                candidate = int(ordered[idx] + step)
            if candidate is None or candidate in present:
                candidate = int(ordered[-1] + tail_step)
            present.add(candidate)
            ordered = sorted(present)

        return np.array(ordered, dtype=int)
@register_processor
class SliceTriggerGenerator(Processor):
    """Generate slice triggers from volume triggers.

    MATLAB FACET parity processor for ``GenerateSliceTriggers``. Every
    volume trigger is expanded into ``slices`` triggers spaced one slice
    duration apart; the first one is shifted back by ``relative_position``
    slice durations.

    Parameters
    ----------
    slices : int
        Number of slices per volume.
    duration_samples : float, optional
        Slice duration in samples. If ``None``, derived from trigger spacing.
    relative_position : float
        Relative position of first slice trigger with respect to the volume
        trigger (default: 0.03).
    add_annotations : bool, optional
        If ``True`` and raw is available, generated triggers are added as
        annotations (default: False).
    """

    name = "slice_trigger_generator"
    description = "Generate slice triggers from volume triggers"
    version = "1.0.0"

    requires_triggers = True
    requires_raw = False
    modifies_raw = False
    parallel_safe = False

    def __init__(
        self,
        slices: int,
        duration_samples: float | None = None,
        relative_position: float = 0.03,
        add_annotations: bool = False,
    ) -> None:
        self.slices = int(slices)
        self.duration_samples = duration_samples
        self.relative_position = relative_position
        self.add_annotations = add_annotations
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Check the slice count and optional duration after the base checks.

        Raises
        ------
        ProcessorValidationError
            If ``slices`` is smaller than 1 or ``duration_samples`` is set
            but not positive.
        """
        super().validate(context)
        if self.slices < 1:
            raise ProcessorValidationError(f"slices must be >= 1, got {self.slices}")
        explicit = self.duration_samples
        if explicit is not None and explicit <= 0:
            raise ProcessorValidationError(f"duration_samples must be positive when set, got {self.duration_samples}")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Replace the volume triggers by generated slice triggers.

        Parameters
        ----------
        context : ProcessingContext
            Context providing the volume triggers.

        Returns
        -------
        ProcessingContext
            The unchanged context when there are no triggers; otherwise a
            context with generated triggers, refreshed timing metadata and,
            optionally, an annotated raw copy.

        Raises
        ------
        ProcessorValidationError
            If the slice duration cannot be resolved.
        """
        volume_triggers = np.asarray(context.get_triggers(), dtype=float)
        if len(volume_triggers) == 0:
            return context

        duration = self._resolve_slice_duration(context, volume_triggers)

        # Offsets of each slice relative to its volume trigger, in samples.
        slice_index = np.arange(self.slices) - self.relative_position
        slice_offsets = np.round(slice_index * duration).astype(int)

        # Volume-major layout: all slices of volume 0, then volume 1, ...
        volume_starts = volume_triggers.astype(int)
        generated = (volume_starts[:, np.newaxis] + slice_offsets[np.newaxis, :]).ravel()

        metadata = context.metadata.copy()
        metadata.triggers = generated
        metadata.slices_per_volume = self.slices

        if len(generated) > 1:
            spacing = np.diff(generated)
            widest = int(np.max(spacing))
            if np.ptp(spacing) > 3:
                # Uneven spacing: use the largest interval below the
                # midpoint of median and maximum as the slice length.
                cutoff = np.mean((np.median(spacing), np.max(spacing)))
                within_volume = spacing[spacing < cutoff]
                if len(within_volume) > 0:
                    metadata.artifact_length = int(np.max(within_volume))
                else:
                    metadata.artifact_length = widest
                metadata.volume_gaps = True
            else:
                metadata.artifact_length = widest
                metadata.volume_gaps = False

        metadata.custom["slice_trigger_generator"] = {
            "slices": self.slices,
            "duration_samples": float(duration),
            "relative_position": float(self.relative_position),
            "num_generated_triggers": int(len(generated)),
        }

        result = context.with_metadata(metadata)
        if not (self.add_annotations and context.has_raw()):
            return result

        raw = context.get_raw().copy()
        n_generated = len(generated)
        raw.set_annotations(
            mne.Annotations(
                onset=generated.astype(float) / raw.info["sfreq"],
                duration=np.zeros(n_generated),
                description=["generated_slice_trigger"] * n_generated,
            )
        )
        return result.with_raw(raw)

    def _resolve_slice_duration(self, context: ProcessingContext, volume_triggers: np.ndarray) -> float:
        """Resolve slice period in samples.

        Sources are tried in this order: the explicit ``duration_samples``;
        the stored artifact length when the metadata also has a slice count;
        the median volume period divided by ``slices`` (at least 1.0); the
        stored artifact length on its own.

        Raises
        ------
        ProcessorValidationError
            If none of the sources is available.
        """
        if self.duration_samples is not None:
            return float(self.duration_samples)

        known_length = context.metadata.artifact_length
        if context.metadata.slices_per_volume and known_length:
            return float(known_length)

        if len(volume_triggers) > 1:
            volume_period = float(np.median(np.diff(volume_triggers)))
            per_slice = volume_period / float(self.slices)
            return max(per_slice, 1.0)

        if known_length:
            return float(known_length)

        raise ProcessorValidationError("Cannot infer slice duration; provide duration_samples explicitly.")