
What is the role of Flatten in Keras? I am running the code below, which is a simple two-layer network. Each input sample is 2-dimensional with shape (3, 2), and the output for each sample is 1-dimensional with 4 values, so predicting on a single sample gives shape (1, 4):

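import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))  # each sample has shape (3, 2)
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

c = np.array([[[2, 3], [4, 5], [6, 7]]])  # one sample: shape (1, 3, 2)
u = model.predict(c)
print(u.shape)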

With Flatten it prints that u has shape (1, 4), but when I remove Flatten the shape of u changes to (1, 3, 4).
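To see where those two shapes come from, here is a minimal NumPy sketch that mimics the forward pass of the model above. It assumes random weights in place of the trained ones and models each Dense layer as `x @ W + b` applied to the last axis of its input. The helper `forward` and the weight names are hypothetical, not part of Keras.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights standing in for the Keras layers
W1, b1 = rng.normal(size=(2, 16)), np.zeros(16)  # Dense(16) acts on the last axis (size 2)
W2_flat = rng.normal(size=(3 * 16, 4))            # Dense(4) after Flatten: 48 inputs
W2_seq = rng.normal(size=(16, 4))                 # Dense(4) without Flatten: 16 inputs
b2 = np.zeros(4)

def forward(x, flatten):
    h = np.maximum(x @ W1 + b1, 0)       # Dense(16) + ReLU -> (batch, 3, 16)
    if flatten:
        h = h.reshape(h.shape[0], -1)    # Flatten -> (batch, 48)
        return h @ W2_flat + b2          # Dense(4) -> (batch, 4)
    return h @ W2_seq + b2               # Dense(4) on the last axis -> (batch, 3, 4)

c = np.array([[[2, 3], [4, 5], [6, 7]]], dtype=float)  # shape (1, 3, 2)
print(forward(c, flatten=True).shape)   # (1, 4)
print(forward(c, flatten=False).shape)  # (1, 3, 4)
```

The first Dense layer already produces a (3, 16) block per sample, which is why Flatten changes what the final Dense(4) sees.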

As I understand it, model.add(Dense(16, input_shape=(3, 2))) creates a hidden fully-connected layer of 16 nodes, and every node is connected to all 3x2 input elements. Hence the 16 nodes of the first layer are already flat, so the output of the first layer should have shape (1, 16), which the second layer takes as input to produce an output of shape (1, 4).

So my question is: why do I need to flatten it further when the first layer's output is already flat?

1 Answer


Flatten reshapes a tensor into a 1-D tensor whose length equals the number of elements it contains. In Keras this is done per sample: the batch dimension is kept and all remaining dimensions are merged into one.

For example, suppose a layer's output has shape (20, 4, 5, 3), where 20 is the batch size. Flatten unrolls the values of each sample into a 1-D tensor of 4*5*3 = 60 elements, so the result has shape (20, 60).
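The same reshaping can be sketched in NumPy, assuming the first axis is the batch axis that Keras's Flatten leaves alone. Here `reshape(batch, -1)` stands in for the layer, and the array contents are arbitrary illustration data.

```python
import numpy as np

x = np.arange(20 * 4 * 5 * 3).reshape(20, 4, 5, 3)  # batch of 20 samples, each (4, 5, 3)

# Keep the batch axis and merge everything else (NumPy's row-major order)
flat = x.reshape(x.shape[0], -1)
print(flat.shape)  # (20, 60)

# Each row holds exactly the elements of the corresponding sample
print(np.array_equal(flat[7], x[7].ravel()))  # True
```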

As per your code,

Dense(16, input_shape=(3,2))

This is not one layer in which every node sees all six input values. Dense operates on the last axis of its input, so the statement creates a dense layer with 2 inputs and 16 outputs that is applied independently to each of the 3 steps (rows). If D(x) maps a 2-D vector to a 16-D vector, the output is the sequence of vectors [D(x[0,:]), D(x[1,:]), D(x[2,:])], which has shape (3, 16). To get the behavior you expected, you first have to flatten the input into a 6-D vector and then apply the dense layer:

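from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))  # (3, 2) -> 6 values per sample
model.add(Dense(16))                    # every unit now sees all 6 inputs
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')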
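The difference between the two models can be checked numerically. The sketch below uses NumPy with random weights in place of trained ones, and assumes that Dense on a (3, 2) sample is a single 2x16 kernel `W` plus bias `b` shared across the rows. It confirms that applying `W` row by row matches one matmul over the last axis, then counts trainable parameters with the usual rule for a Dense layer (n_in * units weights plus units biases). `dense_params` is a hypothetical helper; with Keras installed and input_shape=(3, 2), `model.summary()` on the two models should report the same totals, 244 and 180.

```python
import numpy as np

rng = np.random.default_rng(1)
x = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]])  # one sample, shape (3, 2)

# Question's model: Dense(16) is one 2 -> 16 kernel shared by all 3 rows
W, b = rng.normal(size=(2, 16)), rng.normal(size=16)
per_row = np.stack([x[i, :] @ W + b for i in range(3)])  # [D(x[0,:]), D(x[1,:]), D(x[2,:])]
batched = x @ W + b                                      # matmul over the last axis only
print(per_row.shape, np.allclose(per_row, batched))      # (3, 16) True

# Answer's model: Flatten first, so each of the 16 units sees all 6 inputs
W6, b6 = rng.normal(size=(6, 16)), rng.normal(size=16)
print((x.reshape(-1) @ W6 + b6).shape)                   # (16,)

def dense_params(n_in, n_units):
    """Weights (n_in * n_units) plus one bias per unit."""
    return n_in * n_units + n_units

# Dense(16) on 2-value rows, Flatten (3 * 16 = 48 values), Dense(4)
original = dense_params(2, 16) + dense_params(3 * 16, 4)
# Flatten (3 * 2 = 6 values), Dense(16), Dense(4)
flatten_first = dense_params(3 * 2, 16) + dense_params(16, 4)
print(original, flatten_first)                           # 244 180
```

In the question's model the first layer has only 48 parameters because its 16 units only ever see 2 values at a time; flattening first gives them all 6 values, at the cost of a larger first kernel.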
