in Big Data Hadoop & Spark by (11.4k points)

I have a simple dataframe like this:

rdd = sc.parallelize(
    [
        (0, "A", 223, "201603", "PORT"),
        (0, "A", 22, "201602", "PORT"),
        (0, "A", 422, "201601", "DOCK"),
        (1, "B", 3213, "201602", "DOCK"),
        (1, "B", 3213, "201601", "PORT"),
        (2, "C", 2321, "201601", "DOCK")
    ]
)
df_data = sqlContext.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])

df_data.show()

+---+----+----+------+----+
| id|type|cost|  date|ship|
+---+----+----+------+----+
|  0|   A| 223|201603|PORT|
|  0|   A|  22|201602|PORT|
|  0|   A| 422|201601|DOCK|
|  1|   B|3213|201602|DOCK|
|  1|   B|3213|201601|PORT|
|  2|   C|2321|201601|DOCK|
+---+----+----+------+----+

and I need to pivot it by date:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("cost").show()

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|2321.0|  null|  null|
|  0|   A| 422.0|  22.0| 223.0|
|  1|   B|3213.0|3213.0|  null|
+---+----+------+------+------+


Everything works as expected. But now I need to pivot it and get a non-numeric column:

df_data.groupby(df_data.id, df_data.type).pivot("date").avg("ship").show()


and of course I would get an exception:

AnalysisException: u'"ship" is not a numeric column. Aggregation function can only be applied on a numeric column.;'


I would like to generate something along the lines of:

+---+----+------+------+------+
| id|type|201601|201602|201603|
+---+----+------+------+------+
|  2|   C|  DOCK|  null|  null|
|  0|   A|  DOCK|  PORT|  DOCK|
|  1|   B|  DOCK|  PORT|  null|
+---+----+------+------+------+

1 Answer

by (32.3k points)

Assuming that (id, type, date) combinations are unique and your only goal is pivoting, not aggregation, you can use the first function (or any other function that is not restricted to numeric values):

from pyspark.sql.functions import first

(df_data
    .groupby(df_data.id, df_data.type)
    .pivot("date")
    .agg(first("ship"))
    .show())

## +---+----+------+------+------+
## | id|type|201601|201602|201603|
## +---+----+------+------+------+
## |  2|   C|  DOCK|  null|  null|
## |  0|   A|  DOCK|  PORT|  PORT|
## |  1|   B|  PORT|  DOCK|  null|
## +---+----+------+------+------+
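
If you are not sure whether the uniqueness assumption holds, a quick way to check it (a small sketch, not part of the original answer) is to pivot with collect_set instead, which lists every distinct ship value that lands in each cell; any cell with more than one element means first would pick an arbitrary value there:

from pyspark.sql.functions import collect_set

# Each cell holds the set of distinct ship values for that (id, type, date) combination.
(df_data
    .groupby("id", "type")
    .pivot("date")
    .agg(collect_set("ship"))
    .show())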

If these assumptions do not hold, you will have to pre-aggregate your data first. For example, to keep the most common ship value:

from pyspark.sql.functions import max, struct

(df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship")))
    .show())

## +---+----+--------+--------+--------+
## | id|type|  201601|  201602|  201603|
## +---+----+--------+--------+--------+
## |  2|   C|[1,DOCK]|    null|    null|
## |  0|   A|[1,DOCK]|[1,PORT]|[1,PORT]|
## |  1|   B|[1,PORT]|[1,DOCK]|    null|
## +---+----+--------+--------+--------+
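
If you only need the ship labels and not the counts, the struct fields can be pulled back out with ordinary column selection. A minimal sketch, assuming the pivoted column names 201601, 201602 and 201603 from the sample data above:

from pyspark.sql.functions import col, max, struct

pivoted = (df_data
    .groupby("id", "type", "date", "ship")
    .count()
    .groupby("id", "type")
    .pivot("date")
    .agg(max(struct("count", "ship"))))

# Each pivoted cell is a struct of (count, ship); dot notation extracts the ship field.
pivoted.select(
    "id", "type",
    col("201601.ship").alias("201601"),
    col("201602.ship").alias("201602"),
    col("201603.ship").alias("201603"),
).show()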
