# Training on imbalanced data using TensorFlow


The Situation:

I am wondering how to use TensorFlow optimally when my training data is imbalanced in label distribution between 2 labels. For instance, suppose the MNIST tutorial is simplified to only distinguish between 1's and 0's, where all images available to us are either 1's or 0's. This is straightforward to train using the provided TensorFlow tutorials when we have roughly 50% of each type of image to train and test on. But what about the case where 90% of the images available in our data are 0's and only 10% are 1's? I observe that in this case, TensorFlow routinely predicts my entire test set to be 0's, achieving an accuracy of a meaningless 90%.

One strategy I have used with some success is to pick random batches for training that do have an even distribution of 0's and 1's. This approach ensures that I can still use all of my training data and produces decent results: accuracy below 90%, but a much more useful classifier. Since accuracy is of little use to me in this case, my metric of choice is typically the area under the ROC curve (AUROC), and this approach gives a result respectably higher than 0.50.
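A minimal sketch of the balanced-batch strategy described above, written in plain Python rather than TensorFlow so it runs on its own. The function name balanced_batch_indices is hypothetical. It assumes hashable class labels such as 0 and 1, and it draws each class's share of the batch with replacement, so the 10% class is revisited more often than the 90% class while every example stays eligible.

```python
import random
from collections import defaultdict

def balanced_batch_indices(labels, batch_size, rng=random):
    """Indices for one mini-batch with the same number of examples from each class.

    Assumes hashable class labels (e.g. 0 and 1). Each class's share is drawn
    with replacement, so a rare class can still fill its half of every batch.
    """
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    per_class = batch_size // len(by_class)  # any remainder is dropped
    batch = []
    for c in sorted(by_class):
        batch.extend(rng.choices(by_class[c], k=per_class))
    rng.shuffle(batch)  # mix the classes instead of all 0's followed by all 1's
    return batch

# Usage: 90% zeros and 10% ones, as in the question.
labels = [0] * 90 + [1] * 10
batch = balanced_batch_indices(labels, batch_size=32, rng=random.Random(0))
print(len(batch), sum(labels[i] for i in batch))  # 32 16
```

Sampling with replacement is what lets every batch be 50/50 without throwing away any of the majority class; the trade-off is that minority examples are repeated often, which can make it easier to overfit them.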

Questions:

(1) Is the strategy I have described an accepted or optimal way of training on imbalanced data, or is there one that might work better?

(2) Since the accuracy metric is not as useful in the case of imbalanced data, is there another metric that can be maximized by altering the cost function? I can certainly calculate AUROC post-training, but can I train in such a way as to maximize AUROC?
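To make the accuracy-versus-AUROC point concrete, here is a small sketch (plain Python, hypothetical helper names auroc and accuracy) that scores the degenerate "always predict 0" classifier from the question on a 90/10 set. AUROC is computed as the probability that a randomly chosen positive gets a higher score than a randomly chosen negative, with ties counted as half. Because this pairwise count is a step function of the scores, it has no useful gradient, which is why AUROC is normally computed after training rather than used directly as the cost.

```python
def auroc(labels, scores):
    """AUROC as P(score of a random positive > score of a random negative).

    Ties count as half a win, so a constant score gives exactly 0.5.
    This O(P * N) pairwise version is for illustration, not large test sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def accuracy(labels, preds):
    """Fraction of predictions that match the labels."""
    return sum(y == p for y, p in zip(labels, preds)) / len(labels)

labels = [0] * 90 + [1] * 10
all_zero = [0] * 100               # the classifier that always predicts 0
print(accuracy(labels, all_zero))  # 0.9
print(auroc(labels, all_zero))     # 0.5
```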

(3) Is there some other alteration I can make to my cost function to improve my results for imbalanced data? Currently, I am using a default suggestion given in TensorFlow tutorials:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

I have heard this may be possible by up-weighting the cost of miscategorizing the smaller label class, but I am unsure of how to do this.
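One way to up-weight the smaller class in a softmax setup like the one above is to scale each example's cross-entropy by a weight chosen from its true class before averaging. The sketch below does this in plain Python for clarity. The names weighted_mean_cost and class_weights, and the inverse-frequency weights [1.0, 9.0] for a 90/10 split, are illustrative assumptions to tune, not values from the question. In TensorFlow 1.x the same idea is commonly written by computing the per-example tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred), multiplying it by tf.reduce_sum(y * class_weights, axis=1), and only then applying tf.reduce_mean.

```python
import math

def softmax_cross_entropy(logits, onehot):
    """-sum_k y_k * log(softmax(logits)_k) for one example, via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return -sum(y * (l - log_z) for y, l in zip(onehot, logits))

def weighted_mean_cost(batch_logits, batch_onehot, class_weights):
    """Mean cross-entropy where each example is scaled by the weight of its true class."""
    total = 0.0
    for logits, onehot in zip(batch_logits, batch_onehot):
        w = sum(c * y for c, y in zip(class_weights, onehot))  # weight of the true class
        total += w * softmax_cross_entropy(logits, onehot)
    return total / len(batch_logits)  # same normalization as tf.reduce_mean

# Hypothetical 90/10 split: weight the rare class 1 by 0.9 / 0.1 = 9.
class_weights = [1.0, 9.0]
logits = [[0.0, 0.0], [0.0, 0.0]]
onehot = [[1, 0], [0, 1]]
print(weighted_mean_cost(logits, onehot, class_weights))  # equals (1 + 9) * log(2) / 2
```

Dividing by the batch size mirrors tf.reduce_mean; dividing by the sum of the weights instead is a common alternative that keeps the loss on the same scale as the unweighted cost.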

Answer:

A classification data set with skewed class proportions is called imbalanced. You can use the following methods to deal with imbalanced datasets.

• Upweighting positive samples: This method increases the loss on misclassified positive samples when the dataset has far fewer positive than negative samples. This pushes the model to learn parameters that work better for the positive class. For binary classification, TensorFlow has a simple API that does this, tf.nn.weighted_cross_entropy_with_logits:

https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits

• Batch sampling: This method samples the dataset so that each batch of training data has an even ratio of positive to negative samples. It can be done with the rejection sampling API provided by TensorFlow:

https://www.tensorflow.org/api_docs/python/tf/contrib/training/rejection_sample
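Both techniques in the list above can be sketched in plain Python. weighted_sigmoid_ce follows the formula described in the documentation of tf.nn.weighted_cross_entropy_with_logits, where pos_weight (q) multiplies only the positive-label term, using the overflow-safe rearrangement that documentation gives. That op applies a sigmoid to each output independently, so with the two-column softmax setup from the question you would switch to a single logit per example. rejection_sample is a simplified, hypothetical stand-in for what a rejection-sampling API does: it assumes integer labels 0..K-1 and a known initial class distribution (90/10 here), and accepts each example with a class-dependent probability so that the accepted stream approaches the target distribution (50/50). Note that tf.contrib does not exist in TensorFlow 2.x; there, tf.data.experimental.rejection_resample is the closest counterpart we are aware of.

```python
import math
import random

def weighted_sigmoid_ce(logit, label, pos_weight):
    """q*z*-log(sigmoid(x)) + (1-z)*-log(1-sigmoid(x)), computed in the stable form
    (1 - z)*x + l*(log(1 + exp(-|x|)) + max(-x, 0)) with l = 1 + (q - 1)*z."""
    x, z, q = logit, label, pos_weight
    l = 1.0 + (q - 1.0) * z
    return (1.0 - z) * x + l * (math.log1p(math.exp(-abs(x))) + max(-x, 0.0))

def acceptance_probs(initial_dist, target_dist):
    """Per-class acceptance probability that turns initial_dist into target_dist.

    Accepting class c with probability proportional to target[c] / initial[c]
    makes the accepted stream follow target_dist; dividing by the largest ratio
    keeps every probability <= 1, so the rarest class is always accepted.
    """
    ratios = [t / p for t, p in zip(target_dist, initial_dist)]
    top = max(ratios)
    return [r / top for r in ratios]

def rejection_sample(examples, labels, initial_dist, target_dist, rng=random):
    """Yield (example, label) pairs, keeping each with its class's acceptance probability."""
    accept = acceptance_probs(initial_dist, target_dist)
    for x, y in zip(examples, labels):
        if rng.random() < accept[y]:
            yield x, y

# Usage: pos_weight = 9 makes a missed positive cost 9x more than with q = 1.
print(weighted_sigmoid_ce(0.0, 1.0, 9.0) / weighted_sigmoid_ce(0.0, 1.0, 1.0))

labels = [0] * 900 + [1] * 100
kept = [y for _, y in rejection_sample(range(1000), labels, [0.9, 0.1], [0.5, 0.5], random.Random(0))]
print(len(kept), sum(kept))  # all 100 ones are kept; about 1 in 9 zeros survive
```

Unlike the with-replacement sampler described in the question, rejection sampling discards data: at a 90/10 split it keeps only about one in nine negatives, trading training examples for balance.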