DistributedSignalStrategy#
- class deepinv.distributed.strategies.DistributedSignalStrategy(img_size, **kwargs)[source]#
Bases: ABC
Abstract base class for distributed signal processing strategies.
- A strategy defines how to:
Split a signal into patches for distributed processing
Batch patches for efficient processing
Reduce processed patches back into a complete signal
This allows users to implement custom distributed processing strategies for different types of data and use cases.
- Parameters:
img_size (Sequence[int]) – shape of the complete signal tensor (e.g., [B, C, H, W]).
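As a sketch of how a concrete strategy might look, the class below implements the three abstract methods described further down (get_num_patches, get_local_patches, reduce_patches) for a hypothetical scheme that splits the signal into horizontal strips. The class is standalone for illustration; a real implementation would subclass DistributedSignalStrategy, and the strip-based splitting itself is an assumption, not the library's scheme.

```python
import torch
from typing import Sequence

class StripStrategy:
    """Hypothetical strategy: split a [B, C, H, W] signal into horizontal strips."""

    def __init__(self, img_size: Sequence[int], num_strips: int = 4):
        self.img_size = tuple(img_size)
        self.num_strips = num_strips
        self.strip_h = self.img_size[-2] // num_strips  # height of each strip

    def get_num_patches(self) -> int:
        return self.num_strips

    def get_local_patches(self, x, local_indices):
        # Extract the strips assigned to this worker, in index order.
        return [x[..., i * self.strip_h:(i + 1) * self.strip_h, :]
                for i in local_indices]

    def reduce_patches(self, out_tensor, local_pairs):
        # Write each processed strip back at its global location (in-place).
        for idx, patch in local_pairs:
            out_tensor[..., idx * self.strip_h:(idx + 1) * self.strip_h, :] = patch

x = torch.arange(16.0).reshape(1, 1, 4, 4)
strat = StripStrategy((1, 1, 4, 4), num_strips=2)
patches = strat.get_local_patches(x, [0, 1])
out = torch.zeros_like(x)
strat.reduce_patches(out, list(enumerate(patches)))
# Splitting and then reducing (with identity processing) round-trips the signal.
```

With an identity "processing" step, splitting followed by reduction reconstructs the original tensor, which is the basic contract a custom strategy should satisfy.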
- apply_batching(patches, max_batch_size=None)[source]#
Group patches into batches for efficient processing.
The batching should preserve order: when the batched tensors are processed and then concatenated back, they should yield patches in the same order as the input.
- Parameters:
patches (list[torch.Tensor]) – list of prepared patches.
max_batch_size (int | None) – maximum number of patches per batch. If
None, all patches are batched together. If 1, each patch is processed individually.
- Returns:
batched patches ready for processing. When processed results are concatenated, they should preserve the original patch order.
- Return type:
list[torch.Tensor]
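The order-preserving batching contract can be sketched as follows. This is an illustrative standalone function, not the library's implementation; it assumes patches share a shape with a leading batch dimension and groups them with torch.cat.

```python
import torch

def apply_batching(patches, max_batch_size=None):
    """Group same-shaped patches into batches of at most max_batch_size,
    preserving order (illustrative sketch, assumed behavior)."""
    if max_batch_size is None:
        max_batch_size = len(patches)  # one batch containing everything
    return [torch.cat(patches[i:i + max_batch_size], dim=0)
            for i in range(0, len(patches), max_batch_size)]

patches = [torch.full((1, 1, 2, 2), float(i)) for i in range(5)]
batches = apply_batching(patches, max_batch_size=2)
# 5 patches with max_batch_size=2 -> batch sizes 2, 2, 1
flat = torch.cat(batches, dim=0)
# Concatenating the batches back recovers the patches in their original order.
restored = list(torch.split(flat, 1, dim=0))
```

The key property is the round trip: concatenating the batches and splitting again must yield the patches in the order they were given.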
- abstractmethod get_local_patches(x, local_indices)[source]#
Extract and prepare local patches for processing.
- abstractmethod get_num_patches()[source]#
Get the total number of patches this strategy creates.
- Returns:
total number of patches.
- Return type:
int
- abstractmethod reduce_patches(out_tensor, local_pairs)[source]#
Reduce processed patches into the output tensor.
This operates in-place on out_tensor, placing each processed patch in its correct location within the complete signal.
- Parameters:
out_tensor (torch.Tensor) – output tensor to fill (should be initialized to zeros).
local_pairs (list[tuple[int, torch.Tensor]]) – list of (global_index, processed_patch) pairs.
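Because each pair carries its global index, the reduction is insensitive to the order in which processed patches arrive. A minimal sketch for a hypothetical strip layout (the patch_h slicing is an assumption for illustration, not the library's code):

```python
import torch

def reduce_patches(out_tensor, local_pairs, patch_h):
    """Write each (global_index, processed_patch) pair into its slot of a
    zero-initialized output, in-place (illustrative strip layout)."""
    for global_index, patch in local_pairs:
        out_tensor[..., global_index * patch_h:(global_index + 1) * patch_h, :] = patch

out = torch.zeros(1, 1, 4, 2)
# Pairs arrive out of order; indexed placement makes this harmless.
pairs = [(1, torch.full((1, 1, 2, 2), 2.0)),
         (0, torch.full((1, 1, 2, 2), 1.0))]
reduce_patches(out, pairs, patch_h=2)
```

Initializing out_tensor to zeros matters when a worker only holds some of the patches: untouched regions stay zero and can be summed across workers afterwards.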
- unpack_batched_results(processed_batches, num_patches)[source]#
Unpack processed batches back to individual patches.
Default implementation: concatenate along batch dimension and split back. Uses stored metadata to determine original patch batch size.
- Parameters:
processed_batches (list[torch.Tensor]) – results from processing batched patches.
num_patches (int) – expected number of individual patches.
- Returns:
list of individual processed patches in original order.
- Return type:
list[torch.Tensor]
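The concatenate-and-split default can be sketched as below. The real method reads the per-patch batch size from stored metadata; here it is passed explicitly and assumed to be 1 for illustration.

```python
import torch

def unpack_batched_results(processed_batches, num_patches, patch_batch_size=1):
    """Concatenate processed batches along the batch dimension, then split
    back into individual patches (sketch of the documented default)."""
    flat = torch.cat(processed_batches, dim=0)
    parts = list(torch.split(flat, patch_batch_size, dim=0))
    assert len(parts) == num_patches, "batch sizes must sum to num_patches"
    return parts

# Two batches of sizes 2 and 1 unpack into 3 individual patches.
batches = [torch.ones(2, 1, 2, 2), torch.ones(1, 1, 2, 2)]
patches = unpack_batched_results(batches, num_patches=3)
```

Since apply_batching preserves order, this inverse operation returns the processed patches in the same order as the original input.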