Divide to get ω₁
import org.apache.spark.sql.functions.col
val a4 = a3.withColumn("w1", col("sum(pxy)") / col("sum(pxx)")) // ω₁ = Σpxy / Σpxx per (shop, product)
Result:
a4.orderBy("shop","product").show()
+----------+-------+-------------------+------------------+--------------------+
|      shop|product|           sum(pxy)|          sum(pxx)|                  w1|
+----------+-------+-------------------+------------------+--------------------+
|  megamart|  bread| 1025.6277999999998|142.72660000000002|   7.185961131281762|
|  megamart| cheese|           -83.6016|124.39869999999998| -0.6720456081936549|
|  megamart|   milk|  4.699600000000002|          100.0994|0.046949332363630567|
|  megamart|   nuts|  466.9998999999999|             143.0|   3.265733566433566|
|  megamart| razors|           -69.4553|120.90820000000001| -0.5744465635912204|
|  megamart|   soap|              -11.0|             110.0|                -0.1|
|superstore|  bread|           596.8173|136.18159999999997|   4.382510559429469|
|superstore| cheese|0.39890000000000114|          128.8991|0.003094668620649...|
|superstore|   milk|-2.5002999999999993|          104.0996|-0.02401834397058...|
|superstore|   nuts|         -2096.0004|             143.0| -14.657345454545453|
|superstore| razors|-285.00030000000004|             143.0| -1.9930090909090912|
|superstore|   soap|           -11.0003|             110.0|-0.10000272727272727|
+----------+-------+-------------------+------------------+--------------------+
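The division above can be sketched outside Spark in plain Scala. This sketch assumes that `pxy` is the product of deviations (x − avg_x)(y − avg_y) and `pxx` is (x − avg_x)², as the column names suggest; the earlier steps that built `a3` are not shown here. The names `Obs` and `slopeByGroup` and the toy data are hypothetical and serve only for illustration.

```scala
// Hypothetical plain-Scala sketch of the ω₁ step (no Spark), ASSUMING
// pxy = (x - avgX) * (y - avgY) and pxx = (x - avgX)^2 per observation.
final case class Obs(shop: String, product: String, x: Double, y: Double)

def slopeByGroup(obs: Seq[Obs]): Map[(String, String), Double] =
  obs.groupBy(o => (o.shop, o.product)).map { case (key, rows) =>
    val n = rows.size.toDouble
    val avgX = rows.map(_.x).sum / n
    val avgY = rows.map(_.y).sum / n
    val sumPxy = rows.map(o => (o.x - avgX) * (o.y - avgY)).sum // plays the role of sum(pxy)
    val sumPxx = rows.map(o => (o.x - avgX) * (o.x - avgX)).sum // plays the role of sum(pxx)
    key -> sumPxy / sumPxx // same division as col("sum(pxy)") / col("sum(pxx)")
  }

// Hypothetical toy data lying exactly on y = 2x + 1, so ω₁ should be 2.
val toy = Seq(1.0, 2.0, 3.0, 4.0).map(x => Obs("toyshop", "widget", x, 2 * x + 1))
val w1ByGroup = slopeByGroup(toy)
println(w1ByGroup) // prints Map((toyshop,widget) -> 2.0)
```

The same division applied to the megamart bread sums in the table, 1025.6277999999998 / 142.72660000000002, gives the 7.18596… shown in its w1 column.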
Calculate ω₀
Join column ω₁ onto the aggregation dataframe a2 (which holds the averages avg_x and avg_y), so that we can then compute ω₀ from the formula ω₀ = avg_y − ω₁ · avg_x:
val a5 = a2.join(a4, Seq("shop", "product"), "leftouter").
  select("shop", "product", "avg_x", "avg_y", "w1")
Result:
a5.orderBy("shop","product").show()
+----------+-------+------+---------+--------------------+
|      shop|product| avg_x|    avg_y|                  w1|
+----------+-------+------+---------+--------------------+
|  megamart|  bread|6.4545| 461.2727|   7.185961131281762|
|  megamart| cheese|   6.4|     56.4| -0.6720456081936549|
|  megamart|   milk|   6.7|     27.9|0.046949332363630567|
|  megamart|   nuts|   6.5|1341.6666|   3.265733566433566|
|  megamart| razors| 6.909| 519.5454| -0.5744465635912204|
|  megamart|   soap|   7.0|      8.0|                -0.1|
|superstore|  bread|6.2727| 398.7272|   4.382510559429469|
|superstore| cheese|   6.9|     56.4|0.003094668620649...|
|superstore|   milk|   7.3|     33.5|-0.02401834397058...|
|superstore|   nuts|   6.5|1279.1666| -14.657345454545453|
|superstore| razors|   6.5| 352.3333| -1.9930090909090912|
|superstore|   soap|   6.0|      7.2|-0.10000272727272727|
+----------+-------+------+---------+--------------------+
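What the `"leftouter"` join guarantees can be sketched with plain Scala maps keyed on (shop, product). Every key of the left side (the averages, like a2) survives. ω₁ is `Some(...)` on a match and `None` otherwise, where Spark would show null. The helper `leftOuterJoin`, the `Avgs` class and the `hypotheticalshop` key are invented for illustration. This shows the semantics only, not how Spark executes a join.

```scala
// Hypothetical sketch of left-outer-join semantics on (shop, product) keys.
type Key = (String, String)
final case class Avgs(avgX: Double, avgY: Double)

// Keep every left-side key; attach the right-side value if one exists.
def leftOuterJoin[A, B](left: Map[Key, A], right: Map[Key, B]): Map[Key, (A, Option[B])] =
  left.map { case (k, a) => (k, (a, right.get(k))) }

// avg_x, avg_y and w1 copied from the megamart soap row above;
// "hypotheticalshop" is an invented key with no ω₁ row, to show the unmatched case.
val averages: Map[Key, Avgs] = Map(
  ("megamart", "soap") -> Avgs(7.0, 8.0),
  ("hypotheticalshop", "soap") -> Avgs(1.0, 2.0)
)
val slopes: Map[Key, Double] = Map(("megamart", "soap") -> -0.1)

val joined = leftOuterJoin(averages, slopes)
joined.foreach { case (k, (avgs, w1)) => println(s"$k $avgs w1=$w1") }
```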
Now calculate ω₀ (column w0):
val a6 = a5.withColumn("w0", col("avg_y") - col("w1") * col("avg_x")) // ω₀ = avg_y − ω₁ · avg_x
Result:
a6.orderBy("shop","product").show()
+----------+-------+------+---------+--------------------+------------------+
|      shop|product| avg_x|    avg_y|                  w1|                w0|
+----------+-------+------+---------+--------------------+------------------+
|  megamart|  bread|6.4545| 461.2727|   7.185967535361064|  414.890872543012|
|  megamart| cheese|   6.4|     56.4| -0.6720257234726688| 60.70096463022508|
|  megamart|   milk|   6.7|     27.9| 0.04695304695304696|27.585414585414583|
|  megamart|   nuts|   6.5|1341.6667|   3.265738461538462|         1320.4394|
|  megamart| razors|6.9091| 519.5455| -0.5744354248425476| 523.5143317937795|
|  megamart|   soap|   7.0|      8.0|                -0.1|               8.7|
|superstore|  bread|6.2727| 398.7273|   4.382494922243877| 371.2372241012408|
|superstore| cheese|   6.9|     56.4|0.003103180760279...| 56.37858805275407|
|superstore|   milk|   7.3|     33.5|-0.02401536983669545| 33.67531219980788|
|superstore|   nuts|   6.5|1279.1667| -14.657339160839161|1374.4394045454546|
|superstore| razors|   6.5| 352.3333| -1.9930055944055944|365.28783636363636|
|superstore|   soap|   6.0|      7.2|                -0.1| 7.800000000000001|
+----------+-------+------+---------+--------------------+------------------+
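The w0 column can be checked by hand. The least-squares line passes through the point (avg_x, avg_y), so the intercept is ω₀ = avg_y − ω₁ · avg_x. The plain-Scala sketch below applies that formula to two rows copied from the output above. The helper name `intercept` is hypothetical.

```scala
// Hypothetical helper: the intercept formula behind column w0,
// ω₀ = avg_y − ω₁ · avg_x (the fitted line passes through (avg_x, avg_y)).
def intercept(avgY: Double, w1: Double, avgX: Double): Double = avgY - w1 * avgX

// Inputs copied from the megamart bread and soap rows of the a6 output above.
val breadW0 = intercept(461.2727, 7.185967535361064, 6.4545)
val soapW0 = intercept(8.0, -0.1, 7.0)
println(s"bread w0 = $breadW0, soap w0 = $soapW0")
```

Both results agree with the w0 column (414.890872543012 and 8.7).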