ERROR TaskSetManager: Total size of serialized results of X tasks is bigger than spark.driver.maxResultSize #1501

Closed (rjurney closed this issue 4 years ago)

rjurney commented 5 years ago

Issue description

While this isn't the normal use pattern, I wanted to see whether you'd like me to add a note to the SparkLFApplier documentation about increasing spark.driver.maxResultSize (or setting it to 0) when applying many LFs to a lot of data. Otherwise you get this error:

ERROR TaskSetManager: Total size of serialized results of 8 tasks (1077.1 MB) is bigger than spark.driver.maxResultSize (1024.0 MB)

The error indicates that the label numpy array collect()ed by SparkLFApplier exceeded the default spark.driver.maxResultSize of 1 GB. In my case there are 1.7 million records and 1,800 LFs, so the final numpy array is 1.7M x 1,800 and totals about 24 GB. This is a Spark configuration issue, but it's one that multi-task users may run into.
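For reference, the 24 GB figure lines up with a dense matrix of that shape stored with numpy's default 8-byte integer dtype (a back-of-the-envelope estimate, not output from the script):

rows, cols = 1_700_000, 1_800          # records x labeling functions
bytes_total = rows * cols * 8          # assuming int64 entries
print(f"{bytes_total / 1e9:.1f} GB")   # ~24.5 GB, consistent with the file size mentioned below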

It dies on the second to last line below:

class SparkLFApplier(BaseLFApplier):
    def apply(self, data_points: RDD, fault_tolerant: bool = False) -> np.ndarray:
        f_caller = _FunctionCaller(fault_tolerant)

        def map_fn(args: Tuple[DataPoint, int]) -> RowData:
            return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)

        labels = data_points.zipWithIndex().map(map_fn).collect() # dies right here, I believe
        return self._numpy_from_row_data(labels)

I set it to 0 via spark-submit --conf spark.driver.maxResultSize=0 and it runs fine. The call to numpy.save() produces a 24 GB file, so not everyone will hit this, but for multi-task problems it seems like it might be common enough to warrant a note.
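For anyone hitting the same thing, the limit can also be raised (or removed) when building the SparkSession rather than on the spark-submit command line. A minimal sketch, assuming a standalone PySpark session; setting the value to 0 removes the cap entirely, so the driver needs enough memory to hold the full label matrix:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("snorkel-labeling")                  # hypothetical app name
    .config("spark.driver.maxResultSize", "0")    # 0 = unlimited; or e.g. "4g" to raise the cap
    .config("spark.driver.memory", "32g")         # assumption: sized to hold the ~24 GB result
    .getOrCreate()
)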

Code example/repro steps

The script I ran is this: https://github.com/rjurney/weakly_supervised_learning_code/blob/0491c0a42ac7a2af9bf72287c5d7e9ec76ebff9c/ch05/label.spark.py

I launched it via spark-submit ch05/label.spark.py, and after a while I get this:

[Stage 9:=============================>                            (7 + 7) / 14]19/10/25 12:54:22 ERROR TaskSetManager: Total size of serialized results of 8 tasks (1077.1 MB) is bigger than spark.driver.maxResultSize (1024.0 MB)
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-8-8af585f28459> in <module>
      1 spark_applier = SparkLFApplier(list(keyword_lfs.values()))
----> 2 weak_labels = spark_applier.apply(label_encoded)
      3
      4 # Save the weak labels numpy array for use locallys
      5 np.save(

~/weakly_supervised_learning_code/snorkel/snorkel/labeling/apply/spark.py in apply(self, data_points, fault_tolerant)
     39             return apply_lfs_to_data_point(*args, lfs=self._lfs, f_caller=f_caller)
     40
---> 41         labels = data_points.zipWithIndex().map(map_fn).collect()
     42         return self._numpy_from_row_data(labels)

~/anaconda3/envs/weak/lib/python3.7/site-packages/pyspark/rdd.py in collect(self)
    814         """
    815         with SCCallSiteSync(self.context) as css:
--> 816             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    817         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    818

~/anaconda3/envs/weak/lib/python3.7/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258
   1259         for temp_arg in temp_args:

~/anaconda3/envs/weak/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

~/anaconda3/envs/weak/lib/python3.7/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Total size of serialized results of 8 tasks (1077.1 MB) is bigger than spark.driver.maxResultSize (1024.0 MB)
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
    at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:166)
    at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

Expected behavior

I expect it to return a large numpy array :)


rjurney commented 5 years ago

Oh, one thing I thought about: rather than using -1 for abstain and integers for the classes in a dense matrix, it would be far more efficient to use a sparse matrix and treat missing entries as abstains, if possible. At least in my application.
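A rough sketch of the idea (hypothetical helpers, not Snorkel's current API): keep only the non-abstain votes in a scipy.sparse matrix, shifting labels up by one so class 0 doesn't collide with the implicit sparse zero used for abstain:

import numpy as np
from scipy import sparse

def dense_to_sparse_labels(L_dense: np.ndarray) -> sparse.csr_matrix:
    # Abstains (-1) become implicit zeros; real labels are shifted by +1 so that
    # class 0 stays distinguishable from "no vote". With mostly-abstaining LFs
    # this shrinks a huge dense label matrix dramatically.
    return sparse.csr_matrix(np.where(L_dense == -1, 0, L_dense + 1))

def sparse_to_dense_labels(L_sparse: sparse.csr_matrix) -> np.ndarray:
    # Inverse: implicit zeros back to -1, labels back to 0-based.
    return L_sparse.toarray() - 1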

henryre commented 5 years ago

Hi @rjurney, this is great, thanks for reporting!! Will mark as Q&A for other folks looking for answers. This is, I think, a somewhat unusual case since you have thousands of LFs and are running single-node Spark (see follow-up on #1500), but we discussed the design decisions and future plans around sparse matrices a bit here: https://github.com/snorkel-team/snorkel/pull/1309#issuecomment-545203275

github-actions[bot] commented 4 years ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.