Code reduction
Let's try to get rid of as many of the intermediate DataFrames as possible, i.e. we chain as many transformations together as possible, and we skip the cosmetics (i.e. rounding).
The calculation of the omegas (w0 and w1) then reduces nicely to the following.
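For reference, w0 and w1 are the intercept and slope of the ordinary least-squares line y ≈ w0 + w1·x, fitted separately for each (shop, product). The code computes the standard closed form. In it, xdif = x - avg(x) and ydif = y - avg(y), pxy and pxx are the per-row products, and the sums run over the rows of one group:

$$
w_1 = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2} = \frac{\texttt{sum(pxy)}}{\texttt{sum(pxx)}}, \qquad w_0 = \bar{y} - w_1 \bar{x}
$$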
Starting with sale_narrow_df:
sale_narrow_df.orderBy("shop","product","x").show()
+--------+-------+---+---+
| shop|product| x| y|
+--------+-------+---+---+
|megamart| bread| 1|371|
|megamart| bread| 2|432|
|megamart| bread| 3|425|
|megamart| bread| 4|524|
|megamart| bread| 5|468|
|megamart| bread| 6|414|
..
..
import org.apache.spark.sql.functions._   // avg, col, sum, max

// Per-(shop, product) averages of x and y; Spark names the result columns "avg(x)" and "avg(y)"
val avg_df = sale_narrow_df.
  groupBy("shop","product").agg( avg("x"), avg("y") ).
  select("shop","product","avg(x)","avg(y)")

// Least-squares coefficients per (shop, product)
val coeff_df = sale_narrow_df.
  join( avg_df, Seq("shop","product"), "leftouter").
  withColumn("xdif", col("x")-col("avg(x)")).
  withColumn("ydif", col("y")-col("avg(y)")).
  withColumn("pxy", col("xdif")*col("ydif") ).
  withColumn("pxx", col("xdif")*col("xdif") ).
  groupBy("shop","product").agg( sum("pxy"), sum("pxx") ).
  withColumn("w1", col("sum(pxy)")/col("sum(pxx)") ).    // slope
  join( avg_df, Seq("shop","product"), "leftouter").     // re-attach the averages lost in groupBy
  withColumn("w0", col("avg(y)") - col("w1") * col("avg(x)") ).   // intercept
  select("shop","product","w0","w1")
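To check the arithmetic without a cluster, here is a minimal plain-Scala sketch of the same per-(shop, product) computation. It uses in-memory collections and no Spark. The `Sale` case class and the `coefficients` function are hypothetical names chosen for illustration. They mirror the avg, xdif/ydif, pxy/pxx and w0/w1 steps of the DataFrame chain above, but they are not a Spark API.

```scala
// Hypothetical row type standing in for one row of sale_narrow_df.
case class Sale(shop: String, product: String, x: Double, y: Double)

// Returns (w0, w1) per (shop, product), using the same steps as the DataFrame chain.
def coefficients(sales: Seq[Sale]): Map[(String, String), (Double, Double)] =
  sales.groupBy(s => (s.shop, s.product)).map { case (key, rows) =>
    val n    = rows.size.toDouble
    val avgX = rows.map(_.x).sum / n                           // avg(x)
    val avgY = rows.map(_.y).sum / n                           // avg(y)
    val sPxy = rows.map(r => (r.x - avgX) * (r.y - avgY)).sum  // sum(pxy)
    val sPxx = rows.map(r => (r.x - avgX) * (r.x - avgX)).sum  // sum(pxx)
    val w1   = sPxy / sPxx                                     // slope
    val w0   = avgY - w1 * avgX                                // intercept
    (key, (w0, w1))
  }
```

For a group whose points lie exactly on y = 2 + 3x (x = 1..4), the sketch returns w0 = 2 and w1 = 3.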
Result
coeff_df.orderBy("shop","product").show()
+----------+-------+------------------+--------------------+
| shop|product| w0| w1|
+----------+-------+------------------+--------------------+
| megamart| bread|414.89044585987256| 7.185987261146498|
| megamart| cheese| 60.70096463022508| -0.6720257234726686|
| megamart| milk|27.585414585414583|0.046953046953046945|
| megamart| nuts| 1320.439393939394| 3.265734265734265|
| megamart| razors| 523.5142857142856| -0.5744360902255637|
| megamart| soap| 8.7| -0.1|
|superstore| bread| 371.2369826435247| 4.382510013351134|
|superstore| cheese| 56.37858805275407|0.003103180760279275|
|superstore| milk| 33.67531219980788|-0.02401536983669...|
|superstore| nuts| 1374.439393939394| -14.657342657342657|
|superstore| razors|365.28787878787875| -1.993006993006993|
|superstore| soap| 7.800000000000001| -0.1|
+----------+-------+------------------+--------------------+
Alternative
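This alternative needs one join fewer. Instead of joining avg_df a second time, it carries the averages through the aggregation as max("avg(x)") and max("avg(y)"). Because avg(x) and avg(y) are constant within each (shop, product) group, max simply returns that constant: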
val coeff_df = sale_narrow_df.
  join( avg_df, Seq("shop","product"), "leftouter").
  withColumn("xdif", col("x")-col("avg(x)")).
  withColumn("ydif", col("y")-col("avg(y)")).
  withColumn("pxy", col("xdif")*col("ydif") ).
  withColumn("pxx", col("xdif")*col("xdif") ).
  groupBy("shop","product").agg( sum("pxy"), sum("pxx"), max("avg(x)"), max("avg(y)") ).
  withColumn("w1", col("sum(pxy)")/col("sum(pxx)") ).
  withColumn("w0", col("max(avg(y))") - col("w1") * col("max(avg(x))") ).
  select("shop","product","w0","w1")
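Going one step further, the join could be dropped altogether. One way is to accumulate raw sums per group in a single pass and then use Σ(x - x̄)(y - ȳ) = Σxy - n·x̄·ȳ and Σ(x - x̄)² = Σx² - n·x̄². This is a different technique from the one above. It is shown here only as a minimal plain-Scala sketch, and the `Acc` class and the `coefficientsOnePass` function are hypothetical names. One caveat: the raw-sum form can lose precision through cancellation when x̄ is large relative to the spread of x.

```scala
// Hypothetical per-group accumulator: count and raw sums of x, y, x*y, x*x.
case class Acc(n: Long, sx: Double, sy: Double, sxy: Double, sxx: Double) {
  def add(x: Double, y: Double): Acc = Acc(n + 1, sx + x, sy + y, sxy + x * y, sxx + x * x)
}

// rows are (shop, product, x, y); returns (w0, w1) per (shop, product).
def coefficientsOnePass(rows: Seq[(String, String, Double, Double)]): Map[(String, String), (Double, Double)] = {
  val zero = Acc(0, 0.0, 0.0, 0.0, 0.0)
  // Single pass over the rows, updating the accumulator of each (shop, product) key.
  val accs = rows.foldLeft(Map.empty[(String, String), Acc]) { case (m, (shop, product, x, y)) =>
    val key = (shop, product)
    m.updated(key, m.getOrElse(key, zero).add(x, y))
  }
  accs.map { case (key, a) =>
    val avgX = a.sx / a.n
    val avgY = a.sy / a.n
    val sPxy = a.sxy - a.n * avgX * avgY   // equals sum(pxy)
    val sPxx = a.sxx - a.n * avgX * avgX   // equals sum(pxx)
    val w1   = sPxy / sPxx
    val w0   = avgY - w1 * avgX
    (key, (w0, w1))
  }
}
```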