Source code for trojai.datagen.datatype_xforms
import logging
import numpy as np
from numpy.random import RandomState
from .image_entity import GenericImageEntity, ImageEntity
from .transform_interface import ImageTransform
# Module-level logger; ToTensorXForm.do emits a debug record through it when it reshapes an input.
logger = logging.getLogger(__name__)
# NOTE: this string follows the imports, so Python does not treat it as the module docstring (__doc__);
# it is evaluated as an inert expression statement and serves only as in-source documentation.
"""
Defines data type transformations that may need to occur when processing different data sources
"""
class ToTensorXForm(ImageTransform):
    """
    Transformation which converts an input array into a tensor with at least a specified number of dimensions.

    Missing dimensions are appended as trailing singleton axes, e.g. an array of shape (H, W) becomes
    (H, W, 1) when num_dims=3. Inputs whose data already has num_dims or more dimensions are returned as-is.
    """
    def __init__(self, num_dims: int = 3) -> None:
        """
        Create the transformer object
        :param num_dims: the minimum number of dimensions the transformed data should have
        """
        self.num_dims = num_dims

    def do(self, input_obj: ImageEntity, random_state_obj: RandomState) -> ImageEntity:
        """
        Perform the actual to->tensor conversion
        :param input_obj: the input Entity to be transformed
        :param random_state_obj: ignored
        :return: input_obj itself if its data already has >= num_dims dimensions; otherwise a new
                 GenericImageEntity whose data has trailing singleton axes appended up to num_dims and
                 whose mask is input_obj's mask, passed through unchanged
        """
        img = input_obj.get_data()
        old_shape = img.shape
        num_dims_to_add = self.num_dims - len(old_shape)
        if num_dims_to_add <= 0:
            # Already has enough dimensions: hand back the original entity untouched.
            return input_obj

        # Append all missing trailing axes in one indexing step. This is equivalent to repeatedly calling
        # np.expand_dims(img, axis=img.ndim) and, like it, produces a view rather than a copy.
        img = img[(Ellipsis,) + (np.newaxis,) * num_dims_to_add]
        # Pass the shapes as logging args so the message is only formatted when DEBUG is enabled.
        logger.debug("Converted input entity from shape=%s to %s", old_shape, img.shape)
        # NOTE(review): the mask is not reshaped alongside the data -- confirm downstream consumers do not
        # expect mask and data to have the same number of dimensions.
        return GenericImageEntity(img, input_obj.get_mask())