Closed tomchristie closed 2 years ago
We’ll want a general interface, which alternative backend implementations can then be plugged into.
I tried something hacky to get something going, but I'm a bit out of my league here. This works for client@localhost (sending/receiving all), and for client@anotherhost (sending, but not receiving):
import uvicorn
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint
from starlette.responses import HTMLResponse, JSONResponse
from starlette.middleware.cors import CORSMiddleware
from collections import defaultdict
from starlette.websockets import WebSocketState
# Application instance; CORS is opened to every origin, method and header.
app = Starlette()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Minimal single-page chat client served at "/". The websocket scheme follows
# the page's scheme: browsers refuse plain ws:// from an https page (mixed
# content), so wss:// is used there.
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
var scheme = location.protocol === "https:" ? "wss://" : "ws://";
var ws = new WebSocket(scheme+location.host+"/ws");
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
var content = document.createTextNode(event.data)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage(event) {
var input = document.getElementById("messageText")
ws.send(input.value)
input.value = ''
event.preventDefault()
}
</script>
</body>
</html>
"""
@app.route("/")
class Homepage(HTTPEndpoint):
    """Serve the static chat page for GET /."""

    async def get(self, request):
        page = HTMLResponse(html)
        return page
@app.websocket_route("/ws")
class Broadcast(WebSocketEndpoint):
    """Relay every received text frame to all connected clients.

    ``sessions`` is a class attribute, so one dict is shared by every
    connection's endpoint instance. It must only be mutated in place:
    rebinding ``self.sessions`` would give one connection a private copy and
    silently split the registry.
    """

    encoding = "text"
    sessions = {}  # sec-websocket-key -> WebSocket, shared by all connections

    def _session_key(self, ws):
        # Fall back to a fixed key when the handshake header is absent.
        return ws.headers.get('sec-websocket-key', 'last')

    def update_sess_data(self, ws, data):
        """Register ``ws`` in the shared registry, then drop dead sessions.

        ``data`` is unused; it is kept for interface compatibility.
        """
        self.sessions[self._session_key(ws)] = ws
        self._reap_expired_sessions()

    async def broadcast_message(self, msg):
        """Send ``msg`` to every registered session that is still connected."""
        # Iterate over a snapshot: other connections may add or remove entries
        # while this coroutine is suspended in send_text.
        for ws in list(self.sessions.values()):
            if ws.client_state != WebSocketState.CONNECTED:
                continue
            await ws.send_text(f"message text was: {msg}")

    def _reap_expired_sessions(self):
        """Remove, in place, sessions whose client is no longer connected."""
        expired = [
            k for k, sess in self.sessions.items()
            if sess.client_state != WebSocketState.CONNECTED
        ]
        for k in expired:
            print('removing expired session:', k)
            del self.sessions[k]

    async def on_connect(self, ws):
        # Register at connect time so clients receive broadcasts even before
        # they send anything themselves.
        await ws.accept()
        self.update_sess_data(ws, None)

    async def on_disconnect(self, ws, close_code):
        # Drop only our own entry; the key may now belong to another socket.
        key = self._session_key(ws)
        if self.sessions.get(key) is ws:
            del self.sessions[key]

    async def on_receive(self, ws, data):
        self.update_sess_data(ws, data)
        await self.broadcast_message(data)
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
Hello @rcox771, for a chat I prefer to keep the sessions in a global app variable: add the socket in on_connect,
and in on_disconnect
pop it from the dictionary and notify all websockets connected in sessions. But yes, a channel layer would be nice here, e.g. for push notifications or similar, I think. I'm not a websocket expert.
import re
import jinja2
import uvicorn
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint, HTTPEndpoint
from starlette.responses import HTMLResponse, Response
from starlette.websockets import WebSocket
def setup_jinja2(template_dir):
    """Build a Jinja2 environment that loads templates from *template_dir*.

    Autoescaping is enabled. A ``url_for(name, **path_params)`` global is
    registered that resolves route URLs via the ``request`` object taken from
    the template context, so templates using it must be rendered with
    ``request=...``.
    """
    # ``contextfunction`` was deprecated in Jinja2 3.0 and removed in 3.1 in
    # favour of ``pass_context``; prefer the new name when it exists.
    pass_context = getattr(jinja2, "pass_context", None) or jinja2.contextfunction

    @pass_context
    def url_for(context, name, **path_params):
        request = context['request']
        return request.url_for(name, **path_params)

    loader = jinja2.FileSystemLoader(template_dir)
    env = jinja2.Environment(loader=loader, autoescape=True)
    env.globals['url_for'] = url_for
    return env
# Module-wide Jinja2 environment rooted at ./templates.
env = setup_jinja2(template_dir="templates")
class StarletteCustom(Starlette):
    """Starlette application that also carries a registry of live websockets.

    ``sessions`` maps a chat username to its WebSocket. It is created per
    instance in ``__init__`` rather than as a class attribute, so separate
    application instances never share one mutable dict.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sessions = {}
# The single application instance for this script.
app = StarletteCustom(template_directory="templates")
def get_hostname(scope):
    """Return ``"host:port"`` for the server entry of an ASGI *scope*.

    Falls back to ``"localhost:8000"`` when the scope has no ``server`` key or
    its value is ``None`` (ASGI allows ``None`` when the address is unknown).
    """
    server = scope.get('server') or ('localhost', '8000')
    return ':'.join(str(part) for part in server)
@app.route("/", name='index')
class Homepage(HTTPEndpoint):
    """Render the landing page from ``index.html``."""

    async def get(self, request):
        # Use the module's Jinja2 environment (as the chat view does) rather
        # than the app-level get_template, which newer Starlette releases no
        # longer provide. Pass ``request`` so the ``url_for`` global registered
        # by setup_jinja2 can read it from the template context.
        template = env.get_template('index.html')
        content = template.render(request=request)
        return HTMLResponse(content)
@app.route("/chat/", name='chat')
class Homepage(HTTPEndpoint):
    """Render ``chat.html`` for the user named by the ``user`` query param.

    Responds 404 when ``user`` is missing or empty.

    NOTE: this rebinds the module-level name ``Homepage``, which the index view
    also uses. Routing is unaffected because the decorator has already
    registered each class.
    """

    async def get(self, request):
        from urllib.parse import quote, urlencode

        user = request.query_params.get('user')
        if not user:
            return Response(status_code=404)
        template = env.get_template('chat.html')
        host = request.url_for('ws')
        # Percent-encode the username so characters such as '&', '=', '#' or
        # spaces cannot break or inject parameters into the websocket URL.
        query = urlencode({'username': user}, quote_via=quote)
        content = template.render(url='{}?{}'.format(host, query), user=user)
        return HTMLResponse(content)
@app.websocket_route('/ws', name='ws')
class Broadcast(WebSocketEndpoint):
    """Chat endpoint: announces joins and leaves and relays every message.

    The user's name comes from the ``username`` query parameter (default
    ``'default_name'``). Live sockets are kept in the app's ``sessions`` dict
    (username -> WebSocket), shared by every connection to this app.
    """

    encoding = "text"
    session_name = ''  # NOTE(review): unused; the methods below use channel_name

    def get_params(self, websocket: WebSocket) -> dict:
        """Return the handshake query string as a ``{name: value}`` dict.

        Values are percent-decoded. A parameter without ``=`` maps to ``''``,
        an empty query string yields ``{}``, and for repeated names the last
        value wins.
        """
        from urllib.parse import parse_qsl

        params_raw = websocket.get('query_string', b'').decode('utf-8')
        return dict(parse_qsl(params_raw, keep_blank_values=True))

    async def on_connect(self, websocket: WebSocket):
        app = self.scope.get('app', None)
        self.channel_name = self.get_params(websocket).get('username', 'default_name')
        self.sessions = app.sessions
        await websocket.accept()
        await self.broadcast_message('User {} is connected'.format(self.channel_name))
        self.sessions[self.channel_name] = websocket

    async def on_disconnect(self, websocket: WebSocket, close_code: int):
        # Remove the entry only if it is still this socket: a later connection
        # using the same username may have replaced it.
        if self.sessions.get(self.channel_name) is websocket:
            del self.sessions[self.channel_name]
        await self.broadcast_message('User {} is disconnected'.format(self.channel_name))

    async def broadcast_message(self, msg):
        """Send ``msg`` to every socket in the shared session registry."""
        # Iterate over a snapshot: other connections can join or leave, which
        # mutates the shared dict, while this coroutine awaits send_text.
        for ws in list(self.sessions.values()):
            await ws.send_text(f"message text was: {msg}")

    async def on_receive(self, ws, data):
        await self.broadcast_message(data)
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
I thought about putting some examples on my GitHub to help the community, but I have little time to do it.
How would you publish messages outside the context of the web socket? For example, if a user updated their profile, and you wanted to broadcast that new information to all connected channels, how would you go about that? Is it required to use a separate process? Would you have to continually poll some service (like redis)?
@3lpsy There's no single way to do it. If all your connections are handled by a single process, you can simply maintain a list of active connections and iterate over those to broadcast messages. Otherwise, I guess you could use Unix sockets, or a full-blown message queue system like Redis (which doesn't require polling, see https://aioredis.readthedocs.io/en/v1.2.0/mpsc.html).
@3lpsy You need a broadcast service of one kind or another... Probably one of:
Thanks @tomchristie and @DrPyser for the advice.
For anyone else attempting to do this, I created a working implementation using aioredis and FastAPI (an extension of Starlette). I have serious concerns about thread safety and other things I may be doing wrong, but at least it works. Below is some sample code, but you can see the entire project here: https://github.com/3lpsy/bountydns/
A few code snippets:
# redis helper methods
async def make_redis():
    """Open and return a new aioredis connection to the broadcast URL."""
    url = make_broadcast_url()
    return await aioredis.create_redis(url)
async def make_subscriber(name):
    """Subscribe a fresh Redis connection to ``channel:<name>``.

    Returns ``(connection, channel)``. The connection is not closed here, so
    the caller is responsible for unsubscribing and closing it.
    """
    conn = await make_redis()
    channels = await conn.subscribe(f"channel:{name}")
    return conn, channels[0]
# two web socket route functions (these are attached to FastAPI instance in another file)
# a public socket for all users
async def broadcast_index(websocket: WebSocket):
    """Public socket: answer every JSON message with a fixed greeting.

    Runs until the client disconnects. The original trailing ``close()`` was
    unreachable after ``while True``; a client disconnect is now treated as
    normal termination instead of an unhandled exception.
    """
    from starlette.websockets import WebSocketDisconnect

    await websocket.accept()
    try:
        while True:
            await websocket.receive_json()  # payload is currently ignored
            await websocket.send_json({"message": "greetings"})
    except WebSocketDisconnect:
        # The client closed the connection; nothing left to send or close.
        return
# an authenticated socket for authed users
async def broadcast_authed_index(websocket: WebSocket):
    """Stream messages published on Redis ``channel:auth`` to an authed client.

    The JWT is read from the ``ws_access_token`` query parameter (first value).
    A missing token, or one that fails the scope check, is rejected by closing
    the socket with code 1008 (policy violation). An ``HTTPException`` cannot be
    delivered over a websocket connection.
    """
    params = parse_qs(urlparse(str(websocket.url)).query)
    raw_tokens = params.get("ws_access_token")
    if not raw_tokens:
        await websocket.close(code=1008)
        return
    token = verify_jwt_token(raw_tokens[0])
    if not token_has_required_scopes(token, []):  # TODO: check scopes later
        await websocket.close(code=1008)
        return
    user_repo = UserRepo(session())
    # NOTE(review): result unused; presumably called to reject unknown users.
    current_user(token, user_repo)

    await websocket.accept()
    subscriber, channel = await make_subscriber("auth")
    try:
        while await channel.wait_message():
            msg = await channel.get(encoding="utf-8")
            await websocket.send_json(json.loads(msg))
    finally:
        # Also runs when send_json fails because the client went away, so the
        # Redis subscription and connection are never leaked.
        await subscriber.unsubscribe("channel:auth")
        subscriber.close()
        await subscriber.wait_closed()
    await websocket.close()
With this setup, I can push messages to all authenticated (or unauthenticated users) in route functions like this:
@router.get("/util/ws-test", name="util.ws-test")
async def ws_test():
    """Publish a TESTING_WS message on Redis ``channel:auth``.

    Returns ``{"status": "success"}``.
    """
    publisher = await aioredis.create_redis(make_broadcast_url())
    try:
        await publisher.publish_json(
            "channel:auth", {"type": "MESSAGE", "name": "TESTING_WS", "payload": ""}
        )
    finally:
        # A connection is opened per request; close it or every call leaks one.
        publisher.close()
        await publisher.wait_closed()
    return {"status": "success"}
I also wanted to listen for SQLAlchemy's "after_insert" event. However, I had to attach the async method call to the event loop. It may be incorrect, but it works for now:
# SQLAlchemy mapper events that db_register_model_events wires up to
# ``on_<event>`` handlers defined on each model.
ORM_EVENTS = [
    "after_insert",
]
# sqlalchemy wants a sync function (i think)
def make_event(func):
    """Wrap coroutine function *func* in a synchronous event listener.

    Each call schedules ``func(*args, **kwargs)`` as a task on the event loop
    returned by ``asyncio.get_event_loop()``. It returns that task without
    waiting for it.
    """
    import functools

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task whose return value is discarded (as SQLAlchemy does) can be
    # garbage-collected mid-execution. Hold strong references until done.
    pending = set()

    @functools.wraps(func)
    def _event(*args, **kwargs):
        loop = asyncio.get_event_loop()
        result = asyncio.ensure_future(func(*args, **kwargs), loop=loop)
        pending.add(result)
        result.add_done_callback(pending.discard)
        return result

    return _event
def db_register_model_events(models):
    """Attach SQLAlchemy listeners for each model's ``on_<event>`` handlers.

    For every model in *models* and every event name in ``ORM_EVENTS``, a
    model attribute named ``on_<event>`` (when present) is wrapped with
    ``make_event`` and registered via ``listen(model, event_name, ...)``.
    """
    for model in models:
        for name in ORM_EVENTS:
            handler_attr = f"on_{name}"
            if not hasattr(model, handler_attr):
                continue
            listen(model, name, make_event(getattr(model, handler_attr)))
And here is an example of the event handler in the model class:
@staticmethod
async def on_after_insert(mapper, connection, target):
    """``after_insert`` handler: publish DNS_REQUEST_CREATED on ``channel:auth``."""
    print("on_after_insert", mapper, connection, target)
    publisher = await make_redis()
    try:
        await publisher.publish_json(
            "channel:auth",
            {"type": "MESSAGE", "name": "DNS_REQUEST_CREATED", "payload": ""},
        )
    finally:
        # A connection is opened per insert; close it or every insert leaks one.
        publisher.close()
        await publisher.wait_closed()
Hope this helps anyone looking to do something similar. On the front end, I create two websockets and then emit them using a Vue.js event bus. You can check out the project if you're interested. And if you have any feedback on things I'm doing wrong, please let me know.
If I've got this right, all the implementations here use python loops over all connected users. Is there a possibility that we can get Cython or C loops inbuilt into the library to do this without the python loop overhead, please?
The broadcaster project is a working approach to this. (Although it's stretched a bit thin on maintenance ATM) https://github.com/encode/broadcaster
WebSockets and SSE aren’t really any use without a broadcast interface. (Eg redis pub/sub or postgres listen/notify)
Look to Django Channels’ group API (`group_add` / `group_discard` / `group_send`) for reference here.