Source code for facet.preprocessing.trigger_explorer

"""Interactive Trigger Explorer Processor

Scans EEG data for available annotations and STIM channel events, presents
them to the user via a GUI preview window or terminal table, and lets them
interactively select the trigger source before proceeding with the pipeline.
"""

import re
from typing import Any

import mne
import numpy as np
from loguru import logger

from ..console import get_console, suspend_raw_mode
from ..core import (
    ProcessingContext,
    Processor,
    ProcessorError,
    ProcessorValidationError,
    register_processor,
)

# Sample-count threshold for the GUI preview plot: ``_downsample_for_display``
# decimates any channel longer than this by an integer step of
# ``n_samples // _DISPLAY_DOWNSAMPLE_TARGET`` before plotting.
_DISPLAY_DOWNSAMPLE_TARGET = 5000


@register_processor
class TriggerExplorer(Processor):
    """Interactively explore and select trigger events from annotations or STIM channels.

    Scans the loaded data for all available event sources (MNE annotations and
    STIM channels) and lets the user pick the correct trigger source.

    Three interaction modes are supported:

    ``"gui"`` (default)
        Opens a matplotlib window with a downsampled preview of the first EEG
        channel. Radio buttons list every discovered event type; selecting one
        highlights the corresponding trigger positions on the waveform. A
        *Confirm* button finalises the choice. Falls back to ``"terminal"``
        automatically when no GUI backend is available.
    ``"terminal"``
        Prints a Rich table and prompts for a selection number in the terminal.
    ``"auto"``
        Alias for ``"gui"`` — kept for backward compatibility.

    When ``auto_select`` is provided, all interactive prompts are skipped and
    the first event matching the regex is chosen automatically — useful for
    scripted / non-interactive pipelines.

    Parameters
    ----------
    mode : str, optional
        Interaction mode: ``"gui"`` (default), ``"terminal"``, or ``"auto"``.
    auto_select : str or None, optional
        If given, automatically select the event whose description matches
        this regex, bypassing any interactive prompt (default: None).
    save_to_annotations : bool, optional
        If ``True``, write detected triggers back to the raw annotations
        (default: False).
    """

    name = "trigger_explorer"
    description = "Interactively explore and select trigger events"
    version = "1.1.0"

    requires_triggers = False
    requires_raw = True
    modifies_raw = False
    parallel_safe = False

    def __init__(
        self,
        mode: str = "gui",
        auto_select: str | None = None,
        save_to_annotations: bool = False,
    ) -> None:
        """Store the configuration; see the class docstring for the parameters."""
        self.mode = mode
        self.auto_select = auto_select
        self.save_to_annotations = save_to_annotations
        super().__init__()

    def validate(self, context: ProcessingContext) -> None:
        """Check the configuration before processing.

        Parameters
        ----------
        context : ProcessingContext
            Context forwarded to the base-class validation.

        Raises
        ------
        ProcessorValidationError
            If ``mode`` is not one of the supported values, or if
            ``auto_select`` is not a compilable regular expression.
        """
        super().validate(context)
        if self.mode not in ("gui", "terminal", "auto"):
            raise ProcessorValidationError(f"mode must be 'gui', 'terminal', or 'auto', got '{self.mode}'")
        if self.auto_select is not None:
            # Fail early with a validation error instead of letting re.error
            # surface in the middle of processing.
            try:
                re.compile(self.auto_select)
            except (re.error, TypeError) as exc:
                raise ProcessorValidationError(
                    f"auto_select is not a valid regular expression: {self.auto_select!r} ({exc})"
                ) from exc

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Discover trigger sources, let the user choose one, and store the result.

        Parameters
        ----------
        context : ProcessingContext
            Context providing the raw recording and the metadata to extend.

        Returns
        -------
        ProcessingContext
            Context with updated metadata (``triggers``, ``trigger_regex``,
            ``artifact_length``, ``volume_gaps`` and, when estimated,
            ``slices_per_volume``). When ``save_to_annotations`` is set, the
            returned context also carries a copy of the raw object whose
            annotations are replaced by the detected triggers.

        Raises
        ------
        ProcessorError
            If no trigger source exists in the data or the selection matches
            zero triggers.
        """
        # --- EXTRACT ---
        raw = context.get_raw()

        # --- LOG ---
        sfreq = raw.info["sfreq"]
        t_start = raw.times[0]
        t_end = raw.times[-1]
        n_raw_annotations = len(raw.annotations)
        logger.info(
            "Exploring available trigger sources (mode={}) | "
            "data window: {:.2f}s – {:.2f}s ({:.1f}s) | "
            "raw.annotations: {}",
            self.mode,
            t_start,
            t_end,
            t_end - t_start,
            n_raw_annotations,
        )
        if n_raw_annotations > 0:
            onset_min = raw.annotations.onset.min()
            onset_max = raw.annotations.onset.max()
            logger.debug("raw.annotations onset range: {:.2f}s – {:.2f}s", onset_min, onset_max)

        # --- COMPUTE ---
        annotation_events = self._collect_annotation_events(raw)
        stim_events = self._collect_stim_events(raw)
        logger.debug(
            "Collected {} annotation event type(s) and {} STIM event type(s)",
            len(annotation_events),
            len(stim_events),
        )

        if not annotation_events and not stim_events:
            hint = ""
            if n_raw_annotations == 0:
                hint = (
                    " Tip: raw.annotations is empty — if the file has triggers, "
                    "they may have been removed by a preceding Crop step "
                    "(check that tmin/tmax covers the trigger region)."
                )
            raise ProcessorError(
                f"No trigger sources found: data contains neither annotations nor STIM channels.{hint}"
            )

        event_table = self._build_event_table(annotation_events, stim_events)
        selected = self._select_event(event_table, raw)

        regex = self._regex_for_selection(selected)
        triggers = self._detect_triggers(raw, selected)

        if len(triggers) == 0:
            raise ProcessorError(f"Selection '{selected['description']}' matched 0 triggers.")

        logger.info(
            "Selected trigger '{}' → {} events detected",
            selected["description"],
            len(triggers),
        )

        artifact_meta = self._compute_artifact_metadata(triggers)

        # --- BUILD RESULT ---
        new_metadata = context.metadata.copy()
        new_metadata.triggers = triggers
        new_metadata.trigger_regex = regex
        new_metadata.artifact_length = artifact_meta["artifact_length"]
        new_metadata.volume_gaps = artifact_meta["volume_gaps"]
        if artifact_meta.get("slices_per_volume") is not None:
            new_metadata.slices_per_volume = artifact_meta["slices_per_volume"]

        if self.save_to_annotations:
            # Work on a copy so the raw object held by the incoming context
            # is left untouched.
            raw_copy = raw.copy()
            raw_copy.set_annotations(
                mne.Annotations(
                    onset=triggers / sfreq,
                    duration=np.zeros(len(triggers)),
                    description=["Trigger"] * len(triggers),
                )
            )
            return context.with_raw(raw_copy).with_metadata(new_metadata)

        # --- RETURN ---
        return context.with_metadata(new_metadata)

    # ------------------------------------------------------------------
    # Event collection helpers
    # ------------------------------------------------------------------

    def _collect_annotation_events(self, raw: mne.io.Raw) -> list[dict[str, Any]]:
        """Gather unique annotation descriptions with counts and timing info.

        Uses ``mne.events_from_annotations`` rather than direct
        ``raw.annotations`` access to handle all EDF/EDF+ annotation formats
        reliably.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object to scan.

        Returns
        -------
        list of dict
            Each entry: ``{description, count, first_onset, last_onset}``;
            merged sequential groups additionally carry ``grouped_prefix``.
            Empty when the conversion fails or yields no events.
        """
        try:
            events, event_id = mne.events_from_annotations(raw, verbose=False)
        except (ValueError, RuntimeError) as exc:
            # Treated as "no annotation events" so STIM channels can still be used.
            logger.debug("events_from_annotations raised {}: {}", type(exc).__name__, exc)
            return []

        logger.debug(
            "events_from_annotations: {} event(s), types: {}",
            len(events),
            list(event_id.keys()),
        )

        if len(events) == 0:
            return []

        id_to_desc = {v: k for k, v in event_id.items()}
        sfreq = raw.info["sfreq"]

        desc_map: dict[str, list[float]] = {}
        for event in events:
            desc = id_to_desc.get(int(event[2]), str(event[2]))
            onset = event[0] / sfreq
            desc_map.setdefault(desc, []).append(onset)

        results = [
            {
                "description": desc,
                "count": len(onsets),
                "first_onset": min(onsets),
                "last_onset": max(onsets),
            }
            for desc, onsets in sorted(desc_map.items())
        ]

        return self._maybe_group_sequential_annotations(results)

    @staticmethod
    def _maybe_group_sequential_annotations(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Merge annotations that are 'prefix + number' sequences into one entry.

        Many EDF files encode fMRI scan triggers as sequential annotations
        (e.g. ``"TR 1"``, ``"TR 2"``, …, ``"TR 346"``), each appearing exactly
        once. Showing 346 separate one-occurrence rows in the GUI is unusable;
        this method collapses them into a single ``"TR"`` entry with
        ``count=346`` and the ``grouped_prefix=True`` flag so that downstream
        methods know to match the whole sequence rather than a single label.

        A group is formed whenever **two or more** annotation descriptions
        share the same text prefix and differ only in a trailing integer.

        Parameters
        ----------
        events : list of dict
            Raw per-description entries from ``_collect_annotation_events``.

        Returns
        -------
        list of dict
            Same format; sequential groups replaced by a single merged entry.
            Ungrouped entries come first (input order), followed by grouped /
            single-member prefix entries in sorted prefix order.
        """
        numeric_suffix_re = re.compile(r"^(.+?)\s+\d+$")

        prefix_groups: dict[str, list[dict[str, Any]]] = {}
        no_prefix: list[dict[str, Any]] = []

        for ev in events:
            m = numeric_suffix_re.match(ev["description"])
            if m:
                prefix_groups.setdefault(m.group(1), []).append(ev)
            else:
                no_prefix.append(ev)

        merged: list[dict[str, Any]] = list(no_prefix)
        for prefix, group in sorted(prefix_groups.items()):
            if len(group) < 2:
                # A lone "prefix N" description is not a sequence; keep it verbatim.
                merged.extend(group)
            else:
                merged.append(
                    {
                        "description": prefix,
                        "count": sum(e["count"] for e in group),
                        "first_onset": min(e["first_onset"] for e in group),
                        "last_onset": max(e["last_onset"] for e in group),
                        "grouped_prefix": True,
                    }
                )

        return merged

    def _collect_stim_events(self, raw: mne.io.Raw) -> list[dict[str, Any]]:
        """Gather unique STIM channel event values with counts and timing info.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object to scan.

        Returns
        -------
        list of dict
            Each entry: ``{description, count, first_sample, last_sample, channel_name}``.
        """
        stim_picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=True)
        if len(stim_picks) == 0:
            return []

        results = []
        for ch_idx in stim_picks:
            ch_name = raw.ch_names[ch_idx]
            events = mne.find_events(raw, stim_channel=ch_name, initial_event=True, verbose=False)
            if len(events) == 0:
                continue

            value_map: dict[int, list[int]] = {}
            for event in events:
                value_map.setdefault(int(event[2]), []).append(int(event[0]))

            for value, samples in sorted(value_map.items()):
                results.append(
                    {
                        "description": str(value),
                        "count": len(samples),
                        "first_sample": min(samples),
                        "last_sample": max(samples),
                        "channel_name": ch_name,
                    }
                )

        return results

    # ------------------------------------------------------------------
    # Event table building
    # ------------------------------------------------------------------

    def _build_event_table(
        self,
        annotation_events: list[dict[str, Any]],
        stim_events: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Merge annotation and stim events into a numbered table.

        Parameters
        ----------
        annotation_events : list of dict
            Events from annotations.
        stim_events : list of dict
            Events from STIM channels.

        Returns
        -------
        list of dict
            Unified table with ``index`` (1-based), ``source``,
            ``description``, ``count``, and ``detail`` keys. Annotation rows
            additionally carry ``grouped_prefix``.
        """
        table: list[dict[str, Any]] = []

        for ev in annotation_events:
            time_range = f"{ev['first_onset']:.2f}s – {ev['last_onset']:.2f}s"
            if ev.get("grouped_prefix"):
                detail = f"{time_range} (sequential: '{ev['description']} 1' … '{ev['description']} N')"
            else:
                detail = time_range
            table.append(
                {
                    "index": len(table) + 1,
                    "source": "annotation",
                    "description": ev["description"],
                    "count": ev["count"],
                    "detail": detail,
                    "grouped_prefix": ev.get("grouped_prefix", False),
                }
            )

        for ev in stim_events:
            channel_label = ev.get("channel_name", "STIM")
            table.append(
                {
                    "index": len(table) + 1,
                    # The channel name in parentheses is parsed back out by
                    # ``_triggers_from_stim``.
                    "source": f"stim ({channel_label})",
                    "description": ev["description"],
                    "count": ev["count"],
                    # Separator is required: without it the two sample numbers
                    # run together into one unreadable integer.
                    "detail": f"sample {ev['first_sample']} – {ev['last_sample']}",
                }
            )

        return table

    # ------------------------------------------------------------------
    # Selection dispatch
    # ------------------------------------------------------------------

    def _select_event(self, event_table: list[dict[str, Any]], raw: mne.io.Raw) -> dict[str, Any]:
        """Select a trigger event via the configured interaction mode.

        ``auto_select`` takes precedence over ``mode``. GUI modes fall back to
        the terminal prompt when no interactive matplotlib backend is active.

        Parameters
        ----------
        event_table : list of dict
            Unified event table.
        raw : mne.io.Raw
            Raw object (needed for the GUI preview plot).

        Returns
        -------
        dict
            The selected row from the event table.
        """
        if self.auto_select is not None:
            return self._auto_select_event(event_table)

        if self.mode in ("gui", "auto"):
            if self._gui_backend_available():
                return self._gui_select_event(event_table, raw)
            logger.warning("No matplotlib GUI backend available — falling back to terminal mode")

        self._display_event_table(event_table)
        return self._terminal_select_event(event_table)

    @staticmethod
    def _gui_backend_available() -> bool:
        """Return True if matplotlib can open an interactive window.

        The check is a name comparison of the active backend against a fixed
        set of non-interactive backends; it returns False when matplotlib is
        not installed.
        """
        try:
            import matplotlib

            backend = matplotlib.get_backend().lower()
            non_interactive = {"agg", "pdf", "svg", "ps", "cairo", "template"}
            return backend not in non_interactive
        except ImportError:
            return False

    # ------------------------------------------------------------------
    # GUI selection mode
    # ------------------------------------------------------------------

    @staticmethod
    def _unique_radio_labels(event_table: list[dict[str, Any]]) -> list[str]:
        """Build one distinct radio-button label per event-table row.

        The base label is ``"<description> (<count>)"``. Rows whose base label
        collides with another row (e.g. annotation ``"1"`` and STIM value
        ``"1"`` with equal counts) get their row index and source appended so
        that a label always identifies exactly one row.

        Parameters
        ----------
        event_table : list of dict
            Unified event table.

        Returns
        -------
        list of str
            Labels in table order.
        """
        from collections import Counter

        base_labels = [f"{row['description']} ({row['count']})" for row in event_table]
        occurrences = Counter(base_labels)
        return [
            label if occurrences[label] == 1 else f"{label} [#{row['index']} {row['source']}]"
            for label, row in zip(base_labels, event_table)
        ]

    def _gui_select_event(self, event_table: list[dict[str, Any]], raw: mne.io.Raw) -> dict[str, Any]:
        """Open a matplotlib window with a waveform preview for interactive selection.

        Parameters
        ----------
        event_table : list of dict
            Unified event table (must be non-empty).
        raw : mne.io.Raw
            Raw object used for the preview plot.

        Returns
        -------
        dict
            The confirmed row from the event table.

        Raises
        ------
        ProcessorError
            If the window is closed without pressing *Confirm*.
        """
        import matplotlib.pyplot as plt
        from matplotlib.widgets import Button, RadioButtons

        sfreq = raw.info["sfreq"]
        ch_idx, ch_name = self._pick_preview_channel(raw)
        data, times = self._downsample_for_display(raw, ch_idx)

        # Pre-compute trigger times for every row so switching radio buttons
        # only redraws markers.
        trigger_times_map = {row["index"]: self._detect_triggers(raw, row) / sfreq for row in event_table}

        fig, axes = self._create_gui_layout(len(event_table))
        ax_plot, ax_radio, ax_info, ax_btn = axes

        ax_plot.plot(times, data, linewidth=0.5, color="#2196F3", rasterized=True)
        ax_plot.set_xlabel("Time (s)")
        ax_plot.set_ylabel("Amplitude")
        ax_plot.margins(x=0)

        # The radio callback only receives the label text, so the labels must
        # be unique to map back to the right row.
        labels = self._unique_radio_labels(event_table)
        rows_by_label = dict(zip(labels, event_table))
        radio = RadioButtons(ax_radio, labels, active=0)
        for lbl in radio.labels:
            lbl.set_fontsize(9)

        state: dict[str, Any] = {
            "selected": event_table[0],
            "confirmed": False,
            "vlines": None,
        }

        def on_radio_change(label: str) -> None:
            row = rows_by_label[label]
            state["selected"] = row
            self._update_gui_plot(
                ax_plot,
                state,
                trigger_times_map[row["index"]],
                ch_name,
                data,
            )
            self._update_gui_info(ax_info, row)
            fig.canvas.draw_idle()

        radio.on_clicked(on_radio_change)
        # Draw markers/info for the initially active row.
        on_radio_change(labels[0])

        btn = Button(ax_btn, "Confirm Selection", color="#4CAF50", hovercolor="#66BB6A")
        btn.label.set_fontweight("bold")
        btn.label.set_fontsize(11)

        def on_confirm(_event: Any) -> None:
            state["confirmed"] = True
            plt.close(fig)

        btn.on_clicked(on_confirm)

        logger.info("Waiting for trigger selection in GUI window…")
        plt.show(block=True)

        if not state["confirmed"]:
            raise ProcessorError("Trigger selection cancelled (window closed without confirming).")

        return state["selected"]

    @staticmethod
    def _pick_preview_channel(raw: mne.io.Raw) -> tuple:
        """Return ``(ch_idx, ch_name)`` for the first EEG channel.

        Channels marked as bad are skipped; if no EEG channel is found, the
        first channel of the recording is used instead.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.

        Returns
        -------
        tuple of (int, str)
            Channel index and name.
        """
        eeg_picks = mne.pick_types(raw.info, eeg=True, exclude="bads")
        if len(eeg_picks) > 0:
            return int(eeg_picks[0]), raw.ch_names[eeg_picks[0]]
        return 0, raw.ch_names[0]

    @staticmethod
    def _downsample_for_display(raw: mne.io.Raw, ch_idx: int) -> tuple:
        """Return a downsampled ``(data, times)`` pair for plotting.

        Signals longer than ``_DISPLAY_DOWNSAMPLE_TARGET`` samples are
        decimated by plain striding (no filtering); shorter signals are
        returned unchanged.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.
        ch_idx : int
            Channel index to extract.

        Returns
        -------
        tuple of (np.ndarray, np.ndarray)
            Downsampled amplitude and corresponding time arrays (seconds,
            starting at 0).
        """
        full_data = raw.get_data(picks=[ch_idx])[0]
        n_samples = len(full_data)
        sfreq = raw.info["sfreq"]

        if n_samples > _DISPLAY_DOWNSAMPLE_TARGET:
            step = n_samples // _DISPLAY_DOWNSAMPLE_TARGET
            data = full_data[::step]
            times = np.arange(len(data)) * (step / sfreq)
        else:
            data = full_data
            times = np.arange(n_samples) / sfreq

        return data, times

    @staticmethod
    def _create_gui_layout(n_events: int) -> tuple:
        """Build the matplotlib figure and axes for the explorer window.

        Parameters
        ----------
        n_events : int
            Number of event rows (controls radio-button area height).

        Returns
        -------
        tuple
            ``(fig, (ax_plot, ax_radio, ax_info, ax_btn))``.
        """
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(15, 7))
        fig.canvas.manager.set_window_title("FACETpy — Trigger Explorer")
        fig.patch.set_facecolor("#FAFAFA")

        # Radio area grows with the number of rows but is clamped to
        # 40 %–68 % of the figure height.
        radio_height = max(0.40, min(0.68, n_events * 0.065))

        ax_plot = fig.add_axes([0.06, 0.13, 0.56, 0.80])
        ax_radio = fig.add_axes(
            [0.67, 0.93 - radio_height, 0.30, radio_height],
            facecolor="#F5F5F5",
        )
        ax_info = fig.add_axes(
            [0.67, 0.93 - radio_height - 0.17, 0.30, 0.14],
            facecolor="#FAFAFA",
        )
        ax_btn = fig.add_axes([0.67, 0.04, 0.30, 0.08])

        ax_radio.set_title("Trigger Sources", fontsize=10, fontweight="bold", loc="left")
        ax_info.set_xticks([])
        ax_info.set_yticks([])
        for spine in ax_info.spines.values():
            spine.set_visible(False)

        return fig, (ax_plot, ax_radio, ax_info, ax_btn)

    @staticmethod
    def _update_gui_plot(
        ax: Any,
        state: dict[str, Any],
        trigger_times: np.ndarray,
        ch_name: str,
        data: np.ndarray,
    ) -> None:
        """Redraw trigger markers on the waveform axes.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The waveform axes.
        state : dict
            Mutable state dict holding the current ``vlines`` collection;
            updated in place.
        trigger_times : np.ndarray
            Trigger onset times in seconds.
        ch_name : str
            Preview channel name (for the title).
        data : np.ndarray
            Downsampled amplitude array (for y-limits).
        """
        # Drop the markers of the previously selected row before drawing new ones.
        if state["vlines"] is not None:
            state["vlines"].remove()
            state["vlines"] = None

        if len(trigger_times) > 0:
            ymin, ymax = np.min(data), np.max(data)
            margin = (ymax - ymin) * 0.05
            state["vlines"] = ax.vlines(
                trigger_times,
                ymin - margin,
                ymax + margin,
                colors="#FF5722",
                alpha=0.45,
                linewidth=0.6,
                label="triggers",
            )

        n_shown = len(trigger_times)
        # Separator keeps channel names that end in a digit (e.g. "Fp1") from
        # fusing with the trigger count.
        ax.set_title(
            f"{ch_name} — {n_shown} trigger{'s' if n_shown != 1 else ''} shown",
            fontsize=10,
        )

    @staticmethod
    def _update_gui_info(ax: Any, row: dict[str, Any]) -> None:
        """Update the info text box below the radio buttons.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Info axes (text only, no ticks).
        row : dict
            Currently selected event-table row.
        """
        # ``clear`` also resets ticks and spines, so the bare look is re-applied.
        ax.clear()
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        lines = [
            f"Source: {row['source']}",
            f"Value: {row['description']}",
            f"Count: {row['count']}",
            f"Range: {row['detail']}",
        ]
        ax.text(
            0.05,
            0.90,
            "\n".join(lines),
            transform=ax.transAxes,
            fontsize=9,
            verticalalignment="top",
            fontfamily="monospace",
        )

    # ------------------------------------------------------------------
    # Terminal / Rich table display
    # ------------------------------------------------------------------

    def _display_event_table(self, event_table: list[dict[str, Any]]) -> None:
        """Render the event table to the console using Rich.

        Falls back to ``_display_event_table_plain`` when Rich cannot be
        imported.

        Parameters
        ----------
        event_table : list of dict
            Unified event table produced by ``_build_event_table``.
        """
        try:
            from rich.console import Console as RichConsole
            from rich.table import Table
        except ImportError:
            self._display_event_table_plain(event_table)
            return

        console_obj = get_console()
        rich_console = console_obj.get_rich_console()
        if rich_console is None:
            rich_console = RichConsole()

        table = Table(
            title="Available Trigger Sources",
            show_header=True,
            header_style="bold cyan",
        )
        table.add_column("#", style="bold", width=4, justify="right")
        table.add_column("Source", style="magenta", min_width=14)
        table.add_column("Description / Value", style="green", min_width=18)
        table.add_column("Count", justify="right", style="yellow", min_width=7)
        table.add_column("Range", style="dim", min_width=20)

        for row in event_table:
            table.add_row(
                str(row["index"]),
                row["source"],
                row["description"],
                str(row["count"]),
                row["detail"],
            )

        rich_console.print()
        rich_console.print(table)
        rich_console.print()

    @staticmethod
    def _display_event_table_plain(event_table: list[dict[str, Any]]) -> None:
        """Fallback plain-text display when Rich is unavailable.

        Parameters
        ----------
        event_table : list of dict
            Unified event table.
        """
        header = f"{'#':>4} {'Source':<16} {'Description':<20} {'Count':>7} {'Range'}"
        sep = "-" * len(header)
        lines = ["\nAvailable Trigger Sources", sep, header, sep]
        for row in event_table:
            lines.append(
                f"{row['index']:>4} {row['source']:<16} {row['description']:<20} {row['count']:>7} {row['detail']}"
            )
        lines.append(sep)
        print("\n".join(lines))

    # ------------------------------------------------------------------
    # Terminal selection
    # ------------------------------------------------------------------

    def _auto_select_event(self, event_table: list[dict[str, Any]]) -> dict[str, Any]:
        """Select the first event whose description matches ``auto_select``.

        Parameters
        ----------
        event_table : list of dict
            Unified event table.

        Returns
        -------
        dict
            Matched row.

        Raises
        ------
        ProcessorError
            If ``auto_select`` is not a valid regex or matches no row.
        """
        try:
            pattern = re.compile(self.auto_select)  # type: ignore[arg-type]
        except (re.error, TypeError) as exc:
            # Reached only when ``validate`` was skipped; keep the failure
            # inside the processor's own exception type.
            raise ProcessorError(
                f"auto_select is not a valid regular expression: {self.auto_select!r} ({exc})"
            ) from exc

        for row in event_table:
            if pattern.search(row["description"]):
                logger.info(
                    "Auto-selected trigger '{}' (matched '{}')",
                    row["description"],
                    self.auto_select,
                )
                return row

        descriptions = [r["description"] for r in event_table]
        raise ProcessorError(
            f"auto_select pattern '{self.auto_select}' did not match any event. Available: {descriptions}"
        )

    def _terminal_select_event(self, event_table: list[dict[str, Any]]) -> dict[str, Any]:
        """Prompt the user to pick a trigger event by number or description.

        Resolution order: an all-digit answer is a 1-based row number; then an
        exact description match; then the answer is tried as a regex that must
        match exactly one description.

        Parameters
        ----------
        event_table : list of dict
            Unified event table.

        Returns
        -------
        dict
            Selected row.

        Raises
        ------
        ProcessorError
            On empty or closed input, an out-of-range number, an invalid
            regex, or a pattern matching zero or several rows.
        """
        console_obj = get_console()
        max_idx = len(event_table)

        with suspend_raw_mode():
            console_obj.set_active_prompt(f"Select trigger source [1-{max_idx}]: ")
            try:
                answer = input(f"Select trigger source [1-{max_idx}] or type a description: ").strip()
            except EOFError as exc:
                # stdin closed (e.g. non-interactive run): report it as a
                # selection failure rather than a bare EOFError.
                raise ProcessorError("No trigger source selected (input stream closed).") from exc
            finally:
                console_obj.clear_active_prompt()

        if not answer:
            raise ProcessorError("No trigger source selected (empty input).")

        if answer.isdigit():
            choice = int(answer)
            if 1 <= choice <= max_idx:
                return event_table[choice - 1]
            raise ProcessorError(f"Invalid selection '{choice}'. Must be between 1 and {max_idx}.")

        for row in event_table:
            if row["description"] == answer:
                return row

        try:
            pattern = re.compile(answer)
        except re.error as exc:
            raise ProcessorError(
                f"'{answer}' does not match any available event description "
                f"and is not a valid regular expression ({exc})."
            ) from exc

        matches = [r for r in event_table if pattern.search(r["description"])]
        if len(matches) == 1:
            return matches[0]
        if len(matches) > 1:
            descs = [m["description"] for m in matches]
            raise ProcessorError(
                f"Pattern '{answer}' matched multiple events: {descs}. Please be more specific or use the row number."
            )

        raise ProcessorError(f"'{answer}' does not match any available event description.")

    # ------------------------------------------------------------------
    # Trigger detection
    # ------------------------------------------------------------------

    @staticmethod
    def _bounded_literal_regex(text: str) -> str:
        """Return a regex matching ``text`` literally, not embedded in a longer word.

        Equivalent to wrapping the escaped text in ``\\b`` when ``text``
        starts and ends with word characters. Unlike ``\\b``, the lookarounds
        also work when ``text`` starts or ends with a non-word character
        (e.g. ``"Scan (1)"``), where ``\\b`` could never match at the string
        edge.

        Parameters
        ----------
        text : str
            Literal text to match.

        Returns
        -------
        str
            Regex pattern string.
        """
        return rf"(?<!\w){re.escape(text)}(?!\w)"

    @staticmethod
    def _regex_for_selection(selected: dict[str, Any]) -> str:
        """Build a regex pattern anchored to the selected event description.

        For grouped sequential annotations (e.g. ``"TR 1"``, ``"TR 2"``, …)
        the stored description is the bare prefix (``"TR"``). The regex is
        expanded to ``^TR\\s+\\d+$`` so that downstream ``TriggerDetector``
        steps correctly match all members of the sequence. All other
        selections get a word-bounded literal pattern.

        Parameters
        ----------
        selected : dict
            Selected row from the event table.

        Returns
        -------
        str
            Regex pattern suitable for ``TriggerDetector``.
        """
        desc = selected["description"]
        if selected.get("grouped_prefix"):
            return rf"^{re.escape(desc)}\s+\d+$"
        return TriggerExplorer._bounded_literal_regex(desc)

    def _detect_triggers(self, raw: mne.io.Raw, selected: dict[str, Any]) -> np.ndarray:
        """Detect trigger sample positions for the selected event.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.
        selected : dict
            Selected row from the event table.

        Returns
        -------
        np.ndarray
            Trigger sample positions (int64), relative to the start of
            ``raw`` (``raw.first_samp`` subtracted).
        """
        desc = selected["description"]
        source = selected["source"]

        if source.startswith("stim"):
            pattern = re.compile(self._bounded_literal_regex(desc))
            triggers = self._triggers_from_stim(raw, pattern, source)
        elif selected.get("grouped_prefix"):
            triggers = self._triggers_from_annotations_prefix(raw, desc)
        else:
            triggers = self._triggers_from_annotations(raw, desc)

        # MNE returns absolute sample indices (onset * sfreq + first_samp).
        # Normalize to 0-indexed positions relative to the current raw start
        # so that triggers can be used directly as indices into raw._data.
        return triggers - raw.first_samp

    @staticmethod
    def _triggers_from_annotations_prefix(raw: mne.io.Raw, prefix: str) -> np.ndarray:
        """Extract trigger positions for all annotations matching ``prefix N``.

        Used when the user selects a grouped sequential annotation (e.g. all
        ``"TR N"`` triggers).

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.
        prefix : str
            The common text prefix (e.g. ``"TR"``).

        Returns
        -------
        np.ndarray
            Sorted trigger sample positions (int64).
        """
        pattern = rf"^{re.escape(prefix)}\s+\d+$"
        events, _ = mne.events_from_annotations(raw, regexp=pattern, verbose=False)
        if len(events) == 0:
            return np.array([], dtype=np.int64)
        return np.array(sorted(events[:, 0]), dtype=np.int64)

    @staticmethod
    def _triggers_from_stim(raw: mne.io.Raw, pattern: re.Pattern, source: str) -> np.ndarray:
        """Extract trigger positions from a STIM channel.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.
        pattern : re.Pattern
            Compiled regex to match event values.
        source : str
            Source label containing the channel name in parentheses.

        Returns
        -------
        np.ndarray
            Trigger sample positions (int64).
        """
        ch_match = re.search(r"\((.+)\)", source)
        stim_picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=True)
        # Fall back to the first STIM channel if the label carries no name.
        ch_name = ch_match.group(1) if ch_match else raw.ch_names[stim_picks[0]]

        events = mne.find_events(raw, stim_channel=ch_name, initial_event=True, verbose=False)
        filtered = [ev for ev in events if pattern.search(str(ev[2]))]
        return np.array([ev[0] for ev in filtered], dtype=np.int64)

    @staticmethod
    def _triggers_from_annotations(raw: mne.io.Raw, description: str) -> np.ndarray:
        """Extract trigger positions from annotations matching ``description``.

        The description is matched as a word-bounded literal (the same
        pattern ``_regex_for_selection`` produces for ungrouped rows), not as
        a strict whole-string comparison.

        Parameters
        ----------
        raw : mne.io.Raw
            Raw object.
        description : str
            Annotation description to match.

        Returns
        -------
        np.ndarray
            Trigger sample positions (int64).
        """
        regex = TriggerExplorer._bounded_literal_regex(description)
        events, _ = mne.events_from_annotations(raw, regexp=regex, verbose=False)
        if len(events) == 0:
            return np.array([], dtype=np.int64)
        return np.array([ev[0] for ev in events], dtype=np.int64)

    # ------------------------------------------------------------------
    # Artifact metadata (shared with TriggerDetector)
    # ------------------------------------------------------------------

    def _compute_artifact_metadata(self, triggers: np.ndarray) -> dict:
        """Estimate artifact length and detect volume gaps from trigger spacing.

        Trigger spacing that varies by more than 3 samples is interpreted as
        slice triggers with volume gaps and delegated to
        ``_compute_slice_volume_metadata``; otherwise the largest spacing is
        taken as the artifact length.

        Parameters
        ----------
        triggers : np.ndarray
            Detected trigger sample positions.

        Returns
        -------
        dict
            Keys: ``artifact_length``, ``volume_gaps``, optionally
            ``slices_per_volume``. ``artifact_length`` is ``None`` when fewer
            than two triggers are given.
        """
        if len(triggers) <= 1:
            return {"artifact_length": None, "volume_gaps": False}

        trigger_diffs = np.diff(triggers)
        ptp = np.ptp(trigger_diffs)

        if ptp > 3:
            return self._compute_slice_volume_metadata(triggers, trigger_diffs)

        return {
            "artifact_length": int(np.max(trigger_diffs)),
            "volume_gaps": False,
        }

    def _compute_slice_volume_metadata(self, triggers: np.ndarray, trigger_diffs: np.ndarray) -> dict:
        """Compute metadata when volume-level gaps are present.

        Parameters
        ----------
        triggers : np.ndarray
            All trigger sample positions.
        trigger_diffs : np.ndarray
            Differences between consecutive triggers.

        Returns
        -------
        dict
            Keys: ``artifact_length``, ``volume_gaps``, ``slices_per_volume``.
        """
        # Threshold halfway between the median and the maximum spacing:
        # spacings below it count as slice intervals, the rest as volume gaps.
        mean_val = np.mean([np.median(trigger_diffs), np.max(trigger_diffs)])
        slice_diffs = trigger_diffs[trigger_diffs < mean_val]
        artifact_length = int(np.max(slice_diffs))

        gap_indices = np.where(trigger_diffs >= mean_val)[0]
        slices_per_volume = None
        if len(gap_indices) > 0:
            # Number of triggers between consecutive gaps, plus the trailing
            # run after the last gap.
            slice_counts = []
            last_idx = -1
            for idx in gap_indices:
                slice_counts.append(idx - last_idx)
                last_idx = idx
            if last_idx < len(triggers) - 1:
                slice_counts.append(len(triggers) - 1 - last_idx)
            if slice_counts:
                slices_per_volume = int(np.median(slice_counts))
                logger.info("Estimated slices per volume: {}", slices_per_volume)

        return {
            "artifact_length": artifact_length,
            "volume_gaps": True,
            "slices_per_volume": slices_per_volume,
        }
# Second module-level name bound to the same class object as TriggerExplorer.
# NOTE(review): presumably kept so older imports keep working — confirm.
InteractiveTriggerExplorer = TriggerExplorer