lxzan / gws

simple, fast, reliable websocket server & client, supports running over tcp/kcp/unix domain socket. keywords: ws, proxy, chat, go, golang...
https://pkg.go.dev/github.com/lxzan/gws
Apache License 2.0
1.4k stars 90 forks source link

大佬,请教一个问题 #108

Closed 3517283258 closed 2 months ago

3517283258 commented 2 months ago

大佬,请教一下,我参考 chatroom 示例修改了一下案例,想实现从外部调用函数,把消息发送给前端客户端。

package websocket_api

import (
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/lxzan/gws"
    "net/http"
    "schisandra-cloud-album/global"
    "time"
)

// Heartbeat timing constants.
const (
    // PingInterval is the client heartbeat interval.
    PingInterval time.Duration = 5 * time.Second
    // HeartbeatWaitTimeout is how long to wait for a heartbeat before timing out.
    HeartbeatWaitTimeout time.Duration = 10 * time.Second
)

func (WebsocketAPI) NewGWSServer(c *gin.Context) {
    var handler = NewWebSocket()
    upgrader := gws.NewUpgrader(handler, &gws.ServerOption{
        HandshakeTimeout: 5 * time.Second, // 握手超时时间
        ReadBufferSize:   1024,            // 读缓冲区大小
        ParallelEnabled:  true,            // 开启并行消息处理
        Recovery:         gws.Recovery,    // 开启异常恢复
        CheckUtf8Enabled: true,            // 开启UTF8校验
        PermessageDeflate: gws.PermessageDeflate{
            Enabled: true, // 开启压缩
        },
        Authorize: func(r *http.Request, session gws.SessionStorage) bool {
            var clientId = r.URL.Query().Get("client_id")
            if clientId == "" {
                return false
            }
            session.Store("client_id", clientId)
            return true
        },
    })
    socket, err := upgrader.Upgrade(c.Writer, c.Request)
    if err != nil {
        return
    }
    go func() {
        socket.ReadLoop() // 此处阻塞会使请求上下文不能顺利被GC
    }()
}
// MustLoad returns the session value stored under key, asserted to type T.
// When the key is absent the zero value of T is returned; a stored value of a
// different type makes the type assertion panic.
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
    stored, found := session.Load(key)
    if !found {
        return v
    }
    return stored.(T)
}

// NewWebSocket builds a WebSocket handler whose session map is created with
// gws.NewConcurrentMap(64, 128).
func NewWebSocket() *WebSocket {
    ws := new(WebSocket)
    ws.sessions = gws.NewConcurrentMap[string, *gws.Conn](64, 128)
    return ws
}

// WebSocket is the connection event handler. It embeds
// gws.BuiltinEventHandler and keeps the connections stored by OnOpen.
type WebSocket struct {
    gws.BuiltinEventHandler
    sessions *gws.ConcurrentMap[string, *gws.Conn] // connections kept in the built-in ConcurrentMap, which can reduce lock contention
}

// OnOpen stores the new connection under its client_id. If a connection is
// already stored for that id, it is sent a close frame (code 1000) first.
func (c *WebSocket) OnOpen(socket *gws.Conn) {
    clientID := MustLoad[string](socket.Session(), "client_id")
    previous, found := c.sessions.Load(clientID)
    if found {
        previous.WriteClose(1000, []byte("connection is replaced"))
    }
    c.sessions.Store(clientID, socket)
    global.LOG.Printf("%s connected\n", clientID)
}

// OnClose removes the closed connection from the session map and logs the
// close reason.
//
// The entry is deleted only if it still refers to this socket. When a client
// reconnects, OnOpen stores the new connection and closes the old one; an
// unconditional delete here would then drop the new connection's entry.
// The check and the delete run under the shard lock so they happen together.
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
    name := MustLoad[string](socket.Session(), "client_id")
    sharding := c.sessions.GetSharding(name)
    sharding.Lock()
    if conn, ok := sharding.Load(name); ok && conn == socket {
        sharding.Delete(name)
    }
    sharding.Unlock()

    global.LOG.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
}

// OnPing pushes the connection deadline to now + PingInterval +
// HeartbeatWaitTimeout and answers with a pong carrying the same payload.
// Errors from both calls are discarded.
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
    deadline := time.Now().Add(PingInterval + HeartbeatWaitTimeout)
    _ = socket.SetDeadline(deadline)
    _ = socket.WritePong(payload)
}

// OnPong is a no-op: the body is empty, so received pong frames are ignored.
func (c *WebSocket) OnPong(socket *gws.Conn, payload []byte) {}

// OnMessage looks up the connection stored under the sender's client_id and
// writes the received payload to it as a text frame. The write error is
// discarded and the message is closed on return.
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
    defer message.Close()
    clientID := MustLoad[string](socket.Session(), "client_id")
    target, found := c.sessions.Load(clientID)
    if !found {
        return
    }
    _ = target.WriteMessage(gws.OpcodeText, message.Bytes())
}

// SendMessageToClient sends message as a text frame to the connection stored
// under clientId. It returns the result of the write, or a "not found" error
// when no connection is stored for that id.
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
    target, found := c.sessions.Load(clientId)
    if !found {
        return fmt.Errorf("client %s not found", clientId)
    }
    return target.WriteMessage(gws.OpcodeText, message)
}

我写了一个函数供外部使用:

// SendMessageToClient sends message as a text frame to the connection stored
// under clientId. It returns the result of the write, or a "not found" error
// when no connection is stored for that id.
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
    target, found := c.sessions.Load(clientId)
    if !found {
        return fmt.Errorf("client %s not found", clientId)
    }
    return target.WriteMessage(gws.OpcodeText, message)
}

有个问题就是,这里通过 c.sessions.Load(clientId) 获取到的客户端连接一直是空的。新手学习,不太懂。

lxzan commented 2 months ago

我看了没问题

package main

import (
    "fmt"
    "github.com/lxzan/gws"
    "log"
    "net/http"
    "time"
)

// Heartbeat timing constants.
const (
    // PingInterval is the client heartbeat interval.
    PingInterval time.Duration = 5 * time.Second
    // HeartbeatWaitTimeout is how long to wait for a heartbeat before timing out.
    HeartbeatWaitTimeout time.Duration = 10 * time.Second
)

// main starts an HTTP server on :8000 with two endpoints:
//   - /connect upgrades the request to a WebSocket connection and runs its
//     read loop in a separate goroutine;
//   - /send delivers the "msg" query parameter to the connection registered
//     under the "client_id" query parameter.
//
// One handler and one upgrader are shared by all requests so that every
// connection lands in the same session map.
func main() {
    var handler = NewWebSocket()
    upgrader := gws.NewUpgrader(handler, &gws.ServerOption{
        HandshakeTimeout: 5 * time.Second, // handshake timeout
        ReadBufferSize:   1024,            // read buffer size
        ParallelEnabled:  true,            // enable parallel message processing
        Recovery:         gws.Recovery,    // enable panic recovery
        CheckUtf8Enabled: true,            // enable UTF-8 validation
        PermessageDeflate: gws.PermessageDeflate{
            Enabled: true, // enable compression
        },
        // Authorize rejects the handshake when the client_id query parameter
        // is empty; otherwise it stores the id in the session.
        Authorize: func(r *http.Request, session gws.SessionStorage) bool {
            var clientId = r.URL.Query().Get("client_id")
            if clientId == "" {
                return false
            }
            session.Store("client_id", clientId)
            return true
        },
    })

    http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) {
        socket, err := upgrader.Upgrade(writer, request)
        if err != nil {
            log.Printf("websocket upgrade failed: %v\n", err)
            return
        }
        // Run the read loop outside the handler: blocking here would keep the
        // request context from being garbage-collected.
        go socket.ReadLoop()
    })

    http.HandleFunc("/send", func(writer http.ResponseWriter, request *http.Request) {
        params := request.URL.Query()
        // Report delivery failures to the caller instead of dropping them.
        if err := handler.SendMessageToClient(params.Get("client_id"), []byte(params.Get("msg"))); err != nil {
            http.Error(writer, err.Error(), http.StatusInternalServerError)
        }
    })

    // ListenAndServe only returns on failure; make that visible.
    if err := http.ListenAndServe(":8000", nil); err != nil {
        log.Fatalf("http server stopped: %v", err)
    }
}

// MustLoad returns the session value stored under key, asserted to type T.
// When the key is absent the zero value of T is returned; a stored value of a
// different type makes the type assertion panic.
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
    stored, found := session.Load(key)
    if !found {
        return v
    }
    return stored.(T)
}

// NewWebSocket builds a WebSocket handler whose session map is created with
// gws.NewConcurrentMap(64, 128).
func NewWebSocket() *WebSocket {
    ws := new(WebSocket)
    ws.sessions = gws.NewConcurrentMap[string, *gws.Conn](64, 128)
    return ws
}

// WebSocket is the connection event handler. It embeds
// gws.BuiltinEventHandler and keeps the connections stored by OnOpen.
type WebSocket struct {
    gws.BuiltinEventHandler
    sessions *gws.ConcurrentMap[string, *gws.Conn] // connections kept in the built-in ConcurrentMap, which can reduce lock contention
}

// OnOpen stores the new connection under its client_id. If a connection is
// already stored for that id, it is sent a close frame (code 1000) first.
func (c *WebSocket) OnOpen(socket *gws.Conn) {
    clientID := MustLoad[string](socket.Session(), "client_id")
    previous, found := c.sessions.Load(clientID)
    if found {
        previous.WriteClose(1000, []byte("connection is replaced"))
    }
    c.sessions.Store(clientID, socket)
    log.Printf("%s connected\n", clientID)
}

// OnClose removes the closed connection from the session map and logs the
// close reason.
//
// The entry is deleted only if it still refers to this socket. When a client
// reconnects, OnOpen stores the new connection and closes the old one; an
// unconditional delete here would then drop the new connection's entry.
// The check and the delete run under the shard lock so they happen together.
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
    name := MustLoad[string](socket.Session(), "client_id")
    sharding := c.sessions.GetSharding(name)
    sharding.Lock()
    if conn, ok := sharding.Load(name); ok && conn == socket {
        sharding.Delete(name)
    }
    sharding.Unlock()

    log.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
}

// OnPing pushes the connection deadline to now + PingInterval +
// HeartbeatWaitTimeout and answers with a pong carrying the same payload.
// Errors from both calls are discarded.
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
    deadline := time.Now().Add(PingInterval + HeartbeatWaitTimeout)
    _ = socket.SetDeadline(deadline)
    _ = socket.WritePong(payload)
}

// OnPong is a no-op: the body is empty, so received pong frames are ignored.
func (c *WebSocket) OnPong(socket *gws.Conn, payload []byte) {}

// OnMessage looks up the connection stored under the sender's client_id and
// writes the received payload to it as a text frame. The write error is
// discarded and the message is closed on return.
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
    defer message.Close()
    clientID := MustLoad[string](socket.Session(), "client_id")
    target, found := c.sessions.Load(clientID)
    if !found {
        return
    }
    _ = target.WriteMessage(gws.OpcodeText, message.Bytes())
}

// SendMessageToClient sends message as a text frame to the connection stored
// under clientId. It returns the result of the write, or a "not found" error
// when no connection is stored for that id.
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
    target, found := c.sessions.Load(clientId)
    if !found {
        return fmt.Errorf("client %s not found", clientId)
    }
    return target.WriteMessage(gws.OpcodeText, message)
}
3517283258 commented 2 months ago

我看了没问题

package main

import (
  "fmt"
  "github.com/lxzan/gws"
  "log"
  "net/http"
  "time"
)

// Heartbeat timing constants.
const (
  // PingInterval is the client heartbeat interval.
  PingInterval time.Duration = 5 * time.Second
  // HeartbeatWaitTimeout is how long to wait for a heartbeat before timing out.
  HeartbeatWaitTimeout time.Duration = 10 * time.Second
)

// main starts an HTTP server on :8000 with two endpoints:
//   - /connect upgrades the request to a WebSocket connection and runs its
//     read loop in a separate goroutine;
//   - /send delivers the "msg" query parameter to the connection registered
//     under the "client_id" query parameter.
//
// One handler and one upgrader are shared by all requests so that every
// connection lands in the same session map.
func main() {
    var handler = NewWebSocket()
    upgrader := gws.NewUpgrader(handler, &gws.ServerOption{
        HandshakeTimeout: 5 * time.Second, // handshake timeout
        ReadBufferSize:   1024,            // read buffer size
        ParallelEnabled:  true,            // enable parallel message processing
        Recovery:         gws.Recovery,    // enable panic recovery
        CheckUtf8Enabled: true,            // enable UTF-8 validation
        PermessageDeflate: gws.PermessageDeflate{
            Enabled: true, // enable compression
        },
        // Authorize rejects the handshake when the client_id query parameter
        // is empty; otherwise it stores the id in the session.
        Authorize: func(r *http.Request, session gws.SessionStorage) bool {
            var clientId = r.URL.Query().Get("client_id")
            if clientId == "" {
                return false
            }
            session.Store("client_id", clientId)
            return true
        },
    })

    http.HandleFunc("/connect", func(writer http.ResponseWriter, request *http.Request) {
        socket, err := upgrader.Upgrade(writer, request)
        if err != nil {
            log.Printf("websocket upgrade failed: %v\n", err)
            return
        }
        // Run the read loop outside the handler: blocking here would keep the
        // request context from being garbage-collected.
        go socket.ReadLoop()
    })

    http.HandleFunc("/send", func(writer http.ResponseWriter, request *http.Request) {
        params := request.URL.Query()
        // Report delivery failures to the caller instead of dropping them.
        if err := handler.SendMessageToClient(params.Get("client_id"), []byte(params.Get("msg"))); err != nil {
            http.Error(writer, err.Error(), http.StatusInternalServerError)
        }
    })

    // ListenAndServe only returns on failure; make that visible.
    if err := http.ListenAndServe(":8000", nil); err != nil {
        log.Fatalf("http server stopped: %v", err)
    }
}

// MustLoad returns the session value stored under key, asserted to type T.
// When the key is absent the zero value of T is returned; a stored value of a
// different type makes the type assertion panic.
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
  stored, found := session.Load(key)
  if !found {
      return v
  }
  return stored.(T)
}

// NewWebSocket builds a WebSocket handler whose session map is created with
// gws.NewConcurrentMap(64, 128).
func NewWebSocket() *WebSocket {
  ws := new(WebSocket)
  ws.sessions = gws.NewConcurrentMap[string, *gws.Conn](64, 128)
  return ws
}

// WebSocket is the connection event handler. It embeds
// gws.BuiltinEventHandler and keeps the connections stored by OnOpen.
type WebSocket struct {
  gws.BuiltinEventHandler
  sessions *gws.ConcurrentMap[string, *gws.Conn] // connections kept in the built-in ConcurrentMap, which can reduce lock contention
}

// OnOpen stores the new connection under its client_id. If a connection is
// already stored for that id, it is sent a close frame (code 1000) first.
func (c *WebSocket) OnOpen(socket *gws.Conn) {
  clientID := MustLoad[string](socket.Session(), "client_id")
  previous, found := c.sessions.Load(clientID)
  if found {
      previous.WriteClose(1000, []byte("connection is replaced"))
  }
  c.sessions.Store(clientID, socket)
  log.Printf("%s connected\n", clientID)
}

// OnClose removes the closed connection from the session map and logs the
// close reason.
//
// The entry is deleted only if it still refers to this socket. When a client
// reconnects, OnOpen stores the new connection and closes the old one; an
// unconditional delete here would then drop the new connection's entry.
// The check and the delete run under the shard lock so they happen together.
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
  name := MustLoad[string](socket.Session(), "client_id")
  sharding := c.sessions.GetSharding(name)
  sharding.Lock()
  if conn, ok := sharding.Load(name); ok && conn == socket {
      sharding.Delete(name)
  }
  sharding.Unlock()

  log.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
}

// OnPing pushes the connection deadline to now + PingInterval +
// HeartbeatWaitTimeout and answers with a pong carrying the same payload.
// Errors from both calls are discarded.
func (c *WebSocket) OnPing(socket *gws.Conn, payload []byte) {
  deadline := time.Now().Add(PingInterval + HeartbeatWaitTimeout)
  _ = socket.SetDeadline(deadline)
  _ = socket.WritePong(payload)
}

// OnPong is a no-op: the body is empty, so received pong frames are ignored.
func (c *WebSocket) OnPong(socket *gws.Conn, payload []byte) {}

// OnMessage looks up the connection stored under the sender's client_id and
// writes the received payload to it as a text frame. The write error is
// discarded and the message is closed on return.
func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
  defer message.Close()
  clientID := MustLoad[string](socket.Session(), "client_id")
  target, found := c.sessions.Load(clientID)
  if !found {
      return
  }
  _ = target.WriteMessage(gws.OpcodeText, message.Bytes())
}

// SendMessageToClient sends message as a text frame to the connection stored
// under clientId. It returns the result of the write, or a "not found" error
// when no connection is stored for that id.
func (c *WebSocket) SendMessageToClient(clientId string, message []byte) error {
  target, found := c.sessions.Load(clientId)
  if !found {
      return fmt.Errorf("client %s not found", clientId)
  }
  return target.WriteMessage(gws.OpcodeText, message)
}

好的,感谢大佬,我知道我哪里错了,调用的时候写错了,已经弄好了,感谢