# Source code for trojai.datagen.insert_merges

import logging
import warnings

import numpy as np
from numpy.random import RandomState

import trojai.datagen.image_insert_utils as insert_utils
from .config import ValidInsertLocationsConfig
from .image_entity import GenericImageEntity, ImageEntity
from .merge_interface import ImageMerge, TextMerge
from .text_entity import TextEntity, GenericTextEntity

logger = logging.getLogger(__name__)


"""
Module which defines several insert style merge operations.
"""

class InsertRandomLocationNonzeroAlpha(ImageMerge):
    """
    Inserts a defined pattern into an image in a randomly selected location where the alpha channel is non-zero
    """
    def __init__(self) -> None:
        """
        Initialize the insert merger
        """

    def do(self, img_obj: ImageEntity, pattern_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Perform the described merge operation.

        A candidate location is the top-left corner at which the pattern would be placed. Only candidates at
        which the whole pattern stays inside the image are considered, and of those only the ones whose
        (top-left) pixel has a non-zero alpha value.

        :param img_obj: The input object into which the pattern is to be inserted; its data must have 4
            channels, the last of which (index 3) is treated as the alpha channel
        :param pattern_obj: The pattern object which is to be inserted into the image
        :param random_state_obj: used to sample from the possible valid locations, by providing a random state,
            we ensure reproducibility of the data
        :return: the merged object
        :raises ValueError: if the image does not have 4 channels, if the pattern is larger than the image,
            or if no valid insertion location exists
        """
        img = img_obj.get_data()
        pattern = pattern_obj.get_data()
        num_chans = img.shape[2]
        if num_chans != 4:
            raise ValueError("Alpha Channel expected!")

        i_rows, i_cols, _ = img.shape
        p_rows, p_cols, _ = pattern.shape
        # Largest top-left row/col (inclusive) at which the pattern still fits entirely inside the image.
        max_row = i_rows - p_rows
        max_col = i_cols - p_cols
        if max_row < 0 or max_col < 0:
            # Must be checked explicitly: a negative slice end below would silently select the wrong region.
            msg = 'Pattern is larger than the image, unable to InsertRandomLocationNonzeroAlpha'
            logger.error(msg)
            raise ValueError(msg)

        # Slice ends are exclusive, so +1 keeps the last row/col at which the pattern still fits selectable.
        valid_indices = np.where(img[:max_row + 1, :max_col + 1, 3] != 0)
        num_valid_indices = len(valid_indices[0])
        if num_valid_indices == 0:
            # ValueError keeps this consistent with the other input-validation errors raised by this method.
            msg = 'Unable to InsertRandomLocationNonzeroAlpha, no valid locations found'
            logger.error(msg)
            raise ValueError(msg)

        random_index = random_state_obj.choice(num_valid_indices)
        insert_loc = [valid_indices[0][random_index], valid_indices[1][random_index]]
        # The same location is used for every channel.
        insert_loc_per_chan = np.tile(insert_loc, (num_chans, 1)).astype(int)
        logger.debug("Selected insertion location randomly from available locations")

        inserter = InsertAtLocation(insert_loc_per_chan)
        inserted_img_obj = inserter.do(img_obj, pattern_obj, random_state_obj)
        return inserted_img_obj
class InsertRandomWithMask(ImageMerge):
    """
    Inserts a defined pattern into an image in a randomly selected location where the specified mask is True
    """
    def __init__(self) -> None:
        """
        Initialize the insert merger
        """

    def do(self, img_obj: ImageEntity, pattern_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Perform the described merge operation.

        :param img_obj: The input object into which the pattern is to be inserted; its data must have 4
            channels and its mask (from get_mask()) marks the allowed insertion region
        :param pattern_obj: The pattern object which is to be inserted into the image
        :param random_state_obj: used to sample from the possible valid locations, by providing a random state,
            we ensure reproducibility of the data
        :return: the merged object
        :raises ValueError: if the image does not have 4 channels
        :raises RuntimeError: if no valid insertion location is found
        """
        img = img_obj.get_data()
        img_mask = img_obj.get_mask()
        pattern = pattern_obj.get_data()
        num_chans = img.shape[2]
        if num_chans != 4:
            raise ValueError("Alpha Channel expected!")

        p_rows, p_cols, _ = pattern.shape
        # Mark everything outside the mask as True ("occupied") for the location search. Converting to bool
        # before negating makes this a logical NOT for any mask dtype; np.invert on a signed-int 0/1 mask
        # would instead produce -1/-2 everywhere.
        # NOTE(review): assumes valid_locations treats values > min_val as occupied, in which case -1/-2
        # would have made the mask be ignored entirely — confirm against insert_utils.
        outside_mask = np.logical_not(np.asarray(img_mask, dtype=bool))
        msk_for_loc_determination = np.ones((p_rows, p_cols, 1), dtype=int)
        valid_loc_mask = insert_utils.valid_locations(np.expand_dims(outside_mask, axis=2),
                                                      msk_for_loc_determination,
                                                      ValidInsertLocationsConfig(algorithm='edge_tracing',
                                                                                 min_val=0))
        valid_indices = np.where(valid_loc_mask)
        num_valid_indices = len(valid_indices[0])
        if num_valid_indices == 0:
            raise RuntimeError('Unable to InsertRandomWithMask, no valid locations found')

        random_index = random_state_obj.choice(num_valid_indices)
        insert_loc = [valid_indices[0][random_index], valid_indices[1][random_index]]
        # The same location is used for every channel.
        insert_loc_per_chan = np.tile(insert_loc, (num_chans, 1)).astype(int)
        logger.debug("Selected insertion location randomly from available locations")

        inserter = InsertAtLocation(insert_loc_per_chan)
        inserted_img_obj = inserter.do(img_obj, pattern_obj, random_state_obj)
        return inserted_img_obj
class InsertAtLocation(ImageMerge):
    """
    Inserts a provided pattern at a specified location
    """
    def __init__(self, location: np.ndarray, protect_wrap: bool = True):
        """
        Initializes the inserter object
        :param location: The location to insert, must be of shape=(channels x 2); row i holds the
            (row, col) of the pattern's top-left corner for channel i
        :param protect_wrap: If True, prevents insertion of objects via wrapping
        """
        self.location = location
        self.protect_wrap = protect_wrap

    def do(self, img_obj: ImageEntity, pattern_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Inserts a pattern into an image, using the mask of the pattern to determine which specific pixels are
        modifiable. The input image is not modified; the result is written into a copy.

        :param img_obj: The background image into which the pattern is inserted
        :param pattern_obj: The pattern to be inserted. The mask associated with the pattern is used to
            determine which specific pixels of the pattern are inserted into the img_obj
        :param random_state_obj: ignored
        :return: The merged object (a new GenericImageEntity sharing img_obj's mask)
        :raises ValueError: on non-ImageEntity inputs, mismatched dimensions/channels, a malformed location,
            or a location at which the pattern does not fit
        :raises NotImplementedError: if protect_wrap is False (wrapping is not implemented)
        """
        if not isinstance(img_obj, ImageEntity) or not isinstance(pattern_obj, ImageEntity):
            raise ValueError("img_obj and pattern_obj must both be ImageEntity objects to use InsertAtLocation!")
        # Copy so that inserting the pattern never clobbers the caller's image data.
        img = np.array(img_obj.get_data(), copy=True)
        img_mask = img_obj.get_mask()
        pattern = pattern_obj.get_data()
        pattern_mask = pattern_obj.get_mask()
        location = np.asarray(self.location)

        if len(img.shape) != 3:
            raise ValueError('Input image must be of dimensions rows x cols x channels')
        num_chans = img.shape[2]
        if pattern.shape[2] != num_chans:
            # force user to broadcast the pattern as necessary
            msg = 'The # of channels in the pattern does not match the # of channels in the image!'
            logger.error(msg)
            raise ValueError(msg)
        if location.ndim != 2 or location.shape[0] != num_chans or location.shape[1] != 2:
            msg = 'location input must be of shape=(channels x 2)'
            logger.error(msg)
            raise ValueError(msg)
        if not self.protect_wrap:
            # TODO
            msg = 'Wrapping of images not yet implemented!'
            logger.error(msg)
            raise NotImplementedError(msg)

        # to allow for patterns across channels to be in different locations, we do this in a for-loop
        # TODO: see if this can be vectorized
        for chan_idx in range(num_chans):
            chan_location = location[chan_idx, :]
            r, c = int(chan_location[0]), int(chan_location[1])
            # Plain 2-D slices (no squeeze) so 1-row / 1-col patterns and images keep their 2-D shape.
            chan_pattern = pattern[:, :, chan_idx]
            chan_img = img[:, :, chan_idx]
            p_rows, p_cols = chan_pattern.shape
            logger.debug("Inserting pattern into image for channel=%d at location=[%d,%d]", chan_idx, r, c)
            # protect_wrap is necessarily True here (False raised above), so the fit is always enforced.
            # Negative coordinates would make the slice below wrap around the image edge.
            if r < 0 or c < 0 or not insert_utils.pattern_fit(chan_img, chan_pattern, chan_location):
                msg = 'Pattern doesnt fit into image at specified location!'
                logger.error(msg)
                raise ValueError(msg)
            # take into account masks: only pixels where pattern_mask is True are written
            np.putmask(img[r:r + p_rows, c:c + p_cols, chan_idx], pattern_mask, chan_pattern)

        # TODO: is there something we need to change about the mask?
        return GenericImageEntity(img, img_mask)
class InsertAtRandomLocation(ImageMerge):
    """
    Inserts a provided pattern at a random location, where valid locations are determined according to a
    provided algorithm specification
    """
    def __init__(self, method: str, algo_config: ValidInsertLocationsConfig, protect_wrap: bool = True) -> None:
        """
        Initialize the random inserter object.
        :param method: the insertion method, currently, only uniform_random_available is a valid input
        :param algo_config: The provided configuration object specifying the algorithm to use and necessary
            parameters
        :param protect_wrap: if True, ensures that pattern to be inserted can fit without wrapping and raises an
            Exception otherwise
        """
        self.method = method
        self.algo_config = algo_config
        self.protect_wrap = protect_wrap

    def do(self, img_obj: ImageEntity, pattern_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Perform the specified merge on the input Entities and return the merged Entity.

        If no valid location exists, a warning is issued and the pattern is placed at the upper-left corner
        (0, 0) instead.

        :param img_obj: the image object into which the pattern is to be inserted
        :param pattern_obj: the pattern object to be inserted
        :param random_state_obj: used to sample from the possible valid locations, by providing a random state,
            we ensure reproducibility of the data
        :return: the merged Entity
        :raises ValueError: if either input is not an ImageEntity
        :raises NotImplementedError: if self.method is not 'uniform_random_available'
        """
        if not isinstance(img_obj, ImageEntity) or not isinstance(pattern_obj, ImageEntity):
            raise ValueError("img_obj and pattern_obj must both be ImageEntity objects to use InsertAtRandomLocation!")
        pattern = pattern_obj.get_data()
        img = img_obj.get_data()
        num_chans = img.shape[2]

        if self.method != 'uniform_random_available':
            msg = "Insert method not yet implemented!"
            logger.error(msg)
            raise NotImplementedError(msg)

        valid_location_mask = insert_utils.valid_locations(img, pattern, self.algo_config, self.protect_wrap)
        # trigger same across all channels: for 3-channel images only locations valid in every channel remain.
        # NOTE(review): other channel counts are not intersected, so np.nonzero below may yield a location
        # that is valid in only some channels — confirm this is intended (e.g. for RGBA).
        if num_chans == 3:
            valid_location_mask = np.bitwise_and.reduce(valid_location_mask, axis=2)
        valid_locs = np.nonzero(valid_location_mask)
        if len(valid_locs[0]) == 0:
            # TODO: link back to this image's file pointer in error msg
            warnings.warn('Image contains no space for trigger w/out '
                          'occlusion! Placing trigger on upper left w/ '
                          'possible partial occlusion!')
            insert_loc = [0, 0]
        else:
            idx_select = random_state_obj.choice(np.arange(len(valid_locs[0])))
            logger.debug("Selected random location for insertion")
            insert_loc = [valid_locs[0][idx_select], valid_locs[1][idx_select]]

        # Same location for every channel. Platform int rather than int16, which overflows for images
        # larger than 32767 pixels along an axis.
        insert_locs_per_chan = np.tile(insert_loc, (num_chans, 1)).astype(int)

        inserter = InsertAtLocation(insert_locs_per_chan)
        inserted_img_obj = inserter.do(img_obj, pattern_obj, random_state_obj)
        logger.debug("Inserted pattern into image")
        return inserted_img_obj
class RandomInsertTextMerge(TextMerge):
    """
    Inserts all words (and their delimiters) of one TextEntity into another at a randomly chosen word index.
    """
    def __init__(self):
        pass

    def do(self, obj1: TextEntity, obj2: TextEntity, random_state_obj: RandomState):
        """
        Insert obj2 into a copy of obj1 before a randomly selected word of obj1.

        :param obj1: the host text; it is not modified, a new entity is built from its text
        :param obj2: the text whose words and delimiters are inserted
        :param random_state_obj: source of randomness for the insertion index (reproducibility)
        :return: a new GenericTextEntity; if obj1 has no words, it is built from obj2's text alone
        :raises ValueError: if either input is not a TextEntity
        """
        both_are_text = isinstance(obj1, TextEntity) and isinstance(obj2, TextEntity)
        if not both_are_text:
            raise ValueError("The inputs to RandomInsertTextMerge must be two TextEntity objects!")

        host_words = obj1.get_data()
        if host_words.size == 0:
            # Nothing to insert into: the result is just obj2's text.
            return GenericTextEntity(obj2.get_text())

        # Draw in [0, size) so the chosen index always refers to an existing node of obj1.
        start = random_state_obj.randint(host_words.size, size=1)[0]

        merged = GenericTextEntity(obj1.get_text())
        new_words = obj2.get_data()
        new_delims = obj2.get_delimiters()
        for offset in range(new_words.size):
            # Each inserted node shifts the host's node one step right, so inserting before index
            # start + offset keeps obj2's words contiguous and in their original order.
            target = int(start + offset)
            merged.data.insert(new_words.nodeat(offset).value, merged.data.nodeat(target))
            merged.delimiters.insert(new_delims.nodeat(offset).value, merged.delimiters.nodeat(target))
        return merged
class FixedInsertTextMerge(TextMerge):
    """
    Inserts all words (and their delimiters) of one TextEntity into another at a fixed word index.
    """
    def __init__(self, location: int):
        """
        :param location: word index in the first entity before which the second entity is inserted; a value
            equal to the first entity's word count appends at the end
        """
        self.loc = location

    def do(self, obj1: TextEntity, obj2: TextEntity, random_state_obj: RandomState):
        """
        Insert obj2 into a copy of obj1 at word index self.loc.

        :param obj1: the host text; it is not modified, a new entity is built from its text
        :param obj2: the text whose words and delimiters are inserted
        :param random_state_obj: ignored
        :return: a new GenericTextEntity containing the merged text
        :raises ValueError: if either input is not a TextEntity
        :raises IndexError: if self.loc is negative or greater than obj1's word count
        """
        if not isinstance(obj1, TextEntity) or not isinstance(obj2, TextEntity):
            raise ValueError("The inputs to FixedInsertTextMerge must be two TextEntity objects!")
        # Check that the location is within the size of the first object (loc == size means "append").
        # Negative values are rejected: with loc + ind indexing they would not yield a contiguous insert.
        if self.loc < 0 or obj1.get_data().size < self.loc:
            raise IndexError("Location is not within the object")

        # Insert at that location
        output_entity = GenericTextEntity(obj1.get_text())
        words = obj2.get_data()
        delimiters = obj2.get_delimiters()
        for ind in range(words.size):
            pos = int(self.loc + ind)
            self._insert_at(output_entity.data, words.nodeat(ind).value, pos)
            self._insert_at(output_entity.delimiters, delimiters.nodeat(ind).value, pos)
        return output_entity

    @staticmethod
    def _insert_at(linked_list, value, pos: int) -> None:
        """Insert value so that it ends up at index pos; pos equal to the list size appends to the end."""
        # nodeat(size) raises IndexError, so the end position has to be handled with append.
        if pos < linked_list.size:
            linked_list.insert(value, linked_list.nodeat(pos))
        else:
            linked_list.append(value)