Source code for PrognosAIs.Model.Trainer

from __future__ import annotations

import argparse
import copy
import json
import logging
import os
import sys

from typing import Tuple
from typing import Union

import tensorflow as tf
from tensorflow.keras.models import load_model


from tensorflow.keras.mixed_precision import experimental as mixed_precision

import PrognosAIs.Constants
import PrognosAIs.IO.utils as IO_utils
import PrognosAIs.Model.Architectures
import PrognosAIs.Model.Parsers as ModelParsers

from PrognosAIs.IO import ConfigLoader
from PrognosAIs.IO import DataGenerator
from PrognosAIs.Model.Architectures import VGG
from PrognosAIs.Model.Architectures import AlexNet
from PrognosAIs.Model.Architectures import DDSNet
from PrognosAIs.Model.Architectures import DenseNet
from PrognosAIs.Model.Architectures import InceptionNet
from PrognosAIs.Model.Architectures import ResNet
from PrognosAIs.Model.Architectures import UNet


class Trainer:
    """Trainer to be used for training a model."""

    def __init__(
        self,
        config: ConfigLoader.ConfigLoader,
        sample_folder: str,
        output_folder: str,
        tmp_data_folder: str = None,
        save_name: str = None,
    ) -> None:
        """
        Trainer to be used for training a model.

        Args:
            config (ConfigLoader.ConfigLoader): Config to be used
            sample_folder (str): Folder containing the train and validation samples
            output_folder (str): Folder to put the resulting model
            tmp_data_folder (str): Folder to copy samples to and load from. Defaults to None.
            save_name (str): Specify a name to save the model as instead of using an
                automatically generated one. Defaults to None.
        """
        self._model = None
        self._train_data_generator_is_setup = False
        self._validation_data_generator_is_setup = False
        self.cluster_resolver = None
        self.total_memory_used = 0
        self.multiworker = False
        self.worker_index = 0
        self.n_workers = 1
        self.steps_per_epoch = None
        self.validation_steps = None

        self.config = copy.deepcopy(config)
        self.output_folder = os.path.join(output_folder, PrognosAIs.Constants.MODEL_SUBFOLDER)
        self.sample_folder = sample_folder
        self.tmp_data_folder = tmp_data_folder

        IO_utils.setup_logger()
        logging.info(
            "Using configuration file: {config_file}".format(config_file=self.config.config_file),
        )
        logging.info("Loading samples from: {input_dir}".format(input_dir=self.sample_folder))
        logging.info("Putting output in: {output}".format(output=self.output_folder))

        if save_name is None:
            self.save_name = self.config.get_save_name()
        else:
            self.save_name = save_name

        self.model_save_file = ".".join(
            [os.path.join(self.output_folder, self.save_name), PrognosAIs.Constants.HDF5_EXTENSION],
        )
        logging.info("Will save model as: {save_name}".format(save_name=self.model_save_file))

        IO_utils.create_directory(self.output_folder)
        self.config.copy_config(self.output_folder, self.save_name)

        # First thing we need to do is get the precision strategy.
        # After this we will clear up the GPUs again, so this needs to be done
        # before TensorFlow allocates anything (such as in the distribution
        # strategy)
        self.set_precision_strategy(self.config.get_float_policy())
        self.distribution_strategy = self.get_distribution_strategy()
        self.custom_definitions_file = self.config.get_custom_definitions_file()
        self.class_weights = self.load_class_weights()

        self.validation_ds_dir = os.path.join(
            self.sample_folder,
            PrognosAIs.Constants.VALIDATION_DS_NAME,
        )
        self.do_validation = os.path.exists(self.validation_ds_dir)

    # ===============================================================
    # Distribution and precision strategies
    # ===============================================================

    @staticmethod
    def set_tf_config(
        cluster_resolver: tf.distribute.cluster_resolver.ClusterResolver,
        environment: str = None,
    ) -> None:
        """
        Set the TF_CONFIG env variable from the given cluster resolver.

        From https://github.com/tensorflow/tensorflow/issues/37693

        Args:
            cluster_resolver (tf.distribute.cluster_resolver.ClusterResolver): Cluster
                resolver to use.
            environment (str): Environment to set in TF_CONFIG. Defaults to None.
        """
        cfg = {
            "cluster": cluster_resolver.cluster_spec().as_dict(),
            "task": {
                "type": cluster_resolver.get_task_info()[0],
                "index": cluster_resolver.get_task_info()[1],
            },
            "rpc_layer": cluster_resolver.rpc_layer,
        }
        if environment:
            cfg["environment"] = environment
        os.environ["TF_CONFIG"] = json.dumps(cfg)
        logging.info(
            "Set up TF config environment as: {TF_CONFIG}".format(
                TF_CONFIG=os.environ["TF_CONFIG"],
            ),
        )
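
    # For illustration (host names and values are hypothetical), the TF_CONFIG
    # built above follows TensorFlow's standard multi-worker layout, e.g.:
    #
    #   {
    #     "cluster": {"worker": ["node1:2222", "node2:2222"]},
    #     "task": {"type": "worker", "index": 0},
    #     "rpc_layer": "grpc"
    #   }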

    def set_precision_strategy(self, float_policy_setting: Union[str, bool]) -> None:
        """
        Set the appropriate precision strategy for GPUs.

        If the GPUs support it, a mixed float16 precision will be used (see
        tf.keras.mixed_precision for more information), which reduces the memory
        overhead of training by computing in float16 while keeping variables in
        float32. If the GPUs don't support mixed precision, we will try a float16
        precision setting. If that doesn't work either, the normal policy is used.

        If you get NaN values for the loss or the loss doesn't converge, it might be
        because of the policy. Try running the model without a policy setting.

        Args:
            float_policy_setting (Union[str, bool]): Which policy to select. If set to
                PrognosAIs.Constants.AUTO, we will automatically determine what can be
                done. "mixed" will only consider mixed precision, "float16" only
                considers a float16 policy. Set to False to not use a policy.
        """
        gpus = IO_utils.get_gpu_devices()
        gpu_supports_mixed_precision = all(
            IO_utils.gpu_supports_mixed_precision(i_gpu) for i_gpu in gpus
        )
        gpu_supports_float16_precision = all(IO_utils.gpu_supports_float16(i_gpu) for i_gpu in gpus)
        mixed_precision_allowed = float_policy_setting in [PrognosAIs.Constants.AUTO, "mixed"]
        float16_precision_allowed = float_policy_setting in [PrognosAIs.Constants.AUTO, "float16"]

        if gpu_supports_mixed_precision:
            logging.info("GPUs support a mixed float16 policy")
        if gpu_supports_float16_precision:
            logging.info("GPUs support a float16 precision policy")

        if gpu_supports_mixed_precision and mixed_precision_allowed and len(gpus) > 0:
            policy = mixed_precision.Policy("mixed_float16", loss_scale="dynamic")
            mixed_precision.set_policy(policy)
            logging.info(
                (
                    "Using a mixed float16 policy, with compute dtype {cdtype} "
                    "and variable dtype {vdtype}.\n"
                    "Loss scaling applied: {loss_scale}"
                ).format(
                    cdtype=policy.compute_dtype,
                    vdtype=policy.variable_dtype,
                    loss_scale=policy.loss_scale,
                ),
            )
            logging.warning(
                "If the model is not converging or the loss gives NaNs, "
                "consider turning off this policy",
            )
        elif gpu_supports_float16_precision and float16_precision_allowed and len(gpus) > 0:
            policy = mixed_precision.Policy("float16", loss_scale="dynamic")
            mixed_precision.set_policy(policy)
            logging.info(
                (
                    "Using a float16 policy, with compute dtype {cdtype} "
                    "and variable dtype {vdtype}.\n"
                    "Loss scaling applied: {loss_scale}"
                ).format(
                    cdtype=policy.compute_dtype,
                    vdtype=policy.variable_dtype,
                    loss_scale=policy.loss_scale,
                ),
            )
            logging.warning(
                "If the model is not converging or the loss gives NaNs, "
                "consider turning off this policy",
            )
        else:
            logging.info("No float policy is being used")
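
    # Note: this module uses the pre-TF-2.4 experimental mixed precision API
    # (imported above as mixed_precision). In later TensorFlow releases the
    # equivalent call would be, e.g.,
    # tf.keras.mixed_precision.set_global_policy("mixed_float16"), with dynamic
    # loss scaling applied automatically for the mixed_float16 policy.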

    def get_distribution_strategy(self) -> tf.distribute.Strategy:
        """
        Get the appropriate distribution strategy.

        A strategy will be returned that can either distribute the training over
        multiple SLURM nodes, over multiple GPUs, train on a single GPU, or train on
        a single CPU (in that order).

        Returns:
            tf.distribute.Strategy: The distribution strategy to be used in training.
        """
        if IO_utils.get_number_of_slurm_nodes() > 1:
            # We are in a SLURM environment for multi-worker training
            self.cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver()
            self.set_tf_config(self.cluster_resolver)
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
                cluster_resolver=self.cluster_resolver,
                communication=tf.distribute.experimental.CollectiveCommunication.NCCL,
            )
            self.multiworker = True
            self.worker_index = self.cluster_resolver.get_task_info()[1]
            self.n_workers = (
                len(self.cluster_resolver.get_task_info()[0])
                * self.cluster_resolver.max_tasks_per_node
            )
            logging.info(
                "Using a multi-worker distribution environment with {nodes} nodes "
                "and {tasks} tasks per node".format(
                    nodes=len(self.cluster_resolver.get_task_info()[0]),
                    tasks=self.cluster_resolver.max_tasks_per_node,
                ),
            )
        elif IO_utils.get_number_of_gpu_devices() > 1:
            strategy = tf.distribute.MirroredStrategy()
            logging.info(
                "Using a mirrored distribution environment with {gpus} parallel GPUs".format(
                    gpus=strategy.num_replicas_in_sync,
                ),
            )
        elif IO_utils.get_number_of_gpu_devices() == 1:
            gpus = IO_utils.get_gpu_devices()
            gpu_device_name = ":".join(gpus[0].name.split(":")[-2:])
            strategy = tf.distribute.OneDeviceStrategy(gpu_device_name)
            logging.info(
                "Using a single device strategy with {device} as device".format(
                    device=gpu_device_name,
                ),
            )
        else:
            cpus = IO_utils.get_cpu_devices()
            cpu_device_name = ":".join(cpus[0].name.split(":")[-2:])
            strategy = tf.distribute.OneDeviceStrategy(cpu_device_name)
            logging.info(
                "Using a single device strategy with {device} as device".format(
                    device=cpu_device_name,
                ),
            )
        return strategy

    # ===============================================================
    # Set-up of data
    # ===============================================================

    def load_class_weights(self) -> Union[None, dict]:
        """
        Load the class weights from the class weight file.

        Returns:
            Union[None, dict]: Class weights if requested and the class weight file
                exists, otherwise None.
        """
        class_weight_file = os.path.join(self.sample_folder, PrognosAIs.Constants.CLASS_WEIGHT_FILE)
        has_class_weight_file = os.path.exists(class_weight_file)
        class_weights = self.config.get_class_weights()

        if not has_class_weight_file and class_weights is None:
            logging.warning("Class weight file not found, not using class weights")

        if self.config.get_use_class_weights() and (
            class_weights is not None or has_class_weight_file
        ):
            if class_weights is None:
                with open(class_weight_file, "r") as the_class_weight_file:
                    class_weights = json.load(the_class_weight_file)

            # JSON keys are strings: convert class indices to ints and weights to floats
            out_class_weights = {}
            for i_key, i_value in class_weights.items():
                this_class_weight = {}
                for i_class, i_weight in i_value.items():
                    this_class_weight[int(i_class)] = float(i_weight)
                out_class_weights[i_key] = this_class_weight

            if len(out_class_weights.keys()) == 1:
                out_class_weights = list(out_class_weights.values())[0]

            logging.info(
                "Using the following class weights: {weights}".format(weights=out_class_weights)
            )
        else:
            out_class_weights = None
            logging.info(
                "Not using class weights: either not requested or no class weights were found"
            )
        return out_class_weights
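
    # For illustration (label name and values are hypothetical), a class weight
    # file as parsed above maps each output label to per-class weights:
    #
    #   {"tumor_type": {"0": 1.0, "1": 2.5}}
    #
    # which becomes {"tumor_type": {0: 1.0, 1: 2.5}}; with a single output the
    # inner dict is unwrapped so it can be passed straight to Model.fit.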

    def move_data_to_temporary_folder(self, data_folder: str) -> str:
        """
        Move the data to a temporary directory before loading.

        Args:
            data_folder (str): The original data folder

        Returns:
            str: Folder to which the data has been moved
        """
        if self.tmp_data_folder is not None and self.config.get_copy_files():
            IO_utils.copy_directory(data_folder, self.tmp_data_folder)
            new_folder = os.path.join(self.tmp_data_folder, IO_utils.get_root_name(data_folder))
            logging.info("Loading data from temporary directory {temp}".format(temp=new_folder))
        else:
            new_folder = data_folder
        return new_folder

    @property
    def train_data_generator(self) -> DataGenerator.HDF5Generator:
        """
        The train data generator to be used in training.

        Returns:
            DataGenerator.HDF5Generator: The train data generator
        """
        if not self._train_data_generator_is_setup:
            logging.info("Setting up train data generator")
            train_ds_folder = os.path.join(self.sample_folder, PrognosAIs.Constants.TRAIN_DS_NAME)
            train_ds_folder = self.move_data_to_temporary_folder(train_ds_folder)
            self._train_data_generator = self.setup_data_generator(train_ds_folder)
            self._train_data_generator_is_setup = True
            if self.multiworker:
                self.steps_per_epoch = self.train_data_generator.steps
        return self._train_data_generator

    @property
    def validation_data_generator(self) -> DataGenerator.HDF5Generator:
        """
        The validation data generator to be used in training.

        Returns:
            DataGenerator.HDF5Generator: The validation data generator
        """
        if not self._validation_data_generator_is_setup and self.do_validation:
            logging.info("Setting up validation data generator")
            validation_ds_folder = self.move_data_to_temporary_folder(self.validation_ds_dir)
            self._validation_data_generator = self.setup_data_generator(validation_ds_folder)
            self._validation_data_generator_is_setup = True
            if self.multiworker:
                self.validation_steps = self.validation_data_generator.steps
        elif not self._validation_data_generator_is_setup:
            # No validation data: set up once as None so later accesses don't
            # overwrite an already constructed generator
            self._validation_data_generator = None
            self._validation_data_generator_is_setup = True
        return self._validation_data_generator

    def setup_data_generator(self, sample_folder: str) -> DataGenerator.HDF5Generator:
        """
        Set up a data generator for a folder containing train samples.

        Args:
            sample_folder (str): The path to the folder containing the sample files.

        Raises:
            ValueError: If the sample folder does not exist or does not contain any samples.

        Returns:
            DataGenerator.HDF5Generator: Data generator of the samples in the sample folder.
        """
        if (
            IO_utils.get_root_name(sample_folder) == PrognosAIs.Constants.VALIDATION_DS_NAME
            and not self.config.get_shuffle_val()
        ):
            # In the validation data we don't shuffle and don't apply data augmentation.
            # In that way we can be sure that validation metrics will not depend
            # on random components
            data_augmentation = False
            shuffle = False
            logging.warning("Turned off data augmentation and shuffling for validation data")
        else:
            data_augmentation = self.config.get_do_augmentation()
            shuffle = self.config.get_shuffle()

        # Multiply the requested batch size by the number of replicas in the
        # distribution strategy to get the global batch size; the strategy then
        # makes sure that each device sees the batch size that was requested
        batch_size = self.config.get_batch_size() * self.distribution_strategy.num_replicas_in_sync
        if self.distribution_strategy.num_replicas_in_sync > 1:
            logging.info(
                (
                    "Requested batch size {batch}, changed to {global_batch} to work correctly with "
                    "distribution strategy"
                ).format(batch=self.config.get_batch_size(), global_batch=batch_size),
            )

        if os.path.exists(sample_folder) and len(os.listdir(sample_folder)) > 0:
            with self.distribution_strategy.scope():
                data_generator = DataGenerator.HDF5Generator(
                    sample_folder,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    max_steps=self.config.get_max_steps_per_epoch(),
                )

                # We will try to set up caching of the dataset in memory
                data_generator.setup_caching(
                    self.config.get_cache_in_memory(),
                    self.total_memory_used,
                )
                if data_generator.cache_in_memory:
                    self.total_memory_used += data_generator.memory_size

                if data_augmentation:
                    data_generator.setup_augmentation(
                        self.config.get_data_augmentation_factor(),
                        self.config.get_data_augmentation_settings(),
                    )

                if self.multiworker:
                    data_generator.repeat = True
                    data_generator.setup_sharding(self.worker_index, self.n_workers)
                    logging.info(
                        "Set up sharding with {workers} workers for worker index {index} "
                        "for the data generator".format(
                            workers=self.n_workers,
                            index=self.worker_index,
                        ),
                    )
        else:
            raise ValueError(
                "Dataset directory {ds_name} does not exist, cannot create data generator!".format(
                    ds_name=sample_folder,
                ),
            )
        return data_generator
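
    # For example (numbers are illustrative): with a configured batch size of 8
    # and a MirroredStrategy over 4 GPUs, the generator produces global batches
    # of 8 * 4 = 32 samples, so each replica still processes batches of 8.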

    # ===============================================================
    # Model setup
    # ===============================================================

    @property
    def model(self) -> tf.keras.Model:
        """
        Model to be used in training.

        Returns:
            tf.keras.Model: The model
        """
        if self._model is None:
            self._model = self.setup_model()
            # summary() prints its output and returns None, so route it through
            # the logger instead of formatting its return value
            logging.info("Using the following model:")
            self._model.summary(line_length=120, print_fn=logging.info)
        return self._model

    @staticmethod
    def _get_architecture_name(model_name: str, input_dimensionality: dict) -> Tuple[str, str]:
        """
        Get the full architecture name from the model name and input dimensionality.

        Args:
            model_name (str): Name of the model
            input_dimensionality (dict): Dimensionality of the different inputs

        Returns:
            Tuple[str, str]: Class name of architecture and full architecture name
        """
        separator = "_"
        architecture_name_parts = model_name.split(separator)
        architecture_class_name = architecture_name_parts[0]

        # We get the model that will fit the maximum dimensionality of our inputs
        max_dimensionality = int(max(input_dimensionality.values()))
        full_architecture_name = "{model_name}_{input_dimensionality}D".format(
            model_name=model_name,
            input_dimensionality=max_dimensionality,
        )
        return architecture_class_name, full_architecture_name
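
    # For example (names are illustrative): a model name of "ResNet_18" with a
    # 3D input resolves to the class name "ResNet" and the full architecture
    # name "ResNet_18_3D", which is then looked up in _setup_model below.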

    def _setup_model(self) -> tf.keras.Model:
        """
        Get the model architecture from the architecture name (not yet compiled).

        Raises:
            ValueError: If architecture is not known

        Returns:
            tf.keras.Model: The loaded architecture
        """
        architecture_class_name, full_architecture_name = self._get_architecture_name(
            self.config.get_model_name(),
            self.train_data_generator.get_feature_dimensionality(),
        )

        # First look for the architecture among the built-in architectures; if
        # it is not there, load it from the custom definitions file
        architecture_class = getattr(PrognosAIs.Model.Architectures, architecture_class_name, None)
        if architecture_class is None:
            architecture_class = IO_utils.load_module_from_file(
                self.config.get_custom_definitions_file(),
            )

        architecture = getattr(architecture_class, full_architecture_name, None)
        if architecture is None:
            # The built-in module was found but lacks this dimensionality
            # variant: fall back to the custom definitions file as well
            architecture_class = IO_utils.load_module_from_file(
                self.config.get_custom_definitions_file(),
            )
            architecture = getattr(architecture_class, full_architecture_name, None)

        if architecture is None:
            err_msg = "Could not find requested model {model}!".format(model=full_architecture_name)
            raise ValueError(err_msg)

        return architecture(
            self.train_data_generator.get_feature_shape(),
            self.train_data_generator.get_number_of_classes(),
            model_config=self.config.get_model_settings(),
            # TODO SET FROM CONFIG
            input_data_type=tf.keras.backend.floatx(),
        ).create_model()

    def _load_model(self) -> tf.keras.Model:
        """Load a previously saved model from file (not yet compiled)."""
        logging.info(
            "Loading model from {model_file}".format(model_file=self.config.get_model_file()),
        )
        return load_model(self.config.get_model_file(), compile=False)

    def setup_model(self) -> tf.keras.Model:
        """
        Set up the model to be used during training.

        Returns:
            tf.keras.Model: The compiled model to be trained.
        """
        if self.config.get_use_class_weights_in_losses() and self.class_weights is not None:
            loss_parser = ModelParsers.LossParser(
                self.config.get_loss_settings(),
                self.class_weights,
                self.custom_definitions_file,
            )
            # The losses take care of the weighting, so don't pass weights to fit
            self.class_weights = None
            logging.info("Using class weights directly inside losses instead of during fit")
        else:
            loss_parser = ModelParsers.LossParser(
                self.config.get_loss_settings(),
                module_paths=self.custom_definitions_file,
            )
        model_loss = loss_parser.get_losses()
        loss_weights = self.config.get_loss_weights()

        with self.distribution_strategy.scope():
            # Set up and compile the model inside the strategy scope to make
            # sure everything is distributed properly
            logging.info("Setting up model")
            if self.config.get_resume_training_from_model():
                model = self._load_model()
            else:
                model = self._setup_model()

            logging.info("Setting up optimizer")
            optimizer = ModelParsers.OptimizerParser(
                self.config.get_optimizer_settings(),
                self.custom_definitions_file,
            ).get_optimizer()

            logging.info("Setting up metrics")
            model_metrics = ModelParsers.MetricParser(
                self.config.get_metric_settings(),
                self.train_data_generator.label_names,
                self.custom_definitions_file,
            ).get_metrics()

            model_message = (
                "The following settings are used to set up the model:\n"
                "Loss: {loss}\n"
                "Optimizer: {optimizer}\n"
                "Metrics: {metrics}\n"
                "Loss weights: {loss_weights}"
            ).format(
                loss=model_loss,
                optimizer=optimizer,
                metrics=model_metrics,
                loss_weights=loss_weights,
            )
            logging.info(model_message)

            model.compile(
                loss=model_loss,
                optimizer=optimizer,
                metrics=model_metrics,
                loss_weights=loss_weights,
            )
        return model

    def setup_callbacks(self) -> list:
        """
        Set up callbacks to be used during training.

        Returns:
            list: the callbacks
        """
        with self.distribution_strategy.scope():
            logging.info("Setting up callbacks")
            callbacks = ModelParsers.CallbackParser(
                self.config.get_callback_settings(),
                self.output_folder,
                self.custom_definitions_file,
                self.save_name,
            ).get_callbacks()
            logging.info("Using the following callbacks: {callbacks}".format(callbacks=callbacks))
        return callbacks

    # ===============================================================
    # Model training
    # ===============================================================

    def train_model(self) -> str:
        """
        Train the model.

        Returns:
            str: The location where the model has been saved
        """
        with self.distribution_strategy.scope():
            train_data = self.train_data_generator.get_tf_dataset()
            if self.do_validation:
                validation_data = self.validation_data_generator.get_tf_dataset()
            else:
                validation_data = None

        epochs = self.config.get_N_epoch()
        callbacks = self.setup_callbacks()

        logging.info("Starting training")
        logging.debug(
            (
                "Training with following parameters:\n"
                "Train data: {train}\n"
                "Validation data: {val}\n"
                "Epochs: {epoch}\n"
                "Callbacks: {callback}\n"
                "Class weights: {weights}\n"
                "Steps per epoch: {steps}\n"
                "Validation steps: {val_steps}\n"
            ).format(
                train=train_data,
                val=validation_data,
                epoch=epochs,
                callback=callbacks,
                weights=self.class_weights,
                steps=self.steps_per_epoch,
                val_steps=self.validation_steps,
            ),
        )

        self.model.fit(
            train_data,
            validation_data=validation_data,
            epochs=epochs,
            callbacks=callbacks,
            shuffle=False,
            class_weight=self.class_weights,
            verbose=1,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
        )
        logging.info("Finished training")

        if self.worker_index == 0:
            self.model.save(self.model_save_file)
            logging.info("Model saved to {save_file}".format(save_file=self.model_save_file))
        else:
            # The other workers need to save the model as well, otherwise we run
            # into errors; however, we delete those copies immediately because
            # we don't actually need them
            model_save_file = ".".join(
                [
                    os.path.join(self.output_folder, self.save_name + "_" + str(self.worker_index)),
                    PrognosAIs.Constants.HDF5_EXTENSION,
                ],
            )
            self.model.save(model_save_file)
            os.remove(model_save_file)
        return self.model_save_file

    # ===============================================================
    # External use
    # ===============================================================

    @classmethod
    def init_from_sys_args(cls, args_in: list) -> Trainer:
        """
        Initialize a Trainer object from the command line.

        Args:
            args_in (list): Command line arguments to parse

        Returns:
            Trainer: The trainer object
        """
        parser = argparse.ArgumentParser(description="Train a CNN")
        parser.add_argument(
            "-c",
            "--config",
            required=True,
            help="The location of the PrognosAIs config file",
            metavar="configuration file",
            dest="config",
            type=str,
        )
        parser.add_argument(
            "-i",
            "--input",
            required=True,
            help="The input directory where the samples to train on are located",
            metavar="Input directory",
            dest="input_dir",
            type=str,
        )
        parser.add_argument(
            "-o",
            "--output",
            required=True,
            help="The output directory where to store the saved model",
            metavar="Output directory",
            dest="output_dir",
            type=str,
        )
        parser.add_argument(
            "-T",
            "--tmp",
            required=False,
            help="The temporary directory",
            metavar="Temp directory",
            dest="tmp_dir",
            type=str,
            default=None,
        )
        parser.add_argument(
            "-s",
            "--savename",
            required=False,
            help="Save name for the model",
            metavar="Save name",
            dest="save_name",
            type=str,
            default=None,
        )
        args = parser.parse_args(args_in)
        config = ConfigLoader.ConfigLoader(args.config)
        return cls(config, args.input_dir, args.output_dir, args.tmp_dir, args.save_name)


if __name__ == "__main__":
    trainer = Trainer.init_from_sys_args(sys.argv[1:])
    trainer.train_model()
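
# Example invocation (paths and names are hypothetical), assuming the package
# is installed so the module can be run with -m:
#
#   python -m PrognosAIs.Model.Trainer \
#       -c config.yaml \
#       -i ./samples \
#       -o ./output \
#       -T /tmp/scratch \
#       -s my_model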