# Unbalanced data and weighted cross entropy


I'm trying to train a network with unbalanced data. I have A (198 samples), B (436 samples), C (710 samples), and D (272 samples). I have read about "weighted_cross_entropy_with_logits", but all the examples I found are for binary classification, so I'm not very confident about how to set those weights.

Total samples: 1616

A_weight: 198/1616 = 0.12?

The idea behind it, if I understood correctly, is to penalize the errors of the majority class and to value hits on the minority class more positively, right?

My piece of code:

```python
weights = tf.constant([0.12, 0.26, 0.43, 0.17])

cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
    logits=pred, targets=y, pos_weight=weights))
```

I have read this and other examples with binary classification, but it is still not very clear.


You should understand that weighted_cross_entropy_with_logits is the weighted variant of sigmoid_cross_entropy_with_logits. Sigmoid cross entropy is commonly used for binary classification. It can handle multiple labels, but it makes an independent binary decision for each of them.

For example, a face recognition net might have output labels like "Does the subject wear glasses?", "Is the subject female?", and many others.

In binary classification problems, each output node makes a binary decision, so the weighting has to happen inside the computation of the loss. That is what weighted_cross_entropy_with_logits does: it weights the positive term of the cross entropy relative to the negative one, via pos_weight.
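For intuition on what pos_weight does in the binary case, here is a plain-NumPy sketch (my own illustration, not TensorFlow code) of the documented formula, in which only the positive (target = 1) term is scaled by pos_weight:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def weighted_sigmoid_ce(logits, targets, pos_weight):
    # pos_weight scales only the positive term, mirroring how
    # tf.nn.weighted_cross_entropy_with_logits is documented:
    # pos_weight * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
    p = sigmoid(logits)
    return -(pos_weight * targets * np.log(p)
             + (1 - targets) * np.log(1 - p))

logits = np.array([2.0, -1.0])
targets = np.array([1.0, 0.0])
# pos_weight > 1 increases the penalty for missing positives;
# the negative-class term is left untouched
losses = weighted_sigmoid_ce(logits, targets, pos_weight=3.0)
```

With pos_weight = 1 this reduces to the ordinary sigmoid cross entropy.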

For mutually exclusive multiclass classification, we use softmax_cross_entropy_with_logits, which behaves differently: each output channel corresponds to the score of one class candidate, and the decision comes from comparing the respective outputs of all channels.
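A minimal NumPy sketch (my own illustration, with made-up values) of how softmax cross entropy couples the channels: each class probability depends on all logits at once, and the loss is the negative log-probability assigned to the true class.

```python
import numpy as np

def softmax(x):
    # subtract the max for numerical stability
    e = np.exp(x - np.max(x))
    return e / e.sum()

def softmax_ce(logits, onehot):
    # negative log-probability of the true class
    return -np.sum(onehot * np.log(softmax(logits)))

logits = np.array([2.0, 1.0, 0.1])
onehot = np.array([1.0, 0.0, 0.0])
loss = softmax_ce(logits, onehot)

# the decision compares channels: the predicted class is the argmax
pred_class = np.argmax(logits)
```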

If the weights are to come in before the final decision, the simplest way to modify the scores before comparing them is multiplication by weights. For example, for a ternary classification task:

```python
# class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])

# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)

# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(
    labels=onehot_labels, logits=logits)

# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights

# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)
```

You can also use tf.losses.softmax_cross_entropy to handle the last three steps of the above code.
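As a sanity check on the arithmetic, the same steps can be reproduced in plain NumPy (my own example values, independent of TensorFlow):

```python
import numpy as np

def softmax(x):
    # row-wise softmax with max-subtraction for stability
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

class_weights = np.array([[1.0, 2.0, 3.0]])        # shape (1, 3)
onehot_labels = np.array([[1.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0]])        # shape (2, 3)
logits = np.array([[2.0, 0.5, 0.3],
                   [0.2, 0.4, 1.5]])

# per-sample weight: the class weight of each sample's true class
weights = np.sum(class_weights * onehot_labels, axis=1)    # -> [1.0, 3.0]

# unweighted softmax cross entropy, one value per sample
unweighted_losses = -np.sum(onehot_labels * np.log(softmax(logits)), axis=1)

# apply the weights and reduce
loss = np.mean(unweighted_losses * weights)
```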

To counter the data imbalance, make the class weights inversely proportional to each class's frequency in your training data. Normalizing them so that they sum to one, or to the number of classes, also makes sense.
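Applied to the class counts from the question, inverse-frequency weights normalized to sum to the number of classes could be computed like this (a sketch; the exact normalization is a free design choice):

```python
import numpy as np

# class counts from the question: A, B, C, D
counts = np.array([198.0, 436.0, 710.0, 272.0])

# inverse-frequency weights: the rarer the class, the larger its weight
inv_freq = counts.sum() / counts

# normalize so the weights sum to the number of classes
class_weights = inv_freq * len(counts) / inv_freq.sum()
# class A (fewest samples) gets the largest weight,
# class C (most samples) the smallest
```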

Above, we penalized the loss based on each sample's true label. We could also penalize the loss based on the estimated labels, by deriving each sample's weight from the predicted class instead of from the one-hot targets.
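Concretely, one way to do this (a NumPy sketch under my own variable names) is to look up the class weight of each sample's predicted class, i.e. the argmax over the logits:

```python
import numpy as np

class_weights = np.array([[1.0, 2.0, 3.0]])
logits = np.array([[2.0, 0.5, 0.3],
                   [0.2, 0.4, 1.5]])

# weight each sample by the class weight of its *predicted* class
# rather than its true class
pred_class = np.argmax(logits, axis=1)     # -> [0, 2]
weights = class_weights[0, pred_class]     # -> [1.0, 3.0]
```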

Finally, you may want weights that depend on both the prediction and the true label. In other words, for each pair of labels X and Y, you could choose how to penalize predicting label X when the true label is Y.
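One way to realize such pairwise penalties (my own sketch, not a built-in TensorFlow API) is a cost matrix indexed by the (true, predicted) label pair, looked up per sample:

```python
import numpy as np

# cost_matrix[i, j]: penalty for predicting class j when the true class is i
# (a diagonal of 1.0 leaves the loss of correct predictions unchanged)
cost_matrix = np.array([[1.0, 2.0, 4.0],
                        [2.0, 1.0, 2.0],
                        [4.0, 2.0, 1.0]])

true_labels = np.array([0, 2, 1])
pred_labels = np.array([2, 2, 0])

# per-sample weight depends on the (true, predicted) pair
weights = cost_matrix[true_labels, pred_labels]    # -> [4.0, 1.0, 2.0]
```

These per-sample weights can then multiply the unweighted losses exactly as in the earlier snippet.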