"""AAS variants with MATLAB FACET averaging-weight strategies."""
from __future__ import annotations
import numpy as np
from loguru import logger
from ..core import ProcessingContext, ProcessorValidationError, register_processor
from ..helpers.moosmann import calc_weighted_matrix_by_realignment_parameters_file
from .aas import AASCorrection
@register_processor
class CorrespondingSliceCorrection(AASCorrection):
    """AAS with corresponding-slice averaging across volumes.

    This implements MATLAB FACET's ``AvgArtWghtCorrespondingSlice`` rule:
    each slice epoch is averaged with the epochs at the same slice position
    in neighboring volumes.

    Parameters
    ----------
    slices_per_volume : int, optional
        Number of slices per volume. If ``None``, the value is taken from
        ``context.metadata.slices_per_volume``.
    window_size : int
        Half-window in volumes (default: 30).
    plot_artifacts : bool
        If ``True``, plot one random averaged artifact (default: False).
    realign_after_averaging : bool
        If ``True``, realign triggers after averaging (default: True).
    search_window_factor : float
        Trigger realignment search-window factor (default: 3.0).
    apply_epoch_alpha_scaling : bool
        If ``True``, apply MATLAB-like per-epoch least-squares alpha scaling
        before subtraction (default: False).
    """

    name = "corresponding_slice_correction"
    description = "AAS with corresponding-slice averaging across volumes"
    version = "1.0.0"

    def __init__(
        self,
        slices_per_volume: int | None = None,
        window_size: int = 30,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        self.slices_per_volume = slices_per_volume
        # Slice period resolved by ``process`` for the duration of one run;
        # reset to ``None`` afterwards.
        self._runtime_slices_per_volume: int | None = None
        # The window offset and correlation threshold are fixed: this
        # strategy's ``_calc_averaging_matrix`` discards both.
        super().__init__(
            window_size=window_size,
            rel_window_position=0.0,
            correlation_threshold=0.975,
            plot_artifacts=plot_artifacts,
            realign_after_averaging=realign_after_averaging,
            search_window_factor=search_window_factor,
            apply_epoch_alpha_scaling=apply_epoch_alpha_scaling,
        )

    def validate(self, context: ProcessingContext) -> None:
        """Run base AAS validation and check that slices-per-volume resolves to >= 1.

        Raises
        ------
        ProcessorValidationError
            If slices-per-volume cannot be resolved or is smaller than 1.
        """
        super().validate(context)
        runtime_spv = self._resolve_slices_per_volume(context)
        if runtime_spv < 1:
            raise ProcessorValidationError(f"slices_per_volume must be >= 1, got {runtime_spv}")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Resolve the slice period for this run, then delegate to AAS processing."""
        self._runtime_slices_per_volume = self._resolve_slices_per_volume(context)
        try:
            return super().process(context)
        finally:
            # Never let a run-specific period leak into a later call.
            self._runtime_slices_per_volume = None

    def _resolve_slices_per_volume(self, context: ProcessingContext) -> int:
        """Return the explicit ``slices_per_volume`` or the one stored in context metadata.

        Raises
        ------
        ProcessorValidationError
            If neither source provides a value.
        """
        if self.slices_per_volume is not None:
            return int(self.slices_per_volume)
        if context.metadata.slices_per_volume is not None:
            return int(context.metadata.slices_per_volume)
        raise ProcessorValidationError(
            "slices_per_volume not available. Set it explicitly or run TriggerDetector on slice-trigger data."
        )

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Create the corresponding-slice averaging matrix.

        Row ``r`` puts equal weight on the epochs that share ``r``'s slice
        position (same index modulo the slice period), spanning up to
        ``2 * max(1, window_size) + 1`` volumes.

        Near the start and end of the recording the window is shifted
        inward. When there are too few volumes for a full window, it is
        shrunk and a warning is logged (at most once per call).
        ``rel_window_offset`` and ``correlation_threshold`` are accepted for
        interface compatibility and ignored.
        """
        del rel_window_offset, correlation_threshold
        n_epochs = int(epochs.shape[0])
        matrix = np.zeros((n_epochs, n_epochs), dtype=float)
        if n_epochs == 0:
            return matrix

        # Prefer the value resolved by ``process``. Otherwise honor an
        # explicitly configured value before degrading to period 1 (which
        # would make this an ordinary sliding-window average).
        period = int(self._runtime_slices_per_volume or self.slices_per_volume or 1)
        half_window = max(1, int(window_size))
        span = 2 * half_window * period
        warning_shown = False

        for row0 in range(n_epochs):
            row = row0 + 1  # 1-based for parity with MATLAB equations
            # First epoch (1-based) with the same slice position as ``row``.
            first_same_slice = 1 + ((row - 1) % period)

            i_start = row - (half_window * period)
            if i_start < 1:
                i_start = first_same_slice
            i_end = i_start + span
            if i_end > n_epochs:
                # Shift the window back so it ends on the last same-slice epoch.
                i_end = self._nearest_multiple_before(row, period, n_epochs)
                i_start = i_end - span
                if i_start < 1:
                    i_start = first_same_slice
                    if not warning_shown:
                        warning_shown = True
                        logger.warning("Not enough volumes for full corresponding-slice window. Using reduced slice sets.")

            indices_1based = np.arange(i_start, i_end + 1, period, dtype=int)
            indices_1based = indices_1based[(indices_1based >= 1) & (indices_1based <= n_epochs)]
            if indices_1based.size == 0:
                indices_1based = np.array([row], dtype=int)
            indices_0based = indices_1based - 1
            matrix[row0, indices_0based] = 1.0 / float(indices_0based.size)
        return matrix

    @staticmethod
    def _nearest_multiple_before(origin: int, period: int, limit: int) -> int:
        """Return the largest ``origin + k * period`` (integer ``k``) that is <= ``limit``."""
        # Integer floor division is exact; no float round trip needed.
        return origin + ((limit - origin) // period) * period
@register_processor
class VolumeTriggerCorrection(AASCorrection):
    """AAS using MATLAB FACET's volume-trigger averaging weights.

    Mirrors ``AvgArtWghtVolumeTrigger`` for volume/section trigger
    workflows. Each epoch's template is the equal-weight mean of a
    contiguous run of epochs, and the start of that run follows MATLAB's
    border rules.

    Parameters
    ----------
    window_size : int
        Averaging window size in epochs (default: 30). The contiguous run
        has ``2 * max(1, window_size // 2) + 1`` epochs before clipping to
        the available data.
    plot_artifacts : bool
        If ``True``, plot one random averaged artifact (default: False).
    realign_after_averaging : bool
        If ``True``, realign triggers after averaging (default: True).
    search_window_factor : float
        Trigger realignment search-window factor (default: 3.0).
    apply_epoch_alpha_scaling : bool
        If ``True``, apply MATLAB-like per-epoch least-squares alpha scaling
        before subtraction (default: False).
    """

    name = "volume_trigger_correction"
    description = "AAS with MATLAB volume-trigger averaging weights"
    version = "1.0.0"

    def __init__(
        self,
        window_size: int = 30,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        # Offset and correlation threshold are pinned; this class's
        # _calc_averaging_matrix discards both anyway.
        super().__init__(
            window_size=window_size,
            plot_artifacts=plot_artifacts,
            realign_after_averaging=realign_after_averaging,
            search_window_factor=search_window_factor,
            apply_epoch_alpha_scaling=apply_epoch_alpha_scaling,
            rel_window_position=0.0,
            correlation_threshold=0.975,
        )

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Return the volume-trigger averaging matrix (one row per epoch)."""
        del rel_window_offset, correlation_threshold  # unused by this strategy
        count = int(epochs.shape[0])
        weights = np.zeros((count, count), dtype=float)
        if count == 0:
            return weights

        half = max(1, int(window_size // 2))
        span = 2 * half
        interior_lo = 3 + half
        interior_hi = count - (half + 2)

        # 1-based start of the current run. It only moves at the first epoch
        # and inside the interior range. Border rows keep the most recent
        # value (MATLAB behaviour). By construction it is always >= 2.
        first = 2
        for position in range(1, count + 1):
            if position == 1:
                first = 2
            elif interior_lo < position <= interior_hi:
                first = position - half - 1

            last = min(first + span, count)
            if first > last:
                # The run lies entirely past the data: use the epoch itself.
                weights[position - 1, position - 1] = 1.0
                continue
            weights[position - 1, first - 1 : last] = 1.0 / float(last - first + 1)
        return weights
@register_processor
class SliceTriggerCorrection(AASCorrection):
    """AAS using MATLAB FACET's slice-trigger odd/even template weighting.

    Mirrors ``AvgArtWghtSliceTrigger``. Each epoch's template is the
    equal-weight mean of every second epoch, taken from a start index that
    follows MATLAB's border rules. Consecutive rows therefore draw from
    alternating odd/even epoch sets.

    Parameters
    ----------
    window_size : int
        Half-window size in epochs (default: 30).
    plot_artifacts : bool
        If ``True``, plot one random averaged artifact (default: False).
    realign_after_averaging : bool
        If ``True``, realign triggers after averaging (default: True).
    search_window_factor : float
        Trigger realignment search-window factor (default: 3.0).
    apply_epoch_alpha_scaling : bool
        If ``True``, apply MATLAB-like per-epoch least-squares alpha scaling
        before subtraction (default: False).
    """

    name = "slice_trigger_correction"
    description = "AAS with MATLAB slice-trigger odd/even averaging weights"
    version = "1.0.0"

    def __init__(
        self,
        window_size: int = 30,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        # Offset and correlation threshold are pinned; this class's
        # _calc_averaging_matrix discards both anyway.
        super().__init__(
            window_size=window_size,
            plot_artifacts=plot_artifacts,
            realign_after_averaging=realign_after_averaging,
            search_window_factor=search_window_factor,
            apply_epoch_alpha_scaling=apply_epoch_alpha_scaling,
            rel_window_position=0.0,
            correlation_threshold=0.975,
        )

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Return the slice-trigger odd/even averaging matrix (one row per epoch)."""
        del rel_window_offset, correlation_threshold  # unused by this strategy
        count = int(epochs.shape[0])
        weights = np.zeros((count, count), dtype=float)
        if count == 0:
            return weights

        half = max(1, int(window_size))
        interior_lo = 3 + half
        interior_hi = count - (half + 2)

        # 1-based start of the current stride-2 set. It is fixed for the
        # first two epochs and tracked inside the interior range. All other
        # rows reuse the most recent value (MATLAB behaviour). By
        # construction it is always >= 2.
        first = 2
        for position in range(1, count + 1):
            if position == 1:
                first = 2
            elif position == 2:
                first = 3
            elif interior_lo < position <= interior_hi:
                first = position - half

            last = min(first + 2 * half, count)
            if first > last:
                # The set lies entirely past the data: use the epoch itself.
                weights[position - 1, position - 1] = 1.0
                continue
            # Members are first, first + 2, ..., <= last (1-based).
            members = (last - first) // 2 + 1
            weights[position - 1, first - 1 : last : 2] = 1.0 / float(members)
        return weights
@register_processor
class MoosmannCorrection(AASCorrection):
    """AAS with motion-informed Moosmann averaging weights.

    Uses the realignment-parameter-informed weighting strategy from
    ``AvgArtWghtMoosmann``. The motion summary of the matrix used in a run
    is attached to the result under ``metadata.custom["moosmann"]``.

    Parameters
    ----------
    rp_file : str
        Path to SPM-style realignment parameter text file.
    window_size : int
        AAS base window size (default: 30).
    motion_threshold : float
        Motion threshold passed to the weighting routine (default: 5.0).
    motion_window_size : int, optional
        Number of neighboring epochs for motion weighting. If ``None``,
        uses ``2 * window_size``.
    plot_artifacts : bool
        If ``True``, plot one random averaged artifact (default: False).
    realign_after_averaging : bool
        If ``True``, realign triggers after averaging (default: True).
    search_window_factor : float
        Trigger realignment search-window factor (default: 3.0).
    apply_epoch_alpha_scaling : bool
        If ``True``, apply MATLAB-like per-epoch least-squares alpha scaling
        before subtraction (default: False).
    """

    name = "moosmann_correction"
    description = "AAS with motion-informed Moosmann template weighting"
    version = "1.0.0"

    def __init__(
        self,
        rp_file: str,
        window_size: int = 30,
        motion_threshold: float = 5.0,
        motion_window_size: int | None = None,
        plot_artifacts: bool = False,
        realign_after_averaging: bool = True,
        search_window_factor: float = 3.0,
        apply_epoch_alpha_scaling: bool = False,
    ) -> None:
        self.rp_file = rp_file
        self.motion_threshold = motion_threshold
        self.motion_window_size = motion_window_size
        # (rp_file, n_epochs, motion_window, motion_threshold) ->
        # (row-normalized matrix, motion summary). Keying on every input
        # prevents stale matrices after settings change, and keeping the
        # summary next to its matrix prevents mismatched metadata.
        self._matrix_cache: dict[tuple, tuple[np.ndarray, dict]] = {}
        # Summary of the matrix most recently returned during the current run.
        self._last_motion_summary: dict | None = None
        super().__init__(
            window_size=window_size,
            rel_window_position=0.0,
            correlation_threshold=0.975,
            plot_artifacts=plot_artifacts,
            realign_after_averaging=realign_after_averaging,
            search_window_factor=search_window_factor,
            apply_epoch_alpha_scaling=apply_epoch_alpha_scaling,
        )

    def validate(self, context: ProcessingContext) -> None:
        """Run base AAS validation plus Moosmann-specific parameter checks.

        Raises
        ------
        ProcessorValidationError
            If ``rp_file`` is empty, ``motion_threshold`` is not positive, or
            ``motion_window_size`` is set to a value below 1.
        """
        super().validate(context)
        if not self.rp_file:
            raise ProcessorValidationError("rp_file must be provided for MoosmannCorrection.")
        if self.motion_threshold <= 0:
            raise ProcessorValidationError(f"motion_threshold must be positive, got {self.motion_threshold}")
        if self.motion_window_size is not None and self.motion_window_size < 1:
            raise ProcessorValidationError(f"motion_window_size must be >= 1 when set, got {self.motion_window_size}")

    def process(self, context: ProcessingContext) -> ProcessingContext:
        """Run AAS correction and attach the motion summary to the result metadata."""
        # Clear any summary from a previous run so it cannot leak into this result.
        self._last_motion_summary = None
        result = super().process(context)
        if self._last_motion_summary is None:
            return result
        md = result.metadata.copy()
        md.custom["moosmann"] = self._last_motion_summary
        return result.with_metadata(md)

    def _calc_averaging_matrix(
        self,
        epochs: np.ndarray,
        window_size: int,
        rel_window_offset: float,
        correlation_threshold: float,
    ) -> np.ndarray:
        """Create (or reuse) the Moosmann weighting matrix from the RP file.

        Rows are normalized to sum to 1. Rows whose weights sum to <= 0 are
        left unscaled. Results are cached per (rp_file, epoch count, motion
        window, threshold), and the matching motion summary is re-published
        on every call, including cache hits.
        ``rel_window_offset`` and ``correlation_threshold`` are ignored.
        """
        del rel_window_offset, correlation_threshold
        n_epochs = int(epochs.shape[0])
        motion_window = int(self.motion_window_size) if self.motion_window_size is not None else int(2 * window_size)
        motion_window = max(1, motion_window)

        cache_key = (self.rp_file, n_epochs, motion_window, self.motion_threshold)
        cached = self._matrix_cache.get(cache_key)
        if cached is not None:
            matrix, summary = cached
            self._last_motion_summary = summary
            return matrix

        motiondata, matrix = calc_weighted_matrix_by_realignment_parameters_file(
            rp_file=self.rp_file,
            n_fmri=n_epochs,
            k=motion_window,
            threshold=self.motion_threshold,
        )
        row_sums = np.sum(matrix, axis=1, keepdims=True)
        # Avoid division by zero; such rows keep their original (non-positive-sum) weights.
        row_sums = np.where(row_sums <= 0, 1.0, row_sums)
        matrix = matrix / row_sums

        # Convert array-like values (anything with ``tolist``) to plain lists for metadata.
        motion_serialized = {k: (v.tolist() if hasattr(v, "tolist") else v) for k, v in motiondata.items()}
        summary = {
            "rp_file": self.rp_file,
            "motion_threshold": self.motion_threshold,
            "motion_window_size": motion_window,
            "num_epochs": n_epochs,
            "motion": motion_serialized,
        }
        self._matrix_cache[cache_key] = (matrix, summary)
        self._last_motion_summary = summary
        return matrix
# Backwards-compatible aliases named after the MATLAB FACET weighting
# functions (``AvgArtWght*``), so callers can use either naming scheme.
AvgArtWghtCorrespondingSliceCorrection = CorrespondingSliceCorrection
AvgArtWghtMoosmannCorrection = MoosmannCorrection
AvgArtWghtSliceTriggerCorrection = SliceTriggerCorrection
AvgArtWghtVolumeTriggerCorrection = VolumeTriggerCorrection