Source code for facet.evaluation.visualization

"""
Visualization Processors Module

This module contains processors for generating visual diagnostics of pipeline
results, including Matplotlib and MNE-Python based visualisations.

Author: FACETpy Team
Date: 2025-01-12
"""

import time
from collections.abc import Sequence
from pathlib import Path
from typing import Any

import numpy as np
from loguru import logger
from matplotlib import pyplot as plt
from scipy import signal

from ..console import suspend_raw_mode
from ..core import ProcessingContext, Processor, register_processor
from ..helpers.plotting import show_matplotlib_figure


@register_processor
class RawPlotter(Processor):
    """Plot raw EEG data snippets during the pipeline.

    Supports both Matplotlib-based summary figures as well as native
    MNE-Python interactive plots. By default, a Matplotlib plot is generated
    and saved to the configured path, overlaying the current corrected signal
    with the original recording for quick visual inspection.

    Parameters
    ----------
    mode : str, optional
        Plotting backend: ``'matplotlib'`` (default) or ``'mne'``.
    channel : str or int, optional
        Single channel to visualise (name or index). Defaults to the first
        EEG channel, falling back to the first channel overall.
    start : float, optional
        Start time in seconds of the snippet to plot (default: 0.0).
        Negative values are clamped to 0.0.
    duration : float, optional
        Duration in seconds of the snippet to plot (default: 10.0).
    overlay_original : bool, optional
        Overlay original recording when available (default: True).
    scale : float, optional
        Multiplier applied to amplitude values (default: 1e6 for V → µV).
    save_path : str or Path, optional
        File path to save the generated plot.
    show : bool, optional
        Whether to display the plot interactively (default: False).
    auto_close : bool, optional
        Close the figure after saving when running headless (default: True).
    figure_kwargs : dict, optional
        Additional keyword arguments forwarded to ``plt.subplots()``.
    mne_kwargs : dict, optional
        Additional keyword arguments forwarded to ``mne.io.Raw.plot()``.
    picks : sequence of int or str, optional
        Explicit channel picks for MNE plotting mode.
    title : str, optional
        Custom figure title for Matplotlib mode.
    """

    name = "raw_plotter"
    description = "Plot raw data snippets for visual inspection."
    version = "1.0.0"

    requires_triggers = False
    requires_raw = True
    modifies_raw = False
    parallel_safe = False

    def __init__(
        self,
        mode: str = "matplotlib",
        channel: str | int | None = None,
        start: float = 0.0,
        duration: float = 10.0,
        overlay_original: bool = True,
        scale: float = 1e6,
        save_path: str | Path | None = None,
        show: bool = False,
        auto_close: bool = True,
        figure_kwargs: dict[str, Any] | None = None,
        mne_kwargs: dict[str, Any] | None = None,
        picks: Sequence[int | str] | None = None,
        title: str | None = None,
    ) -> None:
        """Store the plotting configuration; see the class docstring for parameters."""
        self.mode = mode.lower()
        self.channel = channel
        # A negative start time makes no sense for a snippet; clamp to zero.
        self.start = max(0.0, start)
        self.duration = duration
        self.overlay_original = overlay_original
        self.scale = scale
        self.save_path = Path(save_path) if save_path else None
        self.show = show
        self.auto_close = auto_close
        # Fresh dicts per instance so no mutable default is ever shared.
        self.figure_kwargs = figure_kwargs or {}
        self.mne_kwargs = mne_kwargs or {}
        self.picks = picks
        self.title = title
        super().__init__()

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Generate the configured plot and return ``context`` unchanged.

        Plotting is best-effort: missing raw data or an unknown ``mode`` is
        logged and the context is returned without raising.
        """
        # --- EXTRACT ---
        raw = context.get_raw()
        if raw is None:
            logger.warning("No raw data available; skipping plot generation.")
            return context

        # --- LOG ---
        logger.info("Generating {} plot", self.mode)

        # --- COMPUTE ---
        if self.mode == "mne":
            self._plot_with_mne(raw)
        elif self.mode == "matplotlib":
            self._plot_with_matplotlib(raw, context)
        else:
            logger.error("Unknown plotting mode '{}'. Skipping plot.", self.mode)

        # --- RETURN ---
        return context

    def _plot_with_mne(self, raw) -> None:
        """Use mne.io.Raw.plot to visualise the data.

        Parameters
        ----------
        raw : mne.io.Raw
            The Raw object to plot.
        """
        plot_kwargs: dict[str, Any] = dict(
            start=self.start,
            duration=self.duration,
            show=self.show,
        )
        if self.picks is not None:
            plot_kwargs["picks"] = self.picks
        # User-supplied kwargs take precedence over the defaults above ...
        plot_kwargs.update(self.mne_kwargs)

        # Loguru formats with str.format-style braces, not printf-style
        # placeholders, so "{}" must be used for the arguments to appear.
        logger.info(
            "Generating MNE-Python plot (start={:.2f}s, duration={:.2f}s, picks={})",
            self.start,
            self.duration,
            self.picks,
        )

        # ... except ``block``, which is always forced off here; blocking is
        # handled by the explicit event loop further down.
        plot_kwargs["block"] = False
        fig = raw.plot(**plot_kwargs)

        # MNE returns different types depending on backend:
        # - matplotlib backend: Figure with savefig()
        # - Qt backend (mne-qt-browser): MNEQtBrowser without savefig()
        is_matplotlib_figure = hasattr(fig, "savefig")

        if self.save_path:
            self.save_path.parent.mkdir(parents=True, exist_ok=True)
            if is_matplotlib_figure:
                fig.savefig(self.save_path, dpi=150, bbox_inches="tight")
                logger.info("Saved MNE plot to {}", self.save_path)
            else:
                logger.warning(
                    "MNE returned {} (Qt backend); savefig not supported. "
                    "Use mode='matplotlib' or mne.viz.set_browser_backend('matplotlib') for saving.",
                    type(fig).__name__,
                )
                return

        if self.show:
            if is_matplotlib_figure:
                with suspend_raw_mode():
                    plt.show(block=False)
                    # Pump GUI events until the user closes the window.
                    while plt.fignum_exists(fig.number):
                        fig.canvas.flush_events()
                        time.sleep(0.05)
            # Qt backend handles display independently; nothing to do
        elif (self.auto_close or self.save_path) and is_matplotlib_figure:
            plt.close(fig)

    def _plot_with_matplotlib(self, raw, context: ProcessingContext) -> None:
        """Use Matplotlib to create a before/after comparison plot.

        Parameters
        ----------
        raw : mne.io.Raw
            The current (corrected) recording.
        context : ProcessingContext
            Current processing context, used to look up the original recording.
        """
        channel_idx, channel_name = self._resolve_channel(raw)
        sfreq = raw.info["sfreq"]

        start_sample = int(self.start * sfreq)
        if start_sample >= raw.n_times:
            # Nothing to draw: the window starts at or after the last sample.
            # Skip instead of requesting an empty/inverted range.
            logger.warning(
                "Requested start {:.2f}s is beyond the end of the recording ({:.2f}s); skipping plot.",
                self.start,
                raw.n_times / sfreq,
            )
            return

        # A non-positive duration means "plot until the end of the recording".
        stop_sample = start_sample + int(self.duration * sfreq) if self.duration > 0 else raw.n_times
        stop_sample = min(stop_sample, raw.n_times)
        if stop_sample <= start_sample:
            # Duration shorter than one sample: fall back to the full remainder.
            stop_sample = raw.n_times

        times = np.arange(start_sample, stop_sample) / sfreq
        current = raw.get_data(picks=[channel_idx], start=start_sample, stop=stop_sample)[0]
        original = self._extract_original_overlay(context, channel_idx, times, sfreq)

        fig_kwargs = {"figsize": (12, 4)}
        fig_kwargs.update(self.figure_kwargs)
        fig, ax = plt.subplots(**fig_kwargs)

        ax.plot(times, current * self.scale, label="Corrected", alpha=0.8)
        if original is not None:
            ax.plot(times, original * self.scale, label="Original", alpha=0.6)

        ax.set_xlabel("Time (s)")
        ax.set_ylabel("Amplitude (uV)")
        ax.grid(True, alpha=0.2)
        if original is not None:
            ax.legend(loc="upper right")

        # Separate channel name and duration so the default title is readable.
        title = self.title or f"{channel_name} - {self.duration:.1f}s snippet"
        ax.set_title(title)

        if self.save_path:
            self.save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(self.save_path, dpi=150, bbox_inches="tight")
            logger.info("Saved Matplotlib plot to {}", self.save_path)

        if self.show:
            with suspend_raw_mode():
                show_matplotlib_figure(fig)
            plt.close(fig)
        elif self.auto_close or self.save_path:
            plt.close(fig)

    def _extract_original_overlay(
        self,
        context: ProcessingContext,
        channel_idx: int,
        times: np.ndarray,
        sfreq: float,
    ) -> np.ndarray | None:
        """Retrieve original-recording data for overlay, resampling if needed.

        Parameters
        ----------
        context : ProcessingContext
            Current processing context.
        channel_idx : int
            Index of the channel to extract.
        times : np.ndarray
            Time axis of the corrected snippet (used to determine length).
        sfreq : float
            Sampling frequency of the current (corrected) recording.

        Returns
        -------
        np.ndarray or None
            Original data array aligned to ``times``, or ``None`` when
            unavailable.
        """
        if not self.overlay_original:
            return None

        raw_original = context.get_raw_original()
        if raw_original is None:
            logger.warning("Original data unavailable; skipping overlay.")
            return None

        # The original may have a different sampling rate than the corrected
        # data, so the sample window is recomputed in its own time base.
        sfreq_original = raw_original.info["sfreq"]
        start_original = int(self.start * sfreq_original)
        stop_original = (
            start_original + int(self.duration * sfreq_original)
            if self.duration > 0
            else raw_original.n_times
        )
        stop_original = min(stop_original, raw_original.n_times)
        if stop_original <= start_original:
            stop_original = raw_original.n_times

        if channel_idx >= len(raw_original.ch_names):
            logger.warning("Channel index out of range for original data; skipping overlay.")
            return None

        try:
            original_data = raw_original.get_data(
                picks=[channel_idx],
                start=start_original,
                stop=stop_original,
            )
        except (IndexError, ValueError) as exc:
            logger.warning("Failed to extract original data: {}; skipping overlay.", exc)
            return None

        if original_data.size == 0:
            logger.warning("Original data returned empty array; skipping overlay.")
            return None

        original = original_data[0]
        if sfreq_original != sfreq:
            if len(original) <= 1:
                logger.warning(
                    "Original data too short for overlay (length {} vs {}); skipping overlay.",
                    len(original),
                    len(times),
                )
                return None
            # Resample so the overlay has exactly one value per plotted sample.
            original = signal.resample(original, len(times))
        elif len(original) != len(times):
            logger.warning(
                "Original data length mismatch (length {} vs {}); skipping overlay.",
                len(original),
                len(times),
            )
            return None

        return original

    def _resolve_channel(self, raw) -> tuple[int, str]:
        """Resolve channel selection to index and name.

        Parameters
        ----------
        raw : mne.io.Raw
            The Raw object to look up channel information from.

        Returns
        -------
        tuple of (int, str)
            Channel index and channel name.
        """
        if self.channel is None:
            # Prefer the first EEG channel; otherwise the first channel overall.
            for idx, ch_type in enumerate(raw.get_channel_types()):
                if ch_type == "eeg":
                    return idx, raw.ch_names[idx]
            return 0, raw.ch_names[0]

        # NumPy integers (e.g. from np.where / np.argmax) are valid indices too.
        if isinstance(self.channel, (int, np.integer)):
            # Clamp out-of-range indices into the valid channel range.
            idx = max(0, min(int(self.channel), len(raw.ch_names) - 1))
            return idx, raw.ch_names[idx]

        if isinstance(self.channel, str):
            try:
                idx = raw.ch_names.index(self.channel)
            except ValueError:
                logger.warning(
                    "Requested channel '{}' not found. Falling back to first channel.",
                    self.channel,
                )
                return 0, raw.ch_names[0]
            return idx, raw.ch_names[idx]

        logger.warning("Unsupported channel specifier {}. Using first channel.", self.channel)
        return 0, raw.ch_names[0]