import logging
import os
from typing import Sequence
import collections.abc
import cv2
import numpy as np
import pandas as pd
from numpy.random import RandomState
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import trojai.datagen.utils as utils
from .config import XFormMergePipelineConfig
from .constants import RANDOM_STATE_DRAW_LIMIT
from .entity import Entity
from .image_entity import GenericImageEntity
from .text_entity import GenericTextEntity
from .merge_interface import Merge
from .pipeline import Pipeline
from .transform_interface import Transform
logger = logging.getLogger(__name__)
"""
Defines all functions and classes related to the transform+merge pipeline & data movement paradigm.
"""
def subset_clean_df_by_labels(df, labels_to_include):
    """
    Subsets a dataframe with an expected column 'label', to only keep rows whose label is in the list of
    labels to include.

    Rows are returned grouped in the order in which the labels are listed (all rows matching the first label,
    then all rows matching the second label, ...) and the result is re-indexed from 0.

    :param df: the dataframe to subset; must contain a 'label' column
    :param labels_to_include: a list of labels to include, or the string 'all' indicating that everything
                              should be kept
    :return: the subsetted data frame. When labels_to_include is 'all', df itself is returned (not a copy).
             When labels_to_include is an empty sequence, an empty data frame with the same columns is returned.
    :raises ValueError: if labels_to_include is neither a (non-string) sequence of labels nor the string 'all'
    """
    msg = "the argument to subset the data that is modified must either be list of labels, or the string 'all'"

    # Strings are handled first: a str is itself a Sequence, so without this guard any string other than
    # 'all' would be iterated character by character and silently produce a wrong subset.
    if isinstance(labels_to_include, str):
        if labels_to_include == 'all':
            return df
        logger.error(msg)
        raise ValueError(msg)

    if not isinstance(labels_to_include, collections.abc.Sequence):
        logger.error(msg)
        raise ValueError(msg)

    if len(labels_to_include) == 0:
        # pd.concat refuses an empty list of frames; an empty selection is simply an empty subset
        return df.iloc[0:0].reset_index(drop=True)

    return pd.concat([df[df['label'] == c] for c in labels_to_include], ignore_index=True)
def _load_optional_mask(row, column):
    """Load the numpy mask whose path is stored in row[column]; return None when that column is absent."""
    try:
        mask_fp = row[column]
    except KeyError:
        return None
    return np.load(mask_fp)


def _read_image_or_raise(image_path):
    """Read an image with cv2 (unchanged channels/depth), raising ValueError if it cannot be read."""
    img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising
        msg = "Unable to read image file: %s" % (image_path,)
        logger.error(msg)
        raise ValueError(msg)
    return img


def modify_clean_image_dataset(clean_dataset_rootdir: str, clean_csv_file: str,
                               output_rootdir: str, output_subdir: str, mod_cfg: XFormMergePipelineConfig,
                               method: str = 'insert', random_state_obj: RandomState = RandomState(1234)) -> None:
    """
    Modifies a clean dataset given a configuration
    :param clean_dataset_rootdir: root directory where the clean data lives
    :param clean_csv_file: filename of the CSV file which contains information about the clean data
                           The modification method determines which columns and information are expected
                           in the CSV file.  A 'label' column is always required.
    :param output_rootdir: root directory where the modified data will be stored
    :param output_subdir: subdirectory where the modified data will be stored.  This is expected to be one level
                          below the root-directory, and can prove useful if different types of modifications are
                          stored in different subdirectories under the main root directory.  An example tree
                          structure might be:
                          root_data
                             - modification_1
                                 ... data ...
                             - modification_2
                                 ... data ...
    :param mod_cfg: A configuration object for creating a modified dataset
    :param method: Can be "insert" or "regenerate" (case-insensitive).
                   In the insert method, the function takes the clean image (CSV column 'file', relative to
                   clean_dataset_rootdir, with an optional 'mask' column), and inserts a specified Entity
                   (likely, a pattern) into the clean image.
                   In the regenerate method, the function reads a foreground and a background image (CSV columns
                   'fg_file' and 'bg_file', used as given, with optional 'fg_mask' and 'bg_mask' columns) and
                   merges foreground, trigger (if any) and background.
                   In both methods, the output filename is the basename of the CSV column 'file'.
    :param random_state_obj: RandomState object to ensure reproduciblity of dataset.
                             NOTE: the default object is created once, when the function is defined, so successive
                             calls which rely on the default share (and advance) the same generator.
    :return: None
    :raises ValueError: if the stratified split fails, if an image cannot be read, or if method is unknown
    """
    os.makedirs(os.path.join(output_rootdir, output_subdir), exist_ok=True)

    # read in clean dataset
    clean_df = pd.read_csv(os.path.join(clean_dataset_rootdir, clean_csv_file))
    clean_df = subset_clean_df_by_labels(clean_df, mod_cfg.triggered_classes)

    # identify which images will have triggers inserted into them
    random_state = random_state_obj.get_state()
    if mod_cfg.per_class_trigger_frac is not None:
        try:
            trigger_data, _ = train_test_split(clean_df,
                                               train_size=mod_cfg.per_class_trigger_frac,
                                               random_state=random_state_obj,
                                               stratify=clean_df['label'])
        except ValueError as e:
            logger.exception(e)
            raise ValueError(e) from e
    else:
        trigger_data = clean_df
    # reset random state to be ensure reproduciblity regardless of # of splits
    random_state_obj.set_state(random_state)

    # generate the same # of triggers according to the configuration
    num_triggers = len(trigger_data)
    trigger_source_list = mod_cfg.trigger_list
    method_name = method.lower()
    output_dir = os.path.join(output_rootdir, output_subdir)

    # run the xform function for each image & trigger combination
    for ii in tqdm(range(num_triggers), desc='Modifying Clean Dataset ...'):
        # select the trigger
        if trigger_source_list is not None and len(trigger_source_list) != 0:
            trigger = random_state_obj.choice(trigger_source_list, p=mod_cfg.trigger_sampling_prob)
        else:
            trigger = None
        # every image gets its own generator, seeded from the master generator
        img_random_state = RandomState(random_state_obj.randint(RANDOM_STATE_DRAW_LIMIT))

        row = trigger_data.iloc[ii]

        if method_name == 'insert':
            fp = row['file']
            mask = _load_optional_mask(row, 'mask')

            # load the background image
            bg = GenericImageEntity(_read_image_or_raise(os.path.join(clean_dataset_rootdir, fp)), mask)
            bg_xforms = mod_cfg.trigger_bg_xforms
            fg = trigger
            fg_xforms = mod_cfg.trigger_xforms
            merge_obj = mod_cfg.trigger_bg_merge
            postproc_xforms = mod_cfg.trigger_bg_merge_xforms

            # process data through the pipeline
            pipeline_obj = XFormMerge([[bg_xforms, fg_xforms]], [merge_obj], postproc_xforms)
            modified_img = pipeline_obj.process([bg, fg], img_random_state)

            logger.debug("Inserted trigger=%s into image=%s", fg, bg)
        elif method_name == 'regenerate':
            # TODO: NOTE: this needs to be an absolute path!
            #  do a check to ensure the user provided absolute paths!
            bg_fp = row['bg_file']
            fg_fp = row['fg_file']
            bg_mask = _load_optional_mask(row, 'bg_mask')
            fg_mask = _load_optional_mask(row, 'fg_mask')

            # load images into memory
            obj1 = GenericImageEntity(_read_image_or_raise(fg_fp), fg_mask)
            obj2 = trigger
            obj3 = GenericImageEntity(_read_image_or_raise(bg_fp), bg_mask)

            obj1_xforms = mod_cfg.trigger_bg_xforms
            obj2_xforms = mod_cfg.trigger_xforms
            obj12_merge = mod_cfg.trigger_bg_merge
            obj12_xforms = mod_cfg.trigger_bg_merge_xforms
            obj3_xforms = mod_cfg.overall_bg_xforms
            obj123_merge = mod_cfg.overall_bg_triggerbg_merge
            obj123_xforms = mod_cfg.overall_bg_triggerbg_xforms

            if obj2 is None:
                # obj3 is the background, obj1 is the sign (without a point trigger)
                pipeline_obj = XFormMerge([[obj3_xforms, obj1_xforms]],
                                          [obj123_merge], obj123_xforms)
                modified_img = pipeline_obj.process([obj3, obj1], img_random_state)
                logger.info("Regenerated by merge of : ((%s, %s)", obj1, obj3)
            else:
                # cascade: merge obj1 with the trigger first, then merge that result with obj3
                pipeline_obj = XFormMerge([[obj1_xforms, obj2_xforms], [obj3_xforms, obj12_xforms]],
                                          [obj12_merge, obj123_merge], obj123_xforms)
                modified_img = pipeline_obj.process([obj1, obj2, obj3], img_random_state)
                logger.info("Regenerated by cascading merge of : ((%s, %s), %s)", obj1, obj2, obj3)
        else:
            msg = "Unknown/unimplemented data modification method!"
            logger.error(msg)
            raise ValueError(msg)

        output_fname = os.path.basename(row['file'])
        output_filename_fullpath = os.path.join(output_dir, output_fname)
        cv2.imwrite(output_filename_fullpath, modified_img.get_data())
def modify_clean_text_dataset(clean_dataset_rootdir: str, clean_csv_file: str,
                              output_rootdir: str, output_subdir: str, mod_cfg: XFormMergePipelineConfig,
                              method: str = 'insert', random_state_obj: RandomState = RandomState(1234)) -> None:
    """
    Modifies a clean text dataset given a configuration
    :param clean_dataset_rootdir: root directory where the clean data lives
    :param clean_csv_file: filename of the CSV file which contains information about the clean data
                           The modification method determines which columns and information are expected
                           in the CSV file.  The columns 'label' and 'file' are required.
    :param output_rootdir: root directory where the modified data will be stored
    :param output_subdir: subdirectory where the modified data will be stored.  This is expected to be one level
                          below the root-directory, and can prove useful if different types of modifications are
                          stored in different subdirectories under the main root directory.  An example tree
                          structure might be:
                          root_data
                             - modification_1
                                 ... data ...
                             - modification_2
                                 ... data ...
    :param mod_cfg: A configuration object for creating a modified dataset
    :param method: Can only be "insert" (case-insensitive)
                   In the insert method, the function takes the clean text blurb, and inserts a specified
                   TextEntity (likely, a pattern) into the first text input object.
    :param random_state_obj: RandomState object to ensure reproduciblity of dataset.
                             NOTE: the default object is created once, when the function is defined, so successive
                             calls which rely on the default share (and advance) the same generator.
    :return: None
    :raises ValueError: if the stratified split fails, or if method is unknown
    """
    os.makedirs(os.path.join(output_rootdir, output_subdir), exist_ok=True)

    # read in clean dataset
    clean_df = pd.read_csv(os.path.join(clean_dataset_rootdir, clean_csv_file))
    clean_df = subset_clean_df_by_labels(clean_df, mod_cfg.triggered_classes)

    # identify which text files will have triggers inserted into them
    random_state = random_state_obj.get_state()
    if mod_cfg.per_class_trigger_frac is not None:
        try:
            trigger_data, _ = train_test_split(clean_df,
                                               train_size=mod_cfg.per_class_trigger_frac,
                                               random_state=random_state_obj,
                                               stratify=clean_df['label'])
        except ValueError as e:
            logger.exception(e)
            raise ValueError(e) from e
    else:
        trigger_data = clean_df
    # reset random state to be ensure reproduciblity regardless of # of splits
    random_state_obj.set_state(random_state)

    # generate the same # of triggers according to the configuration
    num_triggers = len(trigger_data)
    trigger_source_list = mod_cfg.trigger_list
    method_name = method.lower()
    output_dir = os.path.join(output_rootdir, output_subdir)

    # run the xform function for each text & trigger combination
    for ii in tqdm(range(num_triggers), desc='Modifying Clean Dataset ...'):
        # select the trigger
        if trigger_source_list is not None and len(trigger_source_list) != 0:
            trigger = random_state_obj.choice(trigger_source_list, p=mod_cfg.trigger_sampling_prob)
        else:
            trigger = None
        # every text sample gets its own generator, seeded from the master generator
        txt_random_state = RandomState(random_state_obj.randint(RANDOM_STATE_DRAW_LIMIT))

        if method_name == 'insert':
            # load the data
            # NOTE(review): the path is opened exactly as given in the CSV, i.e. it is not joined with
            #  clean_dataset_rootdir -- confirm this is intended.
            fp = trigger_data.iloc[ii]['file']
            with open(fp, 'r') as fo:
                # newlines are stripped so that the text is processed as a single line
                bg = GenericTextEntity(fo.read().replace('\n', ''))

            # setup trigger
            fg = trigger
            bg_xforms = mod_cfg.trigger_bg_xforms
            fg_xforms = mod_cfg.trigger_xforms
            merge_obj = mod_cfg.trigger_bg_merge
            postproc_xforms = mod_cfg.trigger_bg_merge_xforms

            # process data through the pipeline
            pipeline_obj = XFormMerge([[bg_xforms, fg_xforms]], [merge_obj], postproc_xforms)
            modified_text = pipeline_obj.process([bg, fg], txt_random_state)

            logger.debug("Inserted trigger=%s into text=%s", fg, bg)
        else:
            msg = "Unknown/unimplemented data modification method!"
            logger.error(msg)
            raise ValueError(msg)

        # the modified text keeps the basename of the input file
        output_fname = os.path.join(output_dir, os.path.basename(fp))
        with open(output_fname, 'w+') as f:
            f.write(modified_text.get_text())