# Source code for EMCqMRI.core.models.inference.rnn

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as f


class Rnn(nn.Module):
    """Convolutional recurrent network with two per-pixel GRU layers.

    One call to :meth:`forward` computes::

        conv3x3 -> ReLU -> GRU_1 -> conv3x3 -> ReLU -> conv3x3 -> ReLU
            -> GRU_2 -> conv1x1

    Every 3x3 convolution uses stride 1 and padding 1, so spatial size is
    preserved. Each GRU is applied independently at every spatial
    location: a ``(B, C, H, W)`` feature map is flattened into ``B*H*W``
    GRU batch entries of size ``C`` and fed as a sequence of length 1, so
    each :meth:`forward` call advances the recurrence by exactly one step.
    GRU hidden states are carried between calls through ``hidden_states``.

    Args:
        channels_input: Channels of the input tensor.
        out_channels_l1: Channels after the first conv (and GRU_1 size).
        out_channels_l2: Channels after the second conv.
        out_channels_l3: Channels after the third conv (and GRU_2 size).
        channels_output: Channels produced by the 1x1 output conv.
        args_: Configuration object; stored as ``self.args`` but not read
            by this class.
    """

    def __init__(self, channels_input, out_channels_l1, out_channels_l2,
                 out_channels_l3, channels_output, args_):
        super(Rnn, self).__init__()
        # Submodule names and registration order define state_dict keys and
        # parameters() order; keep them stable for saved checkpoints.
        self.input_layer = nn.Conv2d(channels_input, out_channels_l1,
                                     kernel_size=3, stride=1, padding=1)
        self.conv_layer_2 = nn.Conv2d(out_channels_l1, out_channels_l2,
                                      kernel_size=3, stride=1, padding=1)
        self.conv_layer_3 = nn.Conv2d(out_channels_l2, out_channels_l3,
                                      kernel_size=3, stride=1, padding=1)
        self.output_layer = nn.Conv2d(out_channels_l3, channels_output,
                                      kernel_size=1, stride=1, padding=0)
        self.gru_layer_1 = nn.GRU(out_channels_l1, out_channels_l1)
        self.gru_layer_2 = nn.GRU(out_channels_l3, out_channels_l3)
        self.gru_1_channelSize = out_channels_l1
        self.gru_2_channelSize = out_channels_l3
        self.size_c1 = out_channels_l1
        self.size_c3 = out_channels_l3
        self.args = args_

    # NOTE(review): ``__name__``-style identifiers are reserved for Python
    # itself (PEP 8); these helpers keep their historical names only for
    # backward compatibility with any external callers.
    def __setGruShapes__(self, x):
        """Cache the reshape/permute parameters used by ``__forwardGRU__``.

        Args:
            x: Batch of shape ``(B, C, *spatial)``. Only the batch size and
                spatial dimensions are used; the channel dimension of the
                cached shapes is replaced by the GRU sizes.
        """
        batch_size = x.shape[0]
        spatial = list(x.shape[2:])
        n_spatial = len(spatial)
        # Channels-last shapes of the GRU outputs before permuting back.
        self.shape_gru_1 = [batch_size] + spatial + [self.size_c1]
        self.shape_gru_2 = [batch_size] + spatial + [self.size_c3]
        # (B, C, *S) -> (B, *S, C): put channels last so each location
        # becomes one row of GRU features.
        self.permute_forward = [0] + list(range(2, n_spatial + 2)) + [1]
        # (B, *S, C) -> (B, C, *S): inverse of permute_forward.
        self.permute_backward = [0, -1] + list(range(1, n_spatial + 1))

    def __forwardGRU__(self, x, hs, gru_layer, sizeGRUforward,
                       sizeGRUbackward):
        """Apply ``gru_layer`` independently at every spatial location.

        Args:
            x: Feature map ``(B, C, *spatial)`` with ``C == sizeGRUforward``.
            hs: Previous hidden state of ``gru_layer``, passed straight to
                it (``nn.GRU`` treats ``None`` as a zero initial state).
            gru_layer: The ``nn.GRU`` module to apply.
            sizeGRUforward: GRU feature size, i.e. the channel count of ``x``.
            sizeGRUbackward: Channels-last target shape ``(B, *spatial, C)``
                as prepared by ``__setGruShapes__``.

        Returns:
            ``(out, hs)``: ``out`` has the same shape as ``x``; ``hs`` is the
            new hidden state returned by ``gru_layer``.
        """
        # (B, C, *S) -> (B*prod(S), C) -> (1, B*prod(S), C): a length-1
        # sequence whose "batch" is every pixel of every sample.
        gru_in = (x.permute(self.permute_forward)
                  .contiguous()
                  .view(-1, sizeGRUforward)
                  .unsqueeze(0))
        gru_out, hs = gru_layer(gru_in, hs)
        # (1, N, C) -> (B, *S, C) -> (B, C, *S)
        out = (gru_out.squeeze(0)
               .view(sizeGRUbackward)
               .permute(self.permute_backward)
               .contiguous())
        return out, hs

    def forward(self, input, hidden_states):
        """Run one recurrent step of the network.

        Args:
            input: Batched tensor of shape ``(B, channels_input, H, W)``.
            hidden_states: Mutable indexable container (e.g. a list) whose
                items 0 and 1 hold the GRU_1 and GRU_2 hidden states
                (``None`` for a zero start). It is updated **in place**.

        Returns:
            ``(output, hidden_states)``: ``output`` has shape
            ``(B, channels_output, H, W)``; ``hidden_states`` is the same
            object passed in, now holding the new GRU states.

        Raises:
            ValueError: If ``input`` is not 4-D. Unbatched 3-D input would
                otherwise be misread by ``__setGruShapes__`` (batch and
                spatial sizes taken from the wrong axes).
        """
        if input.dim() != 4:
            raise ValueError(
                "Rnn expects a batched 4-D input (B, C, H, W); "
                "got shape {}".format(tuple(input.shape)))

        self.__setGruShapes__(input)

        h1 = f.relu(self.input_layer(input))
        h1_gru, hidden_states[0] = self.__forwardGRU__(
            h1, hidden_states[0], self.gru_layer_1,
            self.gru_1_channelSize, self.shape_gru_1)

        h2 = f.relu(self.conv_layer_2(h1_gru))
        h3 = f.relu(self.conv_layer_3(h2))
        h2_gru, hidden_states[1] = self.__forwardGRU__(
            h3, hidden_states[1], self.gru_layer_2,
            self.gru_2_channelSize, self.shape_gru_2)

        output = self.output_layer(h2_gru)
        return output, hidden_states