Source code for EMCqMRI.core.base.base_dataset

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABC, abstractmethod
import numpy as np
import logging
import pickle
import os

[docs]class Dataset(ABC): def __init__(self, config_object): """Base class for implementation of datasets Args: config_object ([Configuration]): Configuration object where following attributes must be specified: - args.engine.fileExtension ([string]) """ super().__init__() self.data = [] try: self.args = config_object.args except AttributeError: raise Exception("Invalid configuration object passed to {}".format(self))
[docs] def index_files(self): """Index all files within the folder path contained in self.path. If fileExtension is provided (as an option in engine_configuration.txt), only files that match this file extension will be indexed. Raises: Exception: When fileExtension is not provided. Sets: self.filename_list ([list]) """ self.filename_list = sorted([f for f in os.listdir(self.path) if (os.path.isfile(os.path.join(self.path, f)))]) if not len(self.filename_list)>0: logging.error("Data folder is empty. Either folder path is not correct, file extension is not correct or useSimulatedData flag is incorrectly set.")
[docs] def set_data_path(self, path): """ Set the path for the folder containing training, validation or testing files. Args: path ([string]): path for training or testing data. It is provided as an option in engine_configuration.txt. Raises: Exception: When path does not point to a directory. Exception: When path points to an empty directory. Sets: self.path ([string]) """ if os.path.isdir(path): if not len(os.listdir(path)) == 0: self.path = path else: raise Exception("Data directory is empty.") else: raise Exception("path is not a directory.")
[docs] def load_folder(self): """It pre-loads all the data in self.path. It is called when the option preLoadData==True. Use with care, if the size of the data is too large, might cause memory issues. Raises: Exception: When fileExtension is not provided. Sets: self.data ([list]) """ for file in self.filename_list: if self.args.engine.fileExtension: if 'pkl' in self.args.engine.fileExtension: with open(os.path.join(self.path, file), 'rb') as f: data_ = pickle.load(f) self.data.append(data_) if 'rawb' in self.args.engine.fileExtension: data_ = np.fromfile(os.path.join(self.path, file), dtype='int8', sep="") self.data.append(data_) if 'nii' in self.args.engine.fileExtension: pass if 'nii.gz' in self.args.engine.fileExtension: pass else: raise Exception("File extension not supported or not provided.")
[docs] def load_file(self, idx): """It loads the file indexed by idx. It is called when the option preLoadData==False. Args: idx ([int]): the index of the file to load from disk. Raises: Exception: When fileExtension is not provided. """ file = self.filename_list[idx] if self.args.engine.fileExtension: if 'pkl' in self.args.engine.fileExtension: with open(os.path.join(self.path, file), 'rb') as f: data_ = pickle.load(f) if 'rawb' in self.args.engine.fileExtension: data_ = np.fromfile(os.path.join(self.path, file), dtype='int8', sep="") if 'nii' in self.args.engine.fileExtension: pass if 'nii.gz' in self.args.engine.fileExtension: pass self.data = [data_] else: raise Exception("File extension not supported or not provided.")
[docs] def get_length(self): """ It returns the length of the dataset (i.e. number of files in the dataset folder). Raises: TypeError: When self.filename_list is not initialized. Returns: (Data set length[list]): The number of files within the data directory. """ try: return len(self.filename_list) except TypeError: logging.warn("self.filename_list not initialized")
[docs] @abstractmethod def get_label(self, idx): """Abstract function for fetching labels and masks from the data directory. Need subclass to implement different logics, like Relaxometry, Reconstruction, QSM, etc. Args: idx ([int]): Index automatically provided by the DataLoader. It ranges from 0 to len(self.filename_list)-1. Raises: NotImplementedError: When the subclass does not override this method. Returns: (Label [torch.Tensor]), (Mask [torch.Tensor] - optional): The subclass implementation should return the label, and, if available, a mask. For seemless integration with the framework, label should have shape [N, X, Y, ...] and the mask should have shape [1,X,Y, ...], where N is the number of channels (e.g. weighted images) and X,Y,... are the image dimensions. """ raise NotImplementedError("get_training_label not implemented")
[docs] @abstractmethod def get_signal(self, idx): """Abstract function for fetching the signal from the data directory. Need subclass to implement different logics, like Relaxometry, Reconstruction, QSM, etc. Args: idx ([int]): Index automatically provided by the DataLoader. It ranges from 0 to len(self.filename_list)-1. Raises: NotImplementedError: When the subclass does not override this method. Returns: (Signal [torch.Tensor]): The subclass implementation should return the signal. For seemless integration with the framework, signal should have shape [N, X, Y, ...], where N is the number of channels (e.g. weighted images) and X,Y,... are the image dimensions. """ raise NotImplementedError("get_training_signal not implemented")