
I have a tensor a of type tf.int64, and I want to filter it based on a given Python list.

For example:

l = [1,2,3]

a = tf.constant([1,2,3,4], dtype=tf.int64) 

I need a tensor containing only the values 1, 2 and 3 (everything except 4). In other words, I want to filter a so that it keeps only the elements that appear in l. How can I do this in TensorFlow?

1 Answer


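You can do this with tf.reduce_any combined with tf.boolean_mask: compare every element of a against every value in l, reduce those comparisons to a single boolean per element of a, and use the result as a mask.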

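import tensorflow as tf

# Graph-mode (TF1-style) code; tf.compat.v1.Session also works under TensorFlow 2.x
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    l = tf.constant([1, 2, 3], dtype=tf.int64)     # values to keep
    a = tf.constant([1, 2, 3, 4], dtype=tf.int64)  # tensor to filter
    # a becomes shape [4, 1]; comparing it with l (shape [3]) broadcasts to a
    # [4, 3] boolean matrix, and reduce_any over axis 1 marks each element of a
    # that equals at least one value in l
    m = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), l), axis=1)
    b = tf.boolean_mask(a, m)  # keep only the marked elements
    print(sess.run(b))         # prints [1 2 3]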
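In TensorFlow 2.x, eager execution is the default and tf.Session is no longer in the top-level namespace, so the same technique can run without a graph or session. The sketch below assumes TF 2.x and wraps the approach in a hypothetical helper, filter_by_values (not a TensorFlow API), that accepts the Python list directly as in the question.

```python
import tensorflow as tf


def filter_by_values(a, keep):
    """Hypothetical helper: return the elements of 1-D tensor `a` that appear in `keep`."""
    # Convert the Python list to a tensor with the same dtype as `a` (tf.int64 here)
    keep = tf.constant(keep, dtype=a.dtype)
    # [len(a), 1] vs [len(keep)] broadcasts to a [len(a), len(keep)] comparison matrix;
    # reduce_any gives True for each element of `a` that matches any value in `keep`
    mask = tf.reduce_any(tf.equal(tf.expand_dims(a, 1), keep), axis=1)
    return tf.boolean_mask(a, mask)


a = tf.constant([1, 2, 3, 4], dtype=tf.int64)
b = filter_by_values(a, [1, 2, 3])
print(b.numpy())  # [1 2 3]
```

Because the comparison matrix has len(a) x len(l) entries, memory use grows with both sizes; for a very large l, another approach (for example, sorting l and using tf.searchsorted) may scale better.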
