Source code for trojai.modelgen.model_generator_interface
from abc import ABC, abstractmethod
from typing import Union, Sequence
import logging
from .config import ModelGeneratorConfig
logger = logging.getLogger(__name__)
[docs]class ModelGeneratorInterface(ABC):
"""Generates models based on requested data and saves each to a file."""
def __init__(self, configs: Union[ModelGeneratorConfig, Sequence[ModelGeneratorConfig]]):
"""
:param configs: configuration objects that specify how to generate models for a single experiment
"""
self.configs = configs
if not isinstance(self.configs, Sequence):
self.configs = [self.configs]
[docs] @abstractmethod
def run(self) -> None:
"""
Train and save models as specified.
:return: None
"""
pass
[docs]def validate_model_generator_interface_input(configs: Union[ModelGeneratorConfig, Sequence[ModelGeneratorConfig]]) \
-> None:
"""
Validates a ModelGeneratorConfig
:param configs: (ModelGeneratorConfig or sequence) configurations to be used for model generation
:return None
"""
if not (isinstance(configs, ModelGeneratorConfig) or isinstance(configs, Sequence)):
err_msg = "Expected a ModelGeneratorConfig object or sequence of ModelGeneratorConfig objects for " \
"argument 'configs', instead got type: {}".format(type(configs))
logger.error(err_msg)
raise TypeError(err_msg)
if isinstance(configs, Sequence) and len(configs) == 0:
err_msg = "Emtpy sequence provided for 'configs' argument."
logger.error(err_msg)
raise RuntimeError(err_msg)
if isinstance(configs, Sequence):
for cfg in configs:
if not isinstance(cfg, ModelGeneratorConfig):
err_msg = "non-'ModelGeneratorConfig' type included in argument 'configs': {}".format(type(cfg))
logger.error(err_msg)
raise TypeError(err_msg)
logger.debug("Configuration validated successfully!")