使用 spark 时有需要将一个窄表转换为宽表(根据某列或某几列的内容生成新的列)。
例如有这样一张表 T,其中有三列 A, B, C,假设其内容如下所示
A | B | C |
---|---|---|
a1 | k2 | c12 |
a2 | k2 | c21 |
a2 | k1 | c22 |
a2 | k3 | c23 |
需要将之生成为
A | k1 | k2 | k3 |
---|---|---|---|
a1 | c11 | c12 | null |
a2 | c22 | c12 | c23 |
这种操作在 spark 中术语叫作 pivot(旋转)。在 spark 中需要进行 df.groupBy(...).pivot(...).agg(...)
生成新的 DataFrame
pivot
方法在 org.apache.spark.sql.RelationalGroupedDataset
中定义,返回类型仍是 GroupedData
类型。看得出 pivot
本质上与 group
相似,都是在进行聚合。
如上面的例子在 spark 中可以写作:
df.groupBy($"a").pivot("b").agg(expr("first(c)"))
此外,需要注意到,因为 pivot 是对指定列的内容进行聚合,实际上操作分为两步
相当于以下 spark 方法
df.groupBy("a").agg(
expr("first(case when b = 'k1' then c end, true)").as("k1"),
expr("first(case when b = 'k2' then c end, true)").as("k2"),
expr("first(case when b = 'k3' then c end, true)").as("k3")
).show
其中 first 第二个参数表示是否忽略 null 值
为了提高处理性能,可以在 pivot
加入第二个参数表示可选取值,即
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset
其作用有二:
另外,pivot
只支持对一列进行旋转。如果要对多列(实际上要对多列内容进行合并,生成新的列名)进行操作,则可以先用 concat
方法生成一个新的列用于 pivot