
Classification problems, such as logistic regression or multinomial logistic regression, optimize a cross-entropy loss. Normally, the cross-entropy layer follows the softmax layer, which produces a probability distribution.

In TensorFlow, there are at least a dozen different cross-entropy loss functions:

tf.losses.softmax_cross_entropy

tf.losses.sparse_softmax_cross_entropy

tf.losses.sigmoid_cross_entropy

tf.contrib.losses.softmax_cross_entropy

tf.contrib.losses.sigmoid_cross_entropy

tf.nn.softmax_cross_entropy_with_logits

tf.nn.sigmoid_cross_entropy_with_logits

...

Which ones work only for binary classification, and which are suitable for multi-class problems? When should you use sigmoid instead of softmax? How are the sparse functions different from the others, and why do they exist only for softmax?

Related (more math-oriented) discussion: the cross-entropy jungle.

by (33.1k points)

In the functional sense, the sigmoid is a special case of the softmax function when the number of classes equals two. Both of them perform the same operation: they transform logits into probabilities.

In simple binary classification, there's no big difference between the two, but in multinomial classification, sigmoid can deal with non-exclusive labels (multi-label problems), while softmax deals with mutually exclusive classes.
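A quick way to see the relationship is to compare the two numerically. The sketch below is plain Python, not the TensorFlow API: the sigmoid of a logit x matches the second entry of a two-class softmax over the logits [0, x], and on a multi-class logit vector the sigmoid outputs are independent of each other while the softmax outputs sum to 1.

```python
import math

def sigmoid(x):
    # Logistic function: one logit -> probability of the positive class
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits):
    # Shift by the max logit before exponentiating to avoid overflow
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Binary case: sigmoid(x) equals a two-class softmax over [0, x]
x = 1.7
p_sigmoid = sigmoid(x)
p_softmax = softmax([0.0, x])[1]

# Multi-class logits: independent (multi-label) vs. exclusive probabilities
logits = [2.0, 1.0, -1.0]
independent = [sigmoid(v) for v in logits]  # each in (0, 1); no sum constraint
exclusive = softmax(logits)                 # sums to 1 across classes
```

Here `independent` sums to roughly 1.88, which is fine for multi-label data, whereas `exclusive` always forms a single distribution over the classes.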

Sigmoid functions family:

tf.nn.sigmoid_cross_entropy_with_logits

tf.nn.weighted_cross_entropy_with_logits

tf.losses.sigmoid_cross_entropy

tf.contrib.losses.sigmoid_cross_entropy

The sigmoid loss function is used for binary classification. But the TensorFlow functions are more general and allow multi-label classification when the classes are independent. For example, tf.nn.sigmoid_cross_entropy_with_logits solves N binary classifications at once.
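To make "N binary classifications at once" concrete, here is a plain-Python sketch of the per-element loss. The stable formula max(x, 0) - x * z + log(1 + exp(-|x|)), with x a logit and z its label, is the one given in the TensorFlow documentation for tf.nn.sigmoid_cross_entropy_with_logits; the function below only mirrors that math and is not the real op. The output has the same shape as the logits: one loss per (example, class) pair.

```python
import math

def sigmoid_cross_entropy_with_logits(labels, logits):
    # Element-wise logistic loss: every entry is its own binary problem.
    # Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return [
        [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x))) for x, z in zip(row_x, row_z)]
        for row_x, row_z in zip(logits, labels)
    ]

def naive_logistic_loss(z, x):
    # Textbook form -z*log(p) - (1-z)*log(1-p); breaks down for large |x|
    p = 1.0 / (1.0 + math.exp(-x))
    return -(z * math.log(p) + (1.0 - z) * math.log(1.0 - p))

# Batch of 2 examples, 3 independent labels each (several can be 1 at once)
logits = [[2.0, -1.0, 0.5], [-3.0, 4.0, 0.0]]
labels = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
losses = sigmoid_cross_entropy_with_logits(labels, logits)
```

The stable form also handles extreme logits: a logit of 1000 with label 1 gives a loss of exactly 0.0, while the naive form would try to take log(0) for the negative-class term.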

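The labels must have the same shape as the logits. They are encoded per class like one-hot vectors, except that several entries can be 1 for the same example, or they can contain soft class probabilities between 0 and 1.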

tf.losses.sigmoid_cross_entropy allows setting in-batch weights, i.e. making some examples more important than others. tf.nn.weighted_cross_entropy_with_logits allows setting class weights through its pos_weight argument, i.e. making errors on positive labels cost more than errors on negative labels. This is useful when the training data is unbalanced.
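The effect of pos_weight can be sketched in plain Python. The TensorFlow docs for tf.nn.weighted_cross_entropy_with_logits describe the loss as z * -log(sigmoid(x)) * pos_weight + (1 - z) * -log(1 - sigmoid(x)) and give the stable rewrite used below; the Python function names are only illustrative. With pos_weight = 3, the loss on every positive label is tripled while negatives are unchanged.

```python
import math

def weighted_cross_entropy_with_logits(labels, logits, pos_weight):
    # Stable form: (1 - z) * x + l * (log(1 + exp(-|x|)) + max(-x, 0)),
    # where l = 1 + (pos_weight - 1) * z scales only the positive-label term
    losses = []
    for z, x in zip(labels, logits):
        l = 1.0 + (pos_weight - 1.0) * z
        losses.append((1.0 - z) * x + l * (math.log1p(math.exp(-abs(x))) + max(-x, 0.0)))
    return losses

def direct_formula(labels, logits, pos_weight):
    # z * -log(sigmoid(x)) * pos_weight + (1 - z) * -log(1 - sigmoid(x)); fine for moderate x
    out = []
    for z, x in zip(labels, logits):
        p = 1.0 / (1.0 + math.exp(-x))
        out.append(-z * math.log(p) * pos_weight - (1.0 - z) * math.log(1.0 - p))
    return out

logits = [0.3, -1.2, 2.5]
labels = [1.0, 1.0, 0.0]  # two positives, one negative
plain = weighted_cross_entropy_with_logits(labels, logits, pos_weight=1.0)
boosted = weighted_cross_entropy_with_logits(labels, logits, pos_weight=3.0)
```

With pos_weight = 1.0 this reduces to the ordinary sigmoid cross-entropy, so the argument only changes how positive and negative errors are traded off.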

Softmax functions family:

tf.nn.softmax_cross_entropy_with_logits (DEPRECATED IN 1.5)

tf.nn.softmax_cross_entropy_with_logits_v2

tf.losses.softmax_cross_entropy

tf.contrib.losses.softmax_cross_entropy

These loss functions should be used for multinomial mutually exclusive classification.

Labels must be one-hot encoded or can contain soft class probabilities, as long as each row is a valid probability distribution. For example, a particular example can belong to class A with 50% probability and class B with 50% probability.
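The per-example softmax loss with one-hot or soft labels is -sum_i labels[i] * log_softmax(logits)[i]. A plain-Python sketch of that formula follows; the function name mirrors the TensorFlow op but this is not its implementation. The 50/50 label from the example gives a loss equal to the average of the two one-hot losses for classes A and B.

```python
import math

def softmax_cross_entropy_with_logits(labels, logits):
    # Per-example loss: -sum_i labels[i] * log_softmax(logits)[i]
    # log_softmax uses the log-sum-exp trick for numerical stability
    losses = []
    for y, x in zip(labels, logits):
        m = max(x)
        log_z = m + math.log(sum(math.exp(v - m) for v in x))
        losses.append(-sum(p * (v - log_z) for p, v in zip(y, x)))
    return losses

logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]]
one_hot = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
soft = [[0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]  # first example: 50% class A, 50% class B

hard_losses = softmax_cross_entropy_with_logits(one_hot, logits)
soft_losses = softmax_cross_entropy_with_logits(soft, logits)
```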

As in the sigmoid family, tf.losses.softmax_cross_entropy allows setting in-batch weights, i.e. making some examples more important than others. As of TensorFlow 1.3, there's no built-in way to set class weights.

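In TensorFlow 1.5, softmax_cross_entropy_with_logits was deprecated in favor of softmax_cross_entropy_with_logits_v2. The only difference between them is that in the newer version, backpropagation happens into both logits and labels. If the labels should not receive gradients, pass them through tf.stop_gradient first.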
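The difference is easiest to see in the gradients. This plain-Python sketch (helper names are illustrative, not TensorFlow functions) computes both partial derivatives of the cross-entropy for one example and checks them against finite differences. The gradient with respect to the logits, softmax(x) - y for labels that sum to 1, is propagated by both versions; the gradient with respect to the labels, -log_softmax(x), is the part only the _v2 op sends backward.

```python
import math

def log_softmax(x):
    # log_softmax(x)_i = x_i - logsumexp(x), computed stably
    m = max(x)
    log_z = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - log_z for v in x]

def loss(labels, logits):
    # Cross-entropy for one example: -sum_i y_i * log_softmax(x)_i
    return -sum(y * ls for y, ls in zip(labels, log_softmax(logits)))

def grad_logits(labels, logits):
    # dL/dx_i = sum(y) * softmax(x)_i - y_i, i.e. softmax(x) - y when y sums to 1
    s = sum(labels)
    return [s * math.exp(ls) - y for y, ls in zip(labels, log_softmax(logits))]

def grad_labels(logits):
    # dL/dy_i = -log_softmax(x)_i: the extra gradient that flows into labels in _v2
    return [-ls for ls in log_softmax(logits)]

def numeric_grad(f, v, eps=1e-6):
    # Central finite differences, used only to check the analytic gradients
    g = []
    for i in range(len(v)):
        up, down = list(v), list(v)
        up[i] += eps
        down[i] -= eps
        g.append((f(up) - f(down)) / (2 * eps))
    return g

x = [1.0, -0.5, 2.0]
y = [0.2, 0.3, 0.5]
```

Wrapping the labels in tf.stop_gradient simply discards the second gradient, which restores the old behavior.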

Sparse functions family

tf.nn.sparse_softmax_cross_entropy_with_logits

tf.losses.sparse_softmax_cross_entropy

tf.contrib.losses.sparse_softmax_cross_entropy

These loss functions should also be used for multinomial mutually exclusive classification, i.e. picking one out of N classes. The main difference is the label encoding: the classes are specified as integer indices, not one-hot vectors. These functions don't allow soft classes, but they can save some memory when there are thousands or millions of classes. The logits argument must still contain a logit for each class, so it consumes at least [batch_size, classes] memory.
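A plain-Python sketch of the equivalence: an integer label k gives the same loss as a one-hot row with a 1 in position k, because every other term of the sum is multiplied by 0. The function names are illustrative. Note that the label storage shrinks from batch_size x num_classes floats to batch_size integers, while the logits keep their full [batch_size, num_classes] shape.

```python
import math

def sparse_softmax_cross_entropy(labels, logits):
    # labels: one integer class index per example, shape [batch_size]
    # logits: num_classes scores per example, shape [batch_size, num_classes]
    losses = []
    for k, x in zip(labels, logits):
        m = max(x)
        log_z = m + math.log(sum(math.exp(v - m) for v in x))
        losses.append(log_z - x[k])  # -log softmax(x)[k]
    return losses

def dense_softmax_cross_entropy(labels, logits):
    # labels: one-hot (or soft) rows, shape [batch_size, num_classes]
    losses = []
    for y, x in zip(labels, logits):
        m = max(x)
        log_z = m + math.log(sum(math.exp(v - m) for v in x))
        losses.append(-sum(p * (v - log_z) for p, v in zip(y, x)))
    return losses

def one_hot(indices, num_classes):
    return [[1.0 if c == k else 0.0 for c in range(num_classes)] for k in indices]

logits = [[1.0, 2.0, 0.5, -1.0], [0.0, 0.0, 3.0, 1.0]]
sparse_labels = [1, 2]                    # 2 integers
dense_labels = one_hot(sparse_labels, 4)  # 2 x 4 floats
```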

The tf.losses version has a weights argument which allows setting the in-batch weights.
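How in-batch weights change the final scalar can be sketched in plain Python. The reduction below, a weighted sum of per-example losses divided by the number of non-zero weights, is an assumption meant to mirror the TF 1.x default reduction of tf.losses.* (Reduction.SUM_BY_NONZERO_WEIGHTS); check your version's docs. The class_weight table is hypothetical and shows a common workaround for the missing class weights: look up one weight per example from its label and pass those as the in-batch weights.

```python
import math

def sparse_softmax_cross_entropy(labels, logits):
    # Per-example loss: -log softmax(logits[i])[labels[i]]
    losses = []
    for k, x in zip(labels, logits):
        m = max(x)
        log_z = m + math.log(sum(math.exp(v - m) for v in x))
        losses.append(log_z - x[k])
    return losses

def reduce_sum_by_nonzero_weights(losses, weights):
    # Assumption: mirrors SUM_BY_NONZERO_WEIGHTS with one weight per example,
    # i.e. sum(w_i * loss_i) / count(w_i != 0)
    nonzero = sum(1 for w in weights if w != 0)
    if nonzero == 0:
        return 0.0
    return sum(l * w for l, w in zip(losses, weights)) / nonzero

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 3.0, 0.0]]
labels = [0, 2, 1]
per_example = sparse_softmax_cross_entropy(labels, logits)

# In-batch weights: ignore the second example, count the third one twice
batch_loss = reduce_sum_by_nonzero_weights(per_example, [1.0, 0.0, 2.0])

# Hypothetical class-weight table turned into per-example weights (class 1 is 5x as important)
class_weight = [1.0, 5.0, 1.0]
class_weighted_loss = reduce_sum_by_nonzero_weights(per_example, [class_weight[k] for k in labels])
```

With all weights equal to 1, the result is just the mean loss over the batch.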

Sampled softmax functions family

tf.nn.sampled_softmax_loss

tf.contrib.nn.rank_sampled_softmax_loss

tf.nn.nce_loss

These functions are another way to deal with a huge number of classes. Instead of computing the exact loss over all classes, they compute a loss estimate from a random sample of classes.

The weights and biases arguments specify a separate fully-connected layer that is used to compute the logits for the chosen sample.
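A simplified plain-Python sketch of the idea, not the tf.nn.sampled_softmax_loss implementation: weights has shape [num_classes, dim] and biases [num_classes], as in the TensorFlow signature, but only the rows for the true class and a few sampled negatives are used to compute logits. Assumptions, labeled here and in the code: one integer label per example (TensorFlow uses [batch_size, num_true]), negatives drawn uniformly without replacement per example (TensorFlow shares one sample across the batch and defaults to a log-uniform sampler), and no log-Q correction.

```python
import math
import random

def _dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def _logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def full_softmax_loss(weights, biases, labels, inputs):
    # Exact loss: logits for every class from the output layer (weights, biases)
    losses = []
    for k, x in zip(labels, inputs):
        logits = [_dot(x, w) + b for w, b in zip(weights, biases)]
        losses.append(_logsumexp(logits) - logits[k])
    return losses

def sampled_softmax_loss_sketch(weights, biases, labels, inputs, num_sampled, rng):
    # Simplified estimate: true class plus num_sampled uniformly drawn negatives.
    # Assumptions: per-example uniform sampling, no log-Q correction.
    num_classes = len(weights)
    losses = []
    for k, x in zip(labels, inputs):
        negatives = rng.sample([c for c in range(num_classes) if c != k], num_sampled)
        candidates = [k] + negatives  # true class first
        logits = [_dot(x, weights[c]) + biases[c] for c in candidates]
        losses.append(_logsumexp(logits) - logits[0])
    return losses

rng = random.Random(0)
num_classes, dim = 6, 3
weights = [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(num_classes)]  # [num_classes, dim]
biases = [0.0] * num_classes                                                          # [num_classes]
inputs = [[0.5, -0.2, 1.0], [1.5, 0.3, -0.7]]
labels = [2, 4]

estimate = sampled_softmax_loss_sketch(weights, biases, labels, inputs, num_sampled=2, rng=rng)
exact = full_softmax_loss(weights, biases, labels, inputs)
```

The TensorFlow docs describe the sampled loss as meant for training only and as generally underestimating the full softmax loss; in this simplified version every estimate is at most the exact value, and sampling all num_classes - 1 negatives recovers the exact loss.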