encode / starlette

The little ASGI framework that shines. 🌟
https://www.starlette.io/
BSD 3-Clause "New" or "Revised" License
10.05k stars 899 forks source link

Broadcast interface #133

Closed tomchristie closed 2 years ago

tomchristie commented 5 years ago

WebSockets and SSE aren’t of much use without a broadcast interface (e.g. Redis pub/sub or Postgres LISTEN/NOTIFY).

Look to the group add/discard/send API in Django Channels for inspiration here.

tomchristie commented 5 years ago

We’ll want a general interface, into which alternative backend implementations can then be plugged.

rcox771 commented 5 years ago

I tried something hacky to get something going, but I'm a bit out of my league here. This works for client@localhost (sending/receiving all), and client@anotherhost (sending, but not receiving)

import uvicorn
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint
from starlette.responses import HTMLResponse, JSONResponse
from starlette.middleware.cors import CORSMiddleware
from collections import defaultdict
from starlette.websockets import WebSocketState
app = Starlette()

# NOTE(review): allow_origins=["*"] accepts requests from any origin. CORS
# middleware presumably applies to HTTP requests only, so it would not restrict
# WebSocket connections to /ws — confirm whether an Origin check is needed.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"]
)

# Minimal chat client served by Homepage. The WebSocket URL follows the page's
# scheme: a page loaded over https must connect with wss://, because browsers
# block insecure ws:// connections from secure pages.
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <form action="" onsubmit="sendMessage(event)">
            <input type="text" id="messageText" autocomplete="off"/>
            <button>Send</button>
        </form>
        <ul id='messages'>
        </ul>
        <script>
            var wsScheme = location.protocol === "https:" ? "wss://" : "ws://";
            var ws = new WebSocket(wsScheme + location.host + "/ws");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage(event) {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
                event.preventDefault()
            }
        </script>
    </body>
</html>
"""

@app.route("/")
class Homepage(HTTPEndpoint):
    """Serve the single-page chat client held in ``html``."""

    async def get(self, request):
        page = HTMLResponse(html)
        return page

@app.websocket_route("/ws")
class Broadcast(WebSocketEndpoint):
    """Relay every text message received on ``/ws`` to all connected clients.

    A new endpoint instance handles each connection, so live sockets are kept
    in the class-level ``sessions`` dict, keyed by the client's
    ``Sec-WebSocket-Key`` header (``'last'`` if the header is absent).
    """

    encoding = "text"
    # Shared by every connection. Always mutate it in place: rebinding
    # ``self.sessions`` would create a per-instance copy that the other
    # connections never see.
    sessions = {}

    @staticmethod
    def _session_key(ws):
        """Return the key under which *ws* is stored in ``sessions``."""
        return ws.headers.get('sec-websocket-key', 'last')

    def update_sess_data(self, ws, data):
        """Register *ws* in the shared session table, then drop dead sessions.

        *data* is accepted for backward compatibility and is not used.
        """
        self.sessions[self._session_key(ws)] = ws
        self._reap_expired_sessions()

    async def broadcast_message(self, msg):
        """Send ``"message text was: <msg>"`` to every registered session."""
        # Iterate over a snapshot: other connections can add or remove
        # sessions while this coroutine is suspended in send_text(), and
        # changing a dict's size during iteration raises RuntimeError.
        for ws in list(self.sessions.values()):
            await ws.send_text(f"message text was: {msg}")

    def _reap_expired_sessions(self):
        """Remove sessions whose client state is no longer CONNECTED."""
        expired = [
            k for k, sess in self.sessions.items()
            if sess.client_state != WebSocketState.CONNECTED
        ]
        for k in expired:
            print('removing expired session:', k)
            del self.sessions[k]

    async def on_connect(self, ws):
        """Accept the connection and register it right away.

        Registering here (not only on the first received message) lets
        clients that only listen receive broadcasts as well.
        """
        await ws.accept()
        self.update_sess_data(ws, None)

    async def on_disconnect(self, ws, close_code):
        """Forget *ws* once its connection has closed."""
        key = self._session_key(ws)
        # Only remove the entry if it still refers to this socket.
        if self.sessions.get(key) is ws:
            del self.sessions[key]

    async def on_receive(self, ws, data):
        """Refresh the session table, then broadcast *data* to everyone."""
        self.update_sess_data(ws, data)
        await self.broadcast_message(data)

if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
bergran commented 5 years ago

Hello @rcox771, for a chat I prefer to keep sessions in a global app variable: add each websocket in on_connect, pop it from the dictionary in on_disconnect, and notify all websockets connected in sessions. But yes, I think a channel layer would be nice here for push notifications or similar. I'm not a WebSocket expert.

import re

import jinja2
import uvicorn
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint
from starlette.responses import HTMLResponse, Response
from starlette.websockets import WebSocket

def setup_jinja2(template_dir):
    """Build an autoescaping Jinja2 environment that loads from *template_dir*.

    A ``url_for(name, **path_params)`` global is registered. It resolves routes
    through ``context['request']``, so templates that call it must be rendered
    with ``request=...``.
    """
    # ``jinja2.contextfunction`` was removed in Jinja2 3.1; ``pass_context``
    # (Jinja2 3.0+) replaces it. Fall back so older Jinja2 releases still work.
    pass_context = getattr(jinja2, 'pass_context', None) or jinja2.contextfunction

    @pass_context
    def url_for(context, name, **path_params):
        request = context['request']
        return request.url_for(name, **path_params)

    loader = jinja2.FileSystemLoader(template_dir)
    env = jinja2.Environment(loader=loader, autoescape=True)
    env.globals['url_for'] = url_for
    return env

# Module-level Jinja2 environment (built by setup_jinja2) for page templates.
env = setup_jinja2('templates')

class StarletteCustom(Starlette):
    # Connected WebSockets keyed by username; Broadcast.on_connect binds it
    # to ``self.sessions`` via ``app.sessions`` so all connections share it.
    sessions = {}

# NOTE(review): ``template_directory`` (and ``app.get_template``) are older
# Starlette APIs — confirm the pinned Starlette version still supports them.
app = StarletteCustom(template_directory='templates')

def get_hostname(scope):
    """Return ``"host:port"`` for the server that received the ASGI *scope*.

    Falls back to ``"localhost:8000"`` when ``scope["server"]`` is missing or
    ``None``. A ``None`` element (e.g. the port of a Unix socket) is omitted.
    """
    server = scope.get('server') or ('localhost', '8000')
    return ':'.join(str(part) for part in server if part is not None)

@app.route("/", name='index')
class Homepage(HTTPEndpoint):
    """Landing page rendered from ``index.html``."""

    async def get(self, request):
        # Render through the module-level ``env`` (as the chat page does) and
        # pass the request: the ``url_for`` template global registered by
        # setup_jinja2 reads it from the render context.
        template = env.get_template('index.html')
        content = template.render(request=request)
        return HTMLResponse(content)

@app.route("/chat/", name='chat')
class Homepage(HTTPEndpoint):
    """Chat page for the user named in the ``user`` query parameter.

    NOTE(review): this class reuses the name ``Homepage`` and so rebinds the
    module-level name of the index endpoint; consider renaming it.
    """

    async def get(self, request):
        """Render ``chat.html`` with the WebSocket URL for ``user``; 404 if absent."""
        from urllib.parse import urlencode

        user = request.query_params.get('user')
        if not user:
            return Response(status_code=404)
        template = env.get_template('chat.html')
        host = request.url_for('ws')
        # Percent-encode the username so characters such as '&', '#' or '='
        # cannot truncate it or inject extra query parameters.
        query = urlencode({'username': user})
        content = template.render(
            url='{}?{}'.format(host, query), user=user, request=request
        )
        return HTMLResponse(content)

@app.websocket_route('/ws', name='ws')
class Broadcast(WebSocketEndpoint):
    """Chat endpoint that relays every text message to all connected users.

    Sockets are stored in ``app.sessions`` keyed by the ``username`` query
    parameter, and joins/leaves are announced to the other users.
    """

    encoding = "text"
    # NOTE(review): never read; the handlers below use ``channel_name``.
    session_name = ''

    def get_params(self, websocket: WebSocket) -> dict:
        """Return the connection's query-string parameters as a dict.

        Values are percent-decoded. An empty query string gives ``{}``, a
        parameter without ``=`` maps to ``''``, and for repeated keys the last
        value wins.
        """
        from urllib.parse import parse_qsl

        params_raw = websocket.get('query_string', b'').decode('utf-8')
        return dict(parse_qsl(params_raw, keep_blank_values=True))

    async def on_connect(self, websocket: WebSocket):
        """Accept the socket, announce the user, then register the socket."""
        app = self.scope.get('app', None)
        self.channel_name = self.get_params(websocket).get('username', 'default_name')
        self.sessions = app.sessions
        await websocket.accept()
        # Announced before registering, so the new user is not sent its own
        # join notice.
        await self.broadcast_message('User {} is connected'.format(self.channel_name))
        self.sessions[self.channel_name] = websocket

    async def on_disconnect(self, websocket: WebSocket, close_code: int):
        """Unregister the socket and announce that the user left."""
        # Only remove the entry if it is still this socket: a later connection
        # with the same username replaces it and must not be evicted when the
        # older connection closes.
        if self.sessions.get(self.channel_name) is websocket:
            del self.sessions[self.channel_name]
        await self.broadcast_message('User {} is disconnected'.format(self.channel_name))

    async def broadcast_message(self, msg):
        """Send ``"message text was: <msg>"`` to every registered socket."""
        # Iterate over a snapshot: users may join or leave while this coroutine
        # is suspended in send_text(), and resizing a dict during iteration
        # raises RuntimeError.
        for ws in list(self.sessions.values()):
            await ws.send_text(f"message text was: {msg}")

    async def on_receive(self, ws, data):
        """Broadcast each received message to all users."""
        await self.broadcast_message(data)

if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")

I thought about putting some examples on my GitHub to help the community, but I have little time to do it.

taoufik07 commented 5 years ago

@bergran check out nejma, a "channels"-like package.

3lpsy commented 5 years ago

How would you publish messages outside the context of the web socket? For example, if a user updated their profile, and you wanted to broadcast that new information to all connected channels, how would you go about that? Is it required to use a separate process? Would you have to continually poll some service (like redis)?

DrPyser commented 5 years ago

@3lpsy There's no single way to do it. If all your connections are handled by a single process, you can simply maintain a list of active connections and iterate over those to broadcast messages. Otherwise, I guess you could use Unix sockets, or a full-blown message queue system like Redis (which doesn't require polling, see https://aioredis.readthedocs.io/en/v1.2.0/mpsc.html).

tomchristie commented 5 years ago

@3lpsy You need a broadcast service of one kind or another... Probably one of: [the list of options was not preserved in this copy of the thread]

3lpsy commented 5 years ago

Thanks @tomchristie and @DrPyser for the advice.

For anyone else attempting to do this, I created a working implementation using aioredis and FastAPI (an extension of Starlette). I have serious concerns about thread safety and other things I may be doing wrong, but at least it works. Below is some sample code, but you can see the entire project here: https://github.com/3lpsy/bountydns/

A few code snippets:

# redis helper methods
async def make_redis():
    """Open and return a new Redis connection to the broadcast URL."""
    url = make_broadcast_url()
    conn = await aioredis.create_redis(url)
    return conn
async def make_subscriber(name):
    """Subscribe a fresh Redis connection to ``channel:<name>``.

    Returns a ``(connection, channel)`` pair. The connection is returned
    open; closing it is up to the caller.
    """
    conn = await make_redis()
    channels = await conn.subscribe(f"channel:{name}")
    return conn, channels[0]
# two web socket route functions (these are attached to FastAPI instance in another file)

# a public socket for all users
async def broadcast_index(websocket: WebSocket):
    """Reply ``{"message": "greetings"}`` to every JSON message received.

    Returns normally when the client disconnects.
    """
    from starlette.websockets import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            await websocket.receive_json()
            await websocket.send_json({"message": "greetings"})
    except WebSocketDisconnect:
        # The client closed the socket; there is nothing left to close here.
        return

# an authenticated socket for authed users
async def broadcast_authed_index(websocket: WebSocket):
    """Forward messages published on Redis ``channel:auth`` to an authed client.

    The JWT is read from the ``ws_access_token`` query parameter. If it is
    missing, or the token lacks the required scopes, the socket is closed with
    code 1008 (policy violation).
    """
    await websocket.accept()
    params = parse_qs(urlparse(str(websocket.url)).query)
    raw_tokens = params.get("ws_access_token")
    if not raw_tokens:
        await websocket.close(code=1008)
        return
    token = verify_jwt_token(raw_tokens[0])
    if not token_has_required_scopes(token, []):  # TODO: check scopes later
        # An HTTPException cannot become an HTTP response once the WebSocket
        # has been accepted; reject with a WebSocket close code instead.
        await websocket.close(code=1008)
        return
    user_repo = UserRepo(session())
    # The returned user is not used yet; presumably the call rejects unknown
    # users — confirm.
    current_user(token, user_repo)
    subscriber, channel = await make_subscriber("auth")
    try:
        while await channel.wait_message():
            msg = await channel.get(encoding="utf-8")
            await websocket.send_json(json.loads(msg))
    finally:
        # Runs even when send_json fails because the client went away, so the
        # Redis subscription and its connection are always released.
        await subscriber.unsubscribe("channel:auth")
        subscriber.close()
        await subscriber.wait_closed()
    await websocket.close()

With this setup, I can push messages to all authenticated (or unauthenticated users) in route functions like this:

@router.get("/util/ws-test", name="util.ws-test")
async def ws_test():
    """Publish a TESTING_WS message on ``channel:auth`` and report success."""
    publisher = await aioredis.create_redis(make_broadcast_url())
    try:
        await publisher.publish_json(
            "channel:auth", {"type": "MESSAGE", "name": "TESTING_WS", "payload": ""}
        )
    finally:
        # Each request opens its own connection; close it so they don't pile up.
        publisher.close()
        await publisher.wait_closed()
    return {"status": "success"}

I also wanted to listen for sqlalchemy's "after_insert" event. However, I had to attach the async method call to the event loop. It may be incorrect but it works as of now:

ORM_EVENTS = ["after_insert"]

# Strong references to scheduled event tasks. The event loop keeps only weak
# references to tasks, so a task nobody holds on to can be garbage-collected
# before it finishes.
_EVENT_TASKS = set()


# sqlalchemy wants a sync function (i think)
def make_event(func):
    """Adapt coroutine function *func* into a synchronous event callback.

    The returned callable schedules ``func(*args, **kwargs)`` as a task on the
    current event loop and returns that task without awaiting it. The task is
    held in ``_EVENT_TASKS`` until it completes.
    """
    import functools

    @functools.wraps(func)
    def _event(*args, **kwargs):
        loop = asyncio.get_event_loop()
        task = asyncio.ensure_future(func(*args, **kwargs), loop=loop)
        _EVENT_TASKS.add(task)
        task.add_done_callback(_EVENT_TASKS.discard)
        return task
    return _event

def db_register_model_events(models):
    """Attach each model's ``on_<event>`` coroutine as an ORM event listener.

    For every model and every name in ``ORM_EVENTS``, a handler attribute named
    ``on_<event_name>`` (when present) is wrapped with ``make_event`` and
    registered via ``listen``.
    """
    _absent = object()
    for model in models:
        for event_name in ORM_EVENTS:
            handler = getattr(model, "on_" + event_name, _absent)
            if handler is _absent:
                continue
            listen(model, event_name, make_event(handler))

And here is an example of the event handler in the model class:

@staticmethod
async def on_after_insert(mapper, connection, target):
    """Publish a DNS_REQUEST_CREATED message on ``channel:auth`` after an insert.

    The arguments are those passed to the ``after_insert`` listener; they are
    only printed here.
    """
    print("on_after_insert", mapper, connection, target)
    publisher = await make_redis()
    try:
        await publisher.publish_json(
            "channel:auth",
            {"type": "MESSAGE", "name": "DNS_REQUEST_CREATED", "payload": ""},
        )
    finally:
        # A new connection is opened per insert; close it so they don't leak.
        publisher.close()
        await publisher.wait_closed()

Hope this helps anyone looking to do something similar. On the front end, I create two websockets and then emit their messages using a Vue.js event bus. You can check out the project if you're interested. And if you have any feedback on things I'm doing wrong, please let me know.

devxpy commented 4 years ago

If I've got this right, all the implementations here use Python loops over all connected users. Is there a possibility of getting Cython or C loops built into the library to do this without the Python loop overhead, please?

tomchristie commented 2 years ago

The broadcaster project is a working approach to this. (Although it's stretched a bit thin on maintenance ATM) https://github.com/encode/broadcaster