jupyter-server / enterprise_gateway

A lightweight, multi-tenant, scalable and secure gateway that enables Jupyter Notebooks to share resources across distributed clusters such as Apache Spark, Kubernetes and others.
https://jupyter-enterprise-gateway.readthedocs.io/en/latest/

Changes for handling interrupts for PySpark Kubernetes Kernel #1115

Closed SamRanjha closed 2 years ago

SamRanjha commented 2 years ago

This PR fixes issue #1112. I removed the previous commit and PR (#1114) since we are using the new logic suggested by @kevin-bates.

Testing:

  1. Started a PySpark kernel with the changes.
  2. Started kernels, executed code, and then interrupted it. Verified that the job gets killed in the Spark logs and that we can start executing other cells.
  3. Ran jobs in threads and verified all running jobs were killed.
kevin-bates commented 2 years ago

Hmm. I'm not seeing this work as expected when running within YARN - although it should. You can check this by doing the following...

  1. Build the enterprise-gateway-demo image. I would run make clean dist kernelspecs enterprise-gateway-demo
  2. Start the container. This can be done via make itest-yarn-prep. If the startup times out, the container is still starting and can be monitored using docker logs -f itest-yarn. The container is considered started once you see the EG startup banner: [I 2022-06-21 23:38:30.915 EnterpriseGatewayApp] Jupyter Enterprise Gateway 3.0.0.dev0 is available at http://0.0.0.0:8888
  3. Start the jupyter lab client: jupyter lab --debug --notebook-dir=~/notebooks --gateway-url=http://localhost:8888
  4. Open your notebook using the Spark Python YARN Cluster kernel
  5. It takes a little while to launch and sometimes times out depending on your host, but you should be able to get the kernel launched in cluster mode. (Note that this image is used for the various CI integration tests for the YARN kernelspecs.)
  6. Confirm you have a fully established context by executing sc; you should see version information as well as an application name that matches the kernel id (a quick check is sketched after this list).
  7. I believe you could use the code for your long and short-running scenarios.
  8. The user application logs can be found in the container (via docker exec) in /usr/hdp/current/hadoop/logs/userlogs.
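
For reference, a quick sanity check along these lines (a sketch; sc is the pre-created context injected by the launcher) should show the Spark version plus an application name matching the kernel id:

# Run in the first notebook cell; the SparkContext is pre-created for the kernel.
sc                                                # displays the context (version, master, app name)
print(sc.appName, sc.applicationId, sc.uiWebUrl)  # app name should match the kernel id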

I'm seeing similar behavior, where the next (short-running) Spark command takes far longer than expected, on the order of the long-running command's duration (despite the interrupt occurring at the notebook interface level).

kevin-bates commented 2 years ago

@SamRanjha - are you planning on making the suggested updates, or should one of the maintainers go ahead and make the changes if you're unable? Just checking your availability - thanks.

SamRanjha commented 2 years ago

> @SamRanjha - are you planning on making the suggested updates, or should one of the maintainers go ahead and make the changes if you're unable? Just checking your availability - thanks.

@kevin-bates yes, I am working on it and will push the updated changes for review soon.

kevin-bates commented 2 years ago

Hi @SamRanjha. With the last set of changes, I no longer see the error about a missing cancelAllJobs emitted from the non-Spark kernel, thanks to the change that guards on cluster_type.

However, for spark-based python kernels, I am now seeing the following message emitted...

Error occurred while calling handler: An error occurred while calling o33.sc

which is then followed by the interrupted output...

[E 2022-06-29 18:53:39,179.179 root] Exception while sending command.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/socket.py", line 589, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/clientserver.py", line 475, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
RuntimeError: reentrant call inside <_io.BufferedReader name=57>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/clientserver.py", line 504, in send_command
    "Error while sending or receiving", e, proto.ERROR_ON_RECEIVE)
py4j.protocol.Py4JNetworkError: Error while sending or receiving
[E 2022-06-29 18:53:39,180.180 root] KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/clientserver.py", line 475, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/opt/conda/lib/python3.7/socket.py", line 589, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/usr/local/bin/kernel-launchers/python/scripts/launch_ipykernel.py in <module>
     13 
     14 start=time.time()
---> 15 count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
     16 end=time.time()
     17 print(f"Pi is roughly {4.0 * count / n} in {end-start:.3f}s")

/opt/spark/python/lib/pyspark.zip/pyspark/rdd.py in reduce(self, f)
    997             yield reduce(f, iterator, initial)
    998 
--> 999         vals = self.mapPartitions(func).collect()
   1000         if vals:
   1001             return reduce(f, vals)

/opt/spark/python/lib/pyspark.zip/pyspark/rdd.py in collect(self)
    948         """
    949         with SCCallSiteSync(self.context) as css:
--> 950             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
    951         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
    952 

/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1318             proto.END_COMMAND_PART
   1319 
-> 1320         answer = self.gateway_client.send_command(command)
   1321         return_value = get_return_value(
   1322             answer, self.gateway_client, self.target_id, self.name)

/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/java_gateway.py in send_command(self, command, retry, binary)
   1036         connection = self._get_connection()
   1037         try:
-> 1038             response = connection.send_command(command)
   1039             if binary:
   1040                 return response, self._create_connection_guard(connection)

/opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/clientserver.py in send_command(self, command)
    473         try:
    474             while True:
--> 475                 answer = smart_decode(self.stream.readline()[:-1])
    476                 logger.debug("Answer received: {0}".format(answer))
    477                 # Happens when a the other end is dead. There might be an empty

/opt/conda/lib/python3.7/socket.py in readinto(self, b)
    587         while True:
    588             try:
--> 589                 return self._sock.recv_into(b)
    590             except timeout:
    591                 self._timeout_occurred = True

KeyboardInterrupt: 

Then, the next cell, which should take less than a second...

import time

n = 1000000
start=time.time()
total=sc.parallelize(range(n)).sum()
end=time.time()
print(f"Sum of 0..{n-1} is {total} in {end-start:.3f}s")

produces this output...

Sum of 0..999999 is 499999500000 in 21.213s

which is about right if the Spark job was not cancelled (since the previous cell takes around 26 seconds to complete).

For completeness, here's the first cell's code:

import time
import sys
from random import random
from operator import add

partitions = 2000
n = 10000 * partitions

def f(_):
    x = random() * 2 - 1
    y = random() * 2 - 1
    return 1 if x ** 2 + y ** 2 <= 1 else 0

start=time.time()
count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
end=time.time()
print(f"Pi is roughly {4.0 * count / n} in {end-start:.3f}s")

So the exception caught within the signal handler is "An error occurred while calling o33.sc", which implies that __spark_context isn't what is expected.
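
For context, the interrupt handler involved here is roughly of the following shape (a sketch reconstructed from the error message and the snippets later in this thread, not the exact launcher code; __spark_context is the launcher's module-level SparkContext):

import signal

def __handle_sigusr2(signum, frame):
    # Cancel any running Spark jobs when the gateway interrupts the kernel.
    try:
        __spark_context.cancelAllJobs()
    except Exception as e:
        print(f"Error occurred while calling handler: {e}")

# Registered once the Spark context global has been assigned.
signal.signal(signal.SIGUSR2, __handle_sigusr2)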

I'll try to look into this as well.

kevin-bates commented 2 years ago

I've spent the afternoon on this and find that the second interrupt behaves as we'd expect the first interrupt to behave. By tailing the logs of the driver pod, you can clearly see the job's cancellation on the second interrupt. The stack trace always occurs on the first attempt, but the second attempt is clean.

I've confirmed the __spark_context global reflects a proper context. I've also tried switching the order of the signals such that the kernel listener first sends SIGUSR2, then sleeps for half a second, then sends SIGINT, but that did not change the results.

            if request.get("signum") is not None:
                signum = int(request.get("signum"))
                if signum == 2 and cluster_type == "spark":
                    os.kill(parent_pid, signal.SIGUSR2)
                    time.sleep(0.5)
                os.kill(parent_pid, signum)

Update (I had forgotten to submit the previous comment, so extending it now): FWIW, I just tried issuing the SIGUSR2 twice (but only found the handler called once) - so no effect. I also tried adding a call to __spark_context.cancelAllJobs() just after the global is set (with try/catch to ignore any exceptions), since it's probably fine to call cancelAllJobs when we know there are no jobs. It also had no effect. This implies that whatever the driver thread is doing at the time of the first interrupt is NOT able to fulfill the cancelAllJobs logic, but is then "interrupted enough" that the subsequent interrupt completes the job cancellation.
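
Concretely, that warm-up experiment amounts to something like the following (a sketch, not the actual diff; __spark_context is the launcher's global, and no jobs exist yet so the call should be a no-op):

# Immediately after the global Spark context is assigned by the launcher:
try:
    __spark_context.cancelAllJobs()  # warm up the Py4J call path before any interrupt arrives
except Exception:
    pass  # ignore any exception raised by the warm-up call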

I'm hoping you can share what experiences you have when tailing the logs and whether or not your tests require two interrupts.

SamRanjha commented 2 years ago

Hi @kevin-bates, I am observing the same kind of error: the first interrupt does not work and raises an exception, while the second one works. I tried a couple of things, but none of them seemed to work, so I tried calling __spark_context.cancelAllJobs() again within the same interrupt, and that worked. I also tried calling __spark_context.applicationId and __spark_context.uiWebUrl (to check whether it is only the first use of the Spark context that fails) and those work fine. So basically, the first call to cancelAllJobs fails with the error, while calling it again, whether on the second interrupt or again within the handler, succeeds.

These two solutions are currently working:

    try:
        print("Sending command to cancel all jobs")
        __spark_context.cancelAllJobs()
    except Exception as e:
        if e.__class__.__name__ == 'Py4JError':
            # The first cancelAllJobs() call fails with a Py4JError on the first
            # interrupt; retrying the call once succeeds.
            try:
                print("Sending command to cancel all jobs")
                __spark_context.cancelAllJobs()
            except Exception as e:
                print(f"Error occurred while killing second time: {e}")
        else:
            print(f"Error occurred while calling handler: {e}")

That is, calling cancelAllJobs twice. I tried searching for the error but was not able to find out why the first cancel doesn't work.

The second solution is similar to what I proposed originally, that is, using the Spark UI to kill jobs, except that since we now have the Spark context we can rely on it to fetch the UI address.

def __get_active_jobs_id():
    active_job_ids = []
    application_id = __spark_context.applicationId
    ui_address = __spark_context.uiWebUrl
    url = '{}/api/v1/applications/{}/jobs?status=running'.format(ui_address, application_id)
    print(url)
    response = __requests.get(url)
    if response.status_code == 200 and response.json():
        for job in response.json():
            jobid = job.get('jobId')
            active_job_ids.append(jobid)
    return active_job_ids

def __kill_job(jobid):
    ui_address = __spark_context.uiWebUrl
    url = '{}/jobs/job/kill/?id={}'.format(ui_address, jobid)
    print(f'Job Kill URL {url}')
    response = __requests.get(url)
    if response.status_code == 200:
        print(f"Killed job with id: {jobid}")
    else:
        logger.warning(f"Failed to kill job with id: {jobid}")

def __kill_active_jobs():
    print("Interrupt received killing active jobs..")
    try:
        active_job_ids = __get_active_jobs_id()
        if not active_job_ids:
            print("No active jobs available")
            return
        for active_job_id in active_job_ids:
            try:
                __kill_job(active_job_id)
            except Exception as ex:
                logger.warning(f"Error occurred while killing job {ex}")
    except Exception as e:
        print(f"Error occurred while interrupting jobs {e}")
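
(Presumably this REST-based variant would be wired into the same interrupt path; the sketch below is not part of this PR's diff and assumes it would simply replace the direct cancelAllJobs() call as the SIGUSR2 handler.)

import signal

# Hypothetical wiring: register the REST-based kill as the SIGUSR2 handler in
# place of the direct __spark_context.cancelAllJobs() call.
signal.signal(signal.SIGUSR2, lambda signum, frame: __kill_active_jobs())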

The disadvantage of the first approach is that we still receive the exception.

kevin-bates commented 2 years ago

@SamRanjha - thank you so much for looking into this! It's reassuring that you're seeing the same thing. I think I'd still prefer using the context directly, as it has far fewer points of possible issues - despite it having ONE :smile:.

Would you mind pushing a commit that performs the second attempt? I will then run with that a bit and we'll get this moving forward.

(I apologize for not getting back to you sooner today, but I had my "head down" on another item. Have a good weekend and perhaps we can tackle this early next week.)

SamRanjha commented 2 years ago

@kevin-bates I have pushed an updated review. Another thing to note here is that, with the way I was testing, the interrupt worked: I created multiple threads in order to submit multiple jobs, and then only a single interrupt was required. I used this code for that:

import sys 
from random import random 
from operator import add 
from pyspark import SparkContext 
from threading import Thread
from time import sleep
partitions = 10000
n = 1000 * partitions 
def f(_): 
    x = random() * 2 - 1 
    y = random() * 2  - 1 
    return 1 if x ** 2 + y ** 2 <= 1 else 0 

def picompute(i):
    print('thread', i)
    count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) 
    print("Pi is roughly %f" % (4.0 * count / n)) 

i = 0
thread_list = []
while i < 10:
    i = i + 1
    thread = Thread(target=picompute, args=(i,))
    thread_list.append(thread)

for thread in thread_list:
    thread.start()

for thread in thread_list:
    thread.join()

Also, I tried calling sc.cancelAllJobs() directly in the notebook and that didn't work either; the behavior was similar to what we observe when calling it by interrupting the kernel.

kevin-bates commented 2 years ago

@SamRanjha - This all sounds great - thank you! I'll try to take this for a spin today and reply back.

rahul26goyal commented 2 years ago

@SamRanjha: with the changes, the interrupt definitely works fine for cell executions that run Spark jobs, but does it also work for interrupting cells containing vanilla Python code that may not trigger a Spark job?

kevin-bates commented 2 years ago

@SamRanjha and @rahul26goyal...

> with the changes, the interrupt definitely works fine for cell executions that run Spark jobs, but does it also work for interrupting cells containing vanilla Python code that may not trigger a Spark job?

I'm not seeing an issue in this case. The interrupt of the vanilla Python code occurs as expected and without issues.

SamRanjha commented 2 years ago

@rahul26goyal @kevin-bates I have verified that the kernel interrupt works for vanilla Python code as well, and I have also updated the error messages.

welcome[bot] commented 2 years ago

Congrats on your first merged pull request in this project! :tada: Thank you for contributing, we are very proud of you! :heart: