Source code for trojai.datagen.image_triggers

import logging
import math
from typing import Sequence, Union, Tuple

import numpy as np
from numpy.random import RandomState

from .image_entity import ImageEntity

logger = logging.getLogger(__name__)

"""
Defines various Trigger Entity objects
"""

class ReverseLambdaPattern(ImageEntity):
    """
    Defines a reverse-lambda pattern: the full anti-diagonal of a square
    bounding box together with the lower half of the main diagonal.
    """

    def __init__(self, num_rows: int, num_cols: int, num_chan: int,
                 trigger_cval: Union[int, Sequence[int]],
                 bg_cval: Union[int, Sequence[int]] = 0,
                 thickness: int = 1, pattern_style: str = 'graffiti',
                 dtype=np.uint8) -> None:
        """
        Initialize the reverse-lambda pattern to be created

        :param num_rows: the # of rows of the bounding box containing the
            pattern; the box is square, so this is also its # of columns
        :param num_cols: ignored (stored, but create() only uses num_rows)
        :param num_chan: the # of channels to contain the pattern
        :param trigger_cval: the color value of the trigger, can either be
            a scalar or a Sequence of length=#chan
        :param bg_cval: the color of the background value, can either be
            a scalar or a Sequence of length=#chan
        :param thickness: an integer representing the thickness of the
            pattern strokes
        :param pattern_style: can be either graffiti or postit
            (case-insensitive)
        :param dtype: datatype to generate the pattern for, defaults to
            np.uint8
        :raises ValueError: if trigger_cval or bg_cval has a size other
            than 1 or num_chan, or if pattern_style is not recognized
        """
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.num_chan = num_chan

        if np.size(trigger_cval) != 1 and np.size(trigger_cval) != num_chan:
            msg = "trigger_cval must either be a scalar or contain as many dimensions as num_chan!"
            logger.error(msg)
            raise ValueError(msg)
        self.trigger_cval = trigger_cval

        if np.size(bg_cval) != 1 and np.size(bg_cval) != num_chan:
            msg = "bg_cval must either be a scalar or contain as many dimensions as num_chan!"
            logger.error(msg)
            raise ValueError(msg)
        self.bg_cval = bg_cval

        self.thickness = thickness

        if pattern_style.lower() in ('graffiti', 'postit'):
            self.pattern_style = pattern_style
        else:
            msg = "Unknown pattern style!"
            logger.error(msg)
            raise ValueError(msg)

        self.dtype = dtype
        self.pattern = None
        self.mask = None
        self.create()

    def _paint(self, rows: np.ndarray, cols: np.ndarray) -> None:
        """
        Color the given (row, col) pixels with the trigger value and mark
        them in the mask.
        """
        # indexing with two index arrays selects (N, num_chan) pixels, so a
        # scalar or a length-num_chan trigger_cval both broadcast correctly
        self.pattern[(rows, cols)] = self.trigger_cval
        self.mask[(rows, cols)] = True

    def create(self) -> None:
        """
        Creates the reverse-lambda pattern and associated mask

        For the 'graffiti' style only the trigger pixels are set in the
        mask; for the 'postit' style the whole bounding box is set.

        :return: None
        :raises ValueError: if the stored pattern style is not recognized
        """
        side = self.num_rows
        self.pattern = np.ones((side, side, self.num_chan), dtype=self.dtype)

        style = self.pattern_style.lower()
        if style == 'graffiti':
            self.mask = np.zeros((side, side), dtype=bool)
        elif style == 'postit':
            self.mask = np.ones((side, side), dtype=bool)
        else:
            msg = "Unknown pattern style!"
            logger.error(msg)
            raise ValueError(msg)

        # assign colors to the background based on the provided inputs
        if np.size(self.bg_cval) == 1:
            self.pattern *= self.bg_cval
        else:
            # assign each channel individually
            for ii in range(self.num_chan):
                self.pattern[:, :, ii] = self.bg_cval[ii]

        diag_rows, diag_cols = np.diag_indices(side)

        # first stroke: the full anti-diagonal (top-right to bottom-left)
        anti_rows, anti_cols = diag_rows, np.flipud(diag_cols)
        self._paint(anti_rows, anti_cols)
        # thicken by painting parallel lines shifted left and right; the
        # slices drop the points that would fall outside the box
        for offset in range(1, self.thickness):
            self._paint(anti_rows[:-offset], anti_cols[:-offset] - offset)
            self._paint(anti_rows[offset:], anti_cols[offset:] + offset)

        # second stroke: the lower half of the main diagonal
        start = math.ceil(side / 2)
        low_rows, low_cols = diag_rows[start:], diag_cols[start:]
        self._paint(low_rows, low_cols)
        for offset in range(1, self.thickness):
            # a left shift larger than the starting column would produce
            # negative column indices, which numpy wraps around to the
            # right edge of the box; keep only in-bounds columns
            left_cols = low_cols - offset
            in_bounds = left_cols >= 0
            self._paint(low_rows[in_bounds], left_cols[in_bounds])
            self._paint(low_rows[:-offset], low_cols[:-offset] + offset)

    def get_data(self) -> np.ndarray:
        """
        Get the image associated with the Entity
        :return: return a numpy.ndarray representing the image
        """
        return self.pattern

    def get_mask(self) -> np.ndarray:
        """
        Get the mask associated with the Entity
        :return: return a numpy.ndarray representing the mask
        """
        return self.mask
class RandomRectangularPattern(ImageEntity):
    """
    Defines a random rectangular pattern
    """

    # NOTE: the default RandomState(1234) is evaluated once, when the class
    # is defined, so every instance built without an explicit
    # random_state_obj draws from the same shared generator. The default is
    # kept unchanged for backward compatibility; pass your own RandomState
    # for an independent stream.
    def __init__(self, num_rows: int, num_cols: int, num_chan: int,
                 color_algorithm: str = 'channel_assign',
                 color_options: dict = None,
                 pattern_style='graffiti', dtype=np.uint8,
                 random_state_obj: RandomState = RandomState(1234)) -> None:
        """
        Initialize a random rectangular pattern to be created

        :param num_rows: the # of rows of the rectangle to be created
        :param num_cols: the # of cols of the rectangle to be created
        :param num_chan: the # of channels of the rectangle
        :param color_algorithm: can be "channel_assign", "random"
            channel_assign - if associated cval is a scalar, then we assign
                the specified color to every channel. If associated cval is
                a list, tuple or numpy array of length=num_chan, then we
                assign each element of cval to the associated channel
            random - every element of the (rows, cols, chans) pattern is
                independently drawn as 0 or 1 and multiplied by the maximum
                value of the specified (integer) datatype, resulting in
                each pixel being randomly colored
        :param color_options: only applicable if color_algorithm is
            channel_assign, in which case, this is expected to be a
            dictionary with a key 'cval', which is the color to be assigned
            to each channel; defaults to dict(cval=255)
        :param pattern_style: can be either 'postit' or 'graffiti'
        :param dtype: the default datatype of the rectangle to be generated
        :param random_state_obj: random state object
        """
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.num_chan = num_chan
        self.color_algorithm = color_algorithm
        if color_options is None:
            self.color_options = dict(cval=255)
        else:
            self.color_options = color_options
        self.pattern_style = pattern_style
        self.dtype = dtype
        self.pattern = None
        self.mask = None
        self.random_state_obj = random_state_obj
        self.create()

    def create(self) -> None:
        """
        Create the actual pattern

        A random 0/1 grid of shape (num_rows, num_cols) is drawn first. With
        'channel_assign' that grid is scaled by the color value(s); with
        'random' a second, independent draw produces the pattern. For the
        'graffiti' style, the mask is cleared wherever the first grid is 0.

        :return: None
        :raises ValueError: if a sequence cval does not have length
            num_chan, if the color algorithm is unknown, or if the 'random'
            algorithm is used with a non-integer dtype
        """
        # The dtype maximum is only needed by the 'random' algorithm, and
        # np.iinfo rejects non-integer dtypes, so it is not computed for
        # 'channel_assign'. It is resolved before any random draw so that a
        # failure here does not advance the random state.
        dtype_max_val = None
        if self.color_algorithm == 'random':
            dtype_max_val = np.iinfo(self.dtype).max

        grid_shape = (self.num_rows, self.num_cols)
        cb = self.random_state_obj.choice(2, self.num_rows * self.num_cols) \
            .reshape(grid_shape).astype(self.dtype)
        self.pattern = np.zeros(grid_shape + (self.num_chan,), dtype=self.dtype)
        self.mask = np.ones(grid_shape, dtype=bool)

        # color according to specified options
        if self.color_algorithm == 'channel_assign':
            cval = self.color_options['cval']
            # tuples count as per-channel colors too; a 0-d array is a scalar
            per_channel = isinstance(cval, (np.ndarray, list, tuple)) and np.ndim(cval) > 0
            if per_channel:
                if len(cval) != self.num_chan:
                    msg = "cval must be a scalar or of length=num_chan"
                    logger.error(msg)
                    raise ValueError(msg)
                for chan, chan_cval in enumerate(cval):
                    self.pattern[:, :, chan] = cb * chan_cval
            else:
                # assume scalar
                for chan in range(self.num_chan):
                    self.pattern[:, :, chan] = cb * cval
        elif self.color_algorithm == 'random':
            num_elem_to_generate = self.num_rows * self.num_cols * self.num_chan
            self.pattern = self.random_state_obj.choice(2, num_elem_to_generate) \
                .reshape(grid_shape + (self.num_chan,)).astype(self.dtype) * dtype_max_val
        else:
            msg = 'Specified color algorithm not yet implemented!'
            logger.error(msg)
            raise ValueError(msg)

        if self.pattern_style.lower() == 'graffiti':
            self.mask[cb == 0] = False

    def get_data(self) -> np.ndarray:
        """
        Get the image associated with the Entity
        :return: return a numpy.ndarray representing the image
        """
        return self.pattern

    def get_mask(self) -> np.ndarray:
        """
        Get the mask associated with the Entity
        :return: return a numpy.ndarray representing the mask
        """
        return self.mask
class RectangularPattern(ImageEntity):
    """
    Define a rectangular pattern: a solid block of a single color value
    whose mask covers the whole rectangle.
    """

    def __init__(self, num_rows: int, num_cols: int, num_chan: int, cval: int, dtype=np.uint8) -> None:
        """
        Store the rectangle's geometry and color, then build it.

        :param num_rows: the # of rows of the rectangle to be created
        :param num_cols: the # of cols of the rectangle to be created
        :param num_chan: the # of channels of the rectangle
        :param cval: the color value of the rectangle
        :param dtype: the default datatype of the rectangle to be generated
        """
        self.num_rows, self.num_cols, self.num_chan = num_rows, num_cols, num_chan
        self.cval = cval
        self.dtype = dtype
        self.pattern = None
        self.mask = None
        self.create()

    def create(self) -> None:
        """
        Build the pattern array and its mask.

        :return: None
        """
        shape = (self.num_rows, self.num_cols, self.num_chan)
        base = np.ones(shape, dtype=self.dtype)
        # element-wise scaling: cval is broadcast over the array of ones
        self.pattern = base * self.cval
        # the mask spans the two leading (row, col) axes of the pattern
        self.mask = np.ones(self.pattern.shape[:2], dtype=bool)

    def get_data(self) -> np.ndarray:
        """
        Get the image associated with the Entity
        :return: a numpy.ndarray representing the image
        """
        return self.pattern

    def get_mask(self) -> np.ndarray:
        """
        Get the mask associated with the Entity
        :return: a numpy.ndarray representing the mask
        """
        return self.mask