# Source code for facet.io.loaders

"""
Data Loaders Module

This module contains processors for loading EEG data from various formats.

Author: FACETpy Team
Date: 2025-01-12
"""

from collections.abc import Callable
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as pkg_version
from numbers import Integral, Real
from pathlib import Path
from typing import Any

import mne
from loguru import logger
from mne.io import BaseRaw
from mne_bids import BIDSPath, read_raw_bids

from facet.logging_config import suppress_stdout

from ..core import (
    ProcessingContext,
    ProcessingMetadata,
    Processor,
    ProcessorValidationError,
    register_processor,
)

# Lookup table from a lower-case file suffix to the MNE reader callable and a
# human-readable format name.  ``_detect_format`` lower-cases the path suffix
# before consulting this table, so keys must stay lower-case.
_EXTENSION_READERS: dict[str, tuple[Callable[..., BaseRaw], str]] = {
    ".edf": (mne.io.read_raw_edf, "EDF"),
    ".bdf": (mne.io.read_raw_bdf, "BDF"),
    ".gdf": (mne.io.read_raw_gdf, "GDF"),
    ".vhdr": (mne.io.read_raw_brainvision, "BrainVision"),
    ".set": (mne.io.read_raw_eeglab, "EEGLAB"),
    ".fif": (mne.io.read_raw_fif, "FIF"),
    ".mff": (mne.io.read_raw_egi, "MFF"),
}

# Alphabetically sorted suffixes of the table above; interpolated into the
# "unsupported file extension" error message raised by ``_detect_format``.
SUPPORTED_EXTENSIONS: list[str] = sorted(_EXTENSION_READERS.keys())


def _ensure_mff_runtime_dependencies() -> None:
    """Make sure ``defusedxml`` is importable and exposes ``__version__``.

    Some dependency combinations ship ``defusedxml`` without a module-level
    ``__version__`` attribute while older MNE code still expects it.  When the
    attribute is missing (or empty) it is filled in from the installed package
    metadata, falling back to the string ``"unknown"``.

    Raises
    ------
    ProcessorValidationError
        If ``defusedxml`` cannot be imported.
    """
    try:
        import defusedxml as xml_guard  # type: ignore
    except ImportError as exc:
        raise ProcessorValidationError(
            "Missing dependency for .mff loading: defusedxml. Install with `uv sync` or `pip install defusedxml`."
        ) from exc

    # Only patch the module when it does not already advertise a usable version.
    if not getattr(xml_guard, "__version__", None):
        try:
            resolved_version = pkg_version("defusedxml")
        except PackageNotFoundError:
            resolved_version = "unknown"
        xml_guard.__version__ = resolved_version


def _detect_format(path: Path) -> tuple[Callable[..., BaseRaw], str]:
    """Detect the EEG file format from the file extension.

    Matching is case-insensitive for both plain suffixes (``.EDF``) and the
    compound gzip-compressed FIF suffix (``.FIF.GZ``).

    Parameters
    ----------
    path : Path
        Path to the EEG data file (or directory for MFF format).

    Returns
    -------
    tuple of (callable, str)
        The MNE reader function and a human-readable format name.

    Raises
    ------
    ProcessorValidationError
        If the extension is not recognized.
    """
    # Lower-case every suffix so the compound ".fif.gz" test follows the same
    # case-insensitive rule as the single-suffix lookup below.
    suffixes = [suffix.lower() for suffix in path.suffixes]
    if suffixes[-2:] == [".fif", ".gz"]:
        return _EXTENSION_READERS[".fif"]

    ext = path.suffix.lower()
    if ext in _EXTENSION_READERS:
        return _EXTENSION_READERS[ext]

    raise ProcessorValidationError(
        f"Unsupported file extension '{ext}' for '{path.name}'. "
        f"Supported extensions: {', '.join(SUPPORTED_EXTENSIONS)}. "
        f"For BIDS datasets, use BIDSLoader directly."
    )


def _coerce_sample_index(value: float | None, default: int, name: str) -> int:
    """Convert a potentially optional numeric input into a valid sample index."""
    if value is None:
        return default
    if isinstance(value, bool):
        raise ValueError(f"{name} must be an integer-like value, got boolean")
    if isinstance(value, Integral):
        return int(value)
    if isinstance(value, Real):
        if not float(value).is_integer():
            raise ValueError(f"{name} must be an integer number of samples")
        return int(value)
    raise ValueError(f"{name} must be an integer number of samples")


def _apply_sample_window(
    raw: BaseRaw,
    start_sample: int | None,
    stop_sample: int | None,
) -> tuple[BaseRaw, int, int]:
    """Crop ``raw`` in place to the half-open sample range ``[start, stop)``.

    Parameters
    ----------
    raw : BaseRaw
        Recording to restrict; ``raw.crop`` is called on it when the window
        is narrower than the full recording.
    start_sample : int or None
        First sample to keep (inclusive); ``None`` means sample 0.
    stop_sample : int or None
        Sample to stop at (exclusive); ``None`` means ``raw.n_times``.

    Returns
    -------
    tuple of (BaseRaw, int, int)
        The same Raw object plus the resolved start and stop indices.

    Raises
    ------
    ValueError
        If the recording is empty or the window is not a non-empty range
        inside ``[0, raw.n_times]``.
    """
    total = raw.n_times
    if total == 0:
        raise ValueError("Cannot apply a sample window to an empty recording")

    first = _coerce_sample_index(start_sample, 0, "start_sample")
    last = _coerce_sample_index(stop_sample, total, "stop_sample")

    if first < 0:
        raise ValueError("start_sample must be non-negative")
    if first >= total:
        raise ValueError(f"start_sample ({first}) is out of bounds for data with {total} samples")
    if last <= first:
        raise ValueError(f"stop_sample ({last}) must be greater than start_sample ({first})")
    if last > total:
        raise ValueError(f"stop_sample ({last}) exceeds total samples ({total})")

    # After validation 0 <= first < last <= total, so the full range needs no crop.
    if first == 0 and last == total:
        return raw, first, last

    sfreq = raw.info["sfreq"]
    tmin = None if first == 0 else first / sfreq
    # The stop index is exclusive, so the time bound is taken at sample ``last - 1``.
    tmax = None if last == total else (last - 1) / sfreq
    raw.crop(tmin=tmin, tmax=tmax, verbose=False)

    return raw, first, last


def _configure_bad_channels(raw: mne.io.Raw, bad_channels: list[str]) -> None:
    """Mark specified channels as bad in the Raw object.

    Requested names that are not present in ``raw.ch_names`` are skipped with
    a warning.  When at least one requested name is valid, ``raw.info["bads"]``
    is replaced by the valid names (duplicates removed, request order kept).
    When no requested name is valid, ``raw.info["bads"]`` is left untouched so
    that bad channels already recorded on the data are not discarded.

    Parameters
    ----------
    raw : mne.io.Raw
        The Raw object to update in-place.
    bad_channels : list of str
        Channel names to mark as bad.
    """
    if not bad_channels:
        return

    existing = set(raw.ch_names)
    # dict.fromkeys de-duplicates while preserving the caller's ordering.
    requested = list(dict.fromkeys(bad_channels))
    valid_bads = [ch for ch in requested if ch in existing]
    missing = set(requested) - existing

    if missing:
        logger.warning(
            "Skipping {} bad channel(s) not present in data: {}",
            len(missing),
            ", ".join(sorted(missing)),
        )
        logger.debug("Available channels: {}", ", ".join(raw.ch_names))

    if not valid_bads:
        # Do not assign an empty list here: that would clear existing bads
        # and contradict the message below.
        logger.info("No valid bad channels found to mark; leaving dataset unchanged.")
        return

    raw.info["bads"] = valid_bads
    logger.debug("Marked {} bad channel(s): {}", len(valid_bads), ", ".join(valid_bads))


def _build_context_from_raw(
    raw: BaseRaw,
    bad_channels: list[str],
    start_sample: int | None,
    stop_sample: int | None,
    artifact_to_trigger_offset: float,
    upsampling_factor: int,
    extra_custom: dict[str, Any] | None = None,
) -> ProcessingContext:
    """Run the post-read steps common to every loader.

    In order: mark bad channels, restrict the recording to the requested
    sample window, build the processing metadata, log a summary, and wrap
    everything in a new context.

    Parameters
    ----------
    raw : BaseRaw
        The MNE Raw object returned by the format-specific reader.
    bad_channels : list of str
        Channel names to mark as bad.
    start_sample : int or None
        First sample to keep (inclusive).
    stop_sample : int or None
        Last sample to keep (exclusive).
    artifact_to_trigger_offset : float
        Offset in seconds stored in metadata.
    upsampling_factor : int
        Upsampling factor stored in metadata.
    extra_custom : dict, optional
        Additional entries for ``metadata.custom``.

    Returns
    -------
    ProcessingContext
        New context wrapping the (possibly cropped) Raw object.

    Raises
    ------
    ProcessorValidationError
        If the sample window is rejected with a ``ValueError``.
    """
    _configure_bad_channels(raw, bad_channels)

    # Remember the length before cropping so the log can report kept/total.
    samples_before = raw.n_times
    try:
        raw, first_idx, last_idx = _apply_sample_window(raw, start_sample, stop_sample)
    except ValueError as exc:
        raise ProcessorValidationError(f"Invalid sample window: {exc}") from exc

    if (first_idx, last_idx) != (0, samples_before):
        logger.info(
            "Applied sample window: start={}, stop={} (exclusive), kept {}/{} samples",
            first_idx,
            last_idx,
            raw.n_times,
            samples_before,
        )

    # Acquisition bounds are recorded only for the limits the caller set explicitly.
    metadata = ProcessingMetadata(
        artifact_to_trigger_offset=artifact_to_trigger_offset,
        upsampling_factor=upsampling_factor,
        acq_start_sample=None if start_sample is None else first_idx,
        acq_end_sample=None if stop_sample is None else last_idx,
    )
    if extra_custom:
        metadata.custom.update(extra_custom)

    n_channels = len(raw.ch_names)
    logger.info(
        "Loaded {} channels, {} samples, {} Hz",
        n_channels,
        raw.n_times,
        raw.info["sfreq"],
    )

    return ProcessingContext(raw=raw, metadata=metadata)


@register_processor
class Loader(Processor):
    """Load EEG data with automatic file-format detection.

    Inspects the file extension and selects the appropriate MNE reader.
    Supports EDF, BDF, GDF, BrainVision (.vhdr), EEGLAB (.set), FIF
    (.fif / .fif.gz), and EGI MFF (.mff).  For BIDS datasets use
    :class:`BIDSLoader`.

    Note: MFF format uses a directory structure; pass the path to the .mff
    directory (e.g. recording.mff).

    Parameters
    ----------
    path : str
        Path to the EEG data file.
    bad_channels : list of str, optional
        Channel names to mark as bad (default: none).
    preload : bool, optional
        Whether to load data into memory immediately (default: True).
    artifact_to_trigger_offset : float, optional
        Offset of the artifact relative to the trigger, in seconds
        (default: 0.0).
    upsampling_factor : int, optional
        Upsampling factor stored in metadata for downstream processors
        (default: 1).
    start_sample : int, optional
        First sample index to keep, inclusive (default: first sample).
    stop_sample : int, optional
        Last sample index to keep, exclusive (default: last sample).
    """

    # Processor identification.
    name: str = "auto_loader"
    description: str = "Load EEG data with automatic format detection"
    version: str = "1.0.0"

    # Processor capability flags; their meaning is defined by the Processor base class.
    requires_triggers: bool = False
    requires_raw: bool = False
    modifies_raw: bool = True
    parallel_safe: bool = False

    def __init__(
        self,
        path: str,
        bad_channels: list[str] | None = None,
        preload: bool = True,
        artifact_to_trigger_offset: float = 0.0,
        upsampling_factor: int = 1,
        start_sample: int | None = None,
        stop_sample: int | None = None,
    ) -> None:
        """Store the loader configuration; no file access happens here."""
        self.path = path
        self.bad_channels = bad_channels or []
        self.preload = preload
        self.artifact_to_trigger_offset = artifact_to_trigger_offset
        self.upsampling_factor = upsampling_factor
        self.start_sample = start_sample
        self.stop_sample = stop_sample
        super().__init__()

    def validate(self, context: ProcessingContext | None) -> None:
        """Check that ``path`` exists and has a supported extension.

        Parameters
        ----------
        context : ProcessingContext or None
            Not read by this method.

        Raises
        ------
        ProcessorValidationError
            If the path does not exist, is a directory without a ``.mff``
            suffix, or has an unsupported extension.
        """
        resolved = Path(self.path)
        if not resolved.exists():
            raise ProcessorValidationError(f"File not found: {self.path}")
        if resolved.is_dir() and resolved.suffix.lower() != ".mff":
            raise ProcessorValidationError(
                f"Path is a directory: {self.path}. "
                "For BIDS datasets, use BIDSLoader. For MFF format, the path must "
                "have a .mff extension (e.g. recording.mff)."
            )
        # Called only for its side effect of raising on unsupported extensions.
        _detect_format(resolved)

    def process(self, context: ProcessingContext | None) -> ProcessingContext:
        """Read the file and return a new context wrapping the recording.

        Parameters
        ----------
        context : ProcessingContext or None
            Not read by this method; a new context is built from the file.

        Returns
        -------
        ProcessingContext
            Context holding the loaded Raw object, with ``source_format``
            stored in ``metadata.custom``.
        """
        # --- EXTRACT ---
        resolved = Path(self.path)
        reader_fn, format_name = _detect_format(resolved)

        # --- LOG ---
        logger.info("Loading {} file: {}", format_name, self.path)

        # --- COMPUTE ---
        if format_name == "MFF":
            _ensure_mff_runtime_dependencies()
        with suppress_stdout():
            raw = reader_fn(str(resolved), preload=self.preload, verbose=False)

        # --- RETURN ---
        return _build_context_from_raw(
            raw,
            self.bad_channels,
            self.start_sample,
            self.stop_sample,
            self.artifact_to_trigger_offset,
            self.upsampling_factor,
            extra_custom={"source_format": format_name},
        )
@register_processor
class BIDSLoader(Processor):
    """Load EEG data from a BIDS dataset.

    Creates a new ProcessingContext by reading EEG data identified by
    subject, task, and optional session from a BIDS-compliant directory.
    Optionally restricts the recording to a sample window and marks a list
    of channels as bad before returning.

    Parameters
    ----------
    root : str
        Path to the BIDS root directory.
    subject : str
        Subject identifier (without the ``sub-`` prefix).
    task : str
        Task name.
    session : str, optional
        Session identifier (without the ``ses-`` prefix).
    run : str, optional
        Run identifier (without the ``run-`` prefix).
    bad_channels : list of str, optional
        Channel names to mark as bad (default: none).
    preload : bool, optional
        Whether to load data into memory immediately (default: True).
    artifact_to_trigger_offset : float, optional
        Offset of the artifact relative to the trigger, in seconds
        (default: 0.0).
    upsampling_factor : int, optional
        Upsampling factor stored in metadata for downstream processors
        (default: 1).
    start_sample : int, optional
        First sample index to keep, inclusive (default: first sample).
    stop_sample : int, optional
        Last sample index to keep, exclusive (default: last sample).
    """

    # Processor identification.
    name: str = "bids_loader"
    description: str = "Load EEG data from BIDS dataset"
    version: str = "1.0.0"

    # Processor capability flags; their meaning is defined by the Processor base class.
    requires_triggers: bool = False
    requires_raw: bool = False
    modifies_raw: bool = True
    parallel_safe: bool = False

    def __init__(
        self,
        root: str,
        subject: str,
        task: str,
        session: str | None = None,
        run: str | None = None,
        bad_channels: list[str] | None = None,
        preload: bool = True,
        artifact_to_trigger_offset: float = 0.0,
        upsampling_factor: int = 1,
        start_sample: int | None = None,
        stop_sample: int | None = None,
    ) -> None:
        """Store the loader configuration; no file access happens here."""
        self.root = root
        self.subject = subject
        self.task = task
        self.session = session
        self.run = run
        self.bad_channels = bad_channels or []
        self.preload = preload
        self.artifact_to_trigger_offset = artifact_to_trigger_offset
        self.upsampling_factor = upsampling_factor
        self.start_sample = start_sample
        self.stop_sample = stop_sample
        super().__init__()

    def validate(self, context: ProcessingContext | None) -> None:
        """Check that the BIDS root is an existing directory.

        Parameters
        ----------
        context : ProcessingContext or None
            Not read by this method.

        Raises
        ------
        ProcessorValidationError
            If ``root`` does not exist or is not a directory.
        """
        # is_dir() is False for missing paths and for regular files alike, so a
        # file passed as the root is rejected here rather than during reading.
        if not Path(self.root).is_dir():
            raise ProcessorValidationError(f"BIDS root directory not found: {self.root}")

    def process(self, context: ProcessingContext | None) -> ProcessingContext:
        """Read the BIDS recording and return a new context wrapping it.

        Parameters
        ----------
        context : ProcessingContext or None
            Not read by this method; a new context is built from the dataset.

        Returns
        -------
        ProcessingContext
            Context holding the loaded Raw object, with the ``BIDSPath``
            stored under ``bids_path`` in ``metadata.custom``.
        """
        # --- LOG ---
        logger.info(
            "Loading BIDS data: subject={}, task={}, session={}, run={}",
            self.subject,
            self.task,
            self.session,
            self.run,
        )

        # --- COMPUTE ---
        bids_path = BIDSPath(
            subject=self.subject,
            session=self.session,
            task=self.task,
            run=self.run,
            root=self.root,
        )
        with suppress_stdout():
            raw = read_raw_bids(bids_path, verbose=False)
            # read_raw_bids is called without a preload argument, so loading
            # into memory is requested explicitly afterwards.
            if self.preload:
                raw.load_data()

        # --- RETURN ---
        return _build_context_from_raw(
            raw,
            self.bad_channels,
            self.start_sample,
            self.stop_sample,
            self.artifact_to_trigger_offset,
            self.upsampling_factor,
            extra_custom={"bids_path": bids_path},
        )