Skip to content

SparkSQL代码实操

1. 数据准备

本次 Spark SQL 实操用到的数据全部来自 Hive,因此首先需要在 Hive 中创建表并导入数据。
一共有 3 张表:1 张用户行为表、1 张城市表、1 张产品表。

sql
-- User visit action table: tab-delimited text, loaded from a local file.
CREATE TABLE `user_visit_action` (
  `date`               STRING,
  `user_id`            BIGINT,
  `session_id`         STRING,
  `page_id`            BIGINT,
  `action_time`        STRING,
  `search_keyword`     STRING,
  `click_category_id`  BIGINT,
  `click_product_id`   BIGINT,
  `order_category_ids` STRING,
  `order_product_ids`  STRING,
  `pay_category_ids`   STRING,
  `pay_product_ids`    STRING,
  `city_id`            BIGINT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t';

LOAD DATA LOCAL INPATH '/opt/module/hive-4.0.1/examples/user_visit_action.txt'
INTO TABLE user_visit_action;

-- Product table: tab-delimited text, loaded from a local file.
CREATE TABLE `product_info` (
  `product_id`   BIGINT,
  `product_name` STRING,
  `extend_info`  STRING
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t';

LOAD DATA LOCAL INPATH '/opt/module/hive-4.0.1/examples/product_info.txt'
INTO TABLE product_info;

-- City table: tab-delimited text, loaded from a local file.
CREATE TABLE `city_info` (
  `city_id`   BIGINT,
  `city_name` STRING,
  `area`      STRING
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t';

LOAD DATA LOCAL INPATH '/opt/module/hive-4.0.1/examples/city_info.txt'
INTO TABLE city_info;

2. 需求简介

计算各个区域点击量排名前三的热门商品,并备注每个商品在主要城市中的点击占比:只列出占比最高的两个城市,其余城市合并显示为“其他”。例如:北京 21%, 天津 13%, 其他 66%

3. 功能实现

scala
package com.rocket.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession, functions}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
 * Computes, for every area, the three most-clicked products and annotates each
 * product with the click share of its two biggest cities (remaining cities are
 * folded into a single "其他" entry).
 *
 * All input data is read from the Hive tables `user_visit_action`,
 * `product_info` and `city_info`.
 */
object SparkSQLHive {

    /**
     * Entry point: builds a Hive-enabled local SparkSession, runs the
     * aggregation pipeline and prints the result.
     *
     * @param args command-line arguments (unused)
     */
    def main(args: Array[String]): Unit = {
        System.setProperty("HADOOP_USER_NAME", "root")
        // Build the Spark configuration.
        val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQLHive")
        // Point the warehouse directory at HDFS instead of the local default.
        conf.set("spark.sql.warehouse.dir", "hdfs://hadoop102:8020/warehouse")
        // Enable Hive support so the Hive tables can be queried.
        val session = SparkSession.builder().enableHiveSupport().config(conf).getOrCreate()

        // Release the session even when one of the queries fails.
        try {
            // Step 1: join the three tables, keeping click actions only.
            session.sql(
                """
                  | select
                  | ci.area, pi.product_name, city_name
                  | from user_visit_action uva
                  | join product_info pi on uva.click_product_id=pi.product_id
                  | join city_info ci on ci.city_id = uva.city_id
                  | where click_product_id > -1
                  |""".stripMargin).createOrReplaceTempView("t1")

            // Expose the typed Aggregator as a SQL aggregate function.
            session.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF()))

            // Step 2: one row per (area, product) with its click count and city remark.
            session.sql(
                """
                  |select
                  | count(*) clickCnt, area, product_name,
                  | cityRemark(city_name) as city_remark
                  | from t1 group by product_name, area
                  |""".stripMargin).createOrReplaceTempView("t2")

            // Step 3: rank products inside each area by click count.
            // The window must be partitioned by area only: t2 already holds a single
            // row per (product_name, area), so partitioning by both would rank every
            // row first. The rank is filtered in the next query because a window
            // alias cannot be referenced in the WHERE clause of the same SELECT.
            session.sql(
                """
                  | select *,
                  |  rank() over (partition by area order by clickCnt desc) rk
                  | from t2
                  |""".stripMargin).createOrReplaceTempView("t3")

            // Step 4: keep the top 3 per area; show(false) disables string truncation.
            session.sql(
                """
                  |select *
                  | from t3 where rk<4
                  | order by area, rk
                  |""".stripMargin).show(false)
        } finally {
            session.close()
        }
    }

    /**
     * Aggregation buffer of [[CityRemarkUDAF]].
     *
     * @param total   number of rows aggregated so far
     * @param cityMap number of rows seen per city name
     */
    case class Buffer( var total : Long, var cityMap: mutable.Map[String, Long])

    /**
     * Aggregates city names into a remark string such as
     * `"<city1> <p1>%, <city2> <p2>%, 其他 <rest>%"`.
     *
     * Input: a city name; buffer: [[Buffer]]; output: the remark string.
     */
    class CityRemarkUDAF extends Aggregator[String, Buffer, String]{
        /** Returns an empty buffer: zero rows and no cities. */
        override def zero: Buffer = {
            Buffer(0, mutable.Map[String, Long]())
        }

        /**
         * Adds one row to the buffer.
         *
         * @param buff buffer to update (mutated in place)
         * @param city city name of the incoming row
         * @return the updated buffer
         */
        override def reduce(buff: Buffer, city: String): Buffer = {
            buff.total += 1
            buff.cityMap.update(city, buff.cityMap.getOrElse(city, 0L) + 1)
            buff
        }

        /**
         * Merges two partial buffers.
         *
         * Every per-city count of `buff2` is added into `buff1` so that the
         * city counts stay consistent with the summed total.
         *
         * @param buff1 buffer that receives the merged state (mutated in place)
         * @param buff2 buffer whose state is added to `buff1`
         * @return `buff1` holding the combined state
         */
        override def merge(buff1: Buffer, buff2: Buffer): Buffer = {
            buff1.total += buff2.total

            val merged = buff1.cityMap
            buff2.cityMap.foreach {
                case (city, cnt) =>
                    merged.update(city, merged.getOrElse(city, 0L) + cnt)
            }
            buff1
        }

        /**
         * Renders the buffer as a remark string.
         *
         * The two cities with the highest counts are listed with their integer
         * percentage of the total; when more than two cities exist, the
         * remainder is reported as "其他".
         *
         * @param buff final aggregation buffer
         * @return comma-separated remark, empty when the buffer holds no city
         */
        override def finish(buff: Buffer): String = {
            val totalCnt = buff.total

            // Two most-clicked cities with their (truncated) percentage share.
            val topShares = buff.cityMap.toList
                .sortBy(_._2)(Ordering[Long].reverse)
                .take(2)
                .map { case (city, cnt) => (city, cnt * 100 / totalCnt) }

            val remarks = topShares.map { case (city, pct) => s"${city} ${pct}%" }

            // Whatever is not covered by the listed cities goes to "其他".
            val others =
                if (buff.cityMap.size > 2) List(s"其他 ${100 - topShares.map(_._2).sum}%")
                else Nil

            (remarks ++ others).mkString(", ")
        }

        /** Encoder for the intermediate [[Buffer]] (a case class, hence a product encoder). */
        override def bufferEncoder: Encoder[Buffer] = Encoders.product

        /** Encoder for the final string result. */
        override def outputEncoder: Encoder[String] = Encoders.STRING
    }
}