Source code for PrognosAIs.IO.ConfigLoader

import datetime
import hashlib
import os
import random
import re
import shutil

import PrognosAIs.IO.utils as utils
import yaml


# TODO: make sure in testing that all floats/numbers are actually loaded as such
class ConfigLoader:
    def __init__(self, config_file):
        self.config_file = config_file
        loader = yaml.SafeLoader

        # We add the following to safely load scientific notation as floats
        loader.add_implicit_resolver(
            "tag:yaml.org,2002:float",
            re.compile(
                """^(?:
                 [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
                |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
                |\\.[0-9_]+(?:[eE][-+][0-9]+)?
                |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
                |[-+]?\\.(?:inf|Inf|INF)
                |\\.(?:nan|NaN|NAN))$""",
                re.X,
            ),
            list("-+0123456789."),
        )

        with open(config_file, "r") as ymlfile:
            self.cfg = yaml.load(ymlfile, Loader=loader)

        if "original_location" not in self.cfg:
            self.cfg["original_location"] = utils.get_parent_directory(self.config_file)
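
    # Usage sketch (illustrative, not part of the class; the file name and key layout
    # below are assumed). The implicit resolver registered above makes SafeLoader parse
    # scientific notation such as ``1e-4`` as a float, whereas the default PyYAML
    # SafeLoader would load it as a string because its float resolver requires a dot:
    #
    #     config = ConfigLoader("config.yaml")                     # path assumed
    #     lr = config.cfg["model"]["optimizer"]["learning_rate"]   # key layout assumed
    #     assert isinstance(lr, float)                             # "learning_rate: 1e-4"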

    def copy_config(self, output_folder, save_name=None):
        if save_name is not None:
            new_config_location = os.path.join(
                output_folder, "config_{savename}.yaml".format(savename=save_name)
            )
        else:
            new_config_location = os.path.join(output_folder, "config.yaml")

        with open(new_config_location, "w") as f:
            yaml.dump(self.cfg, f, default_flow_style=False)

        return new_config_location

    def get_config_file(self):
        return self.config_file

    def get_data_folder(self):
        return self.cfg["preprocessing"]["data_folder"]

    def get_test_data_folder(self):
        return self.cfg["testing"]["data_folder"]

    def get_resample_images(self):
        return self.cfg["preprocessing"]["resample_images"]

    def get_make_patches(self):
        return self.cfg["preprocessing"]["make_patches"]

    def get_patch_size(self):
        return self.cfg["preprocessing"]["patch_size"]

    def get_reject_patches(self):
        return self.cfg["preprocessing"]["reject_patches"]

    def get_min_patch_voxels(self):
        return self.cfg["preprocessing"]["min_patch_voxels"]

    def get_extra_input_file(self):
        return self.cfg["preprocessing"]["extra_input_file"]

    def get_mask_file(self):
        return self.cfg["preprocessing"]["mask_file"]

    def get_fsl_reorient_bin(self):
        return self.cfg["program_paths"]["fslreorient2std_bin"]

    def get_fsl_val_bin(self):
        return self.cfg["program_paths"]["fslval_bin"]

    def get_label_file(self):
        return self.cfg["preprocessing"]["label_file"]

    def get_test_label_file(self):
        return self.cfg["testing"]["label_file"]

    def get_dataset_distribution(self):
        return self.cfg["preprocessing"]["dataset_distribution"]

    def get_resample_size(self):
        return self.cfg["preprocessing"]["resample_image_size"]

    def get_random_state(self):
        random_state = self.cfg["preprocessing"]["random_state"]
        if random_state == -1:
            random_state = random.randint(1, 5000000)
        return random_state

    def get_stratify_index(self):
        stratify_index = self.cfg["preprocessing"]["stratify_index"]
        if stratify_index == "None":
            stratify_index = None
        return stratify_index

    def get_multi_channels_patches(self):
        return self.cfg["preprocessing"]["multi_channel_patches"]

    def get_N_max_patches(self):
        return self.cfg["preprocessing"]["N_max_patches"]

    def get_use_mask_as_channel(self):
        return self.cfg["preprocessing"]["use_mask_as_channel"]

    def get_use_mask_as_label(self):
        return self.cfg["preprocessing"]["use_mask_as_label"]

    def get_keep_rejected_patches(self):
        return self.cfg["preprocessing"]["keep_rejected_patches"]

    def get_mask_keyword(self):
        return self.cfg["preprocessing"]["mask_keyword"]

    def get_center_patch_around_mask(self):
        return self.cfg["preprocessing"]["center_patch_around_mask"]

    def get_rescale_mask_intensity(self):
        return self.cfg["preprocessing"]["rescale_mask_intensity"]

    def get_model_name(self):
        return self.cfg["model"]["architecture"]["model_name"]

    def get_model_settings(self):
        if "settings" in self.cfg["model"]["architecture"]:
            model_settings = self.cfg["model"]["architecture"]["settings"]
            if model_settings is None:
                model_settings = {}
        else:
            model_settings = {}
        return model_settings

    def get_cache_in_memory(self):
        if "cache_in_memory" in self.cfg["training"]:
            cache_in_memory = self.cfg["training"]["cache_in_memory"]
        else:
            cache_in_memory = False
        return cache_in_memory

    def get_float_policy(self):
        if "float_policy" in self.cfg["training"]:
            float_policy = self.cfg["training"]["float_policy"]
        else:
            float_policy = False
        return float_policy

    def get_gpu_workers(self):
        return self.cfg["training"]["N_workers"]

    def get_shuffle(self):
        return self.cfg["training"]["shuffle"]

    def get_data_augmentation(self):
        return self.cfg["training"]["data_augmentation"]

    def get_data_augmentation_factor(self):
        return self.cfg["training"]["augmentation_factor"]

    def get_batch_size(self):
        return self.cfg["model"]["architecture"]["batch_size"]

    def get_N_epoch(self):
        return self.cfg["model"]["architecture"]["N_epoch"]

    def get_N_classes(self):
        return self.cfg["model"]["architecture"]["N_output"]

    def get_make_one_hot(self):
        return self.cfg["training"]["make_one_hot"]

    def get_filter_missing(self):
        return self.cfg["training"]["filter_missing"]

    def get_dtype(self):
        return self.cfg["network"]["dtype"]

    def get_copy_files(self):
        return self.cfg["training"]["copy_files"]

    def get_test_model_file(self):
        return self.cfg["testing"]["model_file"]

    def get_max_steps_per_epoch(self):
        return self.cfg["training"]["max_steps_per_epoch"]

    def get_training_multi_processing(self):
        return self.cfg["training"]["multi_processing"]

    def get_float16_epsilon(self):
        if "float16_epsilon" in self.cfg["training"]:
            return float(self.cfg["training"]["float16_epsilon"])
        else:
            return 1e-4

    def get_save_name(self):
        time_string = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        return time_string

    def get_image_size(self):
        if self.get_resample_images():
            return self.cfg["preprocessing"]["resample_image_size"]
        elif self.get_make_patches():
            return self.cfg["preprocessing"]["patch_size"]
        return None

    def get_size_string(self):
        size = self.get_image_size()
        if size is not None:
            size_string = [str(size[i]) + "x" for i in range(len(size))]
            size_string = "".join(size_string)
            size_string = size_string[0:-1]
        else:
            size_string = None
        return size_string
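
    # Illustrative example (config values assumed): with ``make_patches: true``,
    # ``resample_images: false`` and ``patch_size: [128, 128, 64]``, get_image_size()
    # returns [128, 128, 64] and get_size_string() returns "128x128x64"; the helper
    # below uses that string to name the folder of preprocessed samples.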

    def get_processed_samples_folder(self):
        base_dir = utils.get_parent_directory(self.get_data_folder())
        processed_samples_folder = os.path.join(base_dir, "NPZ_samples_" + self.get_size_string())
        return processed_samples_folder

    def get_output_folder(self):
        return self.cfg["general"]["output_folder"]

    def get_input_folder(self):
        return self.cfg["general"]["input_folder"]

    def get_cluster_setting(self):
        if "cluster_type" in self.cfg["general"]:
            return self.cfg["general"]["cluster_type"]
        else:
            return None

    def get_specific_output_folder(self):
        hasher = hashlib.sha512()
        with open(self.config_file, "rb") as afile:
            buf = afile.read()
            hasher.update(buf)
        hash_string = hasher.hexdigest()

        output_specifications = [
            self.get_model_name(),
            self.get_save_name(),
            hash_string,
        ]
        output_string = "_".join(output_specifications)

        # specific_output_folder = os.path.join(self.get_output_folder(), output_string)
        # utils.create_directory(specific_output_folder)

        return output_string
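
    # Illustrative example (values assumed): for model_name "UNet", a run started at
    # 2020-05-01 12:00:00, and a config file whose SHA-512 hex digest starts with
    # "ab12...", the returned string would look like "UNet_2020_05_01_12:00:00_ab12...".
    # Because the raw bytes of the config file are hashed, byte-identical configs
    # always contribute the same hash component to the folder name.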

    def get_N_jobs(self):
        return self.cfg["cluster"]["N_jobs"]

    def get_cluster_type(self):
        return self.cfg["cluster"]["type"]

    def get_use_labels_from_rejection(self):
        return self.cfg["preprocessing"]["labels_from_rejection"]

    def get_use_class_weights(self):
        return self.cfg["training"]["use_class_weights"]

    def get_use_class_weights_in_losses(self):
        return self.cfg["training"]["use_class_weights_in_losses"]

    def get_optimizer_settings(self):
        return self.cfg["model"]["optimizer"]

    def get_loss_settings(self):
        return self.cfg["model"]["losses"]

    def get_metric_settings(self):
        return self.cfg["model"]["metrics"]

    def get_callback_settings(self):
        return self.cfg["model"]["callbacks"]

    def get_loss_weights(self):
        if "loss_weights" in self.cfg["model"]:
            loss_weights = self.cfg["model"]["loss_weights"]
        else:
            loss_weights = None
        return loss_weights

    def get_evaluation_metric_settings(self):
        return self.cfg["evaluation"]["metrics"]

    # ===============================================================
    # NEW
    # ===============================================================

    def get_preprocessings_settings(self):
        return self.cfg["preprocessing"]

    def get_evaluation_mask_labels(self):
        if "image_outputs" in self.cfg["evaluation"]:
            mask_labels = self.cfg["evaluation"]["image_outputs"]
            if mask_labels is None:
                mask_labels = []
        else:
            mask_labels = []
        return mask_labels

    def get_combine_patch_predictions(self):
        if "combine_patch_predictions" in self.cfg["evaluation"]:
            combine_patch_predictions = self.cfg["evaluation"]["combine_patch_predictions"]
        else:
            combine_patch_predictions = False
        return combine_patch_predictions

    def get_patch_predictions(self):
        if "patch_predictions" in self.cfg["evaluation"]:
            patch_predictions = self.cfg["evaluation"]["patch_predictions"]
        else:
            patch_predictions = True
        return patch_predictions

    def get_evaluate_train_set(self):
        if "evaluate_train_set" in self.cfg["evaluation"]:
            evaluate_train_set = self.cfg["evaluation"]["evaluate_train_set"]
        else:
            evaluate_train_set = True
        return evaluate_train_set

    def get_label_combination_type(self):
        if "combination_type" in self.cfg["evaluation"]:
            combination_type = self.cfg["evaluation"]["combination_type"]
        else:
            combination_type = None
        return combination_type

    def get_write_predictions(self):
        if "write_predictions" in self.cfg["evaluation"]:
            write_predictions = self.cfg["evaluation"]["write_predictions"]
        else:
            write_predictions = False
        return write_predictions

    def get_evaluate_metrics(self):
        if "evaluate_metrics" in self.cfg["evaluation"]:
            evaluate_metrics = self.cfg["evaluation"]["evaluate_metrics"]
        else:
            evaluate_metrics = True
        return evaluate_metrics

    def get_custom_definitions_file(self):
        file_found = False
        if "custom_definitions_file" in self.cfg["general"]:
            custom_definitions_file = self.cfg["general"]["custom_definitions_file"]
            if not os.path.exists(custom_definitions_file):
                custom_definitions_file = os.path.join(
                    self.cfg["original_location"], custom_definitions_file
                )

            if os.path.exists(custom_definitions_file):
                file_found = True

            if not file_found:
                raise ValueError("Custom definitions file was not found!")
        else:
            custom_definitions_file = None

        return custom_definitions_file
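
    # Resolution sketch (paths assumed for illustration): given
    # ``general: custom_definitions_file: my_definitions.py`` and a config file that
    # was loaded from /home/user/experiment/config.yaml, a relative path that does not
    # exist from the current working directory is retried as
    # /home/user/experiment/my_definitions.py; if neither exists, a ValueError is raised.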

    def get_do_augmentation(self):
        if "data_augmentation" in self.cfg["training"]:
            return self.cfg["training"]["data_augmentation"]
        else:
            return False

    def get_data_augmentation_settings(self):
        if "augmentation_settings" in self.cfg["training"]:
            return self.cfg["training"]["augmentation_settings"]
        else:
            return {}

    def get_shuffle_val(self):
        if "shuffle_validation" in self.cfg["training"]:
            return self.cfg["training"]["shuffle_validation"]
        else:
            return False

    def get_class_weights(self):
        if "class_weights" in self.cfg["training"]:
            return self.cfg["training"]["class_weights"]
        else:
            return None

    def get_shuffle_evaluation(self):
        if "shuffle" in self.cfg["evaluation"]:
            shuffle_evaluation = self.cfg["evaluation"]["shuffle"]
        else:
            shuffle_evaluation = False
        return shuffle_evaluation

    def get_resume_training_from_model(self):
        if "resume_training" in self.cfg["training"]:
            return self.cfg["training"]["resume_training"]
        else:
            return False

    def get_model_file(self):
        if "model_file" in self.cfg["training"]:
            return self.cfg["training"]["model_file"]
        else:
            return None