# Source code for trojai.modelgen.optimizer_interface

import logging
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple

import torch.nn
from torch.utils.data import Dataset

from .training_statistics import EpochStatistics

logger = logging.getLogger(__name__)


class OptimizerInterface(ABC):
    """Object that performs training and testing of TrojAI models.

    Concrete optimizers must implement every abstract method declared here;
    until they do, the class cannot be instantiated.

    NOTE: ``__eq__`` is declared without a matching ``__hash__``, so Python
    sets ``__hash__`` to ``None`` and instances are unhashable unless a
    subclass defines ``__hash__`` explicitly.
    """

    @abstractmethod
    def train(self, model: torch.nn.Module, data: Dataset, progress_bar_disable: bool,
              torch_dataloader_kwargs: Optional[dict] = None
              ) -> "Tuple[torch.nn.Module, Sequence[EpochStatistics], int]":
        """
        Train the given model using parameters in self.training_params

        :param model: (torch.nn.Module) The untrained Pytorch model
        :param data: (CSVDataset) Object containing training data, output 0 from TrojaiDataManager.load_data()
        :param progress_bar_disable: (bool) Don't display the progress bar if True
        :param torch_dataloader_kwargs: additional arguments to pass to PyTorch's DataLoader class
        :return: (torch.nn.Module, Sequence[EpochStatistics], int) a 3-tuple of: the trained model, a sequence of
            EpochStatistics objects (one for each epoch), and the # of epochs with which the model was trained
            (useful for early stopping).
        """

    @abstractmethod
    def test(self, model: torch.nn.Module, clean_test_data: Dataset, triggered_test_data: Optional[Dataset],
             clean_test_triggered_labels_data: Dataset,
             torch_dataloader_kwargs: Optional[dict] = None) -> dict:
        """
        Perform whatever tests desired on the model with clean data and triggered data, return a dictionary of results.

        :param model: (torch.nn.Module) Trained Pytorch model
        :param clean_test_data: (CSVDataset) Object containing clean test data
        :param triggered_test_data: (CSVDataset or None) Object containing triggered test data, None if triggered data
            was not provided for testing
        :param clean_test_triggered_labels_data: triggered part of the training dataset but with correct labels; see
            DataManager.load_data for more information.
        :param torch_dataloader_kwargs: additional arguments to pass to PyTorch's DataLoader class
        :return: (dict) Dictionary of test accuracy results.
            Required key, value pairs are:
                clean_accuracy: (float in [0, 1]) classification accuracy on clean data
                clean_n_total: (int) number of examples in clean test set
            The following keys are optional, but should be used if triggered test data was provided
                triggered_accuracy: (float in [0, 1]) classification accuracy on triggered data
                triggered_n_total: (int) number of examples in triggered test set

        NOTE: This list may be augmented in the future to allow for additional test data collection.
        """

    @abstractmethod
    def get_device_type(self) -> str:
        """
        Return a string representation of the type of device used by the optimizer to train the model.
        """

    @abstractmethod
    def get_cfg_as_dict(self) -> dict:
        """
        Return a dictionary with key/value pairs that describe the parameters used to train the model.
        """

    @abstractmethod
    def __deepcopy__(self, memodict: Optional[dict] = None):
        """
        Required for training on clusters. Return a deep copy of the optimizer.

        :param memodict: the memo dictionary that ``copy.deepcopy`` passes in. The default is ``None`` rather than a
            shared mutable ``{}`` literal; implementations should treat ``None`` as "no memo yet".
        """

    @abstractmethod
    def __eq__(self, other):
        """
        Required for training on clusters. Define how to check if two optimizers are equal.
        """

    @abstractmethod
    def __str__(self):
        """
        Return a human-readable string description of the optimizer.
        """

    @abstractmethod
    def save(self, fname: str) -> None:
        """
        Save the optimizer to a file

        :param fname: the filename to save the optimizer to
        """

    @staticmethod
    @abstractmethod
    def load(fname: str):
        """
        Load an optimizer from disk and return it

        :param fname: the filename where the optimizer is serialized
        :return: The loaded optimizer
        """