
Spark + HBase Integration (a custom HBase DataSource) #22


teeyog commented 6 years ago

Background

Spark supports many data sources, but it does not ship a particularly elegant API for reading from and writing to HBase, even though Spark and HBase are combined quite often. So I used Spark's DataSource API to implement my own set of APIs that make working with HBase more convenient.

Writing to HBase

When writing, each column is stored in HBase with the data type taken from the DataFrame schema. Usage examples first:

import spark.implicits._
import org.apache.hack.spark._
val df = spark.createDataset(Seq(("ufo",  "play"), ("yy",  ""))).toDF("name", "like")
// Option 1
val options = Map(
            "hbase.table.rowkey.field" -> "name",
            "hbase.table.numReg" -> "12",
            "hbase.table.rowkey.prefix" -> "00",
            "bulkload.enable" -> "false"
        )
df.saveToHbase("hbase_table", Some("XXX:2181"), options)
// Option 2
df.write.format("org.apache.spark.sql.execution.datasources.hbase")
            .options(Map(
                "hbase.table.rowkey.field" -> "name",
                "hbase.table.name" -> "hbase_table",
                "hbase.zookeeper.quorum" -> "XXX:2181",
                "hbase.table.rowkey.prefix" -> "00",
                "hbase.table.numReg" -> "12",
                "bulkload.enable" -> "false"
            )).save()

The two approaches above produce exactly the same result. The parameters mean the following:

- hbase.table.name: name of the target HBase table (in option 1 it is passed as the first argument of saveToHbase instead)
- hbase.zookeeper.quorum: ZooKeeper quorum of the HBase cluster (in option 1 passed as the second argument)
- hbase.table.rowkey.field: the DataFrame column used as the rowkey; defaults to the first column
- hbase.table.rowkey.prefix: prefix prepended to the generated rowkeys
- hbase.table.numReg: number of regions to pre-split the table into if it has to be created
- bulkload.enable: whether to write via HFile bulk load instead of individual Puts; defaults to true
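For completeness, the same write can go through the HFile bulk-load path simply by flipping bulkload.enable (a minimal sketch reusing the df from above; note that the core code further down defaults this flag to "true"):

df.saveToHbase("hbase_table", Some("XXX:2181"), Map(
    "hbase.table.rowkey.field" -> "name",
    "hbase.table.numReg" -> "12",
    "bulkload.enable" -> "true"
))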

Reading from HBase

Example code:

// Option 1
import org.apache.hack.spark._
val options = Map(
    "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
    "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm"
)
spark.hbaseTableAsDataFrame("hbase_table", Some("XXX:2181"), options).show(false)
// Option 2
spark.read.format("org.apache.spark.sql.execution.datasources.hbase").
            options(Map(
            "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
            "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm",
            "hbase.zookeeper.quorum" -> "XXX:2181",
            "hbase.table.name" -> "hbase_table"
        )).load.show(false)  

Specifying the schema mapping between the Spark table and the HBase table is optional. By default two columns are produced, rowkey and content, where content is a JSON string assembled from all HBase columns. The data type of an individual field can be set via field.type.fieldname; everything else defaults to StringType. Data mapped this way still has to be reshaped by your Spark job before it looks the way you want, and every column gets scanned, so it is relatively inefficient.
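As a minimal sketch of that default behaviour (the field.type key follows the family:qualifier naming used in the core code further down; the concrete column name here is only an illustrative assumption):

val raw = spark.read.format("org.apache.spark.sql.execution.datasources.hbase")
    .options(Map(
        "hbase.zookeeper.quorum" -> "XXX:2181",
        "hbase.table.name" -> "hbase_table",
        "field.type.info:appStoreId" -> "IntegerType"  // this one column is decoded as Int inside the JSON
    )).load()
raw.printSchema()  // rowkey: string, content: string (JSON of all scanned columns)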

Hence we can supply a custom schema mapping (spark.table.schema together with hbase.table.schema, as in the examples above) to fetch exactly the data we need:

Note that the two schemas correspond one to one, and HBase only scans the columns listed in hbase.table.schema.
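To make the one-to-one pairing concrete, this is how the two strings from the example above line up (a purely illustrative snippet, not the repo's parsing code):

val sparkSchema = "appid:String,appstoreid:int,firm:String".split(",")
val hbaseSchema = ":rowkey,info:appStoreId,info:firm".split(",", -1)
sparkSchema.zip(hbaseSchema).foreach { case (s, h) => println(s"$s  <-  $h") }
// appid:String    <-  :rowkey          (the rowkey is exposed as column appid)
// appstoreid:int  <-  info:appStoreId  (qualifier appStoreId of family info, read as Int)
// firm:String     <-  info:firm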

Core code

Writing to HBase

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hbase.client.{ConnectionFactory, HTable, Put, Result}
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.{HFileOutputFormat2, LoadIncrementalHFiles, TableOutputFormat}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.{HBaseConfiguration, KeyValue, TableName}
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.SerializableConfiguration
import org.json4s.DefaultFormats

import scala.collection.immutable.HashMap

class DataFrameFunctions(data: DataFrame) extends Logging with Serializable {

    def saveToHbase(tableName: String, zkUrl: Option[String] = None,
                    options: Map[String, String] = new HashMap[String, String]): Unit = {

        val wrappedConf = {
            implicit val formats = DefaultFormats
            val hc = HBaseConfiguration.create()
            hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
            new SerializableConfiguration(hc)
        }
        val hbaseConf = wrappedConf.value

        // accept both the documented hbase.table.* keys and the short keys
        val rowkey = options.getOrElse("hbase.table.rowkey.field", options.getOrElse("rowkey.field", data.schema.head.name))
        val family = options.getOrElse("family", "info")
        val numReg = options.getOrElse("hbase.table.numReg", options.getOrElse("numReg", "-1")).toInt
        val startKey = options.getOrElse("startKey", null)
        val endKey = options.getOrElse("endKey", null)

        val rdd = data.rdd
        val f = family

        val tName = TableName.valueOf(tableName)
        val connection = ConnectionFactory.createConnection(hbaseConf)
        val admin = connection.getAdmin
        if (!admin.isTableAvailable(tName)) {
            HBaseUtils.createTable(connection, tName, family, startKey, endKey, numReg)
        }
        connection.close()
        if (hbaseConf.get("mapreduce.output.fileoutputformat.outputdir") == null) {
            hbaseConf.set("mapreduce.output.fileoutputformat.outputdir", "/tmp")
        }
        val jobConf = new JobConf(hbaseConf, this.getClass)
        jobConf.set(TableOutputFormat.OUTPUT_TABLE, tableName)

        val job = Job.getInstance(jobConf)
        job.setOutputKeyClass(classOf[ImmutableBytesWritable])
        job.setOutputValueClass(classOf[Result])
        job.setOutputFormatClass(classOf[TableOutputFormat[ImmutableBytesWritable]])

        val fields = data.schema.toArray
        val rowkeyIndex = fields.zipWithIndex.filter(f => f._1.name == rowkey).head._2
        val otherFields = fields.zipWithIndex.filter(f => f._1.name != rowkey)

        lazy val setters = otherFields.map(r => HBaseUtils.makeHbaseSetter(r))
        lazy val setters_bulkload = otherFields.map(r => HBaseUtils.makeHbaseSetter_bulkload(r))

        options.getOrElse("bulkload.enable", "true") match {

            case "true" =>
                val tmpPath = s"/tmp/bulkload/${tableName}" + System.currentTimeMillis()
                def convertToPut_bulkload(row: Row) = {
                    val rk = Bytes.toBytes(row.getString(rowkeyIndex))
                    setters_bulkload.map(_.apply(rk, row, f))
                }
                rdd.flatMap(convertToPut_bulkload)
                    .saveAsNewAPIHadoopFile(tmpPath, classOf[ImmutableBytesWritable], classOf[KeyValue], classOf[HFileOutputFormat2], job.getConfiguration)

                val bulkLoader: LoadIncrementalHFiles = new LoadIncrementalHFiles(hbaseConf)
                bulkLoader.doBulkLoad(new Path(tmpPath), new HTable(hbaseConf, tableName))

            case "false" =>
                def convertToPut(row: Row) = {
                    val put = new Put(Bytes.toBytes(row.getString(rowkeyIndex)))
                    setters.foreach(_.apply(put, row, f))
                    (new ImmutableBytesWritable, put)
                }
                rdd.map(convertToPut).saveAsNewAPIHadoopDataset(job.getConfiguration)
        }
    }
}
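The helper HBaseUtils (createTable, makeHbaseSetter, makeHbaseSetter_bulkload, and so on) is not reproduced in this post; the full source is linked further down. As a rough sketch of what createTable might do, matching the call site above and using only the plain HBase 1.x admin API (an assumption, not the repo's actual implementation):

object HBaseUtils {

    import org.apache.hadoop.hbase.client.Connection
    import org.apache.hadoop.hbase.util.Bytes
    import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, TableName}

    def createTable(connection: Connection, tableName: TableName, family: String,
                    startKey: String, endKey: String, numReg: Int): Unit = {
        val admin = connection.getAdmin
        try {
            val desc = new HTableDescriptor(tableName)
            desc.addFamily(new HColumnDescriptor(Bytes.toBytes(family)))
            if (numReg > 0 && startKey != null && endKey != null) {
                // pre-split into numReg regions between startKey and endKey
                admin.createTable(desc, Bytes.toBytes(startKey), Bytes.toBytes(endKey), numReg)
            } else {
                admin.createTable(desc)
            }
        } finally {
            admin.close()
        }
    }
}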

Reading from HBase

import java.sql.Timestamp

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.util.SerializableConfiguration
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization

import scala.collection.JavaConversions._
import scala.collection.immutable.HashMap
import scala.collection.mutable.ArrayBuffer

class SparkSqlContextFunctions(@transient val spark: SparkSession) extends Serializable {

    private val SPARK_TABLE_SCHEMA: String = "spark.table.schema"
    private val HBASE_TABLE_SCHEMA: String = "hbase.table.schema"

    def hbaseTableAsDataFrame(table: String, zkUrl: Option[String] = None,
                              options:Map[String, String] = new HashMap[String, String]
                             ): DataFrame = {

        val wrappedConf = {
            val hc = HBaseConfiguration.create()
            hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
            hc.set(TableInputFormat.INPUT_TABLE, table)
            if (options.contains(HBASE_TABLE_SCHEMA)) {
                // collect the family:qualifier entries (everything except :rowkey) for the scan
                val columns = ArrayBuffer[String]()
                options(HBASE_TABLE_SCHEMA)
                    .split(",", -1)
                    .foreach(field => if (!field.startsWith(":")) columns += field)
                if (columns.nonEmpty) hc.set(TableInputFormat.SCAN_COLUMNS, columns.mkString(" "))
            }
            // copy the remaining recognised options into the Hadoop configuration
            Array(SPARK_TABLE_SCHEMA, HBASE_TABLE_SCHEMA, TableInputFormat.SCAN_ROW_START, TableInputFormat.SCAN_ROW_STOP)
                .foreach(param => options.get(param).foreach(value => hc.set(param, value)))
            new SerializableConfiguration(hc)
        }
        def hbaseConf = wrappedConf.value

        def schema: StructType = {
            import org.apache.spark.sql.types._
            Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
                case Some(schema) => HBaseUtils.registerSparkTableSchema(schema)
                case None =>
                    StructType(
                        Array(
                            StructField("rowkey", StringType, nullable = false),
                            StructField("content", StringType)
                        )
                    )
            }
        }

        Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
            case Some(s) =>
                require(Option(hbaseConf.get(HBASE_TABLE_SCHEMA)).exists(_.nonEmpty), "Because the parameter spark.table.schema has been set, hbase.table.schema also needs to be set.")
                val sparkTableSchemas = schema.fields.map(f => SparkTableSchema(f.name, f.dataType))
                val hBaseTableSchemas = HBaseUtils.registerHbaseTableSchema(hbaseConf.get(HBASE_TABLE_SCHEMA))
                require(sparkTableSchemas.length == hBaseTableSchemas.length, "The length of the parameter spark.table.schema must be the same as the parameter hbase.table.schema.")
                val schemas = sparkTableSchemas.zip(hBaseTableSchemas)
                val getters = schemas.map(schema => HBaseUtils.makeHbaseGetter(schema))

                val hBaseRDD = spark.sparkContext.newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
                    .map { case (_, result) => Row.fromSeq(getters.map(r => r.apply(result)).toSeq) }
                spark.createDataFrame(hBaseRDD, schema)

            case None =>
                val hBaseRDD = spark.sparkContext.newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
                    .map { line =>
                        val rowKey = Bytes.toString(line._2.getRow)

                        implicit val formats = Serialization.formats(NoTypeHints)

                        val content = line._2.getMap.navigableKeySet().flatMap { f =>
                            line._2.getFamilyMap(f).map { c =>
                                val columnName = Bytes.toString(f) + ":" + Bytes.toString(c._1)
                                options.get("field.type." + columnName) match {
                                    case Some(i) =>
                                        val value = i match {
                                            case "LongType" => Bytes.toLong(c._2)
                                            case "FloatType" => Bytes.toFloat(c._2)
                                            case "DoubleType" => Bytes.toDouble(c._2)
                                            case "IntegerType" => Bytes.toInt(c._2)
                                            case "BooleanType" => Bytes.toBoolean(c._2)
                                            case "BinaryType" => c._2
                                            case "TimestampType" => new Timestamp(Bytes.toLong(c._2))
                                            case "DateType" => new java.sql.Date(Bytes.toLong(c._2))
                                            case _ => Bytes.toString(c._2)
                                        }
                                        (columnName, value)
                                    case None => (columnName, Bytes.toString(c._2))
                                }
                            }
                        }.toMap
                        val contentStr = Serialization.write(content)
                        Row.fromSeq(Seq(rowKey,contentStr))
                    }
                spark.createDataFrame(hBaseRDD, schema)
        }
    }
}
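makeHbaseGetter is also part of the HBaseUtils helper that is not shown here. A hypothetical sketch of how it could turn each (SparkTableSchema, HBaseTableSchema) pair into a Result => Any extractor; the two case-class shapes are assumptions for illustration, and Timestamp/Date handling is omitted:

import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types._

case class SparkTableSchema(name: String, dataType: DataType)
case class HBaseTableSchema(family: String, qualifier: String) // family == "" stands for ":rowkey"

object HBaseUtils {

    def makeHbaseGetter(pair: (SparkTableSchema, HBaseTableSchema)): Result => Any = {
        val (sparkCol, hbaseCol) = pair
        // choose a byte[] decoder from the Spark-side data type (StringType is the fallback)
        val decode: Array[Byte] => Any = sparkCol.dataType match {
            case IntegerType => bytes => Bytes.toInt(bytes)
            case LongType    => bytes => Bytes.toLong(bytes)
            case FloatType   => bytes => Bytes.toFloat(bytes)
            case DoubleType  => bytes => Bytes.toDouble(bytes)
            case BooleanType => bytes => Bytes.toBoolean(bytes)
            case _           => bytes => Bytes.toString(bytes)
        }
        if (hbaseCol.family.isEmpty) {
            result => decode(result.getRow) // the ":rowkey" pseudo column
        } else {
            result => decode(result.getValue(Bytes.toBytes(hbaseCol.family), Bytes.toBytes(hbaseCol.qualifier)))
        }
    }
}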

A custom DataSource has to be exposed through a class named DefaultSource (Spark looks it up inside the package name passed to format()):

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}

class DefaultSource extends CreatableRelationProvider with RelationProvider with DataSourceRegister {

    override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
        HBaseRelation(parameters, None)(sqlContext)

    override def shortName(): String = "hbase"

    override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
        val relation = InsertHBaseRelation(data, parameters)(sqlContext)
        relation.insert(data, false)
        relation
    }
}

private[sql] case class InsertHBaseRelation(
                                               dataFrame: DataFrame,
                                               parameters: Map[String, String]
                                           )(@transient val sqlContext: SQLContext)
    extends BaseRelation with InsertableRelation with Logging {

    override def insert(data: DataFrame, overwrite: Boolean): Unit = {

        def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zk or hbase.zookeeper.quorum...")))
        def getOutputTableName: String = parameters.getOrElse("outputTableName", parameters.getOrElse("hbase.table.name", sys.error("You must specify parameter outputTableName or hbase.table.name...")))

        import org.apache.hack.spark._
        data.saveToHbase(getOutputTableName, Some(getZkURL), parameters)
    }
    override def schema: StructType = dataFrame.schema
}

private[sql] case class HBaseRelation(
                                         parameters: Map[String, String],
                                         userSpecifiedschema: Option[StructType]
                                     )(@transient val sqlContext: SQLContext)
    extends BaseRelation with TableScan with Logging {

    def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zk or hbase.zookeeper.quorum...")))
    def getInputTableName: String = parameters.getOrElse("inputTableName", parameters.getOrElse("hbase.table.name", sys.error("You must specify parameter inputTableName or hbase.table.name...")))

    def buildScan(): RDD[Row] = {
        import org.apache.hack.spark._
        sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).rdd
    }

    override def schema: StructType = {
        import org.apache.hack.spark._
        sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).schema
    }
}
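One practical note: since DefaultSource also implements DataSourceRegister with shortName "hbase", it can additionally be registered through Java's ServiceLoader, after which callers may simply write format("hbase") instead of the full package name. The file below assumes the classes live in the package used throughout the examples:

# src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.execution.datasources.hbase.DefaultSource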


FireSK7 commented 6 years ago

Why can df call that method?

teeyog commented 6 years ago

@SteveYanzhi

import org.apache.hack.spark._

That import brings the implicit conversions into scope:

implicit def toSparkSqlContextFunctions(spark: SparkSession): SparkSqlContextFunctions = {
    new SparkSqlContextFunctions(spark)
}

implicit def toDataFrameFunctions(data: DataFrame): DataFrameFunctions = {
    new DataFrameFunctions(data)
}

For the complete integration code, see https://github.com/teeyog/IQL/tree/master/iql-spark/src/main/scala/org/apache/spark/sql/execution/datasources/hbase

teeyog commented 6 years ago

@SteveYanzhi No, the code posted later is the newest version, with more complete functionality.

FireSK7 commented 6 years ago

The code is really well written.