# Source code for trojai.datagen.config

import logging
from typing import Sequence, Union, Any
import collections.abc

from .entity import Entity
from .merge_interface import Merge
from .transform_interface import Transform

logger = logging.getLogger(__name__)

"""
Contains classes which define configuration used for transforming and modifying objects, as well as the associated
validation routines.  Ideally, a configuration class should be defined for every pipeline that is defined.
"""


def check_list_type(op_list, type, err_msg):
    """
    Ensures every element of op_list is an instance of the given type.

    NOTE: the parameter is named ``type`` (shadowing the builtin) for backwards compatibility with keyword callers.

    :param op_list: an iterable of objects to check
    :param type: the class (or tuple of classes) every element must be an instance of
    :param err_msg: message that is logged and raised if an element fails the check
    :return: None
    :raises ValueError: on the first element that is not an instance of ``type``
    """
    expected = type
    # all() stops at the first failing element, exactly like an explicit loop would
    if not all(isinstance(element, expected) for element in op_list):
        logger.error(err_msg)
        raise ValueError(err_msg)
class XFormMergePipelineConfig:
    """
    Defines all configuration items necessary to run the XFormMerge Pipeline, and associated configuration validation.

    NOTE: the argument list can be condensed into lists of lists, but that becomes a bit less intuitive to use.  We
    need to think about how best we want to specify these argument lists.
    """
    def __init__(self,
                 trigger_list: Sequence[Entity] = None,
                 trigger_sampling_prob: Sequence[float] = None,
                 trigger_xforms: Sequence[Transform] = None,
                 trigger_bg_xforms: Sequence[Transform] = None,
                 trigger_bg_merge: Merge = None,
                 trigger_bg_merge_xforms: Sequence[Transform] = None,
                 overall_bg_xforms: Sequence[Transform] = None,
                 overall_bg_triggerbg_merge: Merge = None,
                 overall_bg_triggerbg_xforms: Sequence[Transform] = None,
                 merge_type: str = 'insert',
                 per_class_trigger_frac: float = None,
                 triggered_classes: Union[str, Sequence[Any]] = 'all'):
        """
        Initializes and validates the configuration used by XFormMergePipeline.

        :param trigger_list: a list of Triggers (Entity objects) to insert into the background Entity
        :param trigger_sampling_prob: probability with which each trigger should be sampled; if None, uniform
            sampling happens
        :param trigger_xforms: a list of transforms to apply to the trigger
        :param trigger_bg_xforms: a list of transforms to apply to the trigger background (what the trigger will be
            inserted into)
        :param trigger_bg_merge: merge operator to combine the trigger and the trigger background
        :param trigger_bg_merge_xforms: a list of transforms to apply after combining the trigger and the trigger
            background
        :param overall_bg_xforms: a list of transforms to apply to the overall background, into which the
            trigger+trigger_bg will be inserted.  This is only applicable for the merge_type of "regenerate"
        :param overall_bg_triggerbg_merge: Merge object which defines how to merge the background image with the
            trigger+bg image.  For example, a use case might be inserting a trigger into a traffic sign (which would
            be trigger+bg), and then inserting that into an overall background.  Required for "regenerate"
        :param overall_bg_triggerbg_xforms: any final transforms that should be applied after merging the trigger with
            the background and merging that combined entity with another background (as in the use case above)
        :param merge_type: how data will be merged; case-insensitive.  Accepted values are 'insert' and
            'regenerate', which correspond to the method argument of the Pipeline's modify_clean_dataset() function
        :param per_class_trigger_frac: the fraction of the clean data to modify, strictly between 0 and 1.  If None,
            all the data will be modified
        :param triggered_classes: either the string 'all', or a list of labels which will be triggered
        :raises ValueError: if any configuration item is invalid (see validate() and validate_regenerate_mode())
        """
        self.trigger_list = trigger_list
        self.trigger_xforms = trigger_xforms
        self.trigger_sampling_prob = trigger_sampling_prob
        self.trigger_bg_xforms = trigger_bg_xforms
        self.trigger_bg_merge = trigger_bg_merge
        self.trigger_bg_merge_xforms = trigger_bg_merge_xforms

        self.overall_bg_xforms = overall_bg_xforms
        self.overall_bg_triggerbg_merge = overall_bg_triggerbg_merge
        self.overall_bg_triggerbg_xforms = overall_bg_triggerbg_xforms

        self.merge_type = merge_type.lower()
        self.per_class_trigger_frac = per_class_trigger_frac
        self.triggered_classes = triggered_classes

        # validate configuration based on the merge type
        self.validate_regenerate_mode()
        self.validate()

    def validate(self):
        """
        Validates the configuration items that apply to every merge_type.

        Optional transform lists (trigger_xforms, trigger_bg_xforms, trigger_bg_merge_xforms) that were left as None
        are silently converted to empty lists, the "no xforms applied" format needed by the Pipeline.

        :return: None
        :raises ValueError: if per_class_trigger_frac is not strictly between 0 and 1, merge_type is not 'insert' or
            'regenerate', a list contains an element of the wrong type, trigger_bg_merge is not a Merge object, or
            triggered_classes is neither the string 'all' nor a sequence
        """
        if self.per_class_trigger_frac is not None and \
                (self.per_class_trigger_frac <= 0. or self.per_class_trigger_frac >= 1.):
            msg = "per_class_trigger_frac must be between 0 and 1, noninclusive"
            logger.error(msg)
            raise ValueError(msg)

        if self.merge_type not in ('insert', 'regenerate'):
            msg = "Unknown merge_type! See pipeline's modify_clean_dataset() for valid merge types!"
            logger.error(msg)
            raise ValueError(msg)

        # the following configuration items are checked regardless of the merge type
        if self.trigger_list is not None:
            check_list_type(self.trigger_list, Entity, "trigger_list must be a sequence of Entity objects!")
        if self.trigger_sampling_prob is not None:
            check_list_type(self.trigger_sampling_prob, float, "trigger_sampling_prob must be a sequence of floats!")

        if self.trigger_xforms is None:
            # silently convert None to no xforms applied in the format needed by the Pipeline
            self.trigger_xforms = []
        check_list_type(self.trigger_xforms, Transform, "trigger_xforms must be a list of Transform objects!")

        # trigger_bg_xforms gets the same None -> [] normalization and type check as the other transform lists,
        # so the Pipeline never receives None here
        if self.trigger_bg_xforms is None:
            self.trigger_bg_xforms = []
        check_list_type(self.trigger_bg_xforms, Transform, "trigger_bg_xforms must be a list of Transform objects!")

        # isinstance() is False for None, so this also rejects a missing merge operator
        if not isinstance(self.trigger_bg_merge, Merge):
            msg = "trigger_bg_merge must be specified as a trojai.datagen.Merge.Merge object"
            logger.error(msg)
            raise ValueError(msg)

        if self.trigger_bg_merge_xforms is None:
            # silently convert None to no xforms applied in the format needed by the Pipeline
            self.trigger_bg_merge_xforms = []
        check_list_type(self.trigger_bg_merge_xforms, Transform,
                        "trigger_bg_merge_xforms must be a list of Transform objects")

        if isinstance(self.triggered_classes, str):
            classes_ok = self.triggered_classes == 'all'
        else:
            # NOTE: the element type is left to run-time checking b/c we don't know what the type of a Label is for a
            # particular type of data
            classes_ok = isinstance(self.triggered_classes, collections.abc.Sequence)
        if not classes_ok:
            msg = "triggered_classes must be the string 'all', or a list of labels"
            logger.error(msg)
            raise ValueError(msg)

    def validate_regenerate_mode(self):
        """
        Validates the additional configuration items required when merge_type is "regenerate"; does nothing for any
        other merge_type.

        overall_bg_xforms and overall_bg_triggerbg_xforms left as None are silently converted to empty lists.

        :return: None
        :raises ValueError: if a transform list contains a non-Transform element, or overall_bg_triggerbg_merge is not
            a Merge object
        """
        if self.merge_type != 'regenerate':
            return

        if self.overall_bg_xforms is None:
            # silently convert None to no xforms applied in the format needed by the Pipeline
            self.overall_bg_xforms = []
        check_list_type(self.overall_bg_xforms, Transform, "overall_bg_xforms must be a list of Transform objects!")

        if not isinstance(self.overall_bg_triggerbg_merge, Merge):
            msg = "overall_bg_triggerbg_merge input must be of type trojai.datagen.Merge.Merge"
            logger.error(msg)
            raise ValueError(msg)

        if self.overall_bg_triggerbg_xforms is None:
            # silently convert None to no xforms applied in the format needed by the Pipeline
            self.overall_bg_triggerbg_xforms = []
        check_list_type(self.overall_bg_triggerbg_xforms, Transform,
                        "overall_bg_triggerbg_xforms must be a list of Transform objects!")
def check_non_negative(val, name):
    """
    Ensures a scalar, or every element of a sequence, is non-negative.

    :param val: a single number or a Sequence of numbers; non-Sequence inputs are treated as a one-element list
    :param name: name of the parameter being checked, used in the error message
    :return: None
    :raises ValueError: if any value is less than 0.0
    """
    candidates = val if isinstance(val, Sequence) else [val]
    if any(candidate < 0.0 for candidate in candidates):
        msg = "Illegal value specified %s. All values must be non-negative!" % name
        logger.error(msg)
        raise ValueError(msg)
class ValidInsertLocationsConfig:
    """
    Selects the algorithm used to determine valid trigger insertion spots on an image, together with the parameters
    that algorithm needs.
    """
    def __init__(self, algorithm: str = 'brute_force', min_val: Union[int, Sequence[int]] = 0,
                 threshold_val: Union[float, Sequence[float]] = 5.0, num_boxes: int = 5,
                 allow_overlap: Union[bool, Sequence[bool]] = False):
        """
        Stores and validates the parameters for InsertAtRandomLocation.

        :param algorithm: placement algorithm (case-insensitive); one of
            brute_force -> for every edge pixel of the image, invalidates all intersecting pattern insert locations
            threshold -> a trigger position is invalid if the mean pixel value over the area exceeds threshold_val;
                WARNING: slowest of all options by a substantial amount
            edge_tracing -> follows the perimeter of non-zero image values, invalidating locations where trigger and
                image overlap; works well for convex images with long flat edges
            bounding_boxes -> splits the image into a num_boxes x num_boxes grid, builds a bounding box for the image
                in each cell, and invalidates all intersecting insert locations; much faster for large, finely
                detailed images but may miss valid locations.  WARNING: may find none if num_boxes is too small
        :param min_val: pixels above this value count toward overlap; pixels below it are treated as empty
        :param threshold_val: value the mean pixel value over a candidate area is compared to (threshold only)
        :param num_boxes: grid size for bounding_boxes; larger values give a closer approximation
        :param allow_overlap: channels on which trigger/image overlap is allowed; True allows it on all channels
        :raises ValueError: if validate() rejects the parameters
        """
        self.algorithm = algorithm.lower()
        self.min_val = min_val
        self.threshold_val = threshold_val
        self.num_boxes = num_boxes
        self.allow_overlap = allow_overlap
        self.validate()

    def validate(self):
        """
        Checks the stored parameters for the selected algorithm.

        :return: None
        :raises ValueError: if the algorithm is unknown, min_val (or threshold_val, for threshold) is negative, or
            num_boxes is outside [1, 25] for bounding_boxes
        """
        known_algorithms = ('brute_force', 'threshold', 'edge_tracing', 'bounding_boxes')
        if self.algorithm not in known_algorithms:
            msg = "Algorithm specified is not implemented!"
            logger.error(msg)
            raise ValueError(msg)

        check_non_negative(self.min_val, 'min_val')

        # brute_force and edge_tracing need nothing beyond min_val
        if self.algorithm == 'threshold':
            check_non_negative(self.threshold_val, 'threshold_val')
        elif self.algorithm == 'bounding_boxes':
            out_of_range = self.num_boxes < 1 or self.num_boxes > 25
            if out_of_range:
                msg = "Must specify a value between 1 and 25 for num_boxes!"
                logger.error(msg)
                raise ValueError(msg)
class TrojAICleanDataConfig:
    """
    Configuration for producing clean data: transforms for the sign and the background, the merge operator that
    combines them, and transforms applied to the combined result.
    """
    def __init__(self, sign_xforms: Sequence[Transform] = None, bg_xforms: Sequence[Transform] = None,
                 merge_obj: Merge = None, combined_xforms: Sequence[Transform] = None) -> None:
        """
        Stores and validates the clean-data configuration.

        :param sign_xforms: transforms to apply to the sign; None means no transforms
        :param bg_xforms: transforms to apply to the background; None means no transforms
        :param merge_obj: Merge object combining sign and background; required
        :param combined_xforms: transforms applied after merging; None means no transforms
        :raises ValueError: if validate() rejects the configuration
        """
        self.sign_xforms = sign_xforms
        self.bg_xforms = bg_xforms
        self.merge_obj = merge_obj
        self.combined_xforms = combined_xforms
        self.validate()

    @staticmethod
    def _xforms_or_empty(xforms, err_msg):
        """Return xforms (None replaced by []) after checking that every element is a Transform."""
        normalized = [] if xforms is None else xforms
        check_list_type(normalized, Transform, err_msg)
        return normalized

    def validate(self) -> None:
        """
        Normalizes None transform lists to empty lists and checks all configuration types.

        :return: None
        :raises ValueError: if a transform list contains a non-Transform element or merge_obj is not a Merge
        """
        self.sign_xforms = self._xforms_or_empty(self.sign_xforms, "sign_xforms must be list of Transform objects")
        self.bg_xforms = self._xforms_or_empty(self.bg_xforms, "bg_xforms must be list of Transform objects")

        if not isinstance(self.merge_obj, Merge):
            msg = "merge_obj must be of type trojai.datagen.Merge.Merge"
            logger.error(msg)
            raise ValueError(msg)

        self.combined_xforms = self._xforms_or_empty(self.combined_xforms,
                                                     "combined_xforms must be list of Transform objects")