# Source code for facet.core.channel_sequential

"""
Channel-Sequential Execution Module

Provides memory-efficient channel-by-channel pipeline execution that is
completely independent of multiprocessing parallelism.
"""

import time

import mne
import numpy as np
from loguru import logger

from facet.console.manager import get_console
from facet.console.progress import get_current_step_index
from facet.logging_config import suppress_stdout

from .context import ProcessingContext
from .processor import Processor


class ChannelSequentialExecutor:
    """
    Execute a sequence of processors one channel at a time.

    For each data channel the full processor sequence runs to completion
    before the next channel starts.  This ensures that large intermediate
    representations (e.g. 10x upsampled data) only ever exist for a single
    channel simultaneously.

    Processors with ``run_once = True`` execute only for the first channel;
    all subsequent channels skip them and inherit the metadata produced by
    that first run.

    Non-data channels (stim, misc, ...) are handled separately: copied
    unchanged when the processed output has the same number of samples as
    the input, or resampled via MNE to the output sampling rate otherwise.

    The live console (when active) receives fine-grained channel and
    processor progress updates via the modern console manager.

    Example::

        executor = ChannelSequentialExecutor()
        result = executor.execute(
            [HighPassFilter(1.0), UpSample(10), AASCorrection(), DownSample(10)],
            context,
        )
    """

    def execute(
        self,
        processors: list[Processor],
        context: ProcessingContext,
    ) -> ProcessingContext:
        """
        Run *processors* on *context* one channel at a time.

        Parameters
        ----------
        processors : list of Processor
            Processors to execute in order for every channel.
        context : ProcessingContext
            Input context containing the full multi-channel dataset.

        Returns
        -------
        ProcessingContext
            Merged output context with all channels processed.  The input
            *context* itself is returned unchanged when *processors* is
            empty, the raw object has no channels, or no data channels are
            found.
        """
        if not processors:
            return context

        raw = context.get_raw()
        ch_names = raw.ch_names
        n_ch = len(ch_names)
        if n_ch == 0:
            return context

        data_idx, passthrough_idx = self._classify_channels(raw)
        if not data_idx:
            logger.warning("No data channels found; returning context unchanged")
            return context

        proc_names = " → ".join(p.name for p in processors)
        logger.info(
            "Channel-sequential [{}] ({} data channels)",
            proc_names,
            len(data_idx),
        )

        console = get_console()
        step_idx = get_current_step_index() or 0
        data_ch_names = [ch_names[i] for i in data_idx]
        console.start_channel_batch(
            processor_names=[p.name for p in processors],
            channel_names=data_ch_names,
            batch_step_offset=step_idx,
        )

        # NOTE: run-once tracking is keyed by processor *name*, so two
        # run-once processors sharing a name are treated as one.
        _run_once_executed: set[str] = set()
        # metadata_states[pi] is the metadata seen by processor ``pi`` on the
        # first channel; the final entry is the metadata after the last one.
        metadata_states: list[object] = []
        merged_data: np.ndarray | None = None
        n_times_out = 0
        new_sfreq = raw.info["sfreq"]
        saved_metadata = None
        handle_noise = False
        noise_data: np.ndarray | None = None

        try:
            for k, ch_abs_idx in enumerate(data_idx):
                ch_start = time.time()
                console.channel_started(k, data_ch_names[k])

                ch_ctx = self._create_channel_context(context, ch_abs_idx)

                if k == 0:
                    metadata_states = [ch_ctx.metadata.copy()]

                for pi, proc in enumerate(processors):
                    # Later channels replay the metadata recorded on the
                    # first channel so skipped run-once steps stay in sync.
                    if k > 0 and pi < len(metadata_states):
                        ch_ctx = ch_ctx.with_metadata(metadata_states[pi].copy())

                    skipped = proc.run_once and proc.name in _run_once_executed
                    console.channel_processor_started(k, pi)
                    proc_start = time.time()

                    ch_ctx = self._run_proc(proc, ch_ctx, _run_once_executed)

                    if k == 0:
                        metadata_states.append(ch_ctx.metadata.copy())

                    console.channel_processor_completed(
                        k,
                        pi,
                        time.time() - proc_start,
                        skipped=skipped,
                    )

                ch_data = ch_ctx.get_data(copy=False)

                if k == 0:
                    # The first channel fixes the output geometry; the output
                    # buffers are allocated once and filled row by row.
                    n_times_out = ch_data.shape[1]
                    new_sfreq = ch_ctx.get_raw().info["sfreq"]
                    saved_metadata = metadata_states[-1].copy()
                    merged_data = np.zeros((n_ch, n_times_out), dtype=ch_data.dtype)
                    handle_noise = ch_ctx.has_estimated_noise()
                    if handle_noise:
                        first_noise = ch_ctx.get_estimated_noise()
                        if first_noise is not None and first_noise.ndim == 2:
                            noise_data = np.zeros(
                                (n_ch, first_noise.shape[1]),
                                dtype=first_noise.dtype,
                            )

                merged_data[ch_abs_idx] = ch_data[0]

                if (
                    handle_noise
                    and noise_data is not None
                    and ch_ctx.has_estimated_noise()
                ):
                    ch_noise = ch_ctx.get_estimated_noise()
                    if ch_noise is not None and ch_noise.ndim == 2:
                        noise_data[ch_abs_idx] = ch_noise[0]

                # Drop the per-channel context before the next channel starts.
                del ch_ctx
                console.channel_completed(k, time.time() - ch_start)
        finally:
            console.end_channel_batch()

        # --- pass-through channels (stim, misc, ...) -------------------------
        if passthrough_idx:
            if n_times_out == raw.n_times:
                # Fetch only the pass-through rows instead of the whole array.
                merged_data[passthrough_idx] = raw.get_data(picks=passthrough_idx)
            else:
                picks = [ch_names[i] for i in passthrough_idx]
                pt_raw = raw.copy().pick(picks)
                with suppress_stdout():
                    pt_raw.resample(new_sfreq)
                # Read the resampled array once rather than once per channel.
                resampled = pt_raw.get_data()
                for j, i in enumerate(passthrough_idx):
                    merged_data[i] = resampled[j]
                del resampled
                del pt_raw

        # --- build merged output context -------------------------------------
        info = raw.info.copy()
        if hasattr(info, "_unlock"):
            # Use the info object's own unlock context when it provides one.
            with info._unlock():
                info["sfreq"] = new_sfreq
        else:
            info["sfreq"] = new_sfreq

        with suppress_stdout():
            new_raw = mne.io.RawArray(merged_data, info)

        result = context.with_raw(new_raw)
        result._metadata = saved_metadata
        if handle_noise and noise_data is not None:
            result.set_estimated_noise(noise_data)

        return result

    # ---------------------------------------------------------------------- #
    # Helpers                                                                  #
    # ---------------------------------------------------------------------- #

    @staticmethod
    def _run_proc(
        proc: Processor,
        ctx: ProcessingContext,
        run_once_executed: set[str],
    ) -> ProcessingContext:
        """
        Execute *proc* on *ctx*, honouring the ``run_once`` flag.

        A run-once processor whose name is already in *run_once_executed* is
        skipped and *ctx* is returned as-is.  Otherwise the processor runs
        and, if it is run-once, its name is added to *run_once_executed*.
        """
        if proc.run_once and proc.name in run_once_executed:
            return ctx
        result = proc.execute(ctx)
        if proc.run_once:
            run_once_executed.add(proc.name)
        return result

    @staticmethod
    def _classify_channels(raw: mne.io.Raw) -> tuple[list[int], list[int]]:
        """
        Split channel indices into data channels and pass-through channels.

        Returns
        -------
        tuple of (list of int, list of int)
            ``(data_idx, passthrough_idx)``.  When MNE's channel-type table
            cannot be imported, every channel is treated as a data channel.
        """
        n_channels = len(raw.ch_names)
        try:
            from mne.io.pick import _DATA_CH_TYPES_SPLIT
        except ImportError:
            data_idx = list(range(n_channels))
        else:
            data_idx = [
                i
                for i, ch_type in enumerate(raw.get_channel_types())
                if ch_type in _DATA_CH_TYPES_SPLIT
            ]

        # Build the membership set once instead of once per channel.
        data_lookup = set(data_idx)
        passthrough_idx = [i for i in range(n_channels) if i not in data_lookup]
        return data_idx, passthrough_idx

    @staticmethod
    def _create_channel_context(
        context: ProcessingContext,
        ch_idx: int,
    ) -> ProcessingContext:
        """
        Create a single-channel subset context.

        Uses ``raw.get_data(picks=[name])`` to extract only the requested
        channel's data without duplicating the full array first.  Any 2-D
        estimated-noise array on *context* is reduced to the row belonging
        to *ch_idx* (or copied whole when it has a single row).
        """
        raw = context.get_raw()
        ch_name = raw.ch_names[ch_idx]
        data = raw.get_data(picks=[ch_name])
        info = mne.pick_info(raw.info, [ch_idx])
        with suppress_stdout():
            subset_raw = mne.io.RawArray(data, info)

        subset_ctx = context.with_raw(subset_raw)
        # with_raw() copies the full noise matrix; keep only the active channel.
        subset_ctx._estimated_noise = None
        if context.has_estimated_noise():
            noise = context.get_estimated_noise()
            if noise is not None and noise.ndim == 2:
                if noise.shape[0] == 1:
                    subset_ctx.set_estimated_noise(noise.copy())
                elif ch_idx < noise.shape[0]:
                    subset_ctx.set_estimated_noise(noise[ch_idx : ch_idx + 1].copy())
        return subset_ctx