Source code for EMCqMRI.core.utilities.torch_utils

from torch import nn
from torch.nn import Module


def determine_conv_class(n_dim, transposed=False):
    """Return the ``torch.nn`` convolution module class for a dimensionality.

    Args:
        n_dim: Number of spatial dimensions of the convolution (1, 2 or 3).
        transposed: If true, return the transposed-convolution class
            (``nn.ConvTranspose*d``) instead of the regular one
            (``nn.Conv*d``).

    Returns:
        The matching ``nn.Module`` subclass (the class itself, not an
        instance).

    Raises:
        NotImplementedError: If ``n_dim`` is not 1, 2 or 3.
    """
    # Each entry is (regular class, transposed class).
    conv_classes = {
        1: (nn.Conv1d, nn.ConvTranspose1d),
        2: (nn.Conv2d, nn.ConvTranspose2d),
        3: (nn.Conv3d, nn.ConvTranspose3d),
    }
    pair = conv_classes.get(n_dim)
    if pair is None:
        # The exception must actually be raised; merely constructing it
        # would let the function fall through and return None.
        raise NotImplementedError(
            "No convolution of this dimensionality implemented"
        )
    regular, transposed_cls = pair
    return transposed_cls if transposed else regular
def determine_conv_functional(n_dim, transposed=False):
    """Return the ``torch.nn.functional`` convolution for a dimensionality.

    Args:
        n_dim: Number of spatial dimensions of the convolution (1, 2 or 3).
        transposed: If true, return the transposed-convolution function
            (``conv_transpose*d``) instead of the regular one (``conv*d``).

    Returns:
        The matching function from ``torch.nn.functional``.

    Raises:
        NotImplementedError: If ``n_dim`` is not 1, 2 or 3.
    """
    functional = nn.functional
    # Each entry is (regular function, transposed function). The transposed
    # variants are named ``conv_transpose*d`` in torch.nn.functional.
    conv_functions = {
        1: (functional.conv1d, functional.conv_transpose1d),
        2: (functional.conv2d, functional.conv_transpose2d),
        3: (functional.conv3d, functional.conv_transpose3d),
    }
    pair = conv_functions.get(n_dim)
    if pair is None:
        # The exception must actually be raised; merely constructing it
        # would let the function fall through and return None.
        raise NotImplementedError(
            "No convolution of this dimensionality implemented"
        )
    regular, transposed_fn = pair
    return transposed_fn if transposed else regular
def pixel_unshuffle(x, downscale_factor):
    """Rearrange spatial blocks of a 4-D tensor into channels.

    The input is unpacked as ``(b, c, h, w)``. Every
    ``downscale_factor x downscale_factor`` spatial block is moved into the
    channel dimension, producing a tensor of shape
    ``(b, c * downscale_factor ** 2, h // downscale_factor,
    w // downscale_factor)``.

    Args:
        x: Tensor with exactly four dimensions.
        downscale_factor: Integer side length of the spatial blocks.

    Returns:
        The rearranged tensor.
    """
    batch, channels, height, width = x.size()
    factor = downscale_factor

    new_height = height // factor
    new_width = width // factor
    new_channels = channels * (factor ** 2)

    # Split each spatial axis into (coarse position, offset inside block).
    blocks = x.contiguous().view(
        batch, channels, new_height, factor, new_width, factor
    )
    # Bring the two in-block offsets next to the channel axis so they can be
    # merged into it by the final view.
    reordered = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
    return reordered.view(batch, new_channels, new_height, new_width)
class PixelUnshuffle(Module):
    """Module wrapper around :func:`pixel_unshuffle`.

    Stores a downscale factor at construction time and applies
    ``pixel_unshuffle`` with that factor in :meth:`forward`.
    """

    def __init__(self, downscale_factor):
        """Store ``downscale_factor`` for use in :meth:`forward`."""
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, x):
        """Return ``pixel_unshuffle(x, self.downscale_factor)``."""
        factor = self.downscale_factor
        return pixel_unshuffle(x, factor)