# Source code for EMCqMRI.core.engine.estimate

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
sys.path.insert(0, '../')

from ..utilities import image_utilities
import logging
import torch
import torch.nn as nn

# Configure the root logger at import time so that the INFO-level progress
# messages emitted below (and by Infer) are shown.
# NOTE(review): basicConfig is process-wide and is a no-op if the root logger
# already has handlers; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logging.info('Setting up environment...')

class Infer(object):
    """Run a model over every batch of a dataloader (inference / estimation).

    For each batch the raw data is passed through ``prepare_batch``, the
    prepared input is fed to ``network.forward``. Optionally the loss against
    the label is computed and logged, ``log_training_fun`` is called, and the
    result is pickled to disk.

    Args:
        device: Target device, forwarded unchanged to ``prepare_batch``.
        validation_dataloader: Iterable of batches that supports ``len()``.
        network: Object exposing ``forward(signal)``. ``eval()`` is called on
            it only when it is a ``torch.nn.Module``.
        metric: Stored as ``self.metric``; not used by this class.
        prepare_batch: Callable ``(data, device) -> data_``. ``data_[0]`` is
            the network input; ``data_[1]``, when present, is the label.
        compute_loss: If truthy, the loss between estimate and label is
            computed with ``config_object.args.engine.objective_fun`` and
            logged per sample.
        log_training_fun: Optional callable, invoked as
            ``log_training_fun(processed_data, 0, 1, 1)`` after each batch.
        config_object: Configuration object.

            Reads these ``config_object.args.engine`` attributes:
            ``state_name``, ``saveResults``, ``saveResultsPath`` (only when
            saving) and ``objective_fun`` (only when ``compute_loss``).

            Writes these ``config_object.args.engine`` attributes:
            ``batch_size_iter``, ``len_dataset``, ``iter`` and ``filename``
            (only when saving).
    """

    def __init__(self, device, validation_dataloader, network, metric,
                 prepare_batch, compute_loss, log_training_fun, config_object):
        self.args = config_object.args
        self.device = device
        self.dataloader = validation_dataloader
        self.network = network
        self.metric = metric
        self.prepare_batch = prepare_batch
        self.compute_loss = compute_loss
        self.log_training_fun = log_training_fun

    def run(self, return_result=False):
        """Process every batch of the dataloader.

        Args:
            return_result: If True, return the processed output of the last
                batch.

        Returns:
            When ``return_result`` is True, the ``processed_data`` dict of the
            last batch, or ``None`` if the dataloader yielded nothing.
            Otherwise ``None``.
        """
        self.args.engine.batch_size_iter = 1
        if isinstance(self.network, nn.Module):
            self.network.eval()
        # Gradient tracking is intentionally left enabled.
        # NOTE(review): some inference models may use autograd inside
        # forward(); confirm before wrapping this loop in torch.no_grad().
        logging.info('Running %s', self.args.engine.state_name)
        logging.info('*' * 50)
        self.args.engine.len_dataset = len(self.dataloader)

        # Bound up front so an empty dataloader returns None instead of
        # raising UnboundLocalError.
        processed_data = None
        for i, data in enumerate(self.dataloader):
            self.args.engine.iter = i
            data_ = self.prepare_batch(data, self.device)
            processed_data = self.__run__(data_, i)
            if self.log_training_fun:
                self.log_training_fun(processed_data, 0, 1, 1)
            if self.args.engine.saveResults:
                # NOTE(review): assumes data[1][0] is a filename string -- confirm
                self.args.engine.filename = data[1][0]
                filename_ = self.args.engine.filename + 'patch_' + str(i)
                image_utilities.saveDataPickle(
                    processed_data, self.args.engine.saveResultsPath, filename_)
        if return_result:
            return processed_data

    def __log_error__(self, i, loss):
        """Log the loss of sample ``i`` (0-based, shown 1-based out of N)."""
        logging.info("Sample: {}/{}, Loss: {} ".format(
            i + 1, len(self.dataloader), loss))

    def __run__(self, data_, i):
        """Run the network on one prepared batch.

        Args:
            data_: Prepared batch. ``data_[0]`` is the network input; when
                ``len(data_) > 1``, ``data_[1]`` is the label, which gets a
                new leading axis via ``unsqueeze(0)``.
            i: Batch index, used only for logging.

        Returns:
            dict with keys ``'estimated'`` and ``'signal'``, plus ``'label'``
            when a label is available.
        """
        signal = data_[0]
        label = data_[1].unsqueeze(0) if len(data_) > 1 else None
        # .forward is called directly because network need not be an nn.Module.
        estimate = self.network.forward(signal)
        if self.compute_loss:
            if label is not None:
                loss = self.args.engine.objective_fun(estimate, label)
                self.__log_error__(i, loss)
            else:
                logging.error("No label provided")
        processed_data = {'estimated': estimate, 'signal': signal}
        if label is not None:
            processed_data['label'] = label
        return processed_data