Source code for EMCqMRI.core.models.inference.rim

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .rnn import Rnn
from core.base import base_inference_model
from core.utilities.core_utilities import ProgressBarWrap
import torch
from torch.autograd import Variable
import torch.nn as nn


[docs]class Rim(base_inference_model.InferenceModel, nn.Module): """ Class Implementing the RIM model. Methods: - setOpts inputs: a Dict containing the key and value for a new configuration setting - forward inputs: signal (measured signal); args (options containing, at least args.batchSize and args.device) outputs: Estimated parameters """ def __init__(self, config_object): super(Rim, self).__init__() self.__name__ = 'RIM' self.__require_initial_guess__ = True self.args = config_object.args self.__buildNetwork__() def __buildNetwork__(self): """ Hidden method that instanciates a version of the RNN module based on the configuration file """ self.rnn = Rnn(self.args.inference.inputChannels, self.args.inference.outputChannelsLayer1, self.args.inference.outputChannelsLayer2, self.args.inference.outputChannelsLayer3, self.args.inference.outputChannels, self.args) def __initHidden__(self, signal): """ Initialises all hidden states in the network """ shape_input = torch.tensor((signal.shape)[2:]) shape_hs1 = [1, self.args.engine.batchSize*int(torch.prod(shape_input)), self.args.inference.outputChannelsLayer1] shape_hs2 = [1, self.args.engine.batchSize*int(torch.prod(shape_input)), self.args.inference.outputChannelsLayer3] st_1 = Variable(torch.zeros(tuple(shape_hs1)).to(device=self.args.engine.device)) st_2 = Variable(torch.zeros(tuple(shape_hs2)).to(device=self.args.engine.device)) return [st_1, st_2] def __getGradients__(self, signal, kappa, tau): """ Compute and return gradients of the Likelihood Function w.r.t. parameter maps """ gradientParamBatch = [] for batch in range(len(signal)): if isinstance(kappa,list): #For kappa with multiple types of map (e.g. relaxometry and motion) motion_map = kappa[1] clonedKappa = [] for num_kappa in range(len(kappa)): cloned_k = [Variable(pmap.clone(), requires_grad=True) for pmap in kappa[num_kappa][batch]] clonedKappa.append(cloned_k) gradients = self.args.engine.likelihood_model.gradients(signal[batch], clonedKappa, tau) gradientParamBatch.append(torch.cat(gradients,0)) else: clonedMaps = [Variable(pmap.clone(), requires_grad=True) for pmap in kappa[batch]] gradientParamBatch.append(self.args.engine.likelihood_model.gradients(signal[batch], clonedMaps, tau)) paramGrad = torch.stack(gradientParamBatch) paramGrad[torch.isnan(paramGrad)] = 0 paramGrad[torch.isinf(paramGrad)] = 0 return paramGrad
[docs] def forward(self, inputs): signal = inputs kappa = self.args.engine.signal_model.initialize_parameters(signal) hidden = self.__initHidden__(signal) estimates = [] for _ in range(self.args.inference.inferenceSteps): paramGrad = self.__getGradients__(signal, kappa, self.args.dataset.tau) if isinstance(kappa,list): tensor_kappa = torch.cat(kappa,1) input = torch.cat([tensor_kappa, paramGrad], 1) else: input = torch.cat([kappa, paramGrad], 1) dx, hidden = self.rnn.forward(input, hidden) if isinstance(kappa,list): shape_list_kappa = [k.shape for k in kappa] cutoff_index = [s[1] for s in shape_list_kappa] tensor_kappa = tensor_kappa + dx kappa = list(torch.split(tensor_kappa, [*cutoff_index], dim=1)) else: kappa = torch.abs(kappa + dx) estimates.append(kappa) return torch.stack(estimates)