cjuexuan / mynote

237 stars 34 forks source link

spark streaming mapWithState关键特性解读 #33

Open cjuexuan opened 7 years ago

cjuexuan commented 7 years ago

spark mapWithState

阅读这篇文章之前,假设读者已经对mapWithState的api比较熟悉,如果不熟悉,可以通过以下link进行学习

databricks demo spark example

spark1.6之后,状态管理多了一个新的选择,mapWithState,本文将从以下几点剖析下这个新的api

  1. 设计理念与设计思路
  2. 过期数据的处理
  3. 底层存储的高效体现

设计理念与设计思路

spark streaming与storm一个很大的不同是在storm里面计算是顶点,数据是边,即数据流向计算,而在spark streaming中数据是节点,计算是边(体现在compute方法上面),普遍做法是将计算流向数据,即Dstream是数据本身,这与RDD的immutable的思路是一致的,但是比较坑爹的是基于这种思路实现的老版状态apiupdateStateByKey在数据量很大的情况下,很容易oom,但新版的mapWithState的却以一种新的设计思路实现了这一点,关键代码如下

InternalMapWithStateDStream的compute方法:

  override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
      case Some(rdd) =>
        if (rdd.partitioner != Some(partitioner)) {
          MapWithStateRDD.createFromRDD[K, V, S, E](
            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
        } else {
          rdd
        }
      case None =>
        MapWithStateRDD.createFromPairRDD[K, V, S, E](
          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
          partitioner,
          validTime
        )
    }

这里用到了一个partitioner,该partitioner可以通过StateSpec传入,如果没有指定,将使用HashPartitioner

  private val partitioner = spec.getPartitioner().getOrElse(
    new HashPartitioner(ssc.sc.defaultParallelism))

为什么这样做呢,其实就是类似数据流向计算的思路,如果当前的batch计算有结果,且RDD的partitioner和我们这个类维护的partitioner是一致的,则不需要进行重排,否则,我们将这次计算的数据按照这个类维护的partitioner进行repartition,得到新的RDD,这样就能保证相同的key的数据永远在一个分区,同时也保证了每个分区维护一个大Map这种思路的可行性

过期数据的处理

在我们初始化StateSpec的时候我们传入了一个func,这个方法签名如下

  mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType

这里就产生疑问了,什么时候会出现Value为None的情况,答案是过期的时候,具体方法在MapWithStateRDDRecord中的updateRecordWithData

    dataIterator.foreach { case (key, value) =>
      wrappedState.wrap(newStateMap.get(key))
      //用我们自己的方法去更新状态
      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
      if (wrappedState.isRemoved) {
        newStateMap.remove(key)
      } else if (wrappedState.isUpdated
          || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
      }
      mappedData ++= returned
    }

    // Get the timed out state records, call the mapping function on each and collect the
    // data returned
    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
        wrappedState.wrapTimingOutState(state)
        val returned = mappingFunction(batchTime, key, None, wrappedState)//过期调用
        mappedData ++= returned
        //删除过期数据
        newStateMap.remove(key)
      }
    }

所以在我们的case match中处理过期数据就是判断value是不是None

底层存储的高效性

通过上一个方法,保证了每个分区一个大Map实现思路的可行性,也保证了相同的key一定会打到同一个分区,所以部分更新是可以实现的,现在就看一下真正存放这些状态的数据结果:StateMap, StateMap有两个子类,一个是EmptyStateMap,一个是基于OpenHashMap实现的OpenHashMapBasedStateMap,如果每次key的更新都创建和clone旧的状态,gc的压力会非常大,所以这里在每一层的OpenHashMapBasedStateMap维护了一个deltaMap,在创建新的stateMap的时候,如果不需要合并则传递引用

  //`copy`方法,在`MapWithStateRDDRecord`中被调用,其实就是`copyByReference`
  override def copy(): StateMap[K, S] = {
    new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold)
  }

下面看几个具体方法的实现,首先是remove方法

  //`remove` 如果key在当前层,则将state设为被删除状态,
  // 如果不在当前层,则把key加入当前层,且把state置为删除状态(写冗余),
  // 这也是为什么在getAll方法的时候需要去重,以及get方法要从当前层开始查找的原因
  override def remove(key: K): Unit = {
    val stateInfo = deltaMap(key) //判断在不在当前层
    if (stateInfo != null) {
      stateInfo.markDeleted() //当前层删除
    } else {
      val newInfo = new StateInfo[S](deleted = true)
      deltaMap.update(key, newInfo)//删除父类,
      // 在当前层冗余一个delete标志,这也是get为啥要从当前层先开始判断的一个原因,
      // 防止先从父类获取时能够找到,通过children的delete标志屏蔽
    }
  }

同时我们看一下对应的get方法

  //`get` 则先判断在不在当前层的`deltaMap`中,如果在判断有没有被删除掉,如果不在则递归调用parent的get方法
  override def get(key: K): Option[S] = {
    val stateInfo = deltaMap(key)//判断在不在这一层
    if (stateInfo != null) {
      if (!stateInfo.deleted) {//判断有没有被删掉
        Some(stateInfo.data)
      } else {
        None
      }
    } else {
      parentStateMap.get(key)//从parent中获取
    }
  }

最后看一下获取全部状态的方法

  //`getAll`方法,先获取parent的那些不在当前`deleteMap`的值 ,union上当前`deleteMap`中没被删除的
  override def getAll(): Iterator[(K, S, Long)] = {
    //DISTINCT
    val oldStates = parentStateMap.getAll().filter { case (key, _, _) =>
      !deltaMap.contains(key) //由于每一层都可能冗余写delete,所以要过滤
    }//递归获取父类的

    val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) =>
      (key, stateInfo.data, stateInfo.updateTime)
    }
    oldStates ++ updatedStates
  }

通过写的时候的冗余写和读的时候屏蔽过滤,实现了这种高效的数据结构