cjuexuan / mynote


A wrapper around hbase-rdd #15

Open cjuexuan opened 8 years ago

cjuexuan commented 8 years ago

Project background

Spark SQL support for reading HBase is reportedly being worked on officially, but it is not stable yet, so I built a wrapper on top of the hbase-rdd project. The wrapper distinguishes whether a column is stored in binary form; if it is, you mark it in the configuration file with #b, e.g. long#b. It also relies on an in-house library for converting bytes to other types, so anyone else adopting this will have to implement their own conversion scheme. Once this step is done we get a DataFrame, which can then be registered as a temp table (registerTempTable) and queried normally. I use a HiveContext because some of our data is stored as ORC files, and this scheme works with both HBase and HFile-backed data. For example:

val conf = new SparkConf
implicit val sc = new SparkContext(conf)
implicit val hiveContext = new HiveContext(sc)
HbaseMappingUtil.getHbaseDataFrame(tableName, startRow, stopRow).registerTempTable(tableName)
hiveContext.sql(s"select * from $tableName limit 1").show()
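As mentioned above, the byte conversion goes through an in-house Tran abstraction (com.ximalaya.tran). If you don't have that library you will need your own equivalent; the sketch below shows the shape the core code relies on. The trait definition and the String example are assumptions inferred from how Tran is used further down, not the real com.ximalaya.tran API:

import java.nio.charset.StandardCharsets

// hypothetical stand-in for com.ximalaya.tran.Tran, inferred from its usage in the core code
trait Tran[T <: AnyRef, B] {
  def from(bytes: B): T // decode an HBase cell value
  def to(value: T): B   // encode a value back into cell bytes
}

// example: converter for UTF-8 string columns
object StringTran extends Tran[String, Array[Byte]] {
  override def from(bytes: Array[Byte]): String =
    if (bytes == null) null else new String(bytes, StandardCharsets.UTF_8)
  override def to(value: String): Array[Byte] =
    if (value == null) null else value.getBytes(StandardCharsets.UTF_8)
}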

Configuration file

One entry per HBase table, keyed under hbase.mapping.table:

hbase {
  mapping {
    table {
      usertriat {
        name = "hb_user_trait_7days"
        columnfamily = "stat"
        columns = ["p_du", "p_counts", "p_period_dist"]
        nullable = [true,false,true]
      }

      toddtest {
        name = "todd_test"
        columnfamily = "cf1"
        columns = ["name", "age"]
        schemas = ["String", "int"]
        nullable = [true, true]
      }

      user {
        name = "hb_user"
        columnfamily = "user"
        columns = ["modifiedTime", "nickname", "isThirdparty"]
        schemas = ["long#b", "string", "boolean"]
        nullable = [true, true, true]
      }

    }
  }
}

Each table entry defines the HBase table name, the column family, the columns, and whether each column is nullable; these must always be configured. The loader in the core code also reads the schemas list (the column types), so every entry needs one as well. It is essentially a self-defined schema configuration for the table.
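For reference, this is roughly how the core code below reads one table entry with Ficus; the snippet only prints what it loaded and assumes the configuration above is on the classpath as application.conf:

import com.typesafe.config.ConfigFactory
import net.ceedubs.ficus.Ficus._

val config = ConfigFactory.load() // loads application.conf from the classpath
val prefix = "hbase.mapping.table.user"
val name      = config.as[String](s"$prefix.name")          // "hb_user"
val family    = config.as[String](s"$prefix.columnfamily")  // "user"
val columns   = config.as[Seq[String]](s"$prefix.columns")
val schemas   = config.as[Seq[String]](s"$prefix.schemas")
val nullables = config.as[Seq[Boolean]](s"$prefix.nullable")
println(columns.zip(schemas)) // List((modifiedTime,long#b), (nickname,string), (isThirdparty,boolean))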

Core code

The full implementation of HbaseMappingUtil:

import scala.language._
import unicredit.spark.hbase._
import net.ceedubs.ficus.Ficus._
import org.apache.spark.sql.types._
import org.apache.spark.SparkContext
import com.typesafe.config.ConfigFactory
import org.apache.hadoop.hbase.client.Scan
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.{DataFrame, Row}
import com.ximalaya.tran.{Bytes, PrimitiveByteTrans, Tran}
import java.lang.{Boolean ⇒ JBoolean, Double ⇒ JDouble, Float ⇒ JFloat, Long ⇒ JLong}

/**
  * Created by todd.chen on 16/3/28.
  * email : todd.chen@ximalaya.com
  */
object HbaseMappingUtil {

  lazy val config = ConfigFactory.load()

  def getHbaseDataFrame(table: String)(implicit @transient hiveContext: HiveContext,
                                       @transient sc: SparkContext): DataFrame = {
    getHbaseDataFrame(table, None, None)
  }

  def getHbaseDataFrame(table: String, startRow: Option[String], endRow: Option[String])
                       (implicit @transient hiveContext: HiveContext,
                        @transient sc: SparkContext): DataFrame = {
    lazy val hbasePrefix = s"hbase.mapping.table.$table"
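    // implicit HBase configuration picked up by the sc.hbase calls from hbase-rdd below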
    implicit val hbaseConfig = HBaseConfig()
    implicit def string2Integer(str: String): Integer = new Integer(str)
    val tableName = config.as[String](s"$hbasePrefix.name")
    val columnFamily = config.as[String](s"$hbasePrefix.columnfamily")
    // read columns as a Seq so their order stays aligned with the schemas and nullable lists
    val _columns = config.as[Seq[String]](s"$hbasePrefix.columns")
    val _names = _columns
    val _schemas = config.as[Seq[String]](s"$hbasePrefix.schemas")
    val _nullAbles = config.as[Seq[Boolean]](s"$hbasePrefix.nullable")
    implicit val columnsZipSchema: Map[String, Tran[_ <: AnyRef, Array[Byte]]] = schemaUtil(table)
    val columns = Map(columnFamily → _columns.toSet)
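    // DataFrame schema: a non-nullable "id" column holding the HBase row key, then the configured columns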
    val rddSchema = StructType(Seq(StructField("id", StringType, false)) ++ createSchema(_names, _schemas, _nullAbles))
    val scan = if (startRow.isDefined && endRow.isDefined) Some(getScan(startRow.get, endRow.get)) else None
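    // turn one hbase-rdd row (rowkey, family -> qualifier -> raw bytes) into a Spark SQL Row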
    def row2Row(row: (String, Map[String, Map[String, Array[Byte]]])) = {
      val cf = row._2(columnFamily)
      val values = Seq(row._1) ++ _names.map(name ⇒ {
        val bytesArray = cf.getOrElse(name, null)
        arrayByte2Object(bytesArray, name)
      })
      Row(values: _*)
    }
    val rowRdd = if (scan.isDefined) {
      sc.hbase[Array[Byte]](tableName, columns, scan.get).map(row2Row)
    } else {
      sc.hbase[Array[Byte]](tableName, columns).map(row2Row)
    }
    hiveContext.createDataFrame(rowRdd, rddSchema)
  }

  private def createSchema(names: Seq[String], schemas: Seq[String], nullAbles: Seq[Boolean]): Seq[StructField] = {
    (names, schemas, nullAbles).zipped.map {
      case (name, schema, isnull) ⇒ (name, schema, isnull)
    }.map(string2StructField)
  }

  private def string2StructField(nameAndStyle: (String, String, Boolean)): StructField = {
    val (name, schema, nullAble) = nameAndStyle
    schema.toLowerCase match {
      case "string" ⇒ StructField(name, StringType, nullAble)
      case "double" ⇒ StructField(name, DoubleType, nullAble)
      case "int" | "int#b" ⇒ StructField(name, IntegerType, nullAble)
      case "long" | "long#b" ⇒ StructField(name, LongType, nullAble)
      case "boolean" ⇒ StructField(name, BooleanType, nullAble)
      case "float" ⇒ StructField(name, FloatType, nullAble)
      case "timestamp" ⇒ StructField(name, TimestampType, nullAble)
      case "date" ⇒ StructField(name, DateType, nullAble)
      case other ⇒ throw new IllegalArgumentException(s"unsupported schema type: $other")
    }
  }

  private def arrayByte2Object(arrayBytes: Array[Byte], column: String)
                              (implicit columnsZipTran: Map[String, Tran[_ <: AnyRef, Array[Byte]]]) = {
    val tran = columnsZipTran(column)
    tran.from(arrayBytes)
  }

  private def schemaUtil(tableName: String) = {
    lazy val hbasePrefix = s"hbase.mapping.table.$tableName"
    val _columns = config.as[Seq[String]](s"$hbasePrefix.columns")
    val _schemas = config.as[Seq[String]](s"$hbasePrefix.schemas")
    column2Tran(_columns.zip(_schemas))
  }

  private def column2Tran(columnZipSchema: Seq[(String, String)]) = {
    var columnZipTran = Map.empty[String, Tran[_ <: AnyRef, Array[Byte]]]
    columnZipSchema.foreach { cs ⇒
      val (column, schema) = cs
      columnZipTran += column → schema2Tran(schema)
    }
    columnZipTran
  }

  private def schema2Tran(schema: String): Tran[_ <: AnyRef, Array[Byte]] = {
    schema.toLowerCase match {
      case "string" ⇒ PrimitiveByteTrans.getTran(classOf[String])
      case "boolean" ⇒ PrimitiveByteTrans.getTran(classOf[JBoolean])
      case "double" ⇒ PrimitiveByteTrans.getTran(classOf[JDouble])
      case "float" ⇒ PrimitiveByteTrans.getTran(classOf[JFloat])
      case "long" ⇒ new Tran[JLong, Array[Byte]] {
        override def from(to: Array[Byte]): JLong = {
          val num = Bytes.toString(to)
          if (num == null) null else new JLong(num)
        }

        override def to(from: JLong): Array[Byte] = Bytes.toBytes(from.toString)
      }
      case "long#b" ⇒ PrimitiveByteTrans.getTran(classOf[JLong])
      case "int" ⇒ new Tran[Integer, Array[Byte]] {
        override def from(to: Array[Byte]): Integer = {
          val num = Bytes.toString(to)
          if (num == null) null else new Integer(num)
        }

        override def to(from: Integer): Array[Byte] = Bytes.toBytes(from.toString)
      }
      case "int#b" ⇒ PrimitiveByteTrans.getTran(classOf[java.lang.Integer])
      case other ⇒ throw new IllegalArgumentException(s"unsupported schema type: $other")
    }
  }

  private def getScan(startRow: String, endRow: String): Scan = {
    val scan = new Scan()
    scan.setStartRow(Bytes.toBytes(startRow))
    scan.setStopRow(Bytes.toBytes(endRow))
    scan
  }
}
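Putting it together, usage mirrors the snippet at the top: register one temp table per configured HBase table and query through the HiveContext. A short sketch; the table keys come from the sample config above, and the row-key bounds are just placeholders:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

val conf = new SparkConf().setAppName("hbase-mapping-demo")
implicit val sc = new SparkContext(conf)
implicit val hiveContext = new HiveContext(sc)

// full-table read for the small test table (no start/stop row keys)
HbaseMappingUtil.getHbaseDataFrame("toddtest").registerTempTable("todd_test")

// bounded read over the user table via the start/stop row-key overload
HbaseMappingUtil.getHbaseDataFrame("user", Some("0000"), Some("ffff")).registerTempTable("hb_user")

hiveContext.sql("select name, age from todd_test limit 10").show()
hiveContext.sql("select nickname, modifiedTime from hb_user limit 10").show()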
