# Source code for EMCqMRI.core.models.inference.resnet

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from core.base import base_inference_model
import torch
import torch.nn as nn


def custom_batch_norm(input, channels, with_mean, eps=0.00001):
    """Normalise the first ``channels`` slices along dim 1 of ``input``.

    The statistics are computed across dimension 1 of ``input``:
    - the mean, only when ``with_mean`` is true;
    - the standard deviation, as sqrt(unbiased variance + ``eps``).

    The same statistics are applied to every slice ``input[:, i, ...]`` for
    ``i < channels``. Slices with ``i >= channels`` are left at zero.

    Note: despite the name, the statistics are reduced over dim 1 only, not
    over the batch/spatial dimensions.

    Args:
        input: tensor with at least two dimensions; dim 1 is indexed as the
            channel axis (presumably (N, C, H, W) — confirm with caller).
        channels: number of leading slices along dim 1 to normalise.
        with_mean: if true, subtract the mean across dim 1; otherwise only
            divide by the standard deviation.
        eps: added to the variance before the square root for numerical
            stability (defaults to the previously hard-coded 1e-5).

    Returns:
        A new tensor with the same shape and dtype as ``input``.
    """
    out = torch.zeros_like(input)
    if channels <= 0:
        # Nothing to normalise; same result as running zero loop iterations.
        return out
    # The statistics do not depend on the channel index, so compute them
    # once instead of once per channel.
    mu = torch.mean(input, dim=1) if with_mean else 0
    std = torch.sqrt(torch.var(input, dim=1) + eps)
    for i in range(channels):
        out[:, i, ...] = (input[:, i, ...] - mu) / std
    return out
class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip, -> ReLU.

    Args:
        in_layer: number of input channels.
        out_layer: number of output channels of both convolutions.
        stride: stride of the first convolution only. The second convolution
            always uses stride 1, so the main branch is downsampled by
            ``stride`` exactly once. This matches a strided ``downsample``
            shortcut.
        downsample: optional module applied to the block input to build the
            skip branch when the main branch changes shape. ``None`` means an
            identity shortcut.
        use_bias: whether the convolutions include a bias term.
    """

    # Ratio of the block's output channels to ``out_layer``. Read by
    # Resnet.__makeLayer__ when sizing the shortcut and the next layer.
    expansion = 1

    def __init__(self, in_layer, out_layer, stride=1, downsample=None, use_bias=True):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_layer, out_layer, kernel_size=3, stride=stride,
                               padding=1, bias=use_bias)
        # track_running_stats=False: batch statistics are used in both
        # training and evaluation mode.
        self.bn1 = nn.BatchNorm2d(out_layer, track_running_stats=False)
        self.relu = nn.ReLU(inplace=True)
        # Stride 1 on purpose. Reusing ``stride`` here would shrink the main
        # branch by stride**2 and break the residual addition for stride > 1.
        self.conv2 = nn.Conv2d(out_layer, out_layer, kernel_size=3, stride=1,
                               padding=1, bias=use_bias)
        self.bn2 = nn.BatchNorm2d(out_layer, track_running_stats=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` (a 4-D tensor, as required by nn.Conv2d)."""
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)
class Resnet(base_inference_model.InferenceModel, nn.Module):
    """Class implementing the ResNet model.

    Layout: 1x1 conv -> BN -> ReLU -> six residual layers of BasicBlock ->
    1x1 conv. Every layer is built with stride 1 and 'same' padding, so the
    output keeps the input's spatial size and has ``outputChannels``
    channels.

    Methods:
        - setOpts: inputs a dict containing the key and value for a new
          configuration setting. It is not defined in this class; presumably
          it is inherited from InferenceModel.
        - forward: inputs the measured signal; outputs the estimated
          parameters.

    Args:
        config_object: object whose ``args.inference`` namespace supplies
            useBias, inputChannels, outputChannels, outputChannelsLayer1..7,
            convLayersInResidualBlock and useMeanBatchNorm.
    """

    def __init__(self, config_object):
        super(Resnet, self).__init__()
        self.__name__ = 'RESNET'
        self.__require_initial_guess__ = False
        self.use_bias = config_object.args.inference.useBias
        self.args = config_object.args
        self.__buildNetwork__()

    def __buildNetwork__(self):
        """Create all layers from ``self.args.inference`` and initialise weights."""
        opts = self.args.inference
        # Running channel count, advanced by __makeLayer__.
        self.input_planes = opts.outputChannelsLayer1
        self.conv1 = nn.Conv2d(opts.inputChannels, opts.outputChannelsLayer1,
                               kernel_size=1, stride=1, padding=0, bias=self.use_bias)
        self.bn1 = nn.BatchNorm2d(opts.outputChannelsLayer1, track_running_stats=False)
        self.relu = nn.ReLU(inplace=True)
        n_blocks = opts.convLayersInResidualBlock
        self.layer1 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer2, n_blocks)
        self.layer2 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer3, n_blocks)
        self.layer3 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer4, n_blocks)
        self.layer4 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer5, n_blocks)
        self.layer5 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer6, n_blocks)
        self.layer6 = self.__makeLayer__(BasicBlock, opts.outputChannelsLayer7, n_blocks)
        self.layer_out = nn.Conv2d(opts.outputChannelsLayer7, opts.outputChannels,
                                   kernel_size=1, stride=1, padding=0, bias=self.use_bias)
        # Only relevant to the disabled custom_batch_norm path noted in forward().
        self.use_mean = opts.useMeanBatchNorm

        # Kaiming-normal conv weights and identity BN affine parameters.
        # Conv biases keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def __makeLayer__(self, block, channels, blocks, stride=1):
        """Build ``blocks`` chained residual blocks with ``channels * block.expansion`` outputs.

        Only the first block may change the channel count or stride, so only
        it receives a projection shortcut (1x1 conv + BN), and only when
        needed. ``self.input_planes`` is advanced to the layer's output width.
        At least one block is always built.
        """
        out_planes = channels * block.expansion
        downsample = None
        if stride != 1 or self.input_planes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.input_planes, out_planes, kernel_size=1,
                          stride=stride, bias=self.use_bias),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
            )
        layers = [block(self.input_planes, channels, stride, downsample, self.use_bias)]
        self.input_planes = out_planes
        # The remaining blocks previously fell back to the block's default
        # use_bias=True, ignoring the config. Pass it explicitly.
        # NOTE(review): checkpoints saved with useBias=False before this fix
        # contain bias tensors for these blocks. Load them with strict=False
        # or retrain.
        layers.extend(
            block(self.input_planes, channels, use_bias=self.use_bias)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Estimate parameters from the measured signal.

        Args:
            inputs: measured signal, fed directly to ``self.conv1``. Assumed
                to be (N, inputChannels, H, W) — confirm with caller.

        Returns:
            Tensor with ``outputChannels`` channels and the input's spatial size.
        """
        x = self.relu(self.bn1(self.conv1(inputs)))
        # NOTE: a disabled alternative replaced self.bn1 with
        # custom_batch_norm(x, <channels>, self.use_mean). If re-enabled, pass
        # x's channel count (outputChannelsLayer1), not inputChannels.
        for layer in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            x = layer(x)
        return self.layer_out(x)