Source code for facet.core.parallel

"""
Parallel Execution Module

This module provides multiprocessing support for pipeline execution.

Author: FACETpy Team
Date: 2025-01-12
"""

import contextlib
import functools
import multiprocessing as mp
import sys
from collections.abc import Callable

import mne
import numpy as np
from loguru import logger

from facet.console.progress import processor_progress
from facet.logging_config import suppress_stdout

from .context import ProcessingContext
from .processor import Processor

# Force the "spawn" start method on every non-Windows platform so that worker
# processes begin from a fresh interpreter instead of a fork of this one.
# Forking a multithreaded process (for instance while the Textual TUI thread
# is alive) can copy locks in their held state and deadlock the child.
# Nothing is lost by spawning: workers already exchange their data through
# to_dict/from_dict.  A RuntimeError raised by the call is ignored.
if sys.platform != "win32":
    with contextlib.suppress(RuntimeError):
        mp.set_start_method(method="spawn", force=True)


def _worker_function(processor_config: dict, context_data: dict) -> dict:
    """
    Run one processor on one serialized context inside a worker process.

    Defined at module level so that it can be pickled and shipped to a
    worker by ``multiprocessing``.

    Args:
        processor_config: Mapping with the processor class under ``"class"``
            and its constructor keyword arguments under ``"parameters"``
        context_data: Context in the dictionary form accepted by
            ``ProcessingContext.from_dict``

    Returns:
        The resulting context, converted back to a dictionary with ``to_dict``
    """
    # Rebuild the processor from its class and constructor keyword arguments.
    processor = processor_config["class"](**processor_config["parameters"])

    # Rehydrate the context, run the processor and hand back plain data.
    return processor.execute(ProcessingContext.from_dict(context_data)).to_dict()


class ParallelExecutor:
    """
    Executor for parallel processing of channels or epochs.

    This class handles multiprocessing for processors that support it,
    typically for channel-wise or epoch-wise operations.

    Channel-wise execution splits the channels of the raw recording into
    chunks, runs the processor on every chunk with the selected backend and
    merges the per-chunk results back into a single context.  Epoch-wise
    execution is only scaffolded: every epoch backend currently raises
    ``NotImplementedError``.

    Example:
        executor = ParallelExecutor(n_jobs=4)
        result_context = executor.execute(processor, context)
    """

    def __init__(self, n_jobs: int = -1, backend: str = "multiprocessing", verbose: bool = True):
        """
        Initialize parallel executor.

        Args:
            n_jobs: Number of parallel jobs (-1 for all CPUs, -2 for all but one)
            backend: Parallel backend ("multiprocessing", "threading", or "serial")
            verbose: Show progress messages

        Raises:
            ValueError: If ``n_jobs`` is 0 or ``backend`` is not a known backend
        """
        resolved_jobs = self._determine_n_jobs(n_jobs)
        if backend not in ("multiprocessing", "threading", "serial"):
            raise ValueError(f"Invalid backend: {backend}. Choose from: multiprocessing, threading, serial")

        self.n_jobs = resolved_jobs
        self.backend = backend
        self.verbose = verbose

    def _determine_n_jobs(self, n_jobs: int) -> int:
        """
        Determine actual number of jobs.

        Args:
            n_jobs: Requested job count; negative values count back from the
                number of CPUs (-1 is all CPUs, -2 is all but one, ...)

        Returns:
            A job count of at least 1

        Raises:
            ValueError: If ``n_jobs`` is 0
        """
        if n_jobs == 0:
            raise ValueError("n_jobs cannot be 0")
        if n_jobs > 0:
            return n_jobs
        # -1 -> cpu_count, -2 -> cpu_count - 1, ...; never drop below one job.
        return max(1, mp.cpu_count() + n_jobs + 1)

    def execute(self, processor: Processor, context: ProcessingContext) -> ProcessingContext:
        """
        Execute processor in parallel if possible.

        This method attempts to parallelize the processor execution.
        If parallelization is not applicable, falls back to serial execution.

        A processor that is not ``parallel_safe`` is executed directly.  One
        with a truthy ``channel_wise`` attribute is split by channels, one
        with a truthy ``parallelize_by_epochs`` attribute is routed to the
        epoch-wise path (which currently raises ``NotImplementedError``).

        Args:
            processor: Processor to execute
            context: Input context

        Returns:
            Output context
        """
        if not processor.parallel_safe:
            logger.warning(f"Processor {processor.name} is not parallel-safe, executing serially")
            return processor.execute(context)

        if getattr(processor, "channel_wise", False):
            return self._execute_channel_wise(processor, context)

        if getattr(processor, "parallelize_by_epochs", False):
            return self._execute_epoch_wise(processor, context)

        # Fall back to serial execution
        logger.debug(f"No parallelization strategy found for {processor.name}, executing serially")
        return processor.execute(context)

    def _execute_channel_wise(self, processor: Processor, context: ProcessingContext) -> ProcessingContext:
        """
        Execute processor channel-wise in parallel.

        Args:
            processor: Processor to execute
            context: Input context

        Returns:
            Output context with processed channels; the input context itself
            when the recording has no channels
        """
        logger.info(f"Executing {processor.name} in parallel across {self.n_jobs} jobs")

        raw = context.get_raw()
        n_channels = len(raw.ch_names)
        if n_channels == 0:
            logger.warning("No channels available for parallel execution")
            return context

        channel_chunks = self._split_into_chunks(list(range(n_channels)), self.n_jobs)

        # Unknown backends cannot occur (validated in __init__); anything that
        # is not multiprocessing/threading runs through the serial path.
        runners = {
            "multiprocessing": self._execute_multiprocessing,
            "threading": self._execute_threading,
        }
        run_chunks = runners.get(self.backend, self._execute_serial)

        with processor_progress(
            total=n_channels,
            message=f"{processor.name}: channels",
        ) as progress:

            def _update_progress(processed: int) -> None:
                # Advance the progress display by the size of a finished chunk.
                if processed <= 0:
                    return
                next_value = progress.current + processed
                progress.advance(
                    processed,
                    message=f"{int(next_value)}/{n_channels} channels",
                )

            results = run_chunks(
                processor,
                context,
                channel_chunks,
                progress_callback=_update_progress,
            )

        return self._merge_channel_results(context, results)

    def _execute_epoch_wise(self, processor: Processor, context: ProcessingContext) -> ProcessingContext:
        """
        Execute processor epoch-wise in parallel.

        Args:
            processor: Processor to execute
            context: Input context

        Returns:
            Output context with processed epochs

        Raises:
            ValueError: If the context has no triggers
            NotImplementedError: Always, for now; none of the epoch backends
                is implemented yet
        """
        logger.info(f"Executing {processor.name} epoch-wise in parallel across {self.n_jobs} jobs")

        if not context.has_triggers():
            raise ValueError("Context has no triggers for epoch-wise processing")

        triggers = context.get_triggers()
        n_epochs = len(triggers)
        epoch_chunks = self._split_into_chunks(list(range(n_epochs)), self.n_jobs)

        if self.backend == "multiprocessing":
            results = self._execute_multiprocessing_epochs(processor, context, epoch_chunks)
        elif self.backend == "threading":
            results = self._execute_threading_epochs(processor, context, epoch_chunks)
        else:  # serial
            results = self._execute_serial_epochs(processor, context, epoch_chunks)

        return self._merge_epoch_results(context, results)

    def _execute_multiprocessing(
        self,
        processor: Processor,
        context: ProcessingContext,
        channel_chunks: list[list[int]],
        progress_callback: Callable[[int], None] | None = None,
    ) -> list[ProcessingContext]:
        """
        Execute using multiprocessing.

        Each chunk context is serialized with ``to_dict``, processed by
        ``_worker_function`` in a pool worker and rebuilt with ``from_dict``.

        Args:
            processor: Processor to execute; recreated in each worker from its
                class and ``_parameters``
            context: Input context
            channel_chunks: Channel indices, one list per chunk
            progress_callback: Called with the chunk size after each chunk

        Returns:
            Result contexts in the same order as ``channel_chunks``
        """
        processor_config = {"class": processor.__class__, "parameters": processor._parameters}

        chunk_contexts = [
            self._create_channel_subset_context(context, chunk).to_dict() for chunk in channel_chunks
        ]
        if not chunk_contexts:
            # Nothing to do; also avoids asking for a pool with zero workers.
            return []

        worker = functools.partial(_worker_function, processor_config)

        # Never start more workers than there are chunks to process.
        n_workers = min(self.n_jobs, len(chunk_contexts))

        contexts: list[ProcessingContext] = []
        with mp.Pool(processes=n_workers) as pool:
            # imap yields results in submission order, so zip pairs each
            # result with the chunk that produced it.
            for chunk, result in zip(channel_chunks, pool.imap(worker, chunk_contexts)):
                contexts.append(ProcessingContext.from_dict(result))
                if progress_callback:
                    progress_callback(len(chunk))

        return contexts

    def _execute_threading(
        self,
        processor: Processor,
        context: ProcessingContext,
        channel_chunks: list[list[int]],
        progress_callback: Callable[[int], None] | None = None,
    ) -> list[ProcessingContext]:
        """
        Execute using threading (GIL-limited).

        Progress is reported as chunks complete, but the returned list follows
        the order of ``channel_chunks`` so that merging (which uses the first
        result as template) behaves the same as for the other backends.

        Args:
            processor: Processor to execute
            context: Input context
            channel_chunks: Channel indices, one list per chunk
            progress_callback: Called with the chunk size after each chunk

        Returns:
            Result contexts in the same order as ``channel_chunks``
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
            # Insertion-ordered mapping: future -> number of channels in its chunk.
            chunk_sizes = {}
            for chunk in channel_chunks:
                chunk_context = self._create_channel_subset_context(context, chunk)
                future = executor.submit(processor.execute, chunk_context)
                chunk_sizes[future] = len(chunk)

            for future in as_completed(chunk_sizes):
                # Surface a worker exception as soon as its future completes.
                future.result()
                if progress_callback:
                    progress_callback(chunk_sizes[future])

            # All futures are done; collect in submission order.
            return [future.result() for future in chunk_sizes]

    def _execute_serial(
        self,
        processor: Processor,
        context: ProcessingContext,
        channel_chunks: list[list[int]],
        progress_callback: Callable[[int], None] | None = None,
    ) -> list[ProcessingContext]:
        """
        Execute serially (for debugging/comparison).

        Args:
            processor: Processor to execute
            context: Input context
            channel_chunks: Channel indices, one list per chunk
            progress_callback: Called with the chunk size after each chunk

        Returns:
            Result contexts in the same order as ``channel_chunks``
        """
        results = []
        for chunk in channel_chunks:
            chunk_context = self._create_channel_subset_context(context, chunk)
            results.append(processor.execute(chunk_context))
            if progress_callback:
                progress_callback(len(chunk))
        return results

    def _create_channel_subset_context(
        self, context: ProcessingContext, channel_indices: list[int]
    ) -> ProcessingContext:
        """
        Create context with subset of channels.

        Args:
            context: Context holding the full recording
            channel_indices: Indices (into ``raw.ch_names``) of the channels to keep

        Returns:
            A context whose raw object is a copy restricted to the selected
            channels; a 2-D estimated-noise array, if present, is sliced to
            the same rows
        """
        raw = context.get_raw()
        picks = [raw.ch_names[i] for i in channel_indices]
        subset_raw = raw.copy().pick(picks)
        subset_ctx = context.with_raw(subset_raw, copy_metadata=True)

        if context.has_estimated_noise():
            noise = context.get_estimated_noise()
            if noise is not None and noise.ndim == 2:
                subset_noise = noise[channel_indices, :]
                subset_ctx.set_estimated_noise(subset_noise.copy())

        return subset_ctx

    def _merge_channel_results(
        self, original_context: ProcessingContext, results: list[ProcessingContext]
    ) -> ProcessingContext:
        """
        Merge channel-wise results back into single context.

        The first result serves as template for sampling rate, number of
        samples, dtype and metadata.  Channels are written back at their
        position in the original recording; original channels missing from
        every result stay zero.

        Args:
            original_context: Context the chunks were derived from
            results: Per-chunk result contexts

        Returns:
            Merged context, or ``original_context`` when ``results`` is empty
        """
        if not results:
            return original_context

        original_raw = original_context.get_raw()
        template_raw = results[0].get_raw()
        template_data = results[0].get_data(copy=False)

        new_sfreq = template_raw.info["sfreq"]
        n_times = template_data.shape[1]
        dtype = template_data.dtype

        merged_data = np.zeros((len(original_raw.ch_names), n_times), dtype=dtype)
        channel_index = {name: idx for idx, name in enumerate(original_raw.ch_names)}

        for result_ctx in results:
            result_raw = result_ctx.get_raw()
            result_data = result_ctx.get_data(copy=False)
            for j, ch_name in enumerate(result_raw.ch_names):
                merged_data[channel_index[ch_name]] = result_data[j]

        # Build new RawArray at the upsampled rate to avoid mutating protected info
        info = original_raw.info.copy()
        if hasattr(info, "_unlock"):
            with info._unlock():
                info["sfreq"] = new_sfreq
        else:
            info["sfreq"] = new_sfreq

        with suppress_stdout():
            new_raw = mne.io.RawArray(data=merged_data, info=info)

        merged_context = original_context.with_raw(new_raw)
        merged_context._metadata = results[0].metadata.copy()

        if any(result_ctx.has_estimated_noise() for result_ctx in results):
            # Estimated noise is stored channel-wise; merge similarly
            noise_data = np.zeros_like(merged_data)
            for result_ctx in results:
                if not result_ctx.has_estimated_noise():
                    continue
                result_noise = result_ctx.get_estimated_noise()
                result_raw = result_ctx.get_raw()
                for j, ch_name in enumerate(result_raw.ch_names):
                    noise_data[channel_index[ch_name]] = result_noise[j]
            merged_context.set_estimated_noise(noise_data)

        return merged_context

    def _merge_epoch_results(
        self, original_context: ProcessingContext, results: list[ProcessingContext]
    ) -> ProcessingContext:
        """
        Merge epoch-wise results.

        Placeholder: logs a warning and returns ``original_context``
        unchanged; ``results`` is not used.
        """
        # Implementation depends on how epochs are stored
        # This is a placeholder
        logger.warning("Epoch-wise merging not fully implemented yet")
        return original_context

    def _split_into_chunks(self, items: list, n_chunks: int) -> list[list]:
        """
        Split list into approximately equal chunks.

        Args:
            items: Items to distribute
            n_chunks: Maximum number of chunks to produce

        Returns:
            Contiguous, non-empty chunks in original order; fewer than
            ``n_chunks`` when there are not enough items

        Raises:
            ValueError: If ``n_chunks`` is smaller than 1
        """
        if n_chunks < 1:
            raise ValueError(f"n_chunks must be at least 1, got {n_chunks}")

        chunk_size, remainder = divmod(len(items), n_chunks)

        chunks = []
        start = 0
        for i in range(n_chunks):
            # Distribute remainder across first chunks
            size = chunk_size + (1 if i < remainder else 0)
            if size > 0:
                chunks.append(items[start : start + size])
                start += size

        return chunks

    # Epoch-wise methods (placeholders for now)

    def _execute_multiprocessing_epochs(self, processor, context, epoch_chunks):
        """Execute epoch-wise using multiprocessing."""
        raise NotImplementedError("Epoch-wise multiprocessing not yet implemented")

    def _execute_threading_epochs(self, processor, context, epoch_chunks):
        """Execute epoch-wise using threading."""
        raise NotImplementedError("Epoch-wise threading not yet implemented")

    def _execute_serial_epochs(self, processor, context, epoch_chunks):
        """Execute epoch-wise serially."""
        raise NotImplementedError("Epoch-wise serial execution not yet implemented")