Note that weighted_cross_entropy_with_logits is the weighted variant of sigmoid_cross_entropy_with_logits. Sigmoid cross entropy is commonly used for binary classification. It can handle multiple labels, but it makes an independent binary decision on each of them.
For example, a face recognition net could have (not mutually exclusive) labels such as "Does the subject wear glasses?", "Is the subject female?", and so on.
In binary classification, each output node gives a binary (soft) decision, so the weighting has to happen within the computation of the loss. This is what weighted_cross_entropy_with_logits does: it weights one term of the cross-entropy over the other, multiplying the term for positive labels by pos_weight.
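To make the "one term over the other" concrete, here is a NumPy sketch of the formula the TensorFlow documentation gives for weighted_cross_entropy_with_logits. It is a re-implementation for illustration, not TensorFlow's own code, and the helper name weighted_sigmoid_cross_entropy is hypothetical. It uses the numerically stable rewrite of the formula.

```python
import numpy as np


def weighted_sigmoid_cross_entropy(labels, logits, pos_weight):
    """Element-wise weighted sigmoid cross-entropy (hypothetical helper).

    Equivalent to
        labels * -log(sigmoid(logits)) * pos_weight
        + (1 - labels) * -log(1 - sigmoid(logits))
    but written in a numerically stable form.
    """
    z = np.asarray(labels, dtype=np.float64)
    x = np.asarray(logits, dtype=np.float64)
    # pos_weight only scales the term for positive labels (z == 1)
    log_weight = 1.0 + (pos_weight - 1.0) * z
    # softplus(-x) == -log(sigmoid(x)), computed without overflow
    softplus_neg_x = np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0)
    return (1.0 - z) * x + log_weight * softplus_neg_x


labels = np.array([1.0, 0.0, 1.0])
logits = np.array([0.5, -1.0, 2.0])
unweighted = weighted_sigmoid_cross_entropy(labels, logits, pos_weight=1.0)
weighted = weighted_sigmoid_cross_entropy(labels, logits, pos_weight=3.0)
```

With pos_weight=3.0, the losses for the two positive labels are three times their unweighted values, while the loss for the negative label is unchanged.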
In mutually exclusive multiclass classification, we use softmax_cross_entropy_with_logits, which behaves differently: each output channel holds the score of one class candidate, and the decision comes afterwards, by comparing the outputs of all channels.
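A small NumPy sketch of this behaviour: the loss is computed per sample from all channel scores, and the decision is simply the channel with the highest score. The helper name softmax_cross_entropy and the example values are illustrative assumptions, not TensorFlow code.

```python
import numpy as np


def softmax_cross_entropy(onehot_labels, logits):
    """Per-sample softmax cross-entropy (NumPy sketch)."""
    # subtract the row maximum before exponentiating, for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(onehot_labels * log_probs).sum(axis=1)


logits = np.array([[2.0, 1.0, 0.1],
                   [0.2, 0.3, 3.0]])
onehot_labels = np.array([[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0]])
losses = softmax_cross_entropy(onehot_labels, logits)
# the decision: pick the channel with the highest score
predicted = logits.argmax(axis=1)
```

The second sample is misclassified (its true class is 1, but channel 2 scores highest), and its loss is correspondingly larger than that of the first sample.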
Weighting before that final decision is therefore a simple matter of multiplying by class weights. For example, for a ternary classification task, each sample's loss can be scaled by the weight of its true class:
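```python
import tensorflow as tf

# example batch of 2 samples for a 3-class problem (illustrative values)
logits = tf.constant([[2.0, 1.0, 0.1], [0.2, 0.3, 3.0]])
onehot_labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
# your class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(labels=onehot_labels, logits=logits)
# apply the per-sample weights (both tensors have shape [batch_size])
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)
```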
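In TensorFlow 1.x you can also use tf.losses.softmax_cross_entropy (tf.compat.v1.losses.softmax_cross_entropy in TensorFlow 2.x), passing the per-sample weights through its weights argument, to handle the last three steps of the above code.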
To counter data imbalance, the class weights can be set inversely proportional to each class's frequency in your training data. Normalizing them so that they sum to one, or to the number of classes, also makes sense.
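A minimal NumPy sketch of inverse-frequency class weights with optional normalization. The helper name inverse_frequency_weights and the toy label array are hypothetical; it assumes every class occurs at least once in the training labels.

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes, normalize_to=None):
    """Class weights inversely proportional to class frequency.

    labels: 1-D array of integer class ids from the training set.
    normalize_to: if given, rescale the weights so they sum to this value
    (e.g. 1.0 or num_classes).
    """
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    # assumption: every class appears at least once (no division by zero)
    weights = 1.0 / counts
    if normalize_to is not None:
        weights = weights * (normalize_to / weights.sum())
    return weights


# toy training labels: class 0 is six times as frequent as class 2
train_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
w = inverse_frequency_weights(train_labels, num_classes=3, normalize_to=3.0)
```

Here the counts are 6, 3 and 1, so after normalizing to the number of classes the weights are 1/3, 2/3 and 2: the rarest class gets the largest weight.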
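Above, we penalized the loss based on the true label of each sample. We can also penalize it based on the estimated label instead, for example by deriving the weights from the predicted class and keeping the rest of the code unchanged:
```python
weights = tf.reduce_sum(class_weights * tf.one_hot(tf.argmax(logits, axis=1), depth=3), axis=1)
```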
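More generally, you may want weights that depend on the kind of error made. In other words, for each pair of labels X and Y, you could choose how to penalize predicting label X when the true label is Y. This amounts to a full matrix of prior weights rather than a single weight per class.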
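One possible way to use such a matrix, sketched in NumPy: each sample's softmax cross-entropy is weighted by cost[true_label, predicted_label], where the prediction is the argmax channel. The helper name cost_weighted_softmax_loss and the cost values are hypothetical, and looking the weight up with the hard prediction is an assumption; weighting by the predicted probabilities would be another option.

```python
import numpy as np


def cost_weighted_softmax_loss(onehot_labels, logits, cost):
    """Mean softmax cross-entropy, each sample weighted by
    cost[true_label, predicted_label] (hypothetical helper)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    unweighted = -(onehot_labels * log_probs).sum(axis=1)
    # row of the cost matrix for each sample's true label:
    # shape (num_samples, num_classes)
    per_sample_costs = onehot_labels @ cost
    # one-hot encoding of the predicted (argmax) label
    num_classes = logits.shape[1]
    predicted_onehot = np.eye(num_classes)[logits.argmax(axis=1)]
    weights = (per_sample_costs * predicted_onehot).sum(axis=1)
    return np.mean(unweighted * weights), weights


# cost[y, x]: penalty for predicting x when the true label is y
cost = np.array([[1.0, 2.0, 2.0],
                 [2.0, 1.0, 2.0],
                 [5.0, 2.0, 1.0]])
logits = np.array([[2.0, 1.0, 0.1],
                   [3.0, 0.5, 0.2]])
onehot_labels = np.array([[1.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0]])
loss, weights = cost_weighted_softmax_loss(onehot_labels, logits, cost)
```

The first sample is classified correctly and gets weight 1; the second (true class 2, predicted class 0) gets the weight 5 from cost[2, 0]. If ported to TensorFlow, the argmax has no gradient, so these weights would act as constants with respect to the logits.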
Hope this answer helps.