Source code for EMCqMRI.core.models.inference.mle

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from core.base import base_inference_model
from core.utilities.core_utilities import ProgressBarWrap
import torch


class Mle(base_inference_model.InferenceModel, object):
    """Maximum-likelihood estimation (MLE) inference model.

    Starting from the initial parameters produced by the configured signal
    model, runs a fixed number of Adam steps that minimise the value
    returned by ``likelihood_model.likelihood`` (presumably a negative
    log-likelihood -- TODO confirm against the likelihood model).
    """

    def __init__(self, config_object):
        """Store the run configuration.

        Args:
            config_object: object exposing an ``args`` attribute; ``args`` is
                read later for ``engine``, ``inference`` and ``task`` settings.
        """
        super(Mle, self).__init__()
        self.args = config_object.args
        # Flag presumably consulted by the caller to decide whether to
        # provide an initial guess -- verify against the engine code.
        self.__require_initial_guess__ = True
        self.__name__ = 'MLE'

    @ProgressBarWrap
    def update_bar(self, loss, args):
        """Progress hook invoked once per optimisation step.

        The body is a stub returning -1. Any progress reporting presumably
        happens inside the ``ProgressBarWrap`` decorator (see
        ``core.utilities.core_utilities``) -- confirm there.
        """
        return -1

    def forward(self, inputs):
        """Fit the signal model to ``inputs`` by gradient-based MLE.

        Args:
            inputs: measured signal. It is passed unchanged to
                ``signal_model.initialize_parameters`` and to
                ``likelihood_model.likelihood``.

        Returns:
            ``torch.stack`` of the parameter tensors returned by
            ``initialize_parameters``, after they have been updated in place
            by ``self.args.inference.inferenceSteps`` Adam steps. The tensors
            are not detached.

        Side effects:
            Stores the Adam optimizer on ``self.args.engine.optimizer``.
        """
        measured = inputs
        engine = self.args.engine

        # Assumed to be a sequence of leaf tensors (Adam cannot optimise
        # non-leaf tensors) -- TODO confirm against the signal model.
        kappa = engine.signal_model.initialize_parameters(measured)

        # Non-default betas (PyTorch's Adam default is (0.9, 0.999)).
        optimizer = torch.optim.Adam(
            list(kappa),
            lr=engine.learningRate,
            betas=(0.8, 0.899),
        )
        # Published on the engine so other components can reach it.
        engine.optimizer = optimizer

        n_steps = self.args.inference.inferenceSteps
        tau = self.args.task.tau
        for _step in range(n_steps):
            optimizer.zero_grad()
            predicted = engine.signal_model.forward(kappa, tau)
            objective = engine.likelihood_model.likelihood(measured, predicted)
            self.update_bar(objective.item(), self.args)
            objective.backward()
            optimizer.step()

        return torch.stack(kappa)