What does the following function do? Should I think of it as a lookup table, like in the skip-gram model?

tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None)

by (10.9k points)

It is similar to tf.gather, which returns the elements of params at the indices specified by ids.

Ex-

import tensorflow as tf  # TensorFlow 1.x; .eval() needs an active tf.InteractiveSession()

params = tf.constant([10,20,80,40])

ids = tf.constant([1,2,3])

print(tf.nn.embedding_lookup(params, ids).eval())

This will give the output:

[20 80 40]

But the main purpose of embedding_lookup is to retrieve rows of the params tensor. The params argument may also be a list of tensors instead of a single tensor.

Ex-

param1 = tf.constant([10,2])

param2 = tf.constant([20,30])

ids = tf.constant([2,0,2,1,2,3])

result = tf.nn.embedding_lookup([param1, param2], ids)

In this case, the indexes specified in ids correspond to elements of the tensors according to the partition strategy. The partition_strategy argument controls how the ids are distributed among the tensors in the list; the default partition strategy is 'mod'.

The mod strategy:

Index 0 corresponds to the first element of the first tensor, index 1 corresponds to the first element of the second tensor, and so on. Index n cannot correspond to an (n+1)th tensor, since the list params contains only n tensors, so index n wraps around and corresponds to the second element of the first tensor. Similarly, index n+1 corresponds to the second element of the second tensor, and so on.
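The rule above can be sketched in plain Python, without TensorFlow. This is only an illustration of the 'mod' mapping, not the library's implementation; mod_lookup is a hypothetical helper name, and the tensors are stood in for by ordinary lists.

```python
def mod_lookup(params_list, ids):
    """Sketch of a lookup over a list of 1-D 'tensors' using the 'mod' strategy."""
    n = len(params_list)  # number of tensors in the list
    result = []
    for i in ids:
        shard = i % n     # which tensor in the list holds id i
        offset = i // n   # position of id i inside that tensor
        result.append(params_list[shard][offset])
    return result


param1 = [10, 2]
param2 = [20, 30]
print(mod_lookup([param1, param2], [2, 0, 2, 1, 2, 3]))  # [2, 10, 2, 20, 2, 30]
```

With two tensors, id 2 gives shard 2 % 2 = 0 and offset 2 // 2 = 1, which is the second element of the first tensor.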

Now coming back to the code,

param1 = tf.constant([10,2])

param2 = tf.constant([20,30])

ids = tf.constant([2,0,2,1,2,3])

result = tf.nn.embedding_lookup([param1, param2], ids)

Result:

[ 2 10  2 20  2 30]

Hope this helps!

by (108k points)

The lookup function performs parallel lookups on the list of tensors in params. ids is a tensor of integer indexes to look up, and embedding_lookup retrieves the corresponding rows of the params tensor.

params is either a single tensor representing the complete embedding tensor, or a list of P tensors, all of the same shape except for the first dimension, representing shards (fragments) of the embedding tensor.
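To make the "single tensor versus list of fragments" idea concrete, here is a minimal pure-Python sketch. It assumes the default 'mod' partition strategy (row i of the full table lives in fragment i % P); split_mod and lookup_sharded are hypothetical helper names, and nested lists stand in for tensors.

```python
def split_mod(table, num_shards):
    """Split the rows of `table` so that row i goes to shard i % num_shards."""
    return [table[s::num_shards] for s in range(num_shards)]


def lookup_sharded(shards, ids):
    """Fetch rows by global id from shards produced by split_mod."""
    n = len(shards)
    return [shards[i % n][i // n] for i in ids]


full_table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]]
shards = split_mod(full_table, 2)
# shards[0] holds rows 0, 2, 4 and shards[1] holds rows 1, 3,
# so the shards differ only in their first dimension.

ids = [4, 1, 0]
# Looking up in the shards gives the same rows as indexing the full table.
assert lookup_sharded(shards, ids) == [full_table[i] for i in ids]
```

The point of the sketch is that the caller keeps using global ids; how the table is fragmented only changes where each row is fetched from.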

Here is a piece of code for better understanding

For example (assuming you are inside tf.InteractiveSession())

params = tf.constant([10,20,30,40])

ids = tf.constant([0,1,2,3])

print(tf.nn.embedding_lookup(params, ids).eval())

would return [10 20 30 40], because the first element (index 0) of params is 10, the second element of params (index 1) is 20, etc.

Similarly,

params = tf.constant([10,20,30,40])

ids = tf.constant([1,1,3])

print(tf.nn.embedding_lookup(params, ids).eval())

would return [20 20 40].
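The examples above use a 1-D params, so each id picks a single number. In an actual embedding table params is usually 2-D and each id picks a whole row. A small pure-Python sketch of that behaviour (lookup_rows is a hypothetical helper, and the values are made up for illustration):

```python
def lookup_rows(params, ids):
    """Return params[i] for each i in ids, i.e. gather along the first axis."""
    return [params[i] for i in ids]


embeddings = [
    [0.1, 0.2, 0.3],  # row 0
    [1.1, 1.2, 1.3],  # row 1
    [2.1, 2.2, 2.3],  # row 2
    [3.1, 3.2, 3.3],  # row 3
]

rows = lookup_rows(embeddings, [1, 1, 3])
print(rows)  # [[1.1, 1.2, 1.3], [1.1, 1.2, 1.3], [3.1, 3.2, 3.3]]
```

The same ids as before, [1, 1, 3], now return three 3-element vectors instead of three scalars.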

But embedding_lookup is more than that. The params argument can be a list of tensors, rather than a single tensor.

params1 = tf.constant([1,2])

params2 = tf.constant([10,20])

ids = tf.constant([2,0,2,1,2,3])

result = tf.nn.embedding_lookup([params1, params2], ids)

In such a case, the indexes specified in ids correspond to elements of the tensors according to a partition strategy, where the default partition strategy is 'mod'. In the 'mod' strategy, index 0 corresponds to the first element of the first tensor in the list. Index 1 corresponds to the first element of the second tensor. Index 2 corresponds to the first element of the third tensor, and so on. In general, index i corresponds to the first element of the (i+1)th tensor for all indexes 0..(n-1), assuming params is a list of n tensors.

Now, index n cannot correspond to tensor n+1, because the list params contains only n tensors. So index n corresponds to the second element of the first tensor. Similarly, index n+1 corresponds to the second element of the second tensor, etc. So, in the code

params1 = tf.constant([1,2])

params2 = tf.constant([10,20])

ids = tf.constant([2,0,2,1,2,3])

result = tf.nn.embedding_lookup([params1, params2], ids)

index 0 corresponds to the first element of the first tensor: 1

index 1 corresponds to the first element of the second tensor: 10

index 2 corresponds to the second element of the first tensor: 2

index 3 corresponds to the second element of the second tensor: 20. Thus, the result would be:

[ 2  1  2 10  2 20]

To answer your question: yes, it behaves like a lookup table. For each id in the input it returns the corresponding embedding from params.
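One way to see the lookup-table view, sketched in plain Python under the assumption that the model's input word is treated as a one-hot vector multiplied by the embedding matrix: the product just selects one row, which is what a direct lookup returns. The helper names one_hot and vec_mat and the matrix values are hypothetical.

```python
def one_hot(index, size):
    """Vector of length `size` with 1.0 at `index` and 0.0 elsewhere."""
    return [1.0 if j == index else 0.0 for j in range(size)]


def vec_mat(vec, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    num_cols = len(matrix[0])
    return [
        sum(vec[r] * matrix[r][c] for r in range(len(matrix)))
        for c in range(num_cols)
    ]


embedding_matrix = [[0.5, -1.0], [2.0, 3.0], [4.0, 0.25]]
word_id = 2

# Multiplying by a one-hot vector picks out the same row as a direct lookup.
assert vec_mat(one_hot(word_id, 3), embedding_matrix) == embedding_matrix[word_id]
```

The lookup avoids building the one-hot vector and doing the multiplication, which matters when the vocabulary is large.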