Source code for trojai.modelgen.config

import collections.abc
import copy
import importlib
import logging
import os
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, Sequence, Any
import math

import cloudpickle as pickle
import numpy as np
import torch

from .architecture_factory import ArchitectureFactory
from .constants import VALID_LOSS_FUNCTIONS, VALID_DEVICES, VALID_OPTIMIZERS
from .data_manager import DataManager
from .optimizer_interface import OptimizerInterface

logger = logging.getLogger(__name__)

"""
Defines all configurations pertinent to model generation.
"""


def identity_function(x):
    return x

default_soft_to_hard_fn_kwargs = dict()
class DefaultSoftToHardFn:
    """
    The default conversion from soft-decision outputs to hard-decision
    """
    def __init__(self):
        pass

    def __call__(self, y_hat, *args, **kwargs):
        return torch.argmax(y_hat, dim=1)

    def __repr__(self):
        return "torch.argmax(y_hat, dim=1)"
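
# Illustrative sketch (not part of the library API): how DefaultSoftToHardFn maps a batch of
# soft-decision model outputs (logits) to hard class decisions. The helper name and tensor
# values below are made up for demonstration.
def _example_default_soft_to_hard():
    soft_to_hard = DefaultSoftToHardFn()
    y_hat = torch.tensor([[0.1, 2.3, -0.5],
                          [1.7, 0.2, 0.9]])  # logits for a batch of 2 samples, 3 classes
    hard_decisions = soft_to_hard(y_hat)     # argmax over dim=1 -> tensor([1, 0])
    return hard_decisions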
class ConfigInterface(ABC):
    """
    Defines the interface for all configuration objects
    """
    @abstractmethod
    def __deepcopy__(self, memodict={}):
        pass
class OptimizerConfigInterface(ConfigInterface):
    @abstractmethod
    def get_device_type(self):
        pass

    def save(self, fname):
        pass

    @staticmethod
    @abstractmethod
    def load(fname):
        pass
class EarlyStoppingConfig(ConfigInterface):
    """
    Defines configuration related to early stopping.
    """
    def __init__(self, num_epochs: int = 5, val_loss_eps: float = 1e-3):
        """
        :param num_epochs: the number of epochs over which to monitor the validation loss
        :param val_loss_eps: the minimum improvement in validation loss over the monitored epochs required to
            avoid early stopping
        """
        self.num_epochs = num_epochs
        self.val_loss_eps = val_loss_eps

        self.validate()

    def validate(self):
        if not isinstance(self.num_epochs, int) or self.num_epochs < 2:
            msg = "num_epochs to monitor must be an integer > 1!"
            logger.error(msg)
            raise ValueError(msg)
        try:
            self.val_loss_eps = float(self.val_loss_eps)
        except ValueError:
            msg = "val_loss_eps must be a float"
            logger.error(msg)
            raise ValueError(msg)
        if self.val_loss_eps < 0:
            msg = "val_loss_eps must be >= 0!"
            logger.error(msg)
            raise ValueError(msg)

    def __deepcopy__(self, memodict={}):
        return EarlyStoppingConfig(self.num_epochs, self.val_loss_eps)

    def __eq__(self, other):
        return self.num_epochs == other.num_epochs and math.isclose(self.val_loss_eps, other.val_loss_eps)

    def __str__(self):
        return "ES[%d:%0.02f]" % (self.num_epochs, self.val_loss_eps)
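
# Illustrative sketch (not part of the library API): configure early stopping to halt training if the
# validation loss improves by less than 1e-4 over a 10-epoch window. The numbers are arbitrary examples.
def _example_early_stopping_cfg():
    es_cfg = EarlyStoppingConfig(num_epochs=10, val_loss_eps=1e-4)
    # note that validation runs at construction time, e.g. EarlyStoppingConfig(num_epochs=1) raises ValueError
    return es_cfg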
[docs]class TrainingConfig(ConfigInterface): """ Defines all required items to setup training with an optimizer """ def __init__(self, device: Union[str, torch.device] = 'cpu', epochs: int = 10, batch_size: int = 32, lr: float = 1e-4, optim: Union[str, OptimizerInterface] = 'adam', optim_kwargs: dict = None, objective: Union[str, Callable] = 'cross_entropy_loss', objective_kwargs: dict = None, save_best_model: bool = False, train_val_split: float = 0.05, val_data_transform: Callable[[Any], Any] = None, val_label_transform: Callable[[int], int] = None, val_dataloader_kwargs: dict = None, early_stopping: EarlyStoppingConfig = None, soft_to_hard_fn: Callable = None, soft_to_hard_fn_kwargs: dict = None, lr_scheduler: Any = None, lr_scheduler_init_kwargs: dict = None, lr_scheduler_call_arg: Any = None, clip_grad: bool = False, clip_type: str = "norm", clip_val: float = 1., clip_kwargs: dict = None, adv_training_eps: float = None, adv_training_iterations: int = None, adv_training_ratio: float = None) -> None: """ Initializes a TrainingConfig object :param device: string or torch.device object representing the device on which computation will be performed :param epochs: the number of epochs to train the model :param batch_size: batch size used to train the model :param lr: the learning rate :param optim: either one of trojai_private.modelgen.constants.VALID_OPTIMIZERS or an optimizer object implementing trojai_private.modelgen.optimizer_interface.OptimizerInterface :param optim_kwargs: any additional kwargs to be passed to the optimizer :param objective: either one of trojai_private.modelgen.constants.VALID_OBJECTIVES or a callable function that can compute a metric given y_hat and y_true :param objective_kwargs: a dictionary for kwargs to pass when intializing an inbuilt objective function :param save_best_model: if True, returns the best model as computed by validation accuracy (if computed), else, training accuracy (if validation dataset is not desired). if False, the model returned by the optimizer will just be the model at the final epoch of training :param train_val_split: (float) if > 0, then splits the training dataset and uses it as validation. If 0 the training dataset is not split and validation is not computed :param val_data_transform: (function: any -> any) how to transform the validation data (e.g. an image) to fit into the desired model and objective function; optional NOTE: Currently - this argument is only used if data_type='image' :param val_label_transform: (function: int->int) how to transform the label to the validation data; optional NOTE: Currently - this argument is only used if data_type='image' :param val_dataloader_kwargs: (dict) Keyword arguments to pass to the torch DataLoader object during for validation data. See https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html for more documentation. If None, defaults will be used. Defaults depend on the optimizer used, but are likely something like: {batch_size: <batch size given in training config>, shuffle: False, pin_memory=<decided by optimizer>, drop_last=True} NOTE: Setting values in this dictionary that are normally set by the optimizer will override them during training. Use with caution. We recommend only using the following keys: 'shuffle', 'num_workers', 'pin_memory', and 'drop_last'. :param early_stopping: configuration for early stopping :param soft_to_hard_fn: a callable which will be computed on every batch of predictions to compute hard-decison predictions from the model output. 
Defaults to: torch.max(<args>, dim=1)[1] --> this is equivalent to np.argmax on each row of predictions :param soft_to_hard_fn_kwargs: a dictionary of kwargs to pass to the soft_to_hard_fn when calling it :param lr_scheduler: any of the Learning Rate Schedulers provided in PyTorch see: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate :param lr_scheduler_init_kwargs: a dictionary of kwargs to pass when instantiating the desired learning rate scheduler :param lr_scheduler_call_arg: any arguments that should be called when stepping the learning rate scheduler. This can be one of the following choices: None, 'val_acc', 'val_loss' :param clip_grad: flag indicating whether to enable gradient clipping :param clip_type: can be either "norm" or "val", indicating whether the norm of all gradients should be clipped, or the raw gradient values :param clip_val: the value to clip at :param clip_kwargs: any kwargs to pass to the clipper. See: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html :param adv_training_eps: The epsilon value constraining the adversarial perturbation. :param adv_training_iterations: The number of iterations PGD will take for adversarial training. :param adv_training_ratio: The percent of batches which will be adversarially attacked [0, 1]. TODO: [ ] - allow user to configure what the "best" model is """ self.device = device self.epochs = epochs self.batch_size = batch_size self.lr = lr self.optim = optim self.optim_kwargs = optim_kwargs self.objective = objective self.objective_kwargs = objective_kwargs self.save_best_model = save_best_model self.train_val_split = train_val_split self.early_stopping = early_stopping self.val_data_transform = val_data_transform self.val_label_transform = val_label_transform self.val_dataloader_kwargs = val_dataloader_kwargs self.soft_to_hard_fn = soft_to_hard_fn self.soft_to_hard_fn_kwargs = soft_to_hard_fn_kwargs self.lr_scheduler = lr_scheduler self.lr_scheduler_init_kwargs = lr_scheduler_init_kwargs self.lr_scheduler_call_arg = lr_scheduler_call_arg self.clip_grad = clip_grad self.clip_type = clip_type self.clip_val = clip_val self.clip_kwargs = clip_kwargs self.adv_training_eps = adv_training_eps self.adv_training_iterations = adv_training_iterations self.adv_training_ratio = adv_training_ratio if self.adv_training_eps is None: self.adv_training_eps = float(0.0) if self.adv_training_ratio is None: self.adv_training_ratio = float(0.0) if self.adv_training_iterations is None: self.adv_training_iterations = int(0) if self.optim_kwargs is None: self.optim_kwargs = {} if self.lr_scheduler_init_kwargs is None: self.lr_scheduler_init_kwargs = {} if self.clip_kwargs is None: self.clip_kwargs = {} self.validate() # convert to a torch.device object if isinstance(self.device, str): self.device = torch.device(self.device)
[docs] def validate(self) -> None: """ Validate the object configuration :return: None """ if not isinstance(self.device, torch.device) and self.device not in VALID_DEVICES: msg = "device must be either a torch.device object, or one of the following:" + str(VALID_DEVICES) logger.error(msg) raise ValueError(msg) if not isinstance(self.epochs, int) or self.epochs < 1: msg = "epochs must be an integer > 0" logger.error(msg) raise ValueError(msg) if not isinstance(self.batch_size, int) or self.batch_size < 1: msg = "batch_size must be an integer > 0" logger.error(msg) raise ValueError(msg) if not isinstance(self.lr, float): msg = "lr must be a float!" logger.error(msg) raise ValueError(msg) if not isinstance(self.optim, OptimizerInterface) and self.optim not in VALID_OPTIMIZERS: msg = "optim must be either a OptimizerInterface object, or one of the following:" + str(VALID_OPTIMIZERS) logger.error(msg) raise ValueError(msg) if not isinstance(self.optim_kwargs, dict): msg = "optim_kwargs must be a dictionary!" logger.error(msg) raise ValueError(msg) if not callable(self.objective) and self.objective not in VALID_LOSS_FUNCTIONS: msg = "objective must be a callable, or one of the following:" + str(VALID_LOSS_FUNCTIONS) logger.error(msg) raise ValueError(msg) if not self.objective_kwargs: self.objective_kwargs = dict() elif not isinstance(self.objective_kwargs, dict): msg = "objective_kwargs must be a dictionary" logger.error(msg) raise ValueError(msg) if not isinstance(self.save_best_model, bool): msg = "save_best_model must be a boolean!" logger.error(msg) raise ValueError(msg) if not isinstance(self.train_val_split, float): msg = "train_val_split must a float between 0 and 1!" logger.error(msg) raise ValueError(msg) else: if self.train_val_split < 0 or self.train_val_split > 1: msg = "train_val_split must be between 0 and 1, inclusive" logger.error(msg) raise ValueError(msg) if self.early_stopping is not None and not isinstance(self.early_stopping, EarlyStoppingConfig): msg = "early_stopping must be of type EarlyStoppingConfig or None" logger.error(msg) raise ValueError(msg) if self.adv_training_eps < 0 or self.adv_training_eps > 1: msg = "Adversarial training eps: {} must be between 0 and 1.".format(self.adv_training_eps) logger.error(msg) raise ValueError(msg) if self.adv_training_ratio < 0 or self.adv_training_ratio > 1: msg = "Adversarial training ratio (percent of images with perturbation applied): {} must be between 0 and 1.".format(self.adv_training_ratio) logger.error(msg) raise ValueError(msg) if self.adv_training_iterations < 0: msg = "Adversarial training iteration count: {} must be greater than or equal to 0.".format(self.adv_training_iterations) logger.error(msg) raise ValueError(msg) if self.val_data_transform is not None and not callable(self.val_data_transform): raise TypeError("Expected a function for argument 'val_data_transform', " "instead got type: {}".format(type(self.val_data_transform))) if self.val_label_transform is not None and not callable(self.val_label_transform): raise TypeError("Expected a function for argument 'val_label_transform', " "instead got type: {}".format(type(self.val_label_transform))) if self.val_dataloader_kwargs is not None and not isinstance(self.val_dataloader_kwargs, dict): msg = "val_dataloader_kwargs must be a dictionary or None!" 
logger.error(msg) raise ValueError(msg) if self.soft_to_hard_fn is None: self.soft_to_hard_fn = DefaultSoftToHardFn() elif not callable(self.soft_to_hard_fn): msg = "soft_to_hard_fn must be a callable which accepts as input the output of the model, and outputs " \ "hard-decisions" logger.error(msg) raise ValueError(msg) if self.soft_to_hard_fn_kwargs is None: self.soft_to_hard_fn_kwargs = copy.deepcopy(default_soft_to_hard_fn_kwargs) elif not isinstance(self.soft_to_hard_fn_kwargs, dict): msg = "soft_to_hard_fn_kwargs must be a dictionary of kwargs to pass to soft_to_hard_fn" logger.error(msg) raise ValueError(msg) # we do not validate the lr_scheduler or lr_scheduler_kwargs b/c those will # be validated upon instantiation if self.lr_scheduler_call_arg is not None and self.lr_scheduler_call_arg != 'val_acc' and self.lr_scheduler_call_arg != 'val_loss': msg = "lr_scheduler_call_arg must be one of: None, val_acc, val_loss" logger.error(msg) raise ValueError(msg) if not isinstance(self.clip_grad, bool): msg = "clip_grad must be a bool!" logger.error(msg) raise ValueError(msg) if not isinstance(self.clip_type, str) or (self.clip_type != 'norm' and self.clip_type != 'val'): msg = "clip type must be a string, either norm or val" logger.error(msg) raise ValueError(msg) if not isinstance(self.clip_val, float): msg = "clip_val must be a float" logger.error(msg) raise ValueError(msg) if not isinstance(self.clip_kwargs, dict): msg = "clip_kwargs must be a dict" logger.error(msg) raise ValueError(msg)
[docs] def get_cfg_as_dict(self): """ Returns a dictionary representation of the configuration :return: (dict) a dictionary """ output_dict = dict(device=str(self.device.type), epochs=self.epochs, batch_size=self.batch_size, learning_rate=self.lr, optim=self.optim, objective=self.objective, objective_kwargs=self.objective_kwargs, save_best_model=self.save_best_model, early_stopping=str(self.early_stopping), val_data_transform=self.val_data_transform, val_label_transform=self.val_label_transform, val_dataloader_kwargs=self.val_dataloader_kwargs, soft_to_hard_fn=self.soft_to_hard_fn, soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs, lr_scheduler=self.lr_scheduler, lr_scheduler_init_kwargs=self.lr_scheduler_init_kwargs, lr_scheduler_call_arg=self.lr_scheduler_call_arg, clip_grad=self.clip_grad, clip_type=self.clip_type, clip_val=self.clip_val, clip_kwargs=self.clip_kwargs, adv_training_eps = self.adv_training_eps, adv_training_iterations = self.adv_training_iterations, adv_training_ratio = self.adv_training_ratio) return output_dict
    def __str__(self):
        str_repr = "TrainingConfig: device[%s], num_epochs[%d], batch_size[%d], learning_rate[%.5e], " \
                   "adv_training_eps[%s], adv_training_iterations[%s], adv_training_ratio[%s], optimizer[%s], " \
                   "objective[%s], objective_kwargs[%s], train_val_split[%0.02f], val_data_transform[%s], " \
                   "val_label_transform[%s], val_dataloader_kwargs[%s], early_stopping[%s], " \
                   "soft_to_hard_fn[%s], soft_to_hard_fn_kwargs[%s], " \
                   "lr_scheduler[%s], lr_scheduler_init_kwargs[%s], lr_scheduler_call_arg[%s], " \
                   "clip_grad[%s] clip_type[%s] clip_val[%s] clip_kwargs[%s]" % \
                   (str(self.device.type), self.epochs, self.batch_size, self.lr, self.adv_training_eps,
                    self.adv_training_iterations, self.adv_training_ratio, str(self.optim), str(self.objective),
                    str(self.objective_kwargs), self.train_val_split, str(self.val_data_transform),
                    str(self.val_label_transform), str(self.val_dataloader_kwargs), str(self.early_stopping),
                    str(self.soft_to_hard_fn), str(self.soft_to_hard_fn_kwargs), str(self.lr_scheduler),
                    str(self.lr_scheduler_init_kwargs), str(self.lr_scheduler_call_arg), str(self.clip_grad),
                    str(self.clip_type), str(self.clip_val), str(self.clip_kwargs))
        return str_repr

    def __deepcopy__(self, memodict={}):
        # keep a string version of the device, so that when the copy gets instantiated,
        # it will generate a device object on the node
        new_device = self.device.type
        epochs = self.epochs
        batch_size = self.batch_size
        lr = self.lr
        save_best_model = self.save_best_model
        train_val_split = self.train_val_split
        early_stopping = copy.deepcopy(self.early_stopping)
        val_data_transform = copy.deepcopy(self.val_data_transform)
        val_label_transform = copy.deepcopy(self.val_label_transform)
        val_dataloader_kwargs = copy.deepcopy(self.val_dataloader_kwargs)
        if isinstance(self.optim, str):
            optim = self.optim
        elif isinstance(self.optim, OptimizerInterface):
            optim = copy.deepcopy(self.optim)
        else:
            msg = "The TrainingConfig object you are trying to copy is corrupted!"
            logger.error(msg)
            raise ValueError(msg)
        optim_kwargs = self.optim_kwargs
        if isinstance(self.objective, str):
            objective = self.objective
        elif callable(self.objective):
            objective = copy.deepcopy(self.objective)
        else:
            msg = "The TrainingConfig object you are trying to copy is corrupted!"
            logger.error(msg)
            raise ValueError(msg)
        objective_kwargs = self.objective_kwargs
        soft_to_hard_fn = copy.deepcopy(self.soft_to_hard_fn)  # empirical tests on deepcopy do not seem to
                                                               # create new memory references for lambda functions.
                                                               # I am not sure if this behavior is different with
                                                               # a properly defined function.
        soft_to_hard_fn_kwargs = copy.deepcopy(self.soft_to_hard_fn_kwargs)
        lr_scheduler = self.lr_scheduler  # should be a callable, so this is OK
        lr_scheduler_kwargs = copy.deepcopy(self.lr_scheduler_init_kwargs)
        lr_scheduler_call_arg = self.lr_scheduler_call_arg  # a string, no deep-copy required
        clip_grad = self.clip_grad
        clip_type = self.clip_type
        clip_val = self.clip_val
        clip_kwargs = copy.deepcopy(self.clip_kwargs)
        adv_training_eps = self.adv_training_eps
        adv_training_iterations = self.adv_training_iterations
        adv_training_ratio = self.adv_training_ratio

        return TrainingConfig(new_device, epochs, batch_size, lr, optim, optim_kwargs, objective, objective_kwargs,
                              save_best_model, train_val_split, val_data_transform, val_label_transform,
                              val_dataloader_kwargs, early_stopping, soft_to_hard_fn, soft_to_hard_fn_kwargs,
                              lr_scheduler, lr_scheduler_kwargs, lr_scheduler_call_arg,
                              clip_grad, clip_type, clip_val, clip_kwargs,
                              adv_training_eps, adv_training_iterations, adv_training_ratio)

    def __eq__(self, other):
        # NOTE: we don't check equality of:
        #  1. soft_to_hard_fn
        #  2. lr_scheduler
        # b/c there doesn't seem to be a general way to accomplish this.  This needs
        # to be addressed as needed later on.
        if self.device.type == other.device.type and self.epochs == other.epochs and \
                self.batch_size == other.batch_size and self.lr == other.lr and \
                self.save_best_model == other.save_best_model and \
                self.train_val_split == other.train_val_split and \
                self.early_stopping == other.early_stopping and \
                self.val_data_transform == other.val_data_transform and \
                self.val_label_transform == other.val_label_transform and \
                self.val_dataloader_kwargs == other.val_dataloader_kwargs and \
                self.soft_to_hard_fn_kwargs == other.soft_to_hard_fn_kwargs and \
                self.lr_scheduler_init_kwargs == other.lr_scheduler_init_kwargs and \
                self.lr_scheduler_call_arg == other.lr_scheduler_call_arg and \
                self.clip_grad == other.clip_grad and self.clip_type == other.clip_type and \
                self.adv_training_eps == other.adv_training_eps and \
                self.adv_training_iterations == other.adv_training_iterations and \
                self.adv_training_ratio == other.adv_training_ratio and \
                self.clip_val == other.clip_val and self.clip_kwargs == other.clip_kwargs:
            # now check the objects
            if self.optim == other.optim and self.objective == other.objective:
                return True
            else:
                return False
        else:
            return False
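
# Illustrative sketch (not part of the library API): a TrainingConfig that trains for 20 epochs on a GPU
# (swap in 'cpu' if no GPU is available) with Adam, a cross-entropy objective, a 10% train/validation
# split, and early stopping. All hyperparameter values here are arbitrary examples.
def _example_training_cfg():
    training_cfg = TrainingConfig(device='cuda',
                                  epochs=20,
                                  batch_size=64,
                                  lr=1e-3,
                                  optim='adam',
                                  objective='cross_entropy_loss',
                                  train_val_split=0.1,
                                  early_stopping=EarlyStoppingConfig(num_epochs=5, val_loss_eps=1e-3),
                                  save_best_model=True)
    return training_cfg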
[docs]class ReportingConfig(ConfigInterface): """ Defines all options to setup how data is reported back to the user while models are being trained """ def __init__(self, num_batches_per_logmsg: int = 100, disable_progress_bar: bool = False, num_epochs_per_metric: int = 1, num_batches_per_metrics: int = 50, tensorboard_output_dir: str = None, experiment_name: str = 'experiment'): """ Initializes a ReportingConfig object. :param num_batches_per_logmsg: The # of batches which are computed before a log message is written. :param disable_progress_bar: Whether to disable the tdqm progress bar. :param num_epochs_per_metric: The number of epochs before metrics are computed. :param num_batches_per_metrics: The number of batches before metrics are computed. :param tensorboard_output_dir: the directory to which tensorboard data should be written. :param experiment_name: A string identifier to associate with the configuration. """ self.num_batches_per_logmsg = num_batches_per_logmsg self.disable_progress_bar = disable_progress_bar self.num_epochs_per_metrics = num_epochs_per_metric self.num_batches_per_metrics = num_batches_per_metrics self.tensorboard_output_dir = tensorboard_output_dir self.experiment_name = experiment_name self.validate()
[docs] def validate(self): if not isinstance(self.num_batches_per_logmsg, int) or self.num_batches_per_logmsg < 0: msg = "num_batches_per_logmsg must be an integer > 0" logger.error(msg) raise ValueError(msg) if not isinstance(self.num_epochs_per_metrics, int) or self.num_epochs_per_metrics < 0: msg = "num_epochs_per_metrics must be an integer > 0" logger.error(msg) raise ValueError(msg) if self.num_batches_per_metrics is not None and (not isinstance(self.num_batches_per_metrics, int) or self.num_batches_per_metrics < 0): msg = "num_batches_per_metrics must be an integer > 0 or None!" logger.error(msg) raise ValueError(msg)
def __str__(self): str_repr = "ReportingConfig: num_batches/log_msg[%d], num_epochs/metric[%d], num_batches/metric[%d], " \ "tensorboard_dir[%s] experiment_name=[%s], disable_progress_bar=[%s]" % \ (self.num_batches_per_logmsg, self.num_epochs_per_metrics, self.num_batches_per_metrics, self.tensorboard_output_dir, self.experiment_name, self.disable_progress_bar) return str_repr def __copy__(self): return ReportingConfig(self.num_batches_per_logmsg, self.disable_progress_bar, self.num_epochs_per_metrics, self.num_batches_per_metrics, self.tensorboard_output_dir, self.experiment_name) def __deepcopy__(self, memodict={}): return self.__copy__() def __eq__(self, other): if self.num_batches_per_logmsg == other.num_batches_per_logmsg and \ self.disable_progress_bar == other.disable_progress_bar and \ self.num_epochs_per_metrics == other.num_epochs_per_metrics and \ self.num_batches_per_metrics == other.num_batches_per_metrics and \ self.tensorboard_output_dir == other.tensorboard_output_dir and \ self.experiment_name == other.experiment_name: return True else: return False
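
# Illustrative sketch (not part of the library API): a ReportingConfig that logs every 200 batches,
# computes metrics every epoch and every 100 batches, and writes TensorBoard summaries. The values,
# output directory, and experiment name are arbitrary examples.
def _example_reporting_cfg():
    reporting_cfg = ReportingConfig(num_batches_per_logmsg=200,
                                    num_epochs_per_metric=1,
                                    num_batches_per_metrics=100,
                                    tensorboard_output_dir='/tmp/tensorboard',
                                    experiment_name='example_experiment')
    return reporting_cfg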
[docs]class TorchTextOptimizerConfig(OptimizerConfigInterface): """ Defines the configuration needed to setup the TorchTextOptimizer """ def __init__(self, training_cfg: TrainingConfig = None, reporting_cfg: ReportingConfig = None, copy_pretrained_embeddings: bool = False): """ Initializes a TorchTextOptimizer :param training_cfg: a TrainingConfig object, if None, a default TrainingConfig object will be constructed :param reporting_cfg: a ReportingConfig object, if None, a default ReportingConfig object will be constructed :param copy_pretrained_embeddings: if True, will copy over pretrained embeddings into network from the built vocabulary """ self.training_cfg = training_cfg self.reporting_cfg = reporting_cfg self.copy_pretrained_embeddings = copy_pretrained_embeddings self.validate()
[docs] def validate(self): if self.training_cfg is None: logger.debug( "Using default training configuration to setup Optimizer!") self.training_cfg = TrainingConfig() elif not isinstance(self.training_cfg, TrainingConfig): msg = "training_cfg must be of type TrainingConfig" logger.error(msg) raise TypeError(msg) if self.reporting_cfg is None: logger.debug( "Using default reporting configuration to setup Optimizer!") self.reporting_cfg = ReportingConfig() elif not isinstance(self.reporting_cfg, ReportingConfig): msg = "reporting_cfg must be of type ReportingConfig" logger.error(msg) raise TypeError(msg) if not isinstance(self.copy_pretrained_embeddings, bool): msg = "copy_pretrained_embeddings must be a boolean datatype!" logger.error(msg) raise TypeError(msg)
def __deepcopy__(self, memodict={}): training_cfg_copy = copy.deepcopy(self.training_cfg) reporting_cfg_copy = copy.deepcopy(self.reporting_cfg) return TorchTextOptimizerConfig(training_cfg_copy, reporting_cfg_copy, self.copy_pretrained_embeddings) def __eq__(self, other): if self.training_cfg == other.training_cfg and self.reporting_cfg == other.reporting_cfg and \ self.copy_pretrained_embeddings == other.copy_pretrained_embeddings: return True else: return False
[docs] def save(self, fname): """ Saves the optimizer configuration to a file :param fname: the filename to save the config to :return: None """ with open(fname, 'wb') as f: pickle.dump(self, f)
[docs] @staticmethod def load(fname): """ Loads a configuration from disk :param fname: the filename where the config is stored :return: the loaded configuration """ with open(fname, 'rb') as f: loaded_optimzier_cfg = pickle.load(f) return loaded_optimzier_cfg
[docs] def get_device_type(self): """ Returns the device associated w/ this optimizer configuration. Needed to save/load for UGE. :return (str): the device type represented as a string """ return str(self.training_cfg.device)
[docs]class DefaultOptimizerConfig(OptimizerConfigInterface): """ Defines the configuration needed to setup the DefaultOptimizer """ def __init__(self, training_cfg: TrainingConfig = None, reporting_cfg: ReportingConfig = None): """ Initializes a Default Optimizer :param training_cfg: a TrainingConfig object, if None, a default TrainingConfig object will be constructed :param reporting_cfg: a ReportingConfig object, if None, a default ReportingConfig object will be constructed """ if training_cfg is None: logger.debug( "Using default training configuration to setup Optimizer!") self.training_cfg = TrainingConfig() elif not isinstance(training_cfg, TrainingConfig): msg = "training_cfg must be of type TrainingConfig" logger.error(msg) raise TypeError(msg) else: self.training_cfg = training_cfg if reporting_cfg is None: logger.debug( "Using default reporting configuration to setup Optimizer!") self.reporting_cfg = ReportingConfig() elif not isinstance(reporting_cfg, ReportingConfig): msg = "reporting_cfg must be of type ReportingConfig" logger.error(msg) raise TypeError(msg) else: self.reporting_cfg = reporting_cfg def __deepcopy__(self, memodict={}): training_cfg_copy = copy.deepcopy(self.training_cfg) reporting_cfg_copy = copy.deepcopy(self.reporting_cfg) return DefaultOptimizerConfig(training_cfg_copy, reporting_cfg_copy) def __eq__(self, other): if self.training_cfg == other.training_cfg and self.reporting_cfg == other.reporting_cfg: return True else: return False
[docs] def get_device_type(self): """ Returns the device associated w/ this optimizer configuration. Needed to save/load for UGE. :return (str): the device type represented as a string """ return str(self.training_cfg.device)
[docs] def save(self, fname): """ Saves the optimizer configuration to a file :param fname: the filename to save the config to :return: None """ with open(fname, 'wb') as f: pickle.dump(self, f)
[docs] @staticmethod def load(fname): """ Loads a configuration from disk :param fname: the filename where the config is stored :return: the loaded configuration """ with open(fname, 'rb') as f: loaded_optimzier_cfg = pickle.load(f) return loaded_optimzier_cfg
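
# Illustrative sketch (not part of the library API): combine a TrainingConfig and a ReportingConfig into
# a DefaultOptimizerConfig and round-trip it through save()/load(). The file path is an arbitrary example.
def _example_default_optimizer_cfg():
    optimizer_cfg = DefaultOptimizerConfig(training_cfg=TrainingConfig(epochs=5),
                                           reporting_cfg=ReportingConfig())
    optimizer_cfg.save('/tmp/optimizer_cfg.pkl')
    loaded_cfg = DefaultOptimizerConfig.load('/tmp/optimizer_cfg.pkl')
    assert loaded_cfg == optimizer_cfg  # the configs compare equal after the round trip
    return loaded_cfg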
[docs]class ModelGeneratorConfig(ConfigInterface): """Object used to configure the model generator""" def __init__(self, arch_factory: ArchitectureFactory, data: DataManager, model_save_dir: str, stats_save_dir: str, num_models: int, arch_factory_kwargs: dict = None, arch_factory_kwargs_generator: Callable = None, optimizer: Union[Union[OptimizerInterface, DefaultOptimizerConfig], Sequence[Union[OptimizerInterface, DefaultOptimizerConfig]]] = None, parallel=False, amp=False, experiment_cfg: dict = None, run_ids: Union[Any, Sequence[Any]] = None, filenames: Union[str, Sequence[str]] = None, save_with_hash: bool = False): """ Initializes the ModelGeneratorConfig object which provides needed information for generating models for a given experiment. :param arch_factory: ArchitectureFactory object that provides instantiated architectures (untrained models) to be trained on the data. :param data: TrojaiDataManager object containing the experiment path and files. :param model_save_dir: path to directory where the models should be saved :param stats_save_dir: path to directory where the model training stats should be saved :param num_models: number of models to train with this configuration :param arch_factory_kwargs: (dict) a dictionary which contains keywords and associated values that are needed to instantiate a trainable module from the factory :param arch_factory_kwargs_generator: (callable) a callable, or None, which takes a dictionary of all variables defined in the Runner's namespace, and then creates a new dictionary that contains the keyword arguments to instantiate an architecture from the architecture factory :param optimizer: a OptimizerInterface object, or a DefaultOptimizer configuration, or possibly mixed sequence of both. If a sequence of optimizers is passed, then the length of that sequence must match the number of sequential datasets that are to be used for training the model. :param parallel: (bool) - if True, attempts to use multiple GPU's :param amp: (bool) - if True, attempts to use automatic mixed precision on GPU's :param experiment_cfg: dictionary containing information regarding the experiment which is being run by the ModelGenerator. This information is also saved in the output summary JSON file that is associated with every model that is generated. :param run_ids: Identifiers for models. If a sequence, len(run_ids) must be equal to num_models :param filenames: An optional list of file names to save each model by each file name, or a single filename to have models be saved with the same file name with '_#' added to the end, e.g. 'filename.pt', 'filename_1.pt', 'filename_2.pt', ... If this argument is not provided, then models generated will be saved with filenames indicated by the experiment name in the experiment_cfg dictionary :param save_with_hash: (bool) if True, appends a hash to the end of a filename to prevent any conflicts from occurring w.r.t. filenames. 
This can be useful if you are using a cluster environment and the filesystem across nodes takes time to replicate """ self.arch_factory = arch_factory self.arch_factory_kwargs = arch_factory_kwargs self.arch_factory_kwargs_generator = arch_factory_kwargs_generator self.data = data self.model_save_dir = model_save_dir self.stats_save_dir = stats_save_dir self.num_models = num_models self.optimizer = optimizer self.parallel = parallel self.amp = amp self.experiment_cfg = dict() if experiment_cfg is None else experiment_cfg # it might be useful to allow something like a generator for this argument self.run_ids = run_ids # it might be useful to allow something like a generator for this argument self.filenames = filenames self.save_with_hash = save_with_hash self.validate() def __deepcopy__(self, memodict={}): arch_factory_copy = copy.deepcopy( self.arch_factory) # I think this is OK b/c the ArchFactory is a class definition # the default should work properly here b/c all properties are primitives data_copy = copy.deepcopy(self.data) optimizer_copy = copy.deepcopy(self.optimizer) return ModelGeneratorConfig(arch_factory_copy, data_copy, self.model_save_dir, self.stats_save_dir, self.num_models, self.arch_factory_kwargs, self.arch_factory_kwargs_generator, optimizer_copy, self.parallel, self.amp, self.experiment_cfg, self.run_ids, self.filenames, self.save_with_hash) def __eq__(self, other): if self.arch_factory == other.arch_factory and self.data == other.data and self.optimizer == other.optimizer \ and self.parallel == other.parallel \ and self.amp == other.amp \ and self.model_save_dir == other.model_save_dir and self.stats_save_dir == other.stats_save_dir \ and self.arch_factory_kwargs == other.arch_factory_kwargs \ and self.arch_factory_kwargs_generator == other.arch_factory_kwargs_generator \ and self.experiment_cfg == other.experiment_cfg and self.run_ids == other.run_ids \ and self.filenames == other.filenames and self.save_with_hash == other.save_with_hash: return True else: return False
[docs] def validate(self) -> None: """ Validate the input arguments to construct the object :return: None """ if not (isinstance(self.arch_factory, ArchitectureFactory)): msg = "Expected an ArchitectureFactory object for argument 'architecture_factory', " \ "instead got type: {}".format(type(self.arch_factory)) logger.error(msg) raise TypeError(msg) if self.arch_factory_kwargs is not None and not isinstance(self.arch_factory_kwargs, dict): msg = "Expected dictionary for arch_factory_kwargs" logger.error(msg) raise TypeError(msg) if self.arch_factory_kwargs_generator is not None and not callable(self.arch_factory_kwargs_generator): msg = "arch_factory_kwargs_generator must be a Callable!" logger.error(msg) raise TypeError(msg) if not (isinstance(self.data, DataManager)): msg = "Expected an TrojaiDataManager object for argument 'data', " \ "instead got type: {}".format(type(self.data)) logger.error(msg) raise TypeError(msg) if not type(self.model_save_dir) == str: msg = "Expected type 'string' for argument 'model_save_dir, instead got type: " \ "{}".format(type(self.model_save_dir)) logger.error(msg) raise TypeError(msg) if not os.path.isdir(self.model_save_dir): try: os.makedirs(self.model_save_dir) except IOError as e: msg = "'model_save_dir' was not found and could not be created" \ "...\n{}".format(e.__traceback__) logger.error(msg) raise IOError(msg) if not type(self.num_models) == int: msg = "Expected type 'int' for argument 'num_models, instead got type: " \ "{}".format(type(self.num_models)) logger.error(msg) raise TypeError(msg) if self.filenames is not None: if isinstance(self.filenames, Sequence): for filename in self.filenames: if not type(filename) == str: msg = "Encountered non-string in argument 'filenames': {}".format( filename) logger.error(msg) raise TypeError(msg) else: if not isinstance(self.filenames, str): msg = "Filename provided as prefix must be of type string!" logger.error(msg) raise TypeError(msg) if self.run_ids is not None and len(self.run_ids) != self.num_models: msg = "Argument 'run_ids' was provided, but len(run_ids) != num_models" logger.error(msg) raise RuntimeError(msg) if self.filenames is not None and len(self.filenames) != self.num_models: msg = "Argument 'filenames' was provided, but len(filenames) != num_models" logger.error(msg) raise RuntimeError(msg) if self.run_ids is not None and self.filenames is not None: msg = "Argument 'filenames' was provided with argument 'run_ids', 'run_ids' will be ignored..." logger.warning(msg) if not isinstance(self.save_with_hash, bool): msg = "Expected boolean for save_with_hash argument" logger.error(msg) raise ValueError(msg) RunnerConfig.validate_optimizer(self.optimizer, self.data) if not isinstance(self.parallel, bool): msg = "parallel argument must be a boolean!" logger.error(msg) raise ValueError(msg)
def __getstate__(self): """ Function which dictates which objects will be saved when pickling the ModelGeneratorConfig object. This is only useful for the UGEModelGenerator, which needs to save the data before parallelizing a job. :return: a dictionary of the state of the ModelGeneratorConfig object. """ return {'arch_factory': self.arch_factory, 'data': self.data, 'model_save_dir': self.model_save_dir, 'stats_save_dir': self.stats_save_dir, 'num_models': self.num_models, 'arch_factory_kwargs': self.arch_factory_kwargs, 'arch_factory_kwargs_generator': self.arch_factory_kwargs_generator, 'parallel': self.parallel, 'amp': self.amp, 'experiment_cfg': self.experiment_cfg, 'run_ids': self.run_ids, 'filenames': self.filenames, 'save_with_hash': self.save_with_hash }
[docs] def save(self, fname: str): """ Saves the ModelGeneratorConfig object in two different parts. Every object within the config, except for the optimizer is saved in the .klass.save file, and the optimizer is saved separately. :param fname - the filename to save the configuration to :return: None """ # we save optimizer and the remainder of the components separately optimizer_klass_save_fname = fname + '.optimizer.klass.save' optimizer_save_fname = fname + '.optimizer.save' remainder_data_save_fname = fname + '.arch_data.save' with open(remainder_data_save_fname, 'wb') as f: pickle.dump(self, f) # save the optimizer class name, so we can load it properly optimizer_klass_name = '.'.join( [self.optimizer.__module__, self.optimizer.__class__.__name__]) with open(optimizer_klass_save_fname, 'w') as f: f.write(optimizer_klass_name) self.optimizer.save(optimizer_save_fname)
[docs] @staticmethod def load(fname: str): """ Loads a saved modelgen_cfg object from data that was saved using the .save() function. :param fname: the filename where the modelgen_cfg object is saved :return: a ModelGeneratorConfig object """ optimizer_klass_save_fname = fname + '.optimizer.klass.save' optimizer_save_fname = fname + '.optimizer.save' remainder_data_save_fname = fname + '.arch_data.save' with open(remainder_data_save_fname, 'rb') as f: modelgen_cfg = pickle.load(f) # load the class name of the optimizer that was used with open(optimizer_klass_save_fname, 'r') as f: optimizer_module_and_klass_name = f.readline() # load the module ss = optimizer_module_and_klass_name.split('.') optimizer_module_name = '.'.join(ss[0:-1]) optimizer_klass_name = ss[-1] optimizer_module = importlib.import_module(optimizer_module_name) optimizer_klass_def = getattr(optimizer_module, optimizer_klass_name) optimizer_load = optimizer_klass_def.load(optimizer_save_fname) # reconstruct the ModelGeneratorConfig object modelgen_cfg.optimizer = optimizer_load modelgen_cfg.validate() return modelgen_cfg
[docs]class RunnerConfig(ConfigInterface): """ Container for all parameters needed to use the Runner to train a model. """ def __init__(self, arch_factory: ArchitectureFactory, data: DataManager, arch_factory_kwargs: dict = None, arch_factory_kwargs_generator: Callable = None, optimizer: Union[OptimizerInterface, DefaultOptimizerConfig, Sequence[Union[OptimizerInterface, DefaultOptimizerConfig]]] = None, parallel: bool = False, amp: bool = False, model_save_dir: str = "/tmp/models", stats_save_dir: str = "/tmp/model_stats", model_save_format: str = "pt", run_id: Any = None, filename: str = None, save_with_hash: bool = False): """ Initialize a RunnerConfig object :param arch_factory: (Architecture Factory) a trainable Pytorch module generator. :param data: (TrojaiDataManager) a TrojaiDataManager object containing the paths to the data being trained and tested on, as well as functions dictating how the data should be transformed for training and testing. :param arch_factory_kwargs: (dict) a dictionary which contains keywords and associated values that are needed to instantiate a trainable module from the factory :param arch_factory_kwargs_generator: (callable) a callable, or None, which takes a dictionary of all variables defined in the Runner's namespace, and then creates a new dictionary taht contains the keyword arguments to instantiate an architecture from the architecture factory :param optimizer: a OptimizerInterface object, or a DefaultOptimizer configuration, or possibly mixed sequence of both :param parallel: (bool) if True, spreads GPU tasking over all available GPUs :param amp: (bool) if True, uses automatic mixed precision training :param model_save_dir: (str) path to where the models should be saved. :param stats_save_dir: (str) path to where the model training statistics should be saved. :param run_id: An ending to the save file name. Can be anything, but will be converted to string format. Ignored if a filename is provided. :param filename: (str) File name for the saved model. If not specified, default to the name of the architecture provided. Should end in .pt for consistency. :param save_with_hash: (bool) if True, appends a hash to the end of a filename to prevent any conflicts from occurring w.r.t. filenames. This can be useful if you are using a cluster environment and the filesystem across nodes takes time to replicate """ self.arch_factory = arch_factory self.data = data self.arch_factory_kwargs = arch_factory_kwargs self.arch_factory_kwargs_generator = arch_factory_kwargs_generator self.optimizer = optimizer self.parallel = parallel self.amp = amp self.model_save_dir = model_save_dir self.stats_save_dir = stats_save_dir self.model_save_format = model_save_format self.run_id = run_id self.filename = filename self.save_with_hash = save_with_hash self.validate() # create new attribute instead of overwriting self.optimizer so that self.__deepcopy__ still works. self.optimizer_generator = self.setup_optimizer_generator( self.optimizer, self.data) def __deepcopy__(self, memodict={}): arch_copy = copy.deepcopy(self.arch_factory) data_copy = copy.deepcopy(self.data) optim_copy = copy.deepcopy(self.optimizer) return RunnerConfig(arch_copy, data_copy, self.arch_factory_kwargs, self.arch_factory_kwargs_generator, optim_copy, self.parallel, self.amp, self.model_save_dir, self.stats_save_dir, self.model_save_format, self.run_id, self.filename, self.save_with_hash)
[docs] @staticmethod def setup_optimizer_generator(optimizer, data): """ Converts an optimizer specification to a generator, to be compatible with sequential training. :param optimizer: the optimizer to configure into a generator :param num_datasets: the number of datasets for which optimizers need to be created :return: A generator that returns optimizers for every dataset to be trained """ from .default_optimizer import DefaultOptimizer if optimizer is None or isinstance(optimizer, DefaultOptimizerConfig): if data.train_file is not None and len(data.train_file) > 0: return (DefaultOptimizer(optimizer) for _ in range(len(data.train_file))) else: return (DefaultOptimizer(optimizer) for _ in range(1)) elif isinstance(optimizer, OptimizerInterface): if data.train_file is not None and len(data.train_file) > 0: return (optimizer.__deepcopy__({}) for _ in range(len(data.train_file))) else: return (optimizer for _ in range(1)) else: msg = "Multiple optimizers specified, only final will be used for test calculations" logger.warning(msg) return (opt if isinstance(opt, OptimizerInterface) else DefaultOptimizer(opt) for opt in optimizer)
[docs] @staticmethod def validate_optimizer(optimizer, data): """ Validates an optimzer configuration :param optimizer: the optimizer/optimizer configuration to be validated :param data: the data to be optimized :return: """ if not (optimizer is None or isinstance(optimizer, OptimizerInterface) or isinstance(optimizer, DefaultOptimizerConfig)): if not (hasattr(type(optimizer), '__iter__') and hasattr(type(optimizer), '__len__') and type(optimizer) != str): msg = "Expected OptimizerInterface, DefaultOptimizerConfig, or sequence of them for argument" \ "'optimizer', instead got {}".format(optimizer) logger.error(msg) raise TypeError(msg) else: for item in optimizer: if not (isinstance(item, OptimizerInterface) or isinstance(item, DefaultOptimizerConfig)): msg = "Expected OptimizerInterface or DefaultOptimizerConfig objects in sequence for argument" \ "'optimizer', discovered {} in sequence".format( item) logger.error(msg) raise TypeError(msg) if len(optimizer) != len(data.train_file): msg = "If specifying multiple optimizers, the number of optimizers given must be the same as the " \ "number of training files in the DataManager." logger.error(msg) raise TypeError(msg)
[docs] def validate(self) -> None: """ Validate the RunnerConfig object :return: None """ if not isinstance(self.arch_factory, ArchitectureFactory): msg = "Expected ArchitectureFactory for argument 'architecture', instead got type: {}".format(type( self.arch_factory)) logger.error(msg) raise TypeError(msg) if self.arch_factory_kwargs is not None and not isinstance(self.arch_factory_kwargs, dict): msg = "arch_factory_kwargs must be a dictionary!" logger.error(msg) raise TypeError(msg) if self.arch_factory_kwargs_generator is not None and not callable(self.arch_factory_kwargs_generator): msg = "Expected a function for argument 'arch_factory_kwargs_generator', " \ "instead got type: {}".format( type(self.arch_factory_kwargs_generator)) logger.error(msg) raise TypeError(msg) if not isinstance(self.data, DataManager): msg = "Expected a TrojaiDataManager object for argument 'data', " \ "instead got type: {}".format(type(self.data)) logger.error(msg) raise TypeError(msg) self.validate_optimizer(self.optimizer, self.data) if not isinstance(self.parallel, bool): msg = "parallel argument must be a boolean!" logger.error(msg) raise ValueError(msg) if not type(self.model_save_dir) == str: msg = "Expected type 'string' for argument 'model_save_dir, instead got type: " \ "{}".format(type(self.model_save_dir)) logger.error(msg) raise TypeError(msg) if not os.path.isdir(self.model_save_dir): try: os.makedirs(self.model_save_dir) except OSError as e: # not sure this error is possible as written msg = "'model_save_dir' was not found and could not be created" \ "...\n{}".format(e.__traceback__) logger.error(msg) raise OSError(msg) if not os.path.isdir(self.stats_save_dir): try: os.makedirs(self.stats_save_dir) except OSError as e: # not sure this error is possible as written msg = "'stats_save_dir' was not found and could not be created" \ "...\n{}".format(e.__traceback__) logger.error(msg) raise OSError(msg) if self.filename is not None and not type(self.filename) == str: msg = "Expected a string for argument 'filename', instead got " \ "type {}".format(type(self.filename)) logger.error(msg) raise TypeError(msg) if not isinstance(self.save_with_hash, bool): msg = "Expected boolean for argument save_with_hash" logger.error(msg) raise TypeError(msg) if self.model_save_format != 'pt' and self.model_save_format != 'state_dict': msg = "model_save_format must be either: pt or state_dict" raise TypeError(msg)
[docs]def modelgen_cfg_to_runner_cfg(modelgen_cfg: ModelGeneratorConfig, run_id=None, filename=None) -> RunnerConfig: """ Convenience function which creates a RunnerConfig object, from a ModelGeneratorConfig object. :param modelgen_cfg: the ModelGeneratorConfig to convert :param run_id: run_id to be associated with the RunnerConfig :param filename: filename to be associated with the RunnerConfig :return: the created RunnerConfig object """ return RunnerConfig(modelgen_cfg.arch_factory, modelgen_cfg.data, modelgen_cfg.arch_factory_kwargs, modelgen_cfg.arch_factory_kwargs_generator, modelgen_cfg.optimizer, modelgen_cfg.parallel, modelgen_cfg.amp, modelgen_cfg.model_save_dir, modelgen_cfg.stats_save_dir, run_id=run_id, filename=filename, save_with_hash=modelgen_cfg.save_with_hash)
class UGEQueueConfig:
    """
    Defines the configuration for a Queue w.r.t. UGE in TrojAI
    """
    def __init__(self, queue_name: str, gpu_enabled: bool, sync_mode: bool = False):
        self.queue_name = queue_name
        self.gpu_enabled = gpu_enabled
        self.sync_mode = sync_mode

    def validate(self) -> None:
        """
        Validate the UGEQueueConfig object
        """
        if not isinstance(self.queue_name, str):
            msg = "queue_name must be a string!"
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.gpu_enabled, bool):
            msg = "gpu_enabled argument must be a boolean!"
            logger.error(msg)
            raise TypeError(msg)
        if not isinstance(self.sync_mode, bool):
            msg = "sync_mode argument must be a boolean!"
            logger.error(msg)
            raise TypeError(msg)
        if self.sync_mode:
            msg = "sync_mode=True currently unsupported!"
            logger.error(msg)
            raise TypeError(msg)
class UGEConfig:
    """
    Defines a configuration for the UGE
    """
    def __init__(self, queues: Union[UGEQueueConfig, Sequence[UGEQueueConfig]],
                 queue_distribution: Sequence[float] = None,
                 multi_model_same_gpu: bool = False):
        """
        :param queues: a list of Queue object configurations
        :param queue_distribution: the desired way to distribute the workload across the queues; if None, the
            workload is distributed evenly across the queues, otherwise it is distributed according to the
            given proportions (a sequence of floats, one per queue, summing to one)
        :param multi_model_same_gpu: if True, then if multiple models are desired for a given
            ModelGeneratorConfig, those will all be trained on the same queue. Otherwise, they will be
            distributed as much as possible (which is likely to complete the job faster!)
        """
        self.queues = queues
        self.queue_distribution = queue_distribution
        self.multi_model_same_gpu = multi_model_same_gpu

        self.validate()

    def validate(self):
        """
        Validate the UGEConfig object
        """
        if isinstance(self.queues, UGEQueueConfig):
            self.queues = [self.queues]
        elif isinstance(self.queues, collections.abc.Sequence):
            for q in self.queues:
                if not isinstance(q, UGEQueueConfig):
                    msg = "queues must be a Sequence of UGEQueueConfig objects!"
                    logger.error(msg)
                    raise TypeError(msg)
        else:
            msg = "queues input must be either a UGEQueueConfig object, or a Sequence of UGEQueueConfig objects!"
            logger.error(msg)
            raise TypeError(msg)

        if self.queue_distribution is not None:
            if not isinstance(self.queue_distribution, collections.abc.Sequence):
                msg = "queue_distribution argument must be either None (implying uniform distribution among all " \
                      "queues) or a Sequence of floats summing to one"
                logger.error(msg)
                raise TypeError(msg)
            else:
                try:
                    if len(self.queue_distribution) != len(self.queues):
                        msg = "if a queue_distribution is provided, it must be equal to the number of queues provided!"
                        logger.error(msg)
                        raise TypeError(msg)
                    sum_val = np.sum(self.queue_distribution)
                    if not np.isclose(sum_val, 1):
                        msg = "queue_distribution must be a Sequence of floats summing to 1"
                        logger.error(msg)
                        raise ValueError(msg)
                    for d in self.queue_distribution:
                        if d < 0 or d > 1:
                            msg = "queue_distribution values must be between 0 and 1"
                            logger.error(msg)
                            raise TypeError(msg)
                except TypeError as e:
                    logger.exception(e)
                    raise TypeError(e)

        if not isinstance(self.multi_model_same_gpu, bool):
            msg = "multi_model_same_gpu input must be a boolean!"
            logger.error(msg)
            raise TypeError(msg)
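
# Illustrative sketch (not part of the library API): distribute model-generation jobs across two
# hypothetical UGE queues, sending 75% of the work to the GPU queue. The queue names are examples only.
def _example_uge_cfg():
    gpu_queue = UGEQueueConfig(queue_name='gpu.q', gpu_enabled=True)
    cpu_queue = UGEQueueConfig(queue_name='cpu.q', gpu_enabled=False)
    uge_cfg = UGEConfig(queues=[gpu_queue, cpu_queue],
                        queue_distribution=[0.75, 0.25])
    return uge_cfg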