"""Training loop engine for EMCqMRI (module ``EMCqMRI.core.engine.train_model``)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
sys.path.insert(0, '../')

from ..utilities import image_utilities
from ..utilities import core_utilities
from ..utilities import checkpoint_utilities
from timeit import default_timer as timer

import logging
import numpy as np

import matplotlib.pyplot as plt

# Root logging is configured at import time so that the per-batch progress
# messages emitted by the Trainer (logged at INFO) are visible by default.
_ROOT_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_ROOT_LOG_LEVEL)
logging.getLogger().info('Setting up environment...')



class Trainer(object):
    """Run a training loop, optionally alternating with validation passes.

    For every epoch the trainer iterates over the current data loader,
    prepares each batch with ``prepare_batch``, runs a forward pass through
    ``network``, evaluates ``loss_function`` and - only while in the
    ``'training'`` state - back-propagates and steps ``optimizer``. At the end
    of each training epoch it can save intermediate results and a checkpoint,
    depending on the configuration.

    Args:
        device: Passed unchanged to ``prepare_batch`` for every batch.
        max_epochs (int): Number of passes over the data. When training and
            validation alternate, every pass (training or validation) counts
            towards this number; ``self.epoch`` only counts training passes.
        train_data_loader: Either a single sized iterable of batches
            (training only), or a list of loaders. A list is handed to
            ``core_utilities.alternateTrainingState`` whose yielded mappings
            ``{state_name: loader}`` are consumed one per epoch end.
        network: Model exposing ``forward``, ``train`` and ``eval``.
        optimizer: Object exposing ``zero_grad`` and ``step``.
        loss_function: Callable ``(estimate, label)`` whose result supports
            ``.mean()``; that mean must support ``.backward()`` in training.
        prepare_batch: Callable ``(data, device)`` returning a mapping with at
            least the keys ``'image'`` and ``'label'``.
        log_training_fun: Optional callable
            ``(processed_data, loss, sample_index, epoch_index)`` invoked after
            every batch; skipped when falsy.
        config_object: Configuration object. Its ``args.engine`` namespace is
            read (``batchSize`` when alternating, ``saveResults``,
            ``saveCheckpoint``) and written (``state_name``, ``len_dataset``,
            ``test_sample``). The save/checkpoint helpers presumably read
            further settings (e.g. output paths) - confirm in those modules.
    """

    def __init__(
        self,
        device=[],
        max_epochs=[],
        train_data_loader=[],
        network=[],
        optimizer=[],
        loss_function=[],
        prepare_batch=[],
        log_training_fun=[],
        config_object=[],
    ):
        # NOTE: the ``[]`` defaults are kept only for backward compatibility;
        # they are never mutated. In practice every argument except
        # ``log_training_fun`` must be supplied (``config_object.args`` is
        # dereferenced right away).
        self.args = config_object.args
        self.device = device
        self.max_epochs = max_epochs

        if isinstance(train_data_loader, list):
            # A list of loaders means a validation dataset is in use:
            # alternate between the states produced by the generator.
            self.dataloader_state = core_utilities.alternateTrainingState(
                train_data_loader, self.args.engine.batchSize)
            self._switch_state(next(self.dataloader_state))
        else:
            # Single loader: training only, no state alternation. Recording
            # ``None`` lets __end_epoch__ test the same condition as here.
            self.dataloader_state = None
            self.args.engine.state_name = 'training'
            self.dataloader = train_data_loader

        self.network = network
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.prepare_batch = prepare_batch
        self.log_training_fun = log_training_fun
        self.epoch = 0
        # Match the network mode to the initial state (train for 'training',
        # eval otherwise) instead of unconditionally calling train().
        self._sync_network_mode()

    def _switch_state(self, state):
        """Adopt the first ``{state_name: loader}`` entry of *state*."""
        state_name, loader = next(iter(state.items()))
        self.dataloader = loader
        self.args.engine.state_name = state_name

    def _sync_network_mode(self):
        """Put the network in train mode for 'training', eval mode otherwise."""
        if self.args.engine.state_name == 'training':
            self.network.train()
        else:
            self.network.eval()

    def run(self):
        """Run ``max_epochs`` passes over the data and print the total time."""
        start_time = timer()
        for ep in range(self.max_epochs):
            start_time_epoch = timer()
            logging.info('Running {}; Epoch {}'.format(self.args.engine.state_name, self.epoch))
            logging.info('*' * 50)
            self.args.engine.len_dataset = len(self.dataloader)

            # Stays None when the loader yields no batches; previously the
            # name was left unbound and __end_epoch__ raised UnboundLocalError.
            processed_data = None
            for sample, data in enumerate(self.dataloader):
                self.args.engine.test_sample = sample
                data_ = self.prepare_batch(data, self.device)
                processed_data, loss = self.__run__(data_)
                self.__log_error__(self.epoch + 1, sample, loss)
                if self.log_training_fun:
                    self.log_training_fun(processed_data, loss, sample, ep)
            if processed_data is None:
                logging.warning('No batches were produced in state {}; '
                                'skipping result saving for this epoch.'.format(
                                    self.args.engine.state_name))

            end_time_epoch = timer()
            self.__end_epoch__(processed_data, end_time_epoch - start_time_epoch)
        end_time = timer()
        print("Total training time: {} seconds for {} epochs".format(end_time - start_time, self.max_epochs))

    def __end_epoch__(self, processed_data, elapsed_time):
        """Finish one pass: count/save after training, then advance the state.

        NOTE(review): the original source had its indentation collapsed; the
        save/checkpoint calls are reconstructed as belonging to the training
        branch - confirm against the upstream repository.
        """
        if self.args.engine.state_name == 'training':
            self.epoch += 1
            # processed_data is None when the loader was empty: nothing to save.
            if self.args.engine.saveResults and processed_data is not None:
                image_utilities.saveItermediateResults(processed_data, self.args, self.epoch)
            if self.args.engine.saveCheckpoint:
                checkpoint_utilities.save(self.args, self.epoch, self.network)

        # Alternate only when __init__ actually set up the state generator
        # (previously this tested ``args.engine.dataloader``, which could
        # disagree with how the trainer was constructed).
        if self.dataloader_state is not None:
            self._switch_state(next(self.dataloader_state))
            self._sync_network_mode()

        print("Time in Epoch {}: {} seconds".format(self.epoch, elapsed_time))

    def __log_error__(self, epoch, i, loss):
        """Log the loss of batch ``i`` (0-based, shown 1-based) for ``epoch``."""
        logging.info("Epoch: {}, State: {}, Sample: {}/{}, Loss: {} ".format(
            epoch, self.args.engine.state_name, i + 1, len(self.dataloader), loss))

    def __run__(self, data_):
        """Forward one prepared batch; back-propagate only in 'training' state.

        Returns:
            tuple: ``({'estimated', 'signal', 'label'}, loss)`` where ``loss``
            is the ``.mean()`` of ``loss_function(estimate, label)``.
        """
        self.optimizer.zero_grad()
        signal = data_['image']
        # An empty label collection is normalised to a plain empty list.
        label = data_['label'] if len(data_['label']) else []
        estimate = self.network.forward(signal)
        loss = self.loss_function(estimate, label).mean()
        if self.args.engine.state_name == 'training':
            loss.backward()
            self.optimizer.step()
        # 'label' is always present (possibly []); the former
        # ``label is not None`` guard could never be false.
        processed_data = {'estimated': estimate, 'signal': signal, 'label': label}
        return processed_data, loss