Source code for trojai.datagen.image_affine_xforms

import logging
from typing import Sequence, Dict

import skimage.transform
import numpy as np
from numpy.random import RandomState
import cv2

from .image_entity import GenericImageEntity, ImageEntity
from .transform_interface import ImageTransform

logger = logging.getLogger(__name__)

"""
Module defines several affine transforms using various libraries to perform the actual transformation operation
specified.
"""

class UniformScaleXForm(ImageTransform):
    """
    Transform that rescales both the image and the mask of an Entity by one shared factor.
    """

    def __init__(self, scale_factor: float = 1, kwargs: dict = None) -> None:
        """
        Build the scaler.
        :param scale_factor: the factor handed to skimage.transform.rescale
        :param kwargs: keyword arguments forwarded to skimage.transform.rescale; None means no extra
                       arguments are forwarded
        """
        self.scale_factor = scale_factor
        self.kwargs = {} if kwargs is None else kwargs

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Rescale the data and the mask of the input Entity with skimage.transform.rescale.
        :param input_obj: the Entity whose image and mask are rescaled
        :param random_state_obj: ignored
        :return: a GenericImageEntity built from the rescaled image and the rescaled mask
        """
        source_img = input_obj.get_data()
        source_mask = input_obj.get_mask()

        # The image and the mask go through the same call with the same factor and keyword arguments.
        logger.debug("Applying %0.02f scaling of image" % (self.scale_factor,))
        scaled_img = skimage.transform.rescale(source_img, self.scale_factor, **self.kwargs)

        logger.debug("Applying %0.02f scaling of mask" % (self.scale_factor,))
        scaled_mask = skimage.transform.rescale(source_mask, self.scale_factor, **self.kwargs)

        return GenericImageEntity(scaled_img, scaled_mask)
# Names of the predefined perspective transforms, grouped by how strongly they
# alter the image. The order is significant: random sampling indexes into this list.
valid_predefined_xform_strs = (
    # NOTE: these have a large effect
    ['east', 'north-west', 'shrink-1', 'shrink-2']
    # NOTE: these have a medium effect
    + ['left-tilt-forward', 'right-tilt-forward', 'west']
    # NOTE: these have a small effect
    + ['forward-distortion-%d' % idx for idx in range(1, 12)]
    # NOTE: these have no effect
    + ['forward']
)
# Maps every non-identity perspective name to a callable that, given the image
# size as (rows, cols), returns the (source, destination) point triples handed
# to cv2.getAffineTransform. Points are [x, y] pairs, i.e. [column, row].
_PERSPECTIVE_CONTROL_POINTS = {
    'east': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols / 10, rows / 2]],
        [[cols / 5, rows / 5], [cols / 2, rows / 8], [cols / 5, rows / 1.8]],
    ),
    'north-west': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 4.5 / 5, rows / 5], [cols / 2, rows / 8], [cols * 4.5 / 5, rows / 1.8]],
    ),
    'left-tilt-forward': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols / 10, rows / 2]],
        [[cols / 12, rows / 6], [cols / 2.1, rows / 8], [cols / 10, rows / 1.8]],
    ),
    'right-tilt-forward': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 10 / 12, rows / 6], [cols / 2.2, rows / 8], [cols * 8.4 / 10, rows / 1.8]],
    ),
    'west': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols / 9.95, rows / 10], [cols / 2.05, rows / 9.95], [cols * 9 / 10, rows / 2.05]],
    ),
    'forward-distortion-1': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols / 9.8, rows / 9.8], [cols / 2, rows / 9.8], [cols * 8.8 / 10, rows / 2.05]],
    ),
    'forward-distortion-2': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols / 11, rows / 10], [cols / 2.1, rows / 10], [cols * 8.5 / 10, rows / 1.95]],
    ),
    'forward-distortion-3': lambda rows, cols: (
        [[cols / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols / 11, rows / 11], [cols / 2.1, rows / 10], [cols * 10 / 11, rows / 1.95]],
    ),
    'forward-distortion-4': lambda rows, cols: (
        [[cols * 9.5 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 9.35 / 10, rows / 9.99], [cols / 2.05, rows / 9.95], [cols * 9.05 / 10, rows / 2.03]],
    ),
    'forward-distortion-5': lambda rows, cols: (
        [[cols * 9.5 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 9.65 / 10, rows / 9.95], [cols / 1.95, rows / 9.95], [cols * 9.1 / 10, rows / 2.02]],
    ),
    'forward-distortion-6': lambda rows, cols: (
        [[cols * 9.25 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 9.55 / 10, rows / 9.85], [cols / 1.9, rows / 10], [cols * 9.3 / 10, rows / 2.04]],
    ),
    'forward-distortion-7': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.85 / 10, rows / 9.3], [cols / 1.9, rows / 10.5], [cols * 8.8 / 10, rows / 2.11]],
    ),
    'forward-distortion-8': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.75 / 10, rows / 9.1], [cols / 1.95, rows / 8], [cols * 8.5 / 10, rows / 2.05]],
    ),
    'forward-distortion-9': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.75 / 10, rows / 9.1], [cols / 1.95, rows / 9], [cols * 8.5 / 10, rows / 2.2]],
    ),
    'forward-distortion-10': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.75 / 10, rows / 8], [cols / 1.95, rows / 8], [cols * 8.75 / 10, rows / 2]],
    ),
    'forward-distortion-11': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.8 / 10, rows / 7], [cols / 1.95, rows / 7], [cols * 8.8 / 10, rows / 2]],
    ),
    'shrink-1': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8 / 10, rows / 10], [cols * 1.34 / 3, rows / 10.5], [cols * 8.24 / 10, rows / 2.5]],
    ),
    'shrink-2': lambda rows, cols: (
        [[cols * 9 / 10, rows / 10], [cols / 2, rows / 10], [cols * 9 / 10, rows / 2]],
        [[cols * 8.5 / 10, rows * 3.1 / 10], [cols / 2, rows * 3 / 10], [cols * 8.44 / 10, rows * 1.55 / 2.5]],
    ),
}


def get_predefined_perspective_xform_matrix(xform_str: str, rows: int, cols: int) -> np.ndarray:
    """
    Look up the affine matrix that corresponds to a named perspective transformation.
    :param xform_str: name of the perspective to transform the object into; it is lower-cased before being
                      checked against valid_predefined_xform_strs
    :param rows: the number of rows of the image to be transformed to the specified perspective
    :param cols: the number of cols of the image to be transformed to the specified perspective
    :return: a numpy array of shape (2,3) which specifies the affine transformation. See:
             https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=getaffinetransform
             for more information
    :raises ValueError: if the name is not a known perspective
    """
    name = xform_str.lower()
    if name not in valid_predefined_xform_strs:
        raise ValueError("Unknown perspective transformation string!")

    # 'forward' is the identity mapping, so no control points are needed.
    if name == 'forward':
        return np.asarray([[1, 0, 0], [0, 1, 0]], dtype=np.float32)

    build_points = _PERSPECTIVE_CONTROL_POINTS.get(name)
    if build_points is None:
        # A name that passed validation but has no control points defined.
        raise ValueError("Unknown perspective transformation string!")

    src_points, dst_points = build_points(rows, cols)
    pts1 = np.float32(src_points)
    pts2 = np.float32(dst_points)
    return cv2.getAffineTransform(pts1, pts2)
class PerspectiveXForm(ImageTransform):
    """
    Shifts the perspective of an input Entity by warping its image and mask with a (2,3) affine matrix.
    """

    def __init__(self, xform_matrix) -> None:
        """
        Creates a Perspective shifter object
        :param xform_matrix: can be either a string specification of a perspective shift, where valid strings
                             are defined in the list valid_predefined_xform_strs (matched case-insensitively),
                             or it can be a numpy matrix of shape (2,3).
        :raises ValueError: if the string is not a known perspective, or the input is neither a string nor a
                            numpy array of shape (2,3)
        """
        if isinstance(xform_matrix, str):
            # Fail fast: without this check an unknown name would only be reported on the first call to do().
            # The comparison is lower-cased because the matrix lookup lower-cases the name as well.
            if xform_matrix.lower() not in valid_predefined_xform_strs:
                raise ValueError("Unknown perspective transformation string: %r" % (xform_matrix,))
            self.xform_M = xform_matrix
        elif isinstance(xform_matrix, np.ndarray) and xform_matrix.shape == (2, 3):
            self.xform_M = xform_matrix
        else:
            raise ValueError("Unknown M input, must be either an allowed string or a matrix of shape (2,3)!")

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Performs the perspective shift on the input Entity.
        :param input_obj: the Entity to be transformed according to the perspective shift specified in the
                          constructor.
        :param random_state_obj: ignored
        :return: a GenericImageEntity holding the warped image and the warped boolean mask
        """
        img = input_obj.get_data()
        mask = input_obj.get_mask()

        # Only the two leading dimensions are needed; slicing (instead of unpacking three values) also
        # accepts 2-D images that carry no channel axis.
        i_rows, i_cols = img.shape[:2]

        if isinstance(self.xform_M, str):
            # Named perspectives depend on the image size, so the matrix is resolved per call.
            xform_matrix = get_predefined_perspective_xform_matrix(self.xform_M, i_rows, i_cols)
        else:
            xform_matrix = self.xform_M

        # The output size is passed to cv2.warpAffine as (cols, rows).
        output_size = (i_cols, i_rows)

        logger.debug("Applying cv2.warpAffine to image with matrix:%s", xform_matrix)
        img_xform = cv2.warpAffine(img, xform_matrix, output_size)

        logger.debug("Applying cv2.warpAffine to mask with matrix:%s", xform_matrix)
        # The mask is warped as float32 and converted back to bool afterwards.
        msk_xform = cv2.warpAffine(mask.astype(np.float32), xform_matrix, output_size).astype(bool)

        return GenericImageEntity(img_xform, msk_xform)
class RandomPerspectiveXForm(ImageTransform):
    """
    Applies a perspective shift that is drawn at random from a set of predefined perspectives.
    """

    def __init__(self, perspectives: Sequence[str] = None) -> None:
        """
        Creates a random perspective shifter Transform object, which uniformly samples from the given
        perspectives, or from valid_predefined_xform_strs when none are given.
        :param perspectives: names to sample from; each must appear in valid_predefined_xform_strs
        :raises ValueError: if any given name is not in valid_predefined_xform_strs
        # TODO: add support for non-uniform sampling of perspective transformations
        """
        if perspectives is not None:
            for candidate in perspectives:
                if candidate in valid_predefined_xform_strs:
                    continue
                msg = candidate + " is not in the valid list of transforms"
                logger.error(msg)
                raise ValueError(msg)
            self.perspective_possibilities = perspectives
        else:
            self.perspective_possibilities = valid_predefined_xform_strs

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Draws one of the configured perspectives and applies it to the input object through PerspectiveXForm.
        :param input_obj: Entity to be randomly perspective shifted
        :param random_state_obj: random state the perspective is drawn from, which allows for reproducible
                                 sampling
        :return: the transformed Entity
        """
        # pick a perspective transformation
        chosen_xform = random_state_obj.choice(self.perspective_possibilities)
        logger.debug("Sampled perspective %s from RandomState" % (chosen_xform,))
        return PerspectiveXForm(chosen_xform).do(input_obj, random_state_obj)
class RotateXForm(ImageTransform):
    """
    Implements a rotation of an Entity by a specified angle amount, using skimage.transform.rotate.
    """

    def __init__(self, angle: int = 90, args: tuple = (), kwargs: dict = None) -> None:
        """
        Creates a Rotator Transform object
        :param angle: The degree amount to rotate (in degrees, not radians!)
        :param args: any additional positional arguments to pass to skimage.transform.rotate
        :param kwargs: any keyword arguments to pass to skimage.transform.rotate. 'preserve_range' is always
                       forwarded as True; explicitly passing a falsy value for it is rejected.
        :raises ValueError: if kwargs sets 'preserve_range' to a falsy value
        """
        self.rotation_angle = angle
        self.args = args
        if kwargs is None:
            self.kwargs = {'preserve_range': True}
        else:
            if 'preserve_range' in kwargs and not kwargs['preserve_range']:
                msg = "preserve_range cannot be set to False!"
                logger.error(msg)
                raise ValueError(msg)
            # Work on a copy so the caller's dict is not modified, and make sure preserve_range is set even
            # when the caller supplied other keyword arguments only. Without it the skimage default applies
            # and the cast back to the input dtype in do() would operate on rescaled values.
            self.kwargs = dict(kwargs)
            self.kwargs.setdefault('preserve_range', True)

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Performs the rotation specified by the RotateXForm object on an input
        :param input_obj: The Entity to be rotated
        :param random_state_obj: ignored
        :return: a GenericImageEntity holding the rotated image (cast back to the input image dtype) and the
                 rotated boolean mask
        """
        img = input_obj.get_data()
        mask = input_obj.get_mask()

        logger.debug("Applying %0.02f rotation to image via skimage.transform.rotate", self.rotation_angle)
        img_rotated = skimage.transform.rotate(img, self.rotation_angle, *self.args, **self.kwargs)
        img_rotated = img_rotated.astype(img.dtype)

        logger.debug("Applying %0.02f rotation to mask via skimage.transform.rotate", self.rotation_angle)
        mask_rotated = skimage.transform.rotate(mask, self.rotation_angle, *self.args, **self.kwargs)
        # Every value that is not (almost) zero counts as part of the mask. Comparing against the scalar 0
        # keeps this valid when the rotated mask has a different shape than the input mask.
        mask_rotated = np.logical_not(np.isclose(mask_rotated, 0, atol=.0001))

        return GenericImageEntity(img_rotated, mask_rotated)
class RandomRotateXForm(ImageTransform):
    """
    Rotates an Entity by an angle that is drawn at random from a set of candidate angles.
    """

    def __init__(self, angle_choices: Sequence[float] = None, angle_sampler_prob: Sequence[float] = None,
                 rotator_kwargs: Dict = None) -> None:
        """
        Creates a random rotator Transform object
        :param angle_choices: a Sequence of the possible angles the sampler can choose from; defaults to
                              [0, 90, 180, 270] when None
        :param angle_sampler_prob: passed as the p argument of the random state's choice call when sampling
                                   an angle
        :param rotator_kwargs: keyword arguments handed to RotateXForm as its kwargs; defaults to
                               {'preserve_range': True} when None
        """
        self.angle_choices = [0, 90, 180, 270] if angle_choices is None else angle_choices
        self.rotator_kwargs = {'preserve_range': True} if rotator_kwargs is None else rotator_kwargs
        self.angle_sampler_prob = angle_sampler_prob

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Draws one of the configured angles and applies that rotation to the input object through RotateXForm.
        :param input_obj: Entity to be randomly rotated
        :param random_state_obj: random state the angle is drawn from, used to maintain reproducibility
                                 through transformations
        :return: the transformed Entity
        """
        sampled_angle = random_state_obj.choice(self.angle_choices, p=self.angle_sampler_prob)
        logger.debug("Sampled %0.02f rotation from RandomState" % (sampled_angle,))
        return RotateXForm(sampled_angle, kwargs=self.rotator_kwargs).do(input_obj, random_state_obj)