locationtech-labs / geopyspark

GeoTrellis for PySpark

TiledRasterLayer count fails on combined multiband raster bands #677

Open darrenwiens opened 6 years ago

darrenwiens commented 6 years ago

The following fails when the bands are read from a multiband GeoTIFF and combined, but succeeds with single-band JP2 files.

import geopyspark as gps  # assumed alias; an active SparkContext is also required

tif = 'PATH/TO/MULTIBAND_TIFF'
jpg1 = 'PATH/TO/SINGLEBAND_JP2_B01'
jpg2 = 'PATH/TO/SINGLEBAND_JP2_B02'
jpg3 = 'PATH/TO/SINGLEBAND_JP2_B03'

band0 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(tif, bands=[1]))
band1 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(tif, bands=[2]))
band2 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(tif, bands=[3]))

# band0 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(jpg1))
# band1 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(jpg2))
# band2 = gps.RasterLayer.from_numpy_rdd(gps.LayerType.SPATIAL,  gps.rasterio.get(jpg3))

bands = gps.combine_bands([band0, band1, band2])
tiled_layer = bands.tile_to_layout(gps.GlobalLayout(zoom=14), target_crs=3857, partition_strategy=gps.SpatialPartitionStrategy(12))
pyramided_layer = tiled_layer.pyramid()
for tiled_layer in pyramided_layer.levels.values():
    print(tiled_layer.count())

Error (can provide more thorough traceback if useful):

Py4JJavaError: An error occurred while calling o6350.count.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 10 in stage 24946.0 failed 33 times, most recent failure: Lost task 10.32 in stage 24946.0 (TID 122161, ip-10-150-43-130.us-east-1.datasvc.internal, executor 234): geotrellis.raster.GeoAttrsError: invalid rows: 0

I suspect there is some difference in how the underlying RDDs are treated/disposed of when self.srdd.rdd().count() is called in the count method. The files appear to be tiled properly in both cases. For example, I get the correct tile count using:

def get_tile_count(tiled_raster_layer):
    bounds = tiled_raster_layer.layer_metadata.bounds
    cols = bounds.maxKey.col - bounds.minKey.col + 1
    rows = bounds.maxKey.row - bounds.minKey.row + 1
    return cols * rows
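
Worth noting: this helper only reads the key bounds stored in the layer metadata. It never evaluates the tiles themselves, so it reports the size of the bounding key grid rather than the number of tiles actually present in the RDD. A minimal sketch of the arithmetic, using stand-in objects (the namedtuples `Key`, `Bounds`, `Metadata` and `Layer` are hypothetical, not GeoPySpark classes):

```python
from collections import namedtuple

# Hypothetical stand-ins that mimic only the attributes the helper reads.
Key = namedtuple("Key", ["col", "row"])
Bounds = namedtuple("Bounds", ["minKey", "maxKey"])
Metadata = namedtuple("Metadata", ["bounds"])
Layer = namedtuple("Layer", ["layer_metadata"])


def get_tile_count(tiled_raster_layer):
    # Size of the key grid spanned by the metadata bounds (inclusive on both ends).
    bounds = tiled_raster_layer.layer_metadata.bounds
    cols = bounds.maxKey.col - bounds.minKey.col + 1
    rows = bounds.maxKey.row - bounds.minKey.row + 1
    return cols * rows


# Keys from (col=4, row=10) to (col=6, row=11): 3 columns by 2 rows.
layer = Layer(Metadata(Bounds(Key(4, 10), Key(6, 11))))
grid_size = get_tile_count(layer)

# A layer whose min and max keys coincide spans exactly one key.
single = Layer(Metadata(Bounds(Key(7, 7), Key(7, 7))))
single_size = get_tile_count(single)
```

Because nothing here touches the tile data, a matching result from this helper does not by itself show that the tiles can be computed.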

jbouffard commented 6 years ago

Thanks for bringing this to our attention, @phloem7! If you can, could you post the full stacktrace?

RDDs are lazy and are not evaluated until they have to be. Certain methods, called actions, trigger their execution, and count is one of them. So what's most likely happening is that something went wrong during the creation of tiled_layer, but the error didn't surface until you called count. What's interesting here is that, since you're able to access the layer_metadata, we know that the error most likely occurred somewhere after tiling. What that error is, I'm not sure, but we will find out!
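
As a rough analogy in plain Python (this is not Spark code, and `cut_tile` and `tile_rows` are made up for illustration), a lazy `map` behaves the same way: building the pipeline succeeds, and the failure only shows up once something consumes it.

```python
def cut_tile(rows):
    # Stand-in for a per-tile step that rejects empty input.
    if rows == 0:
        raise ValueError("invalid rows: 0")
    return rows


tile_rows = [256, 256, 0]

# Defining the pipeline does no work yet, so no error is raised here.
lazy_tiles = map(cut_tile, tile_rows)

# Consuming the pipeline (the analogue of an action such as count)
# runs cut_tile on each element and surfaces the deferred error.
try:
    count = sum(1 for _ in lazy_tiles)
    error = None
except ValueError as exc:
    count = None
    error = str(exc)
```

In this sketch the exception is raised while counting, even though the faulty step was defined earlier, which is why the line that reports the error is not necessarily where the problem was introduced.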

darrenwiens commented 6 years ago

Thanks for looking into this, @jbouffard. Here's the full stacktrace:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-330-8eb01bf8ca63> in <module>()
     23 pyramided_layer = tiled_layer.pyramid()
     24 for tiled_layer in pyramided_layer.levels.values():
---> 25     print(tiled_layer.count())
     26 #     print(get_tile_count(tiled_layer))

/usr/local/lib/python3.4/site-packages/geopyspark/geotrellis/layer.py in count(self)
    310         """
    311 
--> 312         return self.srdd.rdd().count()
    313 
    314     def isEmpty(self):

/usr/local/lib/python3.4/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/usr/local/lib/python3.4/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o6350.count.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 10 in stage 24946.0 failed 33 times, most recent failure: Lost task 10.32 in stage 24946.0 (TID 122161, ip-10-150-43-130.us-east-1.datasvc.internal, executor 234): geotrellis.raster.GeoAttrsError: invalid rows: 0
    at geotrellis.raster.RasterExtent.<init>(RasterExtent.scala:78)
    at geotrellis.raster.RasterExtent$.apply(RasterExtent.scala:279)
    at geotrellis.raster.RasterExtent$.apply(RasterExtent.scala:295)
    at geotrellis.raster.resample.Resample.<init>(Resample.scala:44)
    at geotrellis.raster.resample.NearestNeighborResample.<init>(NearestNeighborResample.scala:23)
    at geotrellis.raster.resample.Resample$.apply(Resample.scala:104)
    at geotrellis.raster.merge.SinglebandTileMergeMethods$class.merge(SinglebandTileMergeMethods.scala:104)
    at geotrellis.raster.package$withTileMethods.merge(package.scala:55)
    at geotrellis.raster.merge.MultibandTileMergeMethods$$anonfun$2.apply(MultibandTileMergeMethods.scala:61)
    at geotrellis.raster.merge.MultibandTileMergeMethods$$anonfun$2.apply(MultibandTileMergeMethods.scala:57)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.immutable.Range.foreach(Range.scala:160)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.AbstractTraversable.map(Traversable.scala:104)
    at geotrellis.raster.merge.MultibandTileMergeMethods$class.merge(MultibandTileMergeMethods.scala:57)
    at geotrellis.raster.package$withMultibandTileMethods.merge(package.scala:83)
    at geotrellis.raster.package$withMultibandTileMethods.merge(package.scala:83)
    at geotrellis.spark.tiling.CutTiles$$anonfun$apply$1$$anonfun$apply$2.apply(CutTiles.scala:58)
    at geotrellis.spark.tiling.CutTiles$$anonfun$apply$1$$anonfun$apply$2.apply(CutTiles.scala:54)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at scala.collection.Iterator$$anon$12.next(Iterator.scala:444)
    at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)
    at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1753)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1741)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1740)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1740)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:871)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:871)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:871)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1974)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1923)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1912)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:682)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
    at org.apache.spark.rdd.RDD.count(RDD.scala:1162)
    at sun.reflect.GeneratedMethodAccessor463.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: geotrellis.raster.GeoAttrsError: invalid rows: 0
    at geotrellis.raster.RasterExtent.<init>(RasterExtent.scala:78)
    at geotrellis.raster.RasterExtent$.apply(RasterExtent.scala:279)
    at geotrellis.raster.RasterExtent$.apply(RasterExtent.scala:295)
    at geotrellis.raster.resample.Resample.<init>(Resample.scala:44)
    at geotrellis.raster.resample.NearestNeighborResample.<init>(NearestNeighborResample.scala:23)
    at geotrellis.raster.resample.Resample$.apply(Resample.scala:104)
    at geotrellis.raster.merge.SinglebandTileMergeMethods$class.merge(SinglebandTileMergeMethods.scala:104)
    at geotrellis.raster.package$withTileMethods.merge(package.scala:55)
    at geotrellis.raster.merge.MultibandTileMergeMethods$$anonfun$2.apply(MultibandTileMergeMethods.scala:61)
    at geotrellis.raster.merge.MultibandTileMergeMethods$$anonfun$2.apply(MultibandTileMergeMethods.scala:57)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.immutable.Range.foreach(Range.scala:160)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.AbstractTraversable.map(Traversable.scala:104)
    at geotrellis.raster.merge.MultibandTileMergeMethods$class.merge(MultibandTileMergeMethods.scala:57)
    at geotrellis.raster.package$withMultibandTileMethods.merge(package.scala:83)
    at geotrellis.raster.package$withMultibandTileMethods.merge(package.scala:83)
    at geotrellis.spark.tiling.CutTiles$$anonfun$apply$1$$anonfun$apply$2.apply(CutTiles.scala:58)
    at geotrellis.spark.tiling.CutTiles$$anonfun$apply$1$$anonfun$apply$2.apply(CutTiles.scala:54)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at scala.collection.Iterator$$anon$12.next(Iterator.scala:444)
    at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:193)
    at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:96)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more

jbouffard commented 6 years ago

Hey, @phloem7! After doing some investigation, it looks like there's something wrong with the gps.rasterio module in terms of reading in GeoTiffs. I'm not entirely sure what the cause is right now, but I can offer a temporary workaround.

rdd = gps.geotiff.get(gps.LayerType.SPATIAL, tif)
bands = rdd.bands([1, 2, 3])

The above code should do what you're looking for. I'll keep looking into the gps.rasterio issue and will let you know what I find.