Source code for PrognosAIs.Model.Parsers

import copy
import importlib
import os

import tensorflow.keras.callbacks
import tensorflow.keras.losses
import tensorflow.keras.metrics
import tensorflow.keras.optimizers

import PrognosAIs.Model.Callbacks
import PrognosAIs.Model.Losses
import PrognosAIs.Model.Metrics

from PrognosAIs.IO import utils as IO_utils


[docs]class StandardParser: def __init__(self, config: dict, module_paths: list): self.config = config self.module_paths = module_paths
[docs] def parse_settings(self): if self.config is not None: if "name" in self.config: parsed_settings = self._initiate_class(self.config["name"], self.config["settings"]) else: parsed_settings = {} for key, value in self.config.items(): if "name" in value: parsed_settings[key] = self._initiate_class( value["name"], value["settings"] ) else: settings = [] for key_deep, value_deep in value.items(): settings.append( self._initiate_class(value_deep["name"], value_deep["settings"]) ) parsed_settings[key] = settings else: parsed_settings = None return parsed_settings
def _initiate_class(self, class_name, class_settings): class_function = self.get_class(class_name) if class_settings is None: initiated_class = class_function() else: initiated_class = class_function(**class_settings) return initiated_class
[docs] def get_class(self, class_name): class_function = None module_index = 0 while class_function is None and module_index < len(self.module_paths): module_path = self.module_paths[module_index] if isinstance(module_path, str): module_path = IO_utils.load_module_from_file(module_path) class_function = getattr(module_path, class_name, None) module_index += 1 if class_function is None: raise ValueError("Requested class {} not found!".format(class_name)) return class_function
[docs]class LossParser(StandardParser):
[docs] def __init__(self, loss_settings: dict, class_weights: dict = None, module_paths=None): """ Parse loss settings to actual losses Args: loss_settings: Settings for the losses Returns: None """ # TODO remove the need for a "settings" index if there is no settings # Index just need to default to None. # Also get rid of all the "vague" keywords that are then in the config self.class_weights = class_weights self.module_paths = [tensorflow.keras.losses, PrognosAIs.Model.Losses] if module_paths is not None: if isinstance(module_paths, list): self.module_paths.extend(module_paths) else: self.module_paths.append(module_paths) if "name" in loss_settings and class_weights is not None: loss_settings["settings"]["class_weight"] = self.class_weights elif "name" not in loss_settings: for key, value in loss_settings.items(): if ( value["settings"] is not None and class_weights is not None and key in self.class_weights.keys() ): value["settings"]["class_weight"] = self.class_weights[key] super().__init__(loss_settings, self.module_paths)
[docs] def get_losses(self): return self.parse_settings()
[docs]class OptimizerParser(StandardParser):
[docs] def __init__(self, optimizer_settings: dict, module_paths=None) -> None: """ Interfacing class to easily get a tf.keras.optimizers optimizer Args: optimizer_settings: Arguments to be passed to the optimizer Returns: None """ self.module_paths = [tensorflow.keras.optimizers] if module_paths is not None: if isinstance(module_paths, list): self.module_paths.extend(module_paths) else: self.module_paths.append(module_paths) super().__init__(optimizer_settings, self.module_paths) return
[docs] def get_optimizer(self): return self.parse_settings()
[docs]class CallbackParser(StandardParser):
[docs] def __init__( self, callback_settings: dict, root_path: str = None, module_paths=None, save_name=None ): """ Parse callback settings to actual callbacks Args: callback_settings: Settings for the callbacks Returns: None """ self.module_paths = [tensorflow.keras.callbacks, PrognosAIs.Model.Callbacks] self.save_name = save_name if module_paths is not None: if isinstance(module_paths, list): self.module_paths.extend(module_paths) else: self.module_paths.append(module_paths) super().__init__(callback_settings, self.module_paths) if root_path is not None: self.config = self.replace_root_path(self.config, root_path) return
[docs] def replace_root_path(self, settings, root_path): for key, value in settings.items(): if type(value) == dict: settings[key] = self.replace_root_path(value, root_path) else: # TODO this gives error, for example with TensorBoard which has profile_batch option # (because file is in profile) if "file" in key: settings[key] = os.path.join(root_path, value) if self.save_name is not None: settings[key] = settings[key].format(savename=self.save_name) return settings
[docs] def get_callbacks(self): # Need to make sure that CSVLogger is the last so that everything is properly stored parsed_callbacks = list(self.parse_settings().values()) out_parsed_callbacks = [] logger_index = -1 for i_i_parsed_callback, i_parsed_callback in enumerate(parsed_callbacks): if type(i_parsed_callback) == tensorflow.keras.callbacks.CSVLogger: logger_index = i_i_parsed_callback else: out_parsed_callbacks.append(i_parsed_callback) if logger_index != -1: out_parsed_callbacks.append(parsed_callbacks[logger_index]) return out_parsed_callbacks
[docs]class MetricParser(StandardParser):
[docs] def __init__(self, metric_settings: dict, label_names: list = None, module_paths=None) -> None: """ Parse metrics settings to actual metrics Args: loss_settings: Settings for the losses """ self.module_paths = [tensorflow.keras.metrics, PrognosAIs.Model.Metrics] if module_paths is not None: if isinstance(module_paths, list): self.module_paths.extend(module_paths) else: self.module_paths.append(module_paths) super().__init__(metric_settings, self.module_paths) if label_names is not None: self.label_names = label_names else: self.label_names = []
[docs] def get_metrics(self): parsed_callbacks = self.parse_settings() if not isinstance(parsed_callbacks, dict) and parsed_callbacks is not None: parsed_callbacks = [parsed_callbacks] elif ( parsed_callbacks is not None and self.label_names != [] and sorted(parsed_callbacks.keys()) != sorted(self.label_names) ): parsed_callbacks = list(parsed_callbacks.values()) elif parsed_callbacks is None: parsed_callbacks = [] return parsed_callbacks
[docs] def convert_metrics_list_to_dict(self, metrics: list) -> dict: if isinstance(metrics, list): new_metrics = {} for i_label_name in self.label_names: new_metrics[i_label_name] = [] other_label_names = copy.deepcopy(self.label_names) other_label_names.remove(i_label_name) for i_metric in metrics: i_metric_name = i_metric.name i_metric_base_name = i_metric.__class__().name i_metric_label_name = i_metric_name.split(i_metric_base_name)[0] # No output specific metrics, append all metrics to all outputs if "_" not in i_metric_label_name: new_metrics[i_label_name].append(i_metric) else: # Need to cut off the last underscore to get the label name i_metric_label_name = i_metric_label_name[:-1] if i_metric_label_name == i_label_name: new_metrics[i_label_name].append(i_metric) elif isinstance(metrics, dict): for key, value in metrics.items(): if not isinstance(value, list): metrics[key] = [value] new_metrics = metrics else: new_metrics = metrics return new_metrics