
Spark: implementing MySQL row-level save-mode control #41

Open cjuexuan opened 7 years ago

cjuexuan commented 7 years ago

Migrated from CSDN.

Spark's SaveMode

Spark's SaveMode lives in org.apache.spark.sql.SaveMode. It is an enum that supports the following modes (a minimal usage sketch follows the list):

  1. Append (for MySQL: append rows to the table)
  2. Overwrite (for MySQL: drop the table, then write the whole new DataFrame)
  3. ErrorIfExists (raise an error if the table exists)
  4. Ignore (if the table exists, exit without doing anything)
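
For reference, this is roughly how the stock, table-level modes are used through the standard DataFrameWriter API; a minimal sketch, where the JDBC URL, table name and credentials are placeholders:

import java.util.Properties

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SaveMode}

val sparkContext = new SparkContext(new SparkConf().setAppName("save-mode-demo").setMaster("local[*]"))
val sqlContext = new SQLContext(sparkContext)
import sqlContext.implicits._

val df = sparkContext.parallelize(Seq((1, "ann"), (2, "bob"))).toDF("iid", "name")

val props = new Properties()
props.put("user", "user")         // placeholder credentials
props.put("password", "password")

// Table-level semantics only: Overwrite drops and recreates the whole table,
// Append blindly inserts; there is no per-row upsert or insert-ignore.
df.write
  .mode(SaveMode.Append)
  .jdbc("jdbc:mysql://localhost:3306/test", "demo_table", props)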

In real business development, however, we usually want row-level actions rather than these table-level ones.

The new MySQL SaveMode

Summarizing the requirements that come up most often in business development, I designed the following enumeration:

package org.apache.spark.sql.ximautil

/**
  * @author todd.chen at 8/26/16 9:52 PM.
  *         email : todd.chen@ximalaya.com
  */
object JdbcSaveMode extends Enumeration {
  type SaveMode = Value
  val IgnoreTable, Append, Overwrite, Update, ErrorIfExists, IgnoreRecord = Value
}
  1. IgnoreTable: like the original Ignore, do nothing if the table exists
  2. Append: same as the original Append
  3. Overwrite: same as the original Overwrite
  4. Update: implemented with ON DUPLICATE KEY UPDATE
  5. ErrorIfExists: same as the original ErrorIfExists
  6. IgnoreRecord: implemented with INSERT IGNORE INTO

The corresponding SQL is generated like this:

  /**
    * Returns a PreparedStatement that inserts a row into `table` via `conn`,
    * building the SQL according to the requested `saveMode`.
    */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect, saveMode: SaveMode)
  : PreparedStatement = {
    val columnNames = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name))
    val columns = columnNames.mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")

    val sql = saveMode match {
      case Update ⇒
        val duplicateSetting = columnNames.map(name ⇒ s"$name=?").mkString(",")
        s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
      case Append | Overwrite ⇒
        s"INSERT INTO $table ($columns) VALUES ($placeholders)"
      case IgnoreRecord ⇒
        s"INSERT IGNORE INTO $table ($columns) VALUES ($placeholders)"
      case _ ⇒ throw new IllegalArgumentException(s"$saveMode is illegal")
    }
    conn.prepareStatement(sql)
  }

Reading the JdbcUtils class and rewriting it for these requirements

The pre-2.0 org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils actually had a problem: it pattern-matched on the data type for every value of every row being set, which is very expensive. In 2.0 the setter logic was rewritten into a template for the PreparedStatement: the type match is performed once per column to build an array of setter closures, so the per-row type dispatch is eliminated. The core code:

  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getInt(pos))

    case LongType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setLong(pos + 1, row.getLong(pos))

    case DoubleType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDouble(pos + 1, row.getDouble(pos))

    case FloatType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setFloat(pos + 1, row.getFloat(pos))

    case ShortType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getShort(pos))

    case ByteType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setInt(pos + 1, row.getByte(pos))

    case BooleanType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos))

    case StringType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setString(pos + 1, row.getString(pos))

    case BinaryType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))

    case TimestampType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))

    case DateType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))

    case t: DecimalType =>
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))

    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int) =>
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos).toArray)
        stmt.setArray(pos + 1, array)

    case _ =>
      (_: PreparedStatement, _: Row, pos: Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
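
For intuition, here is a compressed sketch of the same idea (not the actual Spark source; only two column types are handled), contrasting the old per-value type match with the precomputed setter closures:

import java.sql.PreparedStatement

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Pre-2.0 style: the data type is re-matched for every value of every row.
def setRowSlow(stmt: PreparedStatement, schema: StructType, row: Row): Unit =
  schema.fields.zipWithIndex.foreach { case (field, i) =>
    field.dataType match {
      case IntegerType => stmt.setInt(i + 1, row.getInt(i))
      case StringType  => stmt.setString(i + 1, row.getString(i))
      case other       => throw new IllegalArgumentException(s"unsupported type $other")
    }
  }

// 2.0 style: match once per column, then reuse the resulting closures for every row.
type Setter = (PreparedStatement, Row, Int) => Unit

def makeSetters(schema: StructType): Array[Setter] =
  schema.fields.map(_.dataType).map {
    case IntegerType => (s: PreparedStatement, r: Row, i: Int) => s.setInt(i + 1, r.getInt(i))
    case StringType  => (s: PreparedStatement, r: Row, i: Int) => s.setString(i + 1, r.getString(i))
    case other       => throw new IllegalArgumentException(s"unsupported type $other")
  }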

This solves most of the problem, but ON DUPLICATE KEY UPDATE still breaks it:

  1. SQL without DUPLICATE: insert into table_name (name,age,id) values (?,?,?)
  2. SQL with DUPLICATE: insert into table_name (name,age,id) values (?,?,?) on duplicate key update name=?,age=?,id=?

So in update mode the PreparedStatement needs twice as many placeholders as the row has fields, and the setting logic has to look something like this:

row[1,2,3]
setter(0,1) //index of setter,index of row
setter(1,2)
setter(2,3)
setter(3,1)
setter(4,2)
setter(5,3)

We can see that once the setter index passes half of setters.length, the row index has to be setterIndex - setters.length / 2, i.e. the setter index minus the number of row fields.
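
A minimal sketch of that index arithmetic, assuming n row fields and 2 * n placeholders:

// For an update statement over n row fields there are 2 * n placeholders.
// Placeholder i (0-based) reads row field i when i < n, and row field i - n otherwise.
def rowIndexFor(setterIndex: Int, numRowFields: Int): Int =
  if (setterIndex < numRowFields) setterIndex else setterIndex - numRowFields

// rowIndexFor(0, 3) == 0, rowIndexFor(3, 3) == 0, rowIndexFor(5, 3) == 2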

So the new implementation looks like this:

  // A `JDBCValueSetter` is responsible for setting a value from a `Row` into a field of a
  // `PreparedStatement`. The first `Int` is the placeholder index in the SQL statement;
  // the second `Int` is the offset subtracted from it to get the index into the `Row`
  // (non-zero only for the ON DUPLICATE KEY UPDATE half of the placeholders).
  private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) ⇒ Unit

  private def makeSetter(
                          conn: Connection,
                          dialect: JdbcDialect,
                          dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setInt(pos + 1, row.getInt(pos - offset))

    case LongType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setLong(pos + 1, row.getLong(pos - offset))

    case DoubleType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setDouble(pos + 1, row.getDouble(pos - offset))

    case FloatType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setFloat(pos + 1, row.getFloat(pos - offset))

    case ShortType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setInt(pos + 1, row.getShort(pos - offset))

    case ByteType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setInt(pos + 1, row.getByte(pos - offset))

    case BooleanType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setBoolean(pos + 1, row.getBoolean(pos - offset))

    case StringType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setString(pos + 1, row.getString(pos - offset))

    case BinaryType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - offset))

    case TimestampType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - offset))

    case DateType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - offset))

    case t: DecimalType ⇒
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos - offset))

    case ArrayType(et, _) ⇒
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\\(")(0)
      (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos - offset).toArray)
        stmt.setArray(pos + 1, array)

    case _ ⇒
      (_: PreparedStatement, _: Row, pos: Int, offset: Int) ⇒
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }

 private def getSetter(fields: Array[StructField], connection: Connection, dialect: JdbcDialect, isUpdateMode: Boolean): Array[JDBCValueSetter] = {
    val setter = fields.map(_.dataType).map(makeSetter(connection, dialect, _))
    if (isUpdateMode) {
      Array.fill(2)(setter).flatten
    } else {
      setter
    }
  }
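
So for a three-field schema in update mode the returned array has six entries, one per placeholder, which is why setters(i) stays in bounds for the second half as well. A quick illustration, assuming rddSchema, conn and dialect are already in scope:

// In update mode the setter array is duplicated, so its length matches
// the 2 * n placeholders of the ON DUPLICATE KEY UPDATE statement.
val setters = getSetter(rddSchema.fields, conn, dialect, isUpdateMode = true)
assert(setters.length == rddSchema.fields.length * 2)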

The main change in how it is called:

Original source:

  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int],
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    require(batchSize >= 1,
      s"Invalid value `${batchSize.toString}` for parameter " +
      s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.")

    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _)).toArray

      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          val numFields = rddSchema.fields.length
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        }
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
    Array[Byte]().iterator
  }

Modified version (excerpt):

def savePartition(
                     getConnection: () => Connection,
                     table: String,
                     iterator: Iterator[Row],
                     rddSchema: StructType,
                     nullTypes: Array[Int],
                     batchSize: Int,
                     dialect: JdbcDialect,
                     isolationLevel: Int,
                     saveMode: SaveMode) = {
    require(batchSize >= 1,
      s"Invalid value `${batchSize.toString}` for parameter " +
        s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")
    val isUpdateMode = saveMode == Update // are we in update mode?
    val conn = getConnection()
    var committed = false
    val length = rddSchema.fields.length
    val numFields = if (isUpdateMode) length * 2 else length // number of '?' placeholders
    val stmt = insertStatement(conn, table, rddSchema, dialect, saveMode)
    val setters: Array[JDBCValueSetter] = getSetter(rddSchema.fields, conn, dialect, isUpdateMode) // doubled in update mode
    var rowCount = 0
    while (iterator.hasNext) {
      val row = iterator.next()
      var i = 0
      val midField = numFields / 2
      while (i < numFields) {
        // in update mode there are 2 * row.length placeholders to fill
        if (isUpdateMode) {
          i < midField match { // for the second half (i >= midField) the row index is i - midField
            case true ⇒
              if (row.isNullAt(i)) {
                stmt.setNull(i + 1, nullTypes(i))
              } else {
                setters(i).apply(stmt, row, i, 0)
              }
            case false ⇒
              if (row.isNullAt(i - midField)) {
                stmt.setNull(i + 1, nullTypes(i - midField))
              } else {
                setters(i).apply(stmt, row, i, midField)
              }
          }
        } else {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i, 0)
          }
        }
        i = i + 1
      }

The wrapper bean:

case class JdbcSaveExplain(
                            url: String,
                            tableName: String,
                            saveMode: SaveMode,
                            jdbcParam: Properties
                          )

The wrapped DataFrameWriter:


package com.ximalaya.spark.xql.exec.jdbc

import java.util.Properties

import com.ximalaya.spark.common.log.CommonLoggerTrait

import language._
import com.ximalaya.spark.xql.interpreter.jdbc.JdbcSaveExplain
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.ximautil.JdbcSaveMode.SaveMode
import org.apache.spark.sql.ximautil.JdbcSaveMode._
import org.apache.spark.sql.ximautil.XQLJdbcUtil

/**
  * @author todd.chen at 8/26/16 11:33 PM.
  *         email : todd.chen@ximalaya.com
  */
class JdbcDataFrameWriter(dataFrame: DataFrame) extends Serializable with CommonLoggerTrait {
  def writeJdbc(jdbcSaveExplain: JdbcSaveExplain) = {
    this.jdbcSaveExplain = jdbcSaveExplain
    this
  }

  def save(): Unit = {
    assert(jdbcSaveExplain != null)
    val saveMode = jdbcSaveExplain.saveMode
    val url = jdbcSaveExplain.url
    val table = jdbcSaveExplain.tableName
    val props = jdbcSaveExplain.jdbcParam
    if (checkTable(url, table, props, saveMode))
      XQLJdbcUtil.saveTable(dataFrame, url, table, props, saveMode)
  }

  private def checkTable(url: String, table: String, connectionProperties: Properties, saveMode: SaveMode): Boolean = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) =>
      props.put(key, value)
    }
    // connectionProperties should override settings in extraOptions
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnectionFactory(url, props)()

    try {
      var tableExists = JdbcUtils.tableExists(conn, url, table)
      // IgnoreTable: the table exists, so do nothing and exit
      if (saveMode == IgnoreTable && tableExists) {
        logger.info("table {} exists and mode is IgnoreTable, nothing will be saved", table)
        return false
      }
      // ErrorIfExists: fail because the table already exists
      if (saveMode == ErrorIfExists && tableExists) {
        sys.error(s"Table $table already exists.")
      }
      // Overwrite: drop the existing table first
      if (saveMode == Overwrite && tableExists) {
        JdbcUtils.dropTable(conn, table)
        tableExists = false
      }
      // Create the table if the table didn't exist.
      if (!tableExists) {
        checkField(dataFrame)
        val schema = JdbcUtils.schemaString(dataFrame, url)
        val sql = s"CREATE TABLE $table (id int not null primary key auto_increment , $schema)"
        conn.prepareStatement(sql).executeUpdate()
      }
      true
    } finally {
      conn.close()
    }

  }

  // the MySQL table uses `id` as an auto-increment primary key, so the DataFrame
  // must not already contain an `id` column
  private def checkField(dataFrame: DataFrame): Unit = {
    if (dataFrame.schema.exists(_.name == "id")) {
      throw new IllegalArgumentException("dataFrame contains an id column, but id is the auto-increment primary key in MySQL")
    }
  }

  private var jdbcSaveExplain: JdbcSaveExplain = _
  private val extraOptions = new scala.collection.mutable.HashMap[String, String]

}

object JdbcDataFrameWriter {
  implicit def dataFrame2JdbcWriter(dataFrame: DataFrame): JdbcDataFrameWriter = JdbcDataFrameWriter(dataFrame)

  def apply(dataFrame: DataFrame): JdbcDataFrameWriter = new JdbcDataFrameWriter(dataFrame)
}

Test case:

 implicit def map2Prop(map: Map[String, String]): Properties = map.foldLeft(new Properties) {
    case (prop, kv) ⇒ prop.put(kv._1, kv._2); prop
  }

    val sparkContext = new SparkContext(sparkConf)
    val sqlContext = new SQLContext(sparkContext)
    //    val hiveContext = new HiveContext(sparkContext)
    //    import hiveContext.implicits._
    import sqlContext.implicits._
    val df = sparkContext.parallelize(Seq(
      (1, 1, "2", "ctccct", "286"),
      (2, 2, "2", "ccc", "11"),
      (4, 10, "2", "ccct", "12")
    )).toDF("id", "iid", "uid", "name", "age")
    val jdbcSaveExplain = JdbcSaveExplain(
      "jdbc:mysql://localhost:3306/test",
      "mytest",
      JdbcSaveMode.Update,
      Map("user" → "user", "password" → "password")
    )
    import JdbcDataFrameWriter.dataFrame2JdbcWriter
    df.writeJdbc(jdbcSaveExplain).save()

mygithub

teeyog commented 6 years ago

A question: when the index is past the halfway point, here:

 case false ⇒
                  if (row.isNullAt(i - midField)) {
                    stmt.setNull(i + 1, nullTypes(i - midField))
                  } else {
                    setters(i).apply(stmt, row, i, midField)
                  }

shouldn't it instead be:

 case false ⇒
                  if (row.isNullAt(i - midField)) {
                    stmt.setNull(i + 1, nullTypes(i - midField))
                  } else {
                    setters(i - midField).apply(stmt, row, i, midField)
                  }
cjuexuan commented 6 years ago

@SOBIGUFO No. Note that in update mode numFields has already been doubled (length * 2); you can set a breakpoint locally and step through it.

For example: INSERT INTO table (a, b) VALUES (?, ?) ON DUPLICATE KEY UPDATE a=?,b=?;

// row = (1, 2)

First, with isUpdate we get numFields = length * 2 = 4.

The four values to fill are 1, 2, 1, 2, and midField = 2.

So in update mode, midField = 2 and the offset stays 0 until i reaches midField, which means we simply call row.get(pos). Once the first pass is done and i >= midField, pos - offset = i - midField effectively replays the first pass, i.e. the same row.get(pos) as before.

teeyog commented 6 years ago

@cjuexuan I understand the explanation above, but setters has size rddSchema.fields.length, so once i passes the halfway point, setters(i) would be out of bounds unless midField is subtracted.

cjuexuan commented 6 years ago

@SOBIGUFO getSetter does Array.fill(2), so the setter array is already doubled in update mode. Take another look at where you think it goes out of bounds.

teeyog commented 6 years ago

@cjuexuan Got it, I hadn't noticed the Array.fill(2) in getSetter. Thanks!

cjuexuan commented 6 years ago

@SOBIGUFO :)