dask / dask-kubernetes

Native Kubernetes integration for Dask
BSD 3-Clause "New" or "Revised" License

Port forwards can time out or close and not get recreated #663

Closed: jacobtomlinson closed this issue 1 year ago

jacobtomlinson commented 1 year ago

When a Dask cluster is accessed from outside of Kubernetes, port forwards for the comm and dashboard ports are set up automatically in a subprocess.


However, the kubectl port-forward process can exit for a number of reasons, including timeouts or cluster autoscaling operations disrupting the stream.

Currently we do not track this process, so if it exits prematurely the port forward is lost and never recreated.

To resolve this, we could start our own thread or task that runs the kubectl port-forward command in a loop, rerunning it whenever it exits, until the task is cancelled.
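A minimal sketch of that restart loop, assuming an asyncio event loop and a kubectl binary on the PATH. The names `keep_port_forward_alive`, `kubectl_port_forward_cmd`, `retry_delay` and `max_starts` are hypothetical illustrations, not part of dask-kubernetes.

```python
from __future__ import annotations

import asyncio
import sys
from typing import Optional, Sequence


def kubectl_port_forward_cmd(
    namespace: str, service: str, local_port: int, remote_port: int
) -> list[str]:
    """Hypothetical helper: build a kubectl port-forward command for a Service."""
    return [
        "kubectl", "port-forward",
        "--namespace", namespace,
        f"service/{service}",
        f"{local_port}:{remote_port}",
    ]


async def keep_port_forward_alive(
    cmd: Sequence[str],
    retry_delay: float = 1.0,
    max_starts: Optional[int] = None,
) -> int:
    """Run cmd, restarting it whenever it exits, until the task is cancelled.

    max_starts only exists to make the loop testable; None means loop forever.
    Returns how many times the process was started.
    """
    starts = 0
    while max_starts is None or starts < max_starts:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        starts += 1
        try:
            await proc.wait()  # returns when the port forward dies for any reason
        except asyncio.CancelledError:
            # Cancelled by the caller (e.g. the cluster is closing): stop the
            # child process and do not restart it.
            if proc.returncode is None:
                proc.terminate()
                await proc.wait()
            raise
        await asyncio.sleep(retry_delay)  # brief back-off before restarting
    return starts


if __name__ == "__main__":
    # Demo without a cluster: a command that exits immediately is started 3 times.
    print(asyncio.run(keep_port_forward_alive(
        [sys.executable, "-c", "pass"], retry_delay=0, max_starts=3
    )))
```

In practice the task would be started with `asyncio.create_task(keep_port_forward_alive(kubectl_port_forward_cmd(...)))` when the cluster object connects and cancelled when it closes; cancelling it also terminates the running kubectl child.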

jacobtomlinson commented 1 year ago

I started down the road of removing kubectl altogether and opening the port forward in Python; however, this was a lot more challenging than I expected. It would be the ideal route, so I'll make some notes here, but the quick fix for this issue would be to just handle recreating the kubectl port-forward process.

Opening a port forward via the Kubernetes API requires the following steps:

Currently we use a mixture of kubernetes_asyncio and pykube-ng (via our dask_kubernetes.aiopykube wrapper). Ideally we want to move in the direction of being 100% pykube.

It seems that pykube-ng doesn't currently have any support for port forwarding. It is also heavily tied to requests, especially for auth, which goes through a custom HTTPAdapter. That makes opening the websocket challenging, because the Python websockets library does not support requests.adapters.HTTPAdapter objects, and requests itself doesn't appear to have any websocket support. Plus, if we are going to go down the road of implementing this ourselves, we would want to use aiohttp, which also doesn't support HTTPAdapter objects, so we would need to reimplement the pykube auth.
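If we did go the aiohttp route, the connection would be a websocket upgrade against the pod's portforward subresource, with auth handled by us. A minimal sketch, assuming bearer-token auth only, an API server URL and token already read from kubeconfig, and the `v4.channel.k8s.io` subprotocol; `portforward_url`, `bearer_auth_headers` and `open_portforward` are hypothetical names, and none of pykube's other auth methods (client certificates, exec plugins) are covered.

```python
from __future__ import annotations

import ssl
from typing import Optional
from urllib.parse import urlencode

# Assumption: binary channel subprotocol offered by the port-forward endpoint.
PORTFORWARD_SUBPROTOCOL = "v4.channel.k8s.io"


def portforward_url(server: str, namespace: str, pod: str, port: int) -> str:
    """Build the websocket URL for the pods/{name}/portforward subresource."""
    base = server.rstrip("/")
    if base.startswith("https://"):
        base = "wss://" + base[len("https://"):]
    elif base.startswith("http://"):
        base = "ws://" + base[len("http://"):]
    query = urlencode({"ports": port})
    return f"{base}/api/v1/namespaces/{namespace}/pods/{pod}/portforward?{query}"


def bearer_auth_headers(token: str) -> dict[str, str]:
    """Bearer-token auth only; other kubeconfig auth methods need more work."""
    return {"Authorization": f"Bearer {token}"}


async def open_portforward(
    server: str,
    token: str,
    namespace: str,
    pod: str,
    port: int,
    ssl_context: Optional[ssl.SSLContext] = None,
):
    """Open the port-forward websocket; the caller must close ws and session."""
    import aiohttp  # imported lazily so the pure helpers above don't need it

    session = aiohttp.ClientSession(headers=bearer_auth_headers(token))
    kwargs = {} if ssl_context is None else {"ssl": ssl_context}
    try:
        ws = await session.ws_connect(
            portforward_url(server, namespace, pod, port),
            protocols=(PORTFORWARD_SUBPROTOCOL,),
            **kwargs,
        )
    except Exception:
        await session.close()
        raise
    return session, ws
```

Keeping URL and header construction in plain functions means the pykube-specific part that would need reimplementing is isolated in `bearer_auth_headers`.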

I also explored kubernetes_asyncio, which does have a connect_get_namespaced_pod_portforward method in the CoreV1 API, but calling it fails with a 400 Bad Request. My guess is that it makes a regular HTTP request, whereas the API server expects a request to upgrade the connection to a websocket stream.

ApiException: (400)
Reason: Bad Request
HTTP response headers: <CIMultiDictProxy('Audit-Id': 'e75cfa62-fe41-41d6-be65-7eba0c071fab', 'Cache-Control': 'no-cache, private', 'Content-Type': 'application/json', 'Date': 'Wed, 01 Mar 2023 11:45:45 GMT', 'Content-Length': '139')>
HTTP response body: {"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"Upgrade request required","reason":"BadRequest","code":400}

It also seems to have a kubernetes_asyncio.stream.WsApiClient object for connecting websocket streams, but I can't figure out how to use the two together, and there seems to be no documentation on doing this.

I did find this example from the sync kubernetes library, but kubernetes_asyncio.stream doesn't seem to have a portforward module, so the example doesn't translate over very well.

jacobtomlinson commented 1 year ago

I've gotten a bit further with the following code.

```python
import kubernetes_asyncio.client
import kubernetes_asyncio.stream

async with kubernetes_asyncio.stream.WsApiClient() as ws_api_client:
    api = kubernetes_asyncio.client.CoreV1Api(ws_api_client)
    port_forward = await api.connect_get_namespaced_pod_portforward(
        pod_name, namespace, ports=remote_port, _preload_content=False
    )
```

The port_forward I get back is an aiohttp.client_ws.ClientWebSocketResponse, and it looks like I can send and receive binary data with it. I've tried to hook it up directly to a local TCP socket.

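```python
import asyncio

import aiohttp


async def sync_sockets(ws, reader, writer):
    # Relay bytes between one local TCP connection and the port-forward websocket
    async def tcp_to_ws():
        while True:
            data = await reader.read(65536)  # read() with no size blocks until EOF
            if not data:
                raise ValueError("local TCP connection closed")
            await ws.send_bytes(data)  # aiohttp websockets have send_bytes, not send

    async def ws_to_tcp():
        closed = (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR)
        while True:
            message = await ws.receive()
            if message.type in closed:
                raise ValueError("websocket closed")
            writer.write(message.data)  # payload is written to the TCP socket as-is
            await writer.drain()

    tasks = [
        asyncio.create_task(tcp_to_ws()),
        asyncio.create_task(ws_to_tcp()),
    ]
    try:
        await asyncio.gather(*tasks)
    except Exception:
        # One direction failed: stop the other and close the local connection
        for task in tasks:
            task.cancel()
        writer.close()


server = await asyncio.start_server(
    lambda r, w: sync_sockets(port_forward, r, w), port=local_port
)
async with server:
    await server.serve_forever()
```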

But if I connect a Dask Client to that socket, I get some weird exceptions.

In [22]: from dask.distributed import Client
    ...: client = Client("localhost:54619")
MemoryError                               Traceback (most recent call last)
~/Projects/dask/distributed/distributed/comm/core.py in connect(addr, timeout, deserialize, handshake_overrides, **connection_args)
    327         # write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
--> 328         handshake = await asyncio.wait_for(comm.read(), time_left())
    329         await asyncio.wait_for(comm.write(local_info), time_left())

~/miniconda3/envs/dask/lib/python3.8/asyncio/tasks.py in wait_for(fut, timeout, loop)
    482         if fut.done():
--> 483             return fut.result()
    484         else:

~/Projects/dask/distributed/distributed/comm/tcp.py in read(self, deserializers)
--> 228             frames = host_array(frames_nbytes)
    229             for i, j in sliding_window(

~/Projects/dask/distributed/distributed/comm/utils.py in numpy_host_array(n)
     31     def numpy_host_array(n: int) -> memoryview:
---> 32         return numpy.empty((n,), dtype="u1").data

MemoryError: Unable to allocate 3.81 EiB for an array with shape (4395550971915293184,) and data type uint8

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
<ipython-input-21-5f6f42235d27> in <module>
----> 1 client = Client("localhost:54619")

~/Projects/dask/distributed/distributed/client.py in __init__(self, address, loop, timeout, set_as_default, scheduler_file, security, asynchronous, name, heartbeat_interval, serializers, deserializers, extensions, direct_to_workers, connection_limit, **kwargs)
    986         self.preloads = preloading.process_preloads(self, preload, preload_argv)
--> 988         self.start(timeout=timeout)
    989         Client._instances.add(self)

~/Projects/dask/distributed/distributed/client.py in start(self, **kwargs)
   1183             self._started = asyncio.ensure_future(self._start(**kwargs))
   1184         else:
-> 1185             sync(self.loop, self._start, **kwargs)
   1187     def __await__(self):

~/Projects/dask/distributed/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    403     if error:
    404         typ, exc, tb = error
--> 405         raise exc.with_traceback(tb)
    406     else:
    407         return result

~/Projects/dask/distributed/distributed/utils.py in f()
    376                 future = asyncio.wait_for(future, callback_timeout)
    377             future = asyncio.ensure_future(future)
--> 378             result = yield future
    379         except Exception:
    380             error = sys.exc_info()

~/.local/lib/python3.8/site-packages/tornado/gen.py in run(self)
    761                     try:
--> 762                         value = future.result()
    763                     except Exception:
    764                         exc_info = sys.exc_info()

~/Projects/dask/distributed/distributed/client.py in _start(self, timeout, **kwargs)
   1264         try:
-> 1265             await self._ensure_connected(timeout=timeout)
   1266         except (OSError, ImportError):
   1267             await self._close()

~/Projects/dask/distributed/distributed/client.py in _ensure_connected(self, timeout)
   1327         try:
-> 1328             comm = await connect(
   1329                 self.scheduler.address, timeout=timeout, **self.connection_args
   1330             )

~/Projects/dask/distributed/distributed/comm/core.py in connect(addr, timeout, deserialize, handshake_overrides, **connection_args)
    331         with suppress(Exception):
    332             await comm.close()
--> 333         raise OSError(
    334             f"Timed out during handshake while connecting to {addr} after {timeout} s"
    335         ) from exc

OSError: Timed out during handshake while connecting to tcp://localhost:54619 after 30 s

However, the scheduler logs in Kubernetes do show that a connection was established and then abandoned, so the websocket stream is clearly being forwarded to the TCP port on the Pod.

2023-03-01 14:04:16,251 - distributed.comm.tcp - INFO - Connection from tcp:// closed before handshake completed

I'm guessing I'm just hooking the two sockets up incorrectly but I'm struggling to find the way forward.
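One likely explanation, assuming the websocket uses the Kubernetes `v4.channel.k8s.io` framing: every binary message starts with a one-byte channel number (channel 0 carries data and channel 1 errors for the first requested port), and the first message on each channel begins with the port number as a little-endian uint16. Writing `message.data` verbatim pushes those bytes into the Dask TCP stream. The bogus length in the MemoryError fits this exactly: packed as a little-endian uint64, 4395550971915293184 is the bytes `00 52 22 01 52 22 00 3d`, i.e. channel 0 plus port 8786 (`0x2252`), channel 1 plus port 8786, then the channel byte of the first data message. A sketch of removing and adding that framing, with hypothetical names `PortForwardDemux` and `encode_data_message`:

```python
from __future__ import annotations

import struct
from typing import Optional


def encode_data_message(payload: bytes, channel: int = 0) -> bytes:
    """Prefix bytes read from the local TCP socket with the data channel number."""
    return bytes([channel]) + payload


class PortForwardDemux:
    """Turn incoming port-forward websocket messages back into a plain byte stream."""

    def __init__(self, data_channel: int = 0, error_channel: int = 1) -> None:
        self.data_channel = data_channel
        self.error_channel = error_channel
        self.port: Optional[int] = None
        self._header_seen: set[int] = set()  # channels whose port header was consumed

    def feed(self, message: bytes) -> bytes:
        """Return the bytes to write to the local TCP socket (possibly empty)."""
        if not message:
            return b""
        channel, body = message[0], message[1:]
        if channel not in self._header_seen:
            # The first message on each channel begins with the port (little-endian uint16)
            if len(body) < 2:
                raise ValueError("truncated port header")
            (self.port,) = struct.unpack("<H", body[:2])
            self._header_seen.add(channel)
            body = body[2:]
        if channel == self.error_channel:
            if body:
                raise ConnectionError(body.decode(errors="replace"))
            return b""
        if channel == self.data_channel:
            return body
        return b""  # channels for other ports are ignored in this sketch
```

In the relay above, `ws_to_tcp` would write `demux.feed(message.data)` instead of `message.data`, and `tcp_to_ws` would send `encode_data_message(data)`. This only addresses the corrupted handshake, not connection lifetime.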

jacobtomlinson commented 1 year ago

I got further again and have the websocket open and communication flowing. The last thing I'm stuck on is that the sockets seem to close prematurely. I've opened #666 with everything I have so far because I don't think I'll be able to continue that PR for a while.

In the meantime, this issue could still be closed by a PR that runs a background loop which restarts the port forward whenever it exits.

jacobtomlinson commented 1 year ago

Closed by #809.

Port forwards are now handled by kr8s, which reopens the connection if it is closed.