aehrc / VariantSpark

machine learning for genomic variants
http://bioinformatics.csiro.au/variantspark

Batch size #97

Closed · ArashBayatDev closed this 9 months ago

ArashBayatDev commented 5 years ago

When using the VariantSpark interface for Hail, a large batch size can crash the process. For example, with the setup below, a batch size of 250 results in failure (tested several times), while a batch size of 50 works well.

Setup: Hail notebook running on AWS EMR with an r4.xlarge master node and "spark.driver.memory": "24G".

Dataset: ~0.5M SNPs and ~5k samples.
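
For reference, this is roughly the call that triggers the crash, reconstructed from the traceback below (`xvds` is the MAF-filtered Hail VariantDataset from the user code in the trace; `sa.label` holds the phenotype annotation):

```python
# Reconstructed from the traceback below, not a standalone script.
via = xvds.importance_analysis("sa.label",
                               n_trees=1000,
                               mtry_fraction=0.1,
                               oob=False,
                               seed=13L,
                               batch_size=250)  # 250 crashes; 50 works in this setup
```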

The failure error:

```
Py4JJavaError                             Traceback (most recent call last)
<ipython-input> in <module>()
----> 1 Jiali('BD')

<ipython-input> in Jiali(prefix)
     40 vds = vds.annotate_samples_table(sets, root='sa.set')
     41
---> 42 (sa_kt,va_kt) = LR_VS(vds)
     43 vds = vds.annotate_samples_table(sa_kt, root='sa.set_all')
     44 vds = vds.annotate_variants_table(va_kt, root='va.set_all')

<ipython-input> in LR_VS(vds)
      9 print("MAF >>>", xvds.count())
     10
---> 11 via = xvds.importance_analysis("sa.label", n_trees = 1000, mtry_fraction = 0.1, oob = False, seed = 13L, batch_size = 250)
     12 vs_kt = via.important_variants(1000000).order_by(desc('importance')).indexed('rank').rename(['v','vs_is','vs_rank']);
     13

/home/hadoop/miniconda2/envs/jupyter/lib/python2.7/site-packages/typedecorator/__init__.pyc in wrapper(*args, **kwargs)
    396                 "doesn't match signature %s" % (k, repr(v),
    397                 _constraint_to_string(types[k])))
--> 398         return fn(*args, **kwargs)
    399
    400     wrapper.__name__ = fn.__name__

/home/hadoop/miniconda2/envs/jupyter/lib/python2.7/site-packages/varspark/hail/extend.pyc in importance_analysis(self, y_expr, n_trees, mtry_fraction, oob, seed, batch_size)
     54         self._vshf_cache.importanceAnalysis(y_expr, n_trees, joption(mtry_fraction),
     55             oob, joption(long(seed) if seed is not None else None),
---> 56             batch_size))
     57
     58     @params(self=object, operation_name=str)

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1131         answer = self.gateway_client.send_command(command)
   1132         return_value = get_return_value(
-> 1133             answer, self.gateway_client, self.target_id, self.name)
   1134
   1135         for temp_arg in temp_args:

/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    317             raise Py4JJavaError(
    318                 "An error occurred while calling {0}{1}{2}.\n".
--> 319                 format(target_id, ".", name), value)
    320         else:
    321             raise Py4JError(

Py4JJavaError: An error occurred while calling o388.importanceAnalysis.
: org.apache.spark.SparkException: Exception thrown in awaitResult:
    at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:205)
    at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
    at org.apache.spark.storage.BlockManagerMaster.removeBroadcast(BlockManagerMaster.scala:161)
    at org.apache.spark.broadcast.TorrentBroadcast$.unpersist(TorrentBroadcast.scala:306)
    at org.apache.spark.broadcast.TorrentBroadcast.doDestroy(TorrentBroadcast.scala:197)
    at org.apache.spark.broadcast.Broadcast.destroy(Broadcast.scala:111)
    at org.apache.spark.broadcast.Broadcast.destroy(Broadcast.scala:98)
    at au.csiro.pbdava.ssparkle.spark.SparkUtils$.withBroadcast(SparkUtils.scala:16)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$splitSubsets$3.apply(DecisionTree.scala:323)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$splitSubsets$3.apply(DecisionTree.scala:322)
    at au.csiro.pbdava.ssparkle.common.utils.Prof$class.profIt(Prof.scala:19)
    at au.csiro.variantspark.algo.DecisionTree$.profIt(DecisionTree.scala:311)
    at au.csiro.variantspark.algo.DecisionTree$.splitSubsets(DecisionTree.scala:322)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$findBestSplitsAndSubsets$2$$anonfun$apply$19.apply(DecisionTree.scala:644)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$findBestSplitsAndSubsets$2$$anonfun$apply$19.apply(DecisionTree.scala:642)
    at au.csiro.pbdava.ssparkle.spark.SparkUtils$.withBroadcast(SparkUtils.scala:14)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$findBestSplitsAndSubsets$2.apply(DecisionTree.scala:642)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$findBestSplitsAndSubsets$2.apply(DecisionTree.scala:640)
    at au.csiro.pbdava.ssparkle.common.utils.Prof$class.profIt(Prof.scala:19)
    at au.csiro.variantspark.algo.DecisionTree.profIt(DecisionTree.scala:546)
    at au.csiro.variantspark.algo.DecisionTree.findBestSplitsAndSubsets(DecisionTree.scala:640)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:611)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree.buildSplit(DecisionTree.scala:617)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$14.apply(DecisionTree.scala:588)
    at au.csiro.variantspark.algo.DecisionTree$$anonfun$14.apply(DecisionTree.scala:587)
    at au.csiro.pbdava.ssparkle.spark.SparkUtils$.withBroadcast(SparkUtils.scala:14)
    at au.csiro.variantspark.algo.DecisionTree.batchTrain(DecisionTree.scala:587)
    at au.csiro.variantspark.algo.RandomForest$$anon$1.batchTrain(RandomForest.scala:177)
    at au.csiro.variantspark.algo.RandomForest$$anonfun$2$$anonfun$apply$4.apply(RandomForest.scala:235)
    at au.csiro.variantspark.algo.RandomForest$$anonfun$2$$anonfun$apply$4.apply(RandomForest.scala:232)
    at au.csiro.pbdava.ssparkle.common.utils.Timed$.time(Timed.scala:30)
    at au.csiro.variantspark.algo.RandomForest$$anonfun$2.apply(RandomForest.scala:232)
    at au.csiro.variantspark.algo.RandomForest$$anonfun$2.apply(RandomForest.scala:231)
    at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
    at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
    at scala.collection.Iterator$class.foreach(Iterator.scala:893)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:183)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:45)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
    at scala.collection.AbstractIterator.to(Iterator.scala:1336)
    at scala.collection.TraversableOnce$class.toList(TraversableOnce.scala:294)
    at scala.collection.AbstractIterator.toList(Iterator.scala:1336)
    at au.csiro.variantspark.algo.RandomForest.batchTrain(RandomForest.scala:255)
    at au.csiro.variantspark.api.ImportanceAnalysis.rfModel$lzycompute(ImportanceAnalysis.scala:44)
    at au.csiro.variantspark.api.ImportanceAnalysis.rfModel(ImportanceAnalysis.scala:39)
    at au.csiro.variantspark.api.ImportanceAnalysis.<init>(ImportanceAnalysis.scala:47)
    at au.csiro.variantspark.api.ImportanceAnalysis$.fromParams(ImportanceAnalysis.scala:105)
    at au.csiro.variantspark.hail.methods.RfImportanceAnalysis$.apply(RfImportanceAnalysis.scala:56)
    at au.csiro.variantspark.hail.VSHailFunctions$.importanceAnalysis$extension(VSHailFunctions.scala:21)
    at au.csiro.variantspark.hail.VSHailFunctions.importanceAnalysis(VSHailFunctions.scala:19)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:280)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.io.IOException: Connection reset by peer
    at sun.nio.ch.FileDispatcherImpl.read0(Native Method)
    at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:39)
    at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:223)
    at sun.nio.ch.IOUtil.read(IOUtil.java:192)
    at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:380)
    at io.netty.buffer.PooledUnsafeDirectByteBuf.setBytes(PooledUnsafeDirectByteBuf.java:221)
    at io.netty.buffer.AbstractByteBuf.writeBytes(AbstractByteBuf.java:899)
    at io.netty.channel.socket.nio.NioSocketChannel.doReadBytes(NioSocketChannel.java:275)
    at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:119)
    at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:643)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:566)
    at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:480)
    at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:442)
    at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:131)
    at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:144)
    ... 1 more
```
rocreguant commented 9 months ago

You'll need to work with smaller batch sizes.
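
For example, keeping everything else from the report above unchanged, a minimal sketch of the working call (50 is the batch size the reporter found to work; the exact threshold will depend on driver memory and dataset size):

```python
# Same call as in the report, with only batch_size reduced.
# 50 was reported to work reliably with spark.driver.memory=24G.
via = xvds.importance_analysis("sa.label",
                               n_trees=1000,
                               mtry_fraction=0.1,
                               oob=False,
                               seed=13L,
                               batch_size=50)
```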