Source code for deepinv.utils.tensorlist

import torch


class TensorList:
    r"""
    Represents a list of :class:`torch.Tensor` with different shapes.

    It allows to sum, flatten, append, etc. lists of tensors seamlessly, in a similar fashion
    to :class:`torch.Tensor`.

    Binary operators accept either a single operand (scalar or tensor), which is applied to every
    tensor of the list, or a ``list`` / ``TensorList``, which is paired element by element with the
    tensors of this list (pairing stops at the shorter of the two, as with :func:`zip`).

    :param x: a list of :class:`torch.Tensor`, a single :class:`torch.Tensor` or a TensorList.
    :raises TypeError: if ``x`` is none of the supported types.
    """

    def __init__(self, x):
        super().__init__()

        if isinstance(x, (list, TensorList)):
            # Shallow copy: the tensors are shared, the container is not.
            self.x = list(x)
        elif isinstance(x, torch.Tensor):
            self.x = [x]
        else:
            raise TypeError("x must be a list of torch.Tensor or a single torch.Tensor")

    @property
    def shape(self):
        r"""
        List with the shape of each tensor in the list.

        Computed on access so that it stays in sync with ``append`` and item assignment.
        """
        return [xi.shape for xi in self.x]

    def __repr__(self):
        return f"TensorList({self.x})"

    def _elementwise(self, other, op):
        r"""
        Applies the binary callable ``op(xi, other_i)`` to every tensor of the list.

        :param other: a ``list`` / ``TensorList`` (paired element by element) or any other
            operand (shared by all the tensors of the list).
        :param op: callable taking the tensor of this list first and the other operand second.
        :return: a new TensorList with the results.
        """
        if isinstance(other, (list, TensorList)):
            return TensorList([op(xi, oi) for xi, oi in zip(self.x, other)])
        return TensorList([op(xi, other) for xi in self.x])

    def to(self, *args, **kwargs):
        r"""
        Moves the TensorList to the given device and dtype.

        All arguments are forwarded to :meth:`torch.Tensor.to`.
        """
        return TensorList([xi.to(*args, **kwargs) for xi in self.x])

    def clone(self):
        r"""
        Returns a copy of the TensorList.
        """
        return TensorList([xi.clone() for xi in self.x])

    def detach(self):
        r"""
        Returns a copy of the TensorList with detached gradients.
        """
        return TensorList([xi.detach() for xi in self.x])

    def cpu(self):
        r"""
        Moves the TensorList to the cpu.
        """
        return TensorList([xi.cpu() for xi in self.x])

    def numpy(self):
        r"""
        Returns a list of numpy arrays.
        """
        return [xi.numpy() for xi in self.x]

    def cuda(self, *args, **kwargs):
        r"""
        Moves the TensorList to the cuda device.

        All arguments are forwarded to :meth:`torch.Tensor.cuda`.
        """
        return TensorList([xi.cuda(*args, **kwargs) for xi in self.x])

    def type(self, dtype):
        r"""
        Returns the TensorList with the given dtype.
        """
        return TensorList([xi.type(dtype) for xi in self.x])

    def __len__(self):
        r"""
        Returns the number of tensors in the list.
        """
        return len(self.x)

    def __getitem__(self, item):
        r"""
        Returns the ith tensor in the list.
        """
        return self.x[item]

    def __setitem__(self, key, value):
        r"""
        Sets the ith tensor in the list.
        """
        self.x[key] = value

    def flatten(self):
        r"""
        Returns a :class:`torch.Tensor` with a flattened version of the list of tensors.
        """
        return torch.cat([xi.flatten() for xi in self.x])

    def append(self, other):
        r"""
        Appends a :class:`torch.Tensor` or a list of :class:`torch.Tensor` to the list.

        The list is modified in place.

        :param other: a :class:`torch.Tensor`, a list of tensors or a TensorList.
        :return: the TensorList itself, to allow chaining.
        :raises TypeError: if ``other`` is none of the supported types.
        """
        if isinstance(other, list):
            self.x += other
        elif isinstance(other, TensorList):
            self.x += other.x
        elif isinstance(other, torch.Tensor):
            self.x.append(other)
        else:
            raise TypeError(
                "the appended item must be a list of :class:`torch.Tensor` or a single :class:`torch.Tensor`"
            )
        return self

    def __add__(self, other):
        r"""
        Adds two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a + b)

    def __radd__(self, other):
        r"""
        Reflected addition (``other + self``).

        Makes expressions such as ``1 + tensorlist`` and the builtin ``sum`` over TensorLists work.
        """
        return self._elementwise(other, lambda a, b: b + a)

    def __mul__(self, other):
        r"""
        Multiply two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a * b)

    def __rmul__(self, other):
        r"""
        Reflected multiplication (``other * self``). The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a * b)

    def __truediv__(self, other):
        r"""
        Divide two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a / b)

    def __neg__(self):
        r"""
        Negate a TensorList.
        """
        return TensorList([-xi for xi in self.x])

    def __sub__(self, other):
        r"""
        Subtract two TensorLists. The sizes of the tensor lists must match.
        """
        return self._elementwise(other, lambda a, b: a - b)

    def __rsub__(self, other):
        r"""
        Reflected subtraction (``other - self``).
        """
        return self._elementwise(other, lambda a, b: b - a)

    def conj(self):
        r"""
        Computes the conjugate of the elements of the TensorList.
        """
        return TensorList([xi.conj() for xi in self.x])

    def sum(self, dim, keepdim=False):
        r"""
        Computes the sum of each elements of the TensorList along the given dimension(s).

        :param dim: dimension(s) to reduce, forwarded to :meth:`torch.Tensor.sum`.
        :param bool keepdim: whether the reduced dimensions are kept.
        """
        return TensorList([xi.sum(dim, keepdim) for xi in self.x])

    def reshape(self, shape):
        r"""
        Reshape each tensor of the TensorList into the given list of shapes.

        :param shape: indexable collection with one target shape per tensor of the list.
        """
        return TensorList([xi.reshape(shape[i]) for i, xi in enumerate(self.x)])

    def any(self):
        r"""
        Returns True if any of the elements of the TensorList is True.
        """
        return any(bool(xi.any()) for xi in self.x)

    def all(self):
        r"""
        Returns True if all the elements of the TensorList are True.
        """
        return all(bool(xi.all()) for xi in self.x)

    def __any__(self):
        r"""
        Returns True if any of the elements of the TensorList is True.

        Kept for backward compatibility; Python never calls this name implicitly, use :meth:`any`.
        """
        return self.any()

    def __all__(self):
        r"""
        Returns True if all the elements of the TensorList are True.

        Kept for backward compatibility; Python never calls this name implicitly, use :meth:`all`.
        """
        return self.all()

    def __gt__(self, other):
        r"""
        Returns a TensorList of True if the elements of the input TensorList are greater than other.
        """
        return self._elementwise(other, lambda a, b: a > b)

    def __lt__(self, other):
        r"""
        Returns a TensorList of True if the elements of the TensorList are smaller than other.
        """
        return self._elementwise(other, lambda a, b: a < b)

    def squeeze(self, dim=None):
        r"""
        Squeezes each tensor of the TensorList.

        :param dim: dimension to squeeze. If ``None``, :meth:`torch.Tensor.squeeze` is called
            without arguments on every tensor.
        """
        if dim is None:
            # Do not forward dim=None: call the argument-free overload instead.
            return TensorList([xi.squeeze() for xi in self.x])
        return TensorList([xi.squeeze(dim=dim) for xi in self.x])

    def unsqueeze(self, dim=None):
        r"""
        Unsqueezes each tensor of the TensorList.

        :param dim: position of the new dimension, forwarded as is to :meth:`torch.Tensor.unsqueeze`.
            NOTE(review): the ``None`` default is kept for signature compatibility only; torch is
            expected to reject it, so callers should always pass an explicit ``dim``.
        """
        return TensorList([xi.unsqueeze(dim=dim) for xi in self.x])
def randn_like(x):
    r"""
    Draws standard gaussian samples shaped like the input.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` when ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one random tensor per element of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        samples = [torch.randn_like(item) for item in x]
        return TensorList(samples)
    return torch.randn_like(x)
def rand_like(x):
    r"""
    Draws uniform samples in [0,1] shaped like the input.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` when ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one random tensor per element of ``x``.
    """
    single = isinstance(x, torch.Tensor)
    if single:
        return torch.rand_like(x)
    return TensorList([torch.rand_like(entry) for entry in x])
def zeros_like(x):
    r"""
    Builds zero-filled tensors shaped like the input.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` when ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one zero tensor per element of ``x``.
    """
    if not isinstance(x, torch.Tensor):
        return TensorList([torch.zeros_like(member) for member in x])
    return torch.zeros_like(x)
def dirac(shape, dtype=None, device=None):
    r"""
    Returns a :class:`torch.Tensor` with a Dirac delta at the center.

    The tensor is zero everywhere except at index ``s // 2`` along each of the last two
    dimensions, where it is one; all leading dimensions are filled identically.

    :param tuple shape: shape of the output tensor.
    :param dtype: optional dtype of the output, forwarded to :func:`torch.zeros`.
    :param device: optional device of the output, forwarded to :func:`torch.zeros`.
    :return: the :class:`torch.Tensor` holding the Dirac delta.
    """
    out = torch.zeros(shape, dtype=dtype, device=device)
    center = tuple(s // 2 for s in shape[-2:])
    leading = (slice(None),) * (len(shape) - 2)
    # Index with a tuple: indexing with a list of slices/ints relies on the
    # deprecated non-tuple sequence indexing.
    out[leading + center] = 1
    return out
def dirac_like(x):
    r"""
    Returns a :class:`deepinv.utils.TensorList` or :class:`torch.Tensor` with the same type as x,
    filled with a Dirac delta at the center of the last two dimensions.

    The output matches the shape, dtype and device of ``x`` (of each of its tensors when ``x``
    is a list of tensors).

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` when ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one Dirac delta per element of ``x``.
    """
    if isinstance(x, torch.Tensor):
        # Move/cast after creation so the result lives where x lives.
        return dirac(x.shape).to(device=x.device, dtype=x.dtype)
    return TensorList(
        [dirac(xi.shape).to(device=xi.device, dtype=xi.dtype) for xi in x]
    )
def ones_like(x):
    r"""
    Builds one-filled tensors shaped like the input.

    :param x: a :class:`torch.Tensor`, or an iterable of tensors such as a
        :class:`deepinv.utils.TensorList`.
    :return: a :class:`torch.Tensor` when ``x`` is a tensor, otherwise a
        :class:`deepinv.utils.TensorList` holding one tensor of ones per element of ``x``.
    """
    return (
        torch.ones_like(x)
        if isinstance(x, torch.Tensor)
        else TensorList([torch.ones_like(part) for part in x])
    )