import collections.abc
import json
import logging
from typing import Union, Sequence
import csv
import numpy as np
logger = logging.getLogger(__name__)
"""
Contains classes necessary for collecting statistics on the model during training
"""
class BatchStatistics:
    """
    Represents the statistics collected from training a batch
    NOTE: this is currently unused!
    """

    def __init__(self, batch_num: int,
                 batch_train_accuracy: float,
                 batch_train_loss: float):
        """
        :param batch_num: (int) batch number of collected statistics
        :param batch_train_accuracy: (float) training set accuracy for this batch
        :param batch_train_loss: (float) training loss for this batch
        """
        self.batch_num = batch_num
        self.batch_train_accuracy = batch_train_accuracy
        self.batch_train_loss = batch_train_loss

    def get_batch_num(self):
        """Return the batch number these statistics refer to."""
        return self.batch_num

    def get_batch_train_acc(self):
        """Return the recorded training accuracy for this batch."""
        return self.batch_train_accuracy

    def get_batch_train_loss(self):
        """Return the recorded training loss for this batch."""
        return self.batch_train_loss

    def set_batch_train_acc(self, acc):
        """Update the batch training accuracy; it is a percentage in [0, 100]."""
        if not 0 <= acc <= 100:
            msg = "Batch training accuracy should be between 0 and 100!"
            logger.error(msg)
            raise ValueError(msg)
        self.batch_train_accuracy = acc

    def set_batch_train_loss(self, loss):
        """Update the batch training loss (no validation is performed)."""
        self.batch_train_loss = loss
class EpochTrainStatistics:
    """
    Defines the training statistics for one epoch of training
    """

    def __init__(self, train_acc: float, train_loss: float):
        self.train_acc = train_acc
        self.train_loss = train_loss
        self.validate()

    def validate(self):
        """Ensure both statistics are floats; log and raise ValueError otherwise."""
        # checked in declaration order so the first offending field is reported
        for name in ("train_acc", "train_loss"):
            value = getattr(self, name)
            if not isinstance(value, float):
                msg = "{} must be a float, got type {}".format(name, type(value))
                logger.error(msg)
                raise ValueError(msg)

    def get_train_acc(self):
        """Return the epoch's training-set accuracy."""
        return self.train_acc

    def get_train_loss(self):
        """Return the epoch's training loss."""
        return self.train_loss
class EpochValidationStatistics:
    """
    Defines the validation statistics for one epoch of training
    """

    def __init__(self, val_clean_acc, val_clean_loss, val_triggered_acc, val_triggered_loss):
        self.val_clean_acc = val_clean_acc
        self.val_clean_loss = val_clean_loss
        self.val_triggered_acc = val_triggered_acc
        self.val_triggered_loss = val_triggered_loss
        self.validate()

    def validate(self):
        """Ensure each provided (non-None) statistic is a float, else log and raise."""
        for name in ("val_clean_acc", "val_clean_loss",
                     "val_triggered_acc", "val_triggered_loss"):
            value = getattr(self, name)
            # None is allowed: it means the metric was not computed
            if value is not None and not isinstance(value, float):
                msg = "{} must be a float, got type {}".format(name, type(value))
                logger.error(msg)
                raise ValueError(msg)

    def get_val_clean_acc(self):
        """Return the clean-data validation accuracy (or None)."""
        return self.val_clean_acc

    def get_val_clean_loss(self):
        """Return the clean-data validation loss (or None)."""
        return self.val_clean_loss

    def get_val_triggered_acc(self):
        """Return the triggered-data validation accuracy (or None)."""
        return self.val_triggered_acc

    def get_val_triggered_loss(self):
        """Return the triggered-data validation loss (or None)."""
        return self.val_triggered_loss

    def get_val_loss(self):
        """Return the combined validation loss: sum of clean and triggered losses,
        falling back to whichever is available, or None if neither was computed."""
        triggered = self.get_val_triggered_loss()
        clean = self.get_val_clean_loss()
        if triggered is None:
            return clean  # may itself be None when nothing was computed
        if clean is None:
            return triggered
        return triggered + clean

    def get_val_acc(self):
        """Return the combined validation accuracy: mean of clean and triggered,
        falling back to whichever is available, or None if neither was computed."""
        triggered = self.get_val_triggered_acc()
        clean = self.get_val_clean_acc()
        if triggered is None:
            return clean
        if clean is None:
            return triggered
        return (triggered + clean) / 2.

    def __repr__(self):
        # substitute -999 as a sentinel when a combined metric is unavailable
        val_loss = self.get_val_loss()
        val_acc = self.get_val_acc()
        if val_loss is None:
            val_loss = -999
        if val_acc is None:
            val_acc = -999
        return '(%0.04f, %0.04f)' % (val_loss, val_acc)
class EpochStatistics:
    """
    Contains the statistics computed for an Epoch
    """

    def __init__(self, epoch_num, training_stats=None, validation_stats=None, batch_training_stats=None):
        """
        :param epoch_num: epoch number of the collected statistics
        :param training_stats: (EpochTrainStatistics or None) training statistics for the epoch
        :param validation_stats: (EpochValidationStatistics or None) validation statistics for the epoch
        :param batch_training_stats: (Sequence[BatchStatistics] or None) per-batch statistics
        """
        self.epoch_num = epoch_num
        # BUGFIX: a caller-supplied batch_training_stats used to be silently
        # dropped (the attribute was only assigned when the argument was falsy),
        # which made validate() raise AttributeError and lost the caller's data.
        # Store a list copy of whatever was provided, defaulting to empty.
        self.batch_training_stats = list(batch_training_stats) if batch_training_stats else []
        self.epoch_training_stats = training_stats
        self.epoch_validation_stats = validation_stats
        self.validate()

    def add_batch(self, batches: Union[BatchStatistics, Sequence[BatchStatistics]]):
        """Append a single BatchStatistics, or extend with a sequence of them."""
        if isinstance(batches, collections.abc.Sequence):
            self.batch_training_stats.extend(batches)
        else:
            self.batch_training_stats.append(batches)

    def get_batch_stats(self):
        """Return the list of per-batch statistics collected so far."""
        return self.batch_training_stats

    def validate(self):
        """Type-check the stored statistics, logging and raising ValueError on mismatch."""
        if not isinstance(self.batch_training_stats, collections.abc.Sequence):
            msg = "batch_training_stats must be None or a list of BatchTrainingStats objects! " \
                  "Got {}".format(self.batch_training_stats)
            logger.error(msg)
            raise ValueError(msg)
        if self.epoch_training_stats and not isinstance(self.epoch_training_stats, EpochTrainStatistics):
            msg = "training_stats must be None or of type: EpochTrainStatistics!, got type " \
                  "{}".format(type(self.epoch_training_stats))
            logger.error(msg)
            raise ValueError(msg)
        if self.epoch_validation_stats and not isinstance(self.epoch_validation_stats, EpochValidationStatistics):
            msg = "validation_stats must be None or of type: EpochValidationStatistics! Instead got type " \
                  "{}".format(type(self.epoch_validation_stats))
            logger.error(msg)
            raise ValueError(msg)

    def get_epoch_num(self):
        """Return the epoch number these statistics refer to."""
        return self.epoch_num

    def get_epoch_training_stats(self):
        """Return the EpochTrainStatistics for this epoch (or None)."""
        return self.epoch_training_stats

    def get_epoch_validation_stats(self):
        """Return the EpochValidationStatistics for this epoch (or None)."""
        return self.epoch_validation_stats
class TrainingRunStatistics:
    """
    Contains the statistics computed for an entire training run, a sequence of epochs
    TODO:
     [ ] - have another function which returns detailed statistics per epoch in an easily serialized manner
    """

    def __init__(self):
        self.stats_per_epoch_list = []
        self.num_epochs_trained_per_optimizer = []
        self.final_train_acc = 0.
        self.final_train_loss = 0.
        self.final_combined_val_acc = 0.
        self.final_combined_val_loss = 0.
        self.final_clean_val_acc = 0.
        self.final_clean_val_loss = 0.
        self.final_triggered_val_acc = 0.
        self.final_triggered_val_loss = 0.
        self.final_clean_data_test_acc = 0.
        self.final_clean_data_n_total = 0
        # None indicates "not computed / no triggered data in this run"
        self.final_triggered_data_test_acc = None
        self.final_triggered_data_n_total = None
        self.final_clean_data_triggered_labels_test_acc = None
        self.final_clean_data_triggered_labels_n_total = None
        self.final_optimizer_num_epochs_trained = 0
        # index into stats_per_epoch_list of the best epoch; -1 (= last epoch) until set
        self.final_optimizer_best_epoch_val = -1

    @staticmethod
    def _validated_acc(acc, msg, allow_none=False):
        """
        Return ``acc`` unchanged when it is a valid percentage in [0, 100]
        (or None, when ``allow_none`` is set); otherwise log ``msg`` and
        raise ValueError.  Consolidates the range check shared by the setters.
        """
        if allow_none and acc is None:
            return acc
        if 0 <= acc <= 100:
            return acc
        logger.error(msg)
        raise ValueError(msg)

    def add_epoch(self, epoch_stats: Union[EpochStatistics, Sequence[EpochStatistics]]):
        """Append a single EpochStatistics, or extend with a sequence of them."""
        if isinstance(epoch_stats, collections.abc.Sequence):
            self.stats_per_epoch_list.extend(epoch_stats)
        else:
            self.stats_per_epoch_list.append(epoch_stats)

    def add_num_epochs_trained(self, num_epochs):
        """Record how many epochs the most recent optimizer trained for."""
        self.num_epochs_trained_per_optimizer.append(num_epochs)

    def add_best_epoch_val(self, best_epoch):
        """Record the index of the best epoch (used by autopopulate_final_summary_stats)."""
        self.final_optimizer_best_epoch_val = best_epoch

    def get_epochs_stats(self):
        """Return the list of per-epoch EpochStatistics objects."""
        return self.stats_per_epoch_list

    def autopopulate_final_summary_stats(self):
        """
        Uses the best epoch's statistics (indexed by final_optimizer_best_epoch_val,
        which defaults to -1, i.e. the last epoch) to auto-populate:
            final_train_acc / final_train_loss
            final_combined_val_acc / final_combined_val_loss
            final_clean_val_acc / final_clean_val_loss
            final_triggered_val_acc / final_triggered_val_loss
        Also records the number of epochs trained by the last optimizer.
        """
        best_epoch_stats = self.stats_per_epoch_list[self.final_optimizer_best_epoch_val]
        train_stats = best_epoch_stats.get_epoch_training_stats()
        self.set_final_train_acc(train_stats.get_train_acc())
        self.set_final_train_loss(train_stats.get_train_loss())
        val_stats = best_epoch_stats.get_epoch_validation_stats()
        if val_stats:
            self.set_final_val_combined_acc(val_stats.get_val_acc())
            self.set_final_val_combined_loss(val_stats.get_val_loss())
            self.set_final_val_clean_acc(val_stats.get_val_clean_acc())
            self.set_final_val_clean_loss(val_stats.get_val_clean_loss())
            self.set_final_val_triggered_acc(val_stats.get_val_triggered_acc())
            self.set_final_val_triggered_loss(val_stats.get_val_triggered_loss())
        self.final_optimizer_num_epochs_trained = self.num_epochs_trained_per_optimizer[-1]

    def set_final_train_acc(self, acc):
        """Set the final training accuracy; must be in [0, 100]."""
        self.final_train_acc = self._validated_acc(
            acc, "Final Training accuracy should be between 0 and 100!")

    def set_final_train_loss(self, loss):
        self.final_train_loss = loss

    def set_final_val_combined_acc(self, acc):
        """Set the combined validation accuracy; None is allowed when validation
        metrics are not computed."""
        self.final_combined_val_acc = self._validated_acc(
            acc, "Final validation accuracy should be between 0 and 100!", allow_none=True)

    def set_final_val_combined_loss(self, loss):
        self.final_combined_val_loss = loss

    def set_final_val_clean_acc(self, acc):
        self.final_clean_val_acc = acc

    def set_final_val_triggered_acc(self, acc):
        self.final_triggered_val_acc = acc

    def set_final_val_clean_loss(self, loss):
        self.final_clean_val_loss = loss

    def set_final_val_triggered_loss(self, loss):
        self.final_triggered_val_loss = loss

    def set_final_clean_data_test_acc(self, acc):
        """Set the clean-data test accuracy; must be in [0, 100]."""
        self.final_clean_data_test_acc = self._validated_acc(
            acc, "Final clean data test accuracy should be between 0 and 100!")

    def set_final_triggered_data_test_acc(self, acc):
        """Set the triggered-data test accuracy; None indicates that triggered
        data wasn't present in this dataset."""
        self.final_triggered_data_test_acc = self._validated_acc(
            acc, "Final triggered data test accuracy should be between 0 and 100!",
            allow_none=True)

    def set_final_clean_data_triggered_label_test_acc(self, acc):
        """Set the clean-data/triggered-label test accuracy; None allowed.
        BUGFIX: the error message previously said "clean data test accuracy",
        copy-pasted from set_final_clean_data_test_acc."""
        self.final_clean_data_triggered_labels_test_acc = self._validated_acc(
            acc, "Final clean data triggered label test accuracy should be between 0 and 100!",
            allow_none=True)

    def set_final_clean_data_n_total(self, n):
        self.final_clean_data_n_total = n

    def set_final_triggered_data_n_total(self, n):
        self.final_triggered_data_n_total = n

    def set_final_clean_data_triggered_label_n(self, n):
        self.final_clean_data_triggered_labels_n_total = n

    def get_summary(self):
        """
        Returns a dictionary of the summary statistics from the training run
        """
        summary_dict = {
            'final_train_acc': self.final_train_acc,
            'final_train_loss': self.final_train_loss,
            'final_combined_val_acc': self.final_combined_val_acc,
            'final_combined_val_loss': self.final_combined_val_loss,
            'final_clean_val_acc': self.final_clean_val_acc,
            'final_clean_val_loss': self.final_clean_val_loss,
            'final_triggered_val_acc': self.final_triggered_val_acc,
            'final_triggered_val_loss': self.final_triggered_val_loss,
            'final_clean_data_test_acc': self.final_clean_data_test_acc,
            'final_triggered_data_test_acc': self.final_triggered_data_test_acc,
            'final_clean_data_n_total': self.final_clean_data_n_total,
            'final_triggered_data_n_total': self.final_triggered_data_n_total,
            'clean_test_triggered_label_accuracy': self.final_clean_data_triggered_labels_test_acc,
            'clean_test_triggered_label_n_total': self.final_clean_data_triggered_labels_n_total,
            # NOTE(review): despite the singular key name, this reports the
            # per-optimizer *list* of epoch counts (not the scalar
            # final_optimizer_num_epochs_trained) -- preserved as-is since
            # downstream consumers may rely on the list form
            'final_optimizer_num_epochs_trained': self.num_epochs_trained_per_optimizer,
        }
        return summary_dict

    def save_summary_to_json(self, json_fname: str) -> None:
        """
        Saves the training summary to a JSON file
        :param json_fname: path of the JSON file to write
        :return: None
        """
        summary_dict = self.get_summary()
        # write it to json
        with open(json_fname, 'w') as fp:
            json.dump(summary_dict, fp)
        logger.info("Wrote summary statistics: %s to %s" % (str(summary_dict), json_fname))

    def save_detailed_stats_to_disk(self, fname: str) -> None:
        """
        Saves the per-epoch summary statistics as a CSV file
        :param fname: filename to save the detailed information to
        :return: None
        """
        keys = ['epoch_number', 'train_acc', 'train_loss', 'combined_val_acc', 'combined_val_loss',
                'clean_val_acc', 'clean_val_loss', 'triggered_val_acc', 'triggered_val_loss']
        # BUGFIX: newline='' is required by the csv module for file objects,
        # otherwise spurious blank rows appear on Windows
        with open(fname, 'w', newline='') as output_file:
            dict_writer = csv.DictWriter(output_file, keys)
            dict_writer.writeheader()
            for e in self.stats_per_epoch_list:
                # TODO: we ignore batch_statistics for now, we may want to add this in in the future
                epoch_training_stats = e.get_epoch_training_stats()
                epoch_val_stats = e.get_epoch_validation_stats()
                row = dict(epoch_number=e.get_epoch_num(),
                           train_acc=epoch_training_stats.get_train_acc(),
                           train_loss=epoch_training_stats.get_train_loss(),
                           combined_val_acc=None, combined_val_loss=None,
                           clean_val_acc=None, clean_val_loss=None,
                           triggered_val_acc=None, triggered_val_loss=None)
                if epoch_val_stats is not None:
                    row.update(combined_val_acc=epoch_val_stats.get_val_acc(),
                               combined_val_loss=epoch_val_stats.get_val_loss(),
                               clean_val_acc=epoch_val_stats.get_val_clean_acc(),
                               clean_val_loss=epoch_val_stats.get_val_clean_loss(),
                               triggered_val_acc=epoch_val_stats.get_val_triggered_acc(),
                               triggered_val_loss=epoch_val_stats.get_val_triggered_loss())
                dict_writer.writerow(row)
        logger.info("Wrote detailed statistics to %s" % (fname,))