Source code for trojai.datagen.image_insert_utils

from trojai.datagen.config import ValidInsertLocationsConfig

from typing import Sequence, Any, Tuple, Optional

import numpy as np
from scipy.ndimage import filters

import logging
logger = logging.getLogger(__name__)

# all possible directions to leave pixel along, for edge_tracing algorithm
DIRECTIONS = [(-1, -1), (-1, 1), (1, 1), (1, -1), (0, -1), (-1, 0), (0, 1), (1, 0)]


def pattern_fit(chan_img: np.ndarray, chan_pattern: np.ndarray, chan_location: Sequence[Any]) -> bool:
    """
    Returns True if the pattern at the desired location can fit into the image channel without wrap, and False
    otherwise

    :param chan_img: a numpy.ndarray of shape (nrows, ncols) which represents an image channel
    :param chan_pattern: a numpy.ndarray of shape (prows, pcols) which represents a channel of the pattern
    :param chan_location: a Sequence of length 2, which contains the (row, column) offset of the top left corner
        of the pattern to be inserted for this specific channel
    :return: True/False depending on whether the pattern will fit into the image
    """
    p_rows, p_cols = chan_pattern.shape
    r, c = chan_location
    i_rows, i_cols = chan_img.shape

    # a negative offset would be read by numpy as a position relative to the far edge of the image, which is a
    # wrapped (or empty) placement rather than a fit
    if r < 0 or c < 0:
        return False

    # the bottom/right edge of the pattern must not extend past the image
    if (r + p_rows) > i_rows or (c + p_cols) > i_cols:
        return False

    return True
def _get_edge_length_in_direction(curr_i: int, curr_j: int, dir_i: int, dir_j: int, i_rows: int, i_cols: int,
                                  edge_pixels: set) -> int:
    """
    Walk from the current pixel in one direction for as long as the next pixel is an unvisited edge pixel, and
    report how far the walk got. Every pixel stepped onto is removed from edge_pixels.

    :param curr_i: current row index
    :param curr_j: current col index
    :param dir_i: direction of change in row index
    :param dir_j: direction of change in col index
    :param i_rows: number of rows of containing array
    :param i_cols: number of cols of containing array
    :param edge_pixels: set of remaining edge pixels to visit, modified in place
    :return: the length of the edge in the given direction, 0 if none exists; for a diagonal direction the
        length is at most 1
    """
    is_diagonal = dir_i != 0 and dir_j != 0
    steps = 0
    while True:
        next_i, next_j = curr_i + dir_i, curr_j + dir_j
        # stop at the border of the array
        if not (0 <= next_i < i_rows and 0 <= next_j < i_cols):
            break
        # stop when the next pixel is not a remaining edge pixel
        if (next_i, next_j) not in edge_pixels:
            break
        # mark the pixel as visited and advance onto it
        edge_pixels.remove((next_i, next_j))
        steps += 1
        curr_i, curr_j = next_i, next_j
        # diagonal moves are limited to a single step
        if is_diagonal:
            break
    return steps


def _get_next_edge_from_pixel(curr_i: int, curr_j: int, i_rows: int, i_cols: int,
                              edge_pixels: set) -> Optional[Tuple[int, int]]:
    """
    Try each entry of DIRECTIONS in order and return the displacement of the first one along which an unvisited
    edge exists

    :param curr_i: current row index
    :param curr_j: current col index
    :param i_rows: number of rows of containing array
    :param i_cols: number of cols of containing array
    :param edge_pixels: set of remaining edge pixels to visit, modified in place by the search
    :return: a tuple of row distance, col distance if an undiscovered edge is found, otherwise None
    """
    for step_i, step_j in DIRECTIONS:
        run = _get_edge_length_in_direction(curr_i, curr_j, step_i, step_j, i_rows, i_cols, edge_pixels)
        if run:
            return step_i * run, step_j * run
    return None


def _get_bounding_box(coords: Sequence[int], img: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """
    Return the smallest rectangle containing all non-zero pixels of the given sub-region of img

    :param coords: sequence of image subset coordinates, top, left, bottom, right (used as slice bounds, so
        bottom and right are exclusive)
    :param img: provided image
    :return: a tuple of y1 (top), x1 (left), y2 (bottom), x2 (right) of the bounding box in coordinates of the
        full image, where y2 and x2 are one past the last non-zero row/column, or a 4-tuple of zeros if there
        are no non-zero pixels in the sub-region
    """
    top, left, bottom, right = coords
    window = img[top:bottom, left:right]

    # indices (relative to the window) of rows / columns holding at least one non-zero pixel
    occupied_rows = np.flatnonzero(np.any(window, axis=1))
    occupied_cols = np.flatnonzero(np.any(window, axis=0))

    if occupied_rows.size == 0 or occupied_cols.size == 0:
        return 0, 0, 0, 0

    y1, y2 = occupied_rows[0], occupied_rows[-1]
    x1, x2 = occupied_cols[0], occupied_cols[-1]
    # shift back into full-image coordinates
    return top + y1, left + x1, top + y2 + 1, left + x2 + 1
def _broadcast_per_channel(value: Any, name: str, num_chans: int) -> Sequence[Any]:
    """
    Expand a scalar configuration value to one entry per channel, or validate the length of a value that is
    already a sequence

    :param value: the configuration value, either a scalar or a Sequence
    :param name: name of the configuration attribute, used in the error message
    :param num_chans: the number of channels in the image
    :return: a sequence of length num_chans
    :raises ValueError: if value is a Sequence whose length differs from num_chans
    """
    if not isinstance(value, Sequence):
        return [value] * num_chans
    if len(value) != num_chans:
        msg = "Length of provided {} sequence does not equal the number of channels in the image!".format(name)
        logger.error(msg)
        raise ValueError(msg)
    return value


def valid_locations(img: np.ndarray, pattern: np.ndarray, algo_config: ValidInsertLocationsConfig,
                    protect_wrap: bool = True) -> np.ndarray:
    """
    Returns a boolean mask marking, per channel, the top left corner positions at which the pattern can be
    inserted into the image, according to the overlap algorithm dictated by the provided configuration

    :param img: a numpy.ndarray which represents the image of shape: (nrows, ncols, nchans)
    :param pattern: the pattern to be inserted into the image of shape: (prows, pcols, nchans)
    :param algo_config: The provided configuration object specifying the algorithm to use and necessary
        parameters. The attributes read are allow_overlap, min_val, threshold_val, algorithm and num_boxes
    :param protect_wrap: if True, ensures that pattern to be inserted can fit without wrapping; if False, a
        ValueError is raised for every channel that does not allow overlap, since wrapping is not implemented
    :return: A boolean mask of the same shape as the input image, with True indicating that that pixel is a
        valid location for placement of the specified pattern. A channel whose pattern is larger than the image
        has no valid locations
    :raises ValueError: if a per-channel configuration sequence or the pattern does not match the number of
        image channels, or if protect_wrap is False for a channel that does not allow overlap
    """
    # imported here because the top-level scipy.ndimage functions replace the deprecated
    # scipy.ndimage.filters namespace
    from scipy.ndimage import maximum_filter, minimum_filter

    num_chans = img.shape[2]

    # broadcast per-channel configuration values if necessary
    allow_overlap = _broadcast_per_channel(algo_config.allow_overlap, 'allow_overlap', num_chans)
    min_val = _broadcast_per_channel(algo_config.min_val, 'min_val', num_chans)
    threshold_val = algo_config.threshold_val
    if algo_config.algorithm == 'threshold':
        threshold_val = _broadcast_per_channel(threshold_val, 'threshold_val', num_chans)

    if pattern.shape[2] != num_chans:
        # force user to broadcast the pattern as necessary
        msg = "The # of channels in the pattern does not match the # of channels in the image!"
        logger.error(msg)
        raise ValueError(msg)

    # TODO: look for vectorization opportunities
    output_mask = np.zeros(img.shape, dtype=bool)
    for chan_idx in range(num_chans):
        chan_img = img[:, :, chan_idx]
        chan_pattern = pattern[:, :, chan_idx]
        i_rows, i_cols = chan_img.shape
        p_rows, p_cols = chan_pattern.shape

        # if the pattern is larger than the image, the slice bounds computed below become negative and would
        # wrap around to the far edge, so such a channel must be handled explicitly
        pattern_fits = p_rows <= i_rows and p_cols <= i_cols

        if allow_overlap[chan_idx]:
            # every position at which the pattern stays inside the image is valid
            if pattern_fits:
                output_mask[0:i_rows - p_rows + 1, 0:i_cols - p_cols + 1, chan_idx] = True
            continue

        if not protect_wrap:
            msg = "Wrapping for trigger insertion has not been implemented yet!"
            logger.error(msg)
            raise ValueError(msg)

        if not pattern_fits:
            # no valid location exists, the channel of output_mask stays all False
            continue

        mask = (chan_img <= min_val[chan_idx])

        # True if image present, False if not
        img_mask = np.logical_not(mask)

        # remove boundaries from valid locations
        mask[i_rows - p_rows + 1:i_rows, :] = False
        mask[:, i_cols - p_cols + 1:i_cols] = False

        # get all edge pixels: image pixels whose 3x3 neighbourhood (zero padded) is not uniformly image.
        # The bounding_boxes algorithm does not use them
        edge_pixels = None
        if algo_config.algorithm != 'bounding_boxes':
            edge_pixel_coords = np.nonzero(
                np.logical_and(
                    np.logical_xor(
                        maximum_filter(img_mask, 3, mode='constant', cval=0.0),
                        minimum_filter(img_mask, 3, mode='constant', cval=0.0)),
                    img_mask))
            edge_pixels = zip(edge_pixel_coords[0], edge_pixel_coords[1])

        if algo_config.algorithm == 'edge_tracing':
            logger.debug("Computing valid locations according to edge_tracing algorithm")
            edge_pixel_set = set(edge_pixels)
            # search until all edges have been visited
            while edge_pixel_set:
                start_i, start_j = edge_pixel_set.pop()

                # invalidate relevant pixels for start square
                top_boundary = max(0, start_i - p_rows + 1)
                left_boundary = max(0, start_j - p_cols + 1)
                mask[top_boundary:start_i + 1, left_boundary:start_j + 1] = False

                curr_i, curr_j = start_i, start_j
                move = 0, 0
                while move is not None:
                    # what edge was last traversed
                    action_i, action_j = move

                    # current location
                    curr_i += action_i
                    curr_j += action_j

                    # truncate search when near top or left boundary
                    top_index = max(0, curr_i - p_rows + 1)
                    left_index = max(0, curr_j - p_cols + 1)

                    # update invalidation based on last move, marking a row or column invalid based on the
                    # size of action_i or action_j
                    # if action_i or action_j has absolute value greater than 0, the other must be 0,
                    # i.e diagonal moves of length greater than 1 aren't updated correctly by this
                    if action_i < 0:
                        # update top border
                        mask[top_index:top_index - action_i, left_index:curr_j + 1] = False
                    elif action_i > 0:
                        # update bottom border
                        mask[curr_i - action_i + 1:curr_i + 1, left_index:curr_j + 1] = False

                    if action_j < 0:
                        # update left border
                        mask[top_index:curr_i + 1, left_index:left_index - action_j] = False
                    elif action_j > 0:
                        # update right border
                        mask[top_index:curr_i + 1, curr_j - action_j + 1:curr_j + 1] = False

                    # obtain next pixel to inspect
                    move = _get_next_edge_from_pixel(curr_i, curr_j, i_rows, i_cols, edge_pixel_set)

        elif algo_config.algorithm == 'brute_force':
            logger.debug("Computing valid locations according to brute_force algorithm")
            # invalidate every top left corner whose pattern footprint would cover an edge pixel
            for i, j in edge_pixels:
                top_index, left_index = max(0, i - p_rows + 1), max(0, j - p_cols + 1)
                mask[top_index:i + 1, left_index:j + 1] = False

        elif algo_config.algorithm == 'threshold':
            logger.debug("Computing valid locations according to threshold algorithm")
            for i, j in edge_pixels:
                mask[max(0, i - p_rows + 1):i + 1, max(0, j - p_cols + 1):j + 1] = False

            # enumerate all possible invalid locations
            mask_coords = np.nonzero(np.logical_not(mask))
            possible_locations = zip(mask_coords[0], mask_coords[1])

            # if average pixel value in location is below specified value, allow possible trigger overlap
            for i, j in possible_locations:
                if i <= i_rows - p_rows and j <= i_cols - p_cols and \
                        np.mean(chan_img[i:i + p_rows, j:j + p_cols]) <= threshold_val[chan_idx]:
                    mask[i, j] = True

        elif algo_config.algorithm == 'bounding_boxes':
            logger.debug("Computing valid locations according to bounding_boxes algorithm")
            # generate top-left and bottom-right corners of all grid squares
            top_left_coords = np.swapaxes(np.indices((algo_config.num_boxes, algo_config.num_boxes)), 0, 2) \
                .reshape((algo_config.num_boxes * algo_config.num_boxes, 2))
            bottom_right_coords = top_left_coords + 1

            # rows give y1, x1, y2, x2 of grid boxes, y2 and x2 exclusive
            box_coords = np.concatenate((top_left_coords, bottom_right_coords), axis=1)
            box_coords = np.multiply(box_coords, np.array([i_rows, i_cols, i_rows, i_cols]))
            box_coords //= algo_config.num_boxes

            # generate bounding boxes for image in each grid square
            bounding_coords = np.apply_along_axis(_get_bounding_box, 1, box_coords, img_mask)

            # update mask, bounds -> top, left, bottom, right
            for bounds in bounding_coords:
                top_index = max(0, bounds[0] - p_rows + 1)
                left_index = max(0, bounds[1] - p_cols + 1)
                mask[top_index:bounds[2], left_index:bounds[3]] = False

        output_mask[:, :, chan_idx] = mask

    return output_mask