TimeAveragingNet
- class deepinv.models.TimeAveragingNet(backbone_net)
Time-averaging network wrapper.
Adapts a static image reconstruction network to time-varying inputs so that it outputs static reconstructions. The data are averaged across the time dimension before being passed to the network.
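The averaging step can be sketched in plain PyTorch. This is a minimal sketch, not deepinv's implementation. The helper name time_average is hypothetical. The masked branch is an assumption: it takes a 0/1 mask with the same (B,C,T,H,W) shape as y and averages each pixel only over the frames where that pixel was sampled.

```python
from typing import Optional

import torch


def time_average(y: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: collapse the time axis of a (B, C, T, H, W) tensor to (B, C, H, W)."""
    if mask is None:
        # Plain mean over the time dimension (dim=2).
        return y.mean(dim=2)
    # Assumed masked variant: sum the sampled values over time and divide by the
    # number of frames that sampled each pixel. The clamp avoids dividing by zero
    # for pixels that were never sampled (they come out as 0).
    counts = mask.sum(dim=2).clamp(min=1)
    return (y * mask).sum(dim=2) / counts


if __name__ == "__main__":
    y = torch.rand(1, 1, 4, 8, 8)  # B,C,T,H,W
    print(time_average(y).shape)  # torch.Size([1, 1, 8, 8])
```

Without a mask this is an unweighted mean, so frames are treated equally regardless of how densely they were sampled.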
Note
The input physics is assumed to be a temporal physics that produced the temporal measurements y (potentially with a temporal mask, mask). It must either implement a to_static method that removes the time dimension, or already be a static physics (e.g. deepinv.physics.MRI).
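The fallback logic described in the Note (use the physics' mask if present, call to_static() if it exists, otherwise pass the physics through unchanged) can be sketched as a small wrapper. This is an illustrative sketch, not deepinv's source. ToyTimeAveragingNet, ToyTemporalPhysics, ToyStaticPhysics and RecordingBackbone are hypothetical names. The attribute name mask, the masked-averaging rule, and a backbone that accepts (x, physics) are all assumptions.

```python
from typing import Any, Optional

import torch
from torch import nn


class ToyTimeAveragingNet(nn.Module):
    """Hypothetical wrapper mirroring the documented behaviour (not deepinv's code)."""

    def __init__(self, backbone_net: nn.Module):
        super().__init__()
        self.backbone_net = backbone_net

    def forward(self, y: torch.Tensor, physics: Any, **kwargs) -> torch.Tensor:
        # Assumed attribute name: a temporal physics may expose its mask as `mask`.
        mask: Optional[torch.Tensor] = getattr(physics, "mask", None)
        if mask is None:
            x_avg = y.mean(dim=2)  # (B, C, T, H, W) -> (B, C, H, W)
        else:
            # Assumed masked mean: average each pixel over its sampled frames only.
            x_avg = (y * mask).sum(dim=2) / mask.sum(dim=2).clamp(min=1)
        # Use to_static() when available; otherwise assume physics is already static.
        static_physics = physics.to_static() if hasattr(physics, "to_static") else physics
        return self.backbone_net(x_avg, static_physics, **kwargs)


class ToyStaticPhysics:
    """Hypothetical static physics with no time dimension."""


class ToyTemporalPhysics:
    """Hypothetical temporal physics carrying a (B, C, T, H, W) sampling mask."""

    def __init__(self, mask: torch.Tensor):
        self.mask = mask

    def to_static(self) -> ToyStaticPhysics:
        return ToyStaticPhysics()


class RecordingBackbone(nn.Module):
    """Hypothetical static backbone: returns its input and records the physics it saw."""

    def forward(self, x: torch.Tensor, physics: Any = None) -> torch.Tensor:
        self.last_physics = physics
        return x


if __name__ == "__main__":
    backbone = RecordingBackbone()
    model = ToyTimeAveragingNet(backbone)
    y = torch.rand(1, 1, 4, 8, 8)  # B,C,T,H,W
    x_net = model(y, ToyTemporalPhysics(torch.ones_like(y)))
    print(x_net.shape, type(backbone.last_physics).__name__)
    # torch.Size([1, 1, 8, 8]) ToyStaticPhysics
```

Passing a physics that is already static, or None, skips to_static() and hands it to the backbone unchanged.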
- Example:
>>> from torch import rand
>>> from deepinv.models import UNet, TimeAveragingNet
>>> model = UNet(scales=2)
>>> model = TimeAveragingNet(model)
>>> y = rand(1, 1, 4, 8, 8)  # B,C,T,H,W
>>> x_net = model(y, None)
>>> x_net.shape  # B,C,H,W
torch.Size([1, 1, 8, 8])
- Parameters:
backbone_net (torch.nn.Module) – Base network that only accepts static inputs of shape (B,C,H,W).
device (torch.device) – CPU or GPU.