# Source code for trojai.datagen.common_label_behaviors

import logging
from typing import Optional

from .label_behavior import LabelBehavior

logger = logging.getLogger(__name__)

"""
Defines some common behaviors which are used to modify labels when designing an experiment with triggered and clean data
"""


class WrappedAdd(LabelBehavior):
    """
    Adds a defined amount to each input label, with an optional maximum value around which labels are wrapped
    """
    def __init__(self, add_val: int, max_num_classes: Optional[int] = None) -> None:
        """
        Creates the WrappedAdd object
        :param add_val: the value to add to each input label
        :param max_num_classes: the maximum number of classes such that modified labels are wrapped;
            if None, labels are not wrapped
        :raises ValueError: if max_num_classes is given but is not positive
        """
        # Reject a non-positive modulus up front. Zero would raise ZeroDivisionError inside do().
        # A negative value would silently produce non-positive labels.
        if max_num_classes is not None and max_num_classes <= 0:
            raise ValueError("max_num_classes must be a positive integer, got %r" % (max_num_classes,))
        self.add_val = add_val
        self.max_num_classes = max_num_classes

    def do(self, y_true: int) -> int:
        """
        Performs the actual specified label modification
        :param y_true: input label to be modified
        :return: the modified label, y_true + add_val, reduced modulo max_num_classes when that is set
        """
        modified_label = y_true + self.add_val
        if self.max_num_classes is not None:
            # With a positive modulus, Python's % yields a value in [0, max_num_classes),
            # even when add_val is negative.
            modified_label %= self.max_num_classes
        # Lazy logger args: the message is formatted only if DEBUG is enabled, so an
        # unformattable label can no longer crash do() when debug logging is off.
        logger.debug("Converted label %s to %s", y_true, modified_label)
        return modified_label


class StaticTarget(LabelBehavior):
    """
    Sets label to a defined value
    """
    def __init__(self, target) -> None:
        """
        Creates the StaticTarget object
        :param target: the value to set each input label to
        """
        self.target = target

    def do(self, y_true):
        """
        Performs the actual specified label modification
        :param y_true: input label to be modified; it is only logged, never used to compute the result
        :return: the modified label, which is always self.target
        """
        modified_label = self.target
        # Lazy logger args: the message is formatted (via str()) only if DEBUG is enabled.
        logger.debug("Converted label %s to %s", y_true, modified_label)
        return modified_label