
I've got a dataframe like this, and I want to repeat each row n times, where n is the value in column n (so a row with n = 1 stays as it is):

A   B   n 
1   2   1 
2   9   1 
3   8   2   
4   1   1   
5   3   3 


And transform like this:

A   B   n 
1   2   1 
2   9   1 
3   8   2
3   8   2      
4   1   1   
5   3   3
5   3   3
5   3   3 


I think I should use explode, but I don't understand how it works...

1 Answer


You can use the explode function here.

explode(): This function takes an array (or map) column as input and outputs each element of the array (or map) as a separate row. Put simply, it creates a new row for each element in the selected array or map column.
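To make the behaviour concrete, here is a minimal plain-Python sketch of what explode does to a set of rows. It is only a model for illustration, not Spark code: the helper name explode_rows and the dict-based rows are assumptions made up for this example.

```python
def explode_rows(rows, column):
    """Plain-Python model of explode: one output row per element of row[column]."""
    out = []
    for row in rows:
        # A missing (None) or empty array yields no output rows,
        # which mirrors how explode drops such rows.
        for element in row[column] or []:
            new_row = dict(row)          # copy the other columns unchanged
            new_row[column] = element    # replace the array with one element
            out.append(new_row)
    return out


rows = [
    {"A": 3, "B": 8, "n": [2, 2]},
    {"A": 4, "B": 1, "n": [1]},
]

exploded = explode_rows(rows, "n")
for r in exploded:
    print(r)
```

The row with the two-element array comes out twice and the row with the one-element array comes out once, so the length of the array controls how many copies of the row you get.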


 

To apply this function here, you can use a udf to build an array of length n for each row, and then explode the resulting array.

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

df = spark.createDataFrame([(1, 2, 1), (2, 9, 1), (3, 8, 2), (4, 1, 1), (5, 3, 3)], ["A", "B", "n"])

# use a udf to turn each n into an array holding n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))

# now explode the array so that each element becomes its own row
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
