Source code for trojai.modelgen.model_generator

import logging
from typing import Union, Sequence
import torch

from tqdm import tqdm

from .model_generator_interface import ModelGeneratorInterface, validate_model_generator_interface_input
from .config import ModelGeneratorConfig
from .runner import Runner
from .config import modelgen_cfg_to_runner_cfg

logger = logging.getLogger(__name__)


class ModelGenerator(ModelGeneratorInterface):
    """
    Generates models based on requested data and saves each to a file.
    """
    def __init__(self, configs: Union[ModelGeneratorConfig, Sequence[ModelGeneratorConfig]], *args, **kwargs):
        """
        :param configs: (ModelGeneratorConfig or sequence) ModelGeneratorConfig objects configured to
            generate models for a single experiment
        """
        super().__init__(configs)
        self.validate()
    def run(self, *args, **kwargs) -> None:
        """
        Train and save models as specified.
        :return: None
        """
        loop = tqdm(self.configs, desc='Configurations')
        for cfg in loop:
            loop.set_postfix_str(cfg.experiment_cfg['name'])
            for i in range(cfg.num_models):
                # determine the filename or run id to associate with this model, if either was configured
                filename = None
                run_id = None
                if cfg.filenames is not None:
                    if isinstance(cfg.filenames, str):
                        filename = cfg.filenames
                    else:
                        filename = cfg.filenames[i]
                elif cfg.run_ids is not None:
                    run_id = cfg.run_ids[i]

                # convert the generator-level config into a runner-level config, then train and save the model
                run_cfg = modelgen_cfg_to_runner_cfg(cfg, run_id=run_id, filename=filename)
                runner = Runner(run_cfg, persist_metadata=cfg.experiment_cfg)
                runner.run()

                # clear up memory between runs
                torch.cuda.empty_cache()
    def validate(self) -> None:
        """
        Validate the provided input when constructing the ModelGenerator interface.
        """
        validate_model_generator_interface_input(self.configs)
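
A minimal usage sketch follows. It assumes a ModelGeneratorConfig has already been constructed; the variable name my_config is a placeholder, and its constructor arguments are defined in trojai.modelgen.config, not in this module.

from trojai.modelgen.config import ModelGeneratorConfig
from trojai.modelgen.model_generator import ModelGenerator

# build a config as described in trojai.modelgen.config; the details are omitted here
my_config: ModelGeneratorConfig = ...

# a single config or a sequence of configs may be passed; validate() runs in __init__,
# and run() trains and saves num_models models per config
generator = ModelGenerator(my_config)
generator.run()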