lesismal / nbio

Pure Go 1000k+ connections solution, support tls/http1.x/websocket and basically compatible with net/http, with high-performance and low memory cost, non-blocking, event-driven, easy-to-use.
MIT License
2.15k stars 151 forks source link

当tls.Config中ClientAuth为tls.RequireAndVerifyClientCert时, http.Request.TLS中没拿到数据 #223

Closed Kotodian closed 1 year ago

Kotodian commented 1 year ago

在gorilla包下我能看到http: TLS handshake error from 192.168.0.4:56081: remote error: tls: bad certificate这样的错误,但是在nbio下我看不到tls的错误

lesismal commented 1 year ago

先提供下你的demo我看下

Kotodian commented 1 year ago

先提供下你的demo我看下

这个先不管, 就是我使用wss时,我的tlsConfig.ClientAuth设置的是tls.RequireAndVerifyCert, 但是我的*http.Request拿不到tls.ConnectState

lesismal commented 1 year ago

我没有为仓库设置issue模板是因为给其他一些仓库提issue的时候,模板里的一些选项不是很通用,所以不想在别人给我的仓库提issue时带来这种困扰。

维护个仓库不容易,提issue请尽量提供完善一点的信息。

Kotodian commented 1 year ago
package main

import (
    "flag"
    "fmt"
    "log"
    "net/http"
    "os"
    "os/signal"
    "time"

    "github.com/lesismal/llib/std/crypto/tls"
    "github.com/lesismal/nbio/nbhttp"
    "github.com/lesismal/nbio/nbhttp/websocket"
)

var (
    svr   *nbhttp.Server
    print = flag.Bool("print", false, "stdout output of echoed data")
)

func newUpgrader() *websocket.Upgrader {
    u := websocket.NewUpgrader()
    u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) {
        // echo
        c.WriteMessage(messageType, data)
        if *print {
            fmt.Println("OnMessage:", messageType, string(data))
        }
        c.SetReadDeadline(time.Now().Add(nbhttp.DefaultKeepaliveTime))
    })
    u.OnClose(func(c *websocket.Conn, err error) {
        if *print {
            fmt.Println("OnClose:", c.RemoteAddr().String(), err)
        }
    })

    return u
}

func onWebsocket(w http.ResponseWriter, r *http.Request) {
    // time.Sleep(time.Second * 5)
    upgrader := newUpgrader()
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        panic(err)
    }
      // 没拿到
    if r.TLS == nil {
        panic("no tls connection state")
    }
    wsConn := conn.(*websocket.Conn)
    if *print {
        fmt.Println("OnOpen:", wsConn.RemoteAddr().String())
    }
}

func main() {
    flag.Parse()

    cert, err := tls.X509KeyPair(rsaCertPEM, rsaKeyPEM)
    if err != nil {
        log.Fatalf("tls.X509KeyPair failed: %v", err)
    }
    tlsConfig := &tls.Config{
        Certificates:       []tls.Certificate{cert},
        InsecureSkipVerify: true,
                // 我需要证书
        ClientAuth: tls.RequireAnyClientCert,
    }

    mux := &http.ServeMux{}
    mux.HandleFunc("/wss", onWebsocket)

    svr = nbhttp.NewServer(nbhttp.Config{
        Network:   "tcp",
        AddrsTLS:  []string{"localhost:8888"},
        TLSConfig: tlsConfig,
        Handler:   mux,
    })

    err = svr.Start()
    if err != nil {
        fmt.Printf("nbio.Start failed: %v\n", err)
        return
    }
    defer svr.Stop()

    interrupt := make(chan os.Signal, 1)
    signal.Notify(interrupt, os.Interrupt)
    <-interrupt
    log.Println("exit")
}

var rsaCertPEM = []byte(`-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUJeohtgk8nnt8ofratXJg7kUJsI4wDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDEyMDcwODIyNThaFw0zMDEy
MDUwODIyNThaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQCy+ZrIvwwiZv4bPmvKx/637ltZLwfgh3ouiEaTchGu
IQltthkqINHxFBqqJg44TUGHWthlrq6moQuKnWNjIsEc6wSD1df43NWBLgdxbPP0
x4tAH9pIJU7TQqbznjDBhzRbUjVXBIcn7bNknY2+5t784pPF9H1v7h8GqTWpNH9l
cz/v+snoqm9HC+qlsFLa4A3X9l5v05F1uoBfUALlP6bWyjHAfctpiJkoB9Yw1TJa
gpq7E50kfttwfKNkkAZIbib10HugkMoQJAs2EsGkje98druIl8IXmuvBIF6nZHuM
lt3UIZjS9RwPPLXhRHt1P0mR7BoBcOjiHgtSEs7Wk+j7AgMBAAGjUzBRMB0GA1Ud
DgQWBBQdheJv73XSOhgMQtkwdYPnfO02+TAfBgNVHSMEGDAWgBQdheJv73XSOhgM
QtkwdYPnfO02+TAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBf
SKVNMdmBpD9m53kCrguo9iKQqmhnI0WLkpdWszc/vBgtpOE5ENOfHGAufHZve871
2fzTXrgR0TF6UZWsQOqCm5Oh3URsCdXWewVMKgJ3DCii6QJ0MnhSFt6+xZE9C6Hi
WhcywgdR8t/JXKDam6miohW8Rum/IZo5HK9Jz/R9icKDGumcqoaPj/ONvY4EUwgB
irKKB7YgFogBmCtgi30beLVkXgk0GEcAf19lHHtX2Pv/lh3m34li1C9eBm1ca3kk
M2tcQtm1G89NROEjcG92cg+GX3GiWIjbI0jD1wnVy2LCOXMgOVbKfGfVKISFt0b1
DNn00G8C6ttLoGU2snyk
-----END CERTIFICATE-----
`)

var rsaKeyPEM = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAsvmayL8MImb+Gz5rysf+t+5bWS8H4Id6LohGk3IRriEJbbYZ
KiDR8RQaqiYOOE1Bh1rYZa6upqELip1jYyLBHOsEg9XX+NzVgS4HcWzz9MeLQB/a
SCVO00Km854wwYc0W1I1VwSHJ+2zZJ2Nvube/OKTxfR9b+4fBqk1qTR/ZXM/7/rJ
6KpvRwvqpbBS2uAN1/Zeb9ORdbqAX1AC5T+m1soxwH3LaYiZKAfWMNUyWoKauxOd
JH7bcHyjZJAGSG4m9dB7oJDKECQLNhLBpI3vfHa7iJfCF5rrwSBep2R7jJbd1CGY
0vUcDzy14UR7dT9JkewaAXDo4h4LUhLO1pPo+wIDAQABAoIBAF6yWwekrlL1k7Xu
jTI6J7hCUesaS1yt0iQUzuLtFBXCPS7jjuUPgIXCUWl9wUBhAC8SDjWe+6IGzAiH
xjKKDQuz/iuTVjbDAeTb6exF7b6yZieDswdBVjfJqHR2Wu3LEBTRpo9oQesKhkTS
aFF97rZ3XCD9f/FdWOU5Wr8wm8edFK0zGsZ2N6r57yf1N6ocKlGBLBZ0v1Sc5ShV
1PVAxeephQvwL5DrOgkArnuAzwRXwJQG78L0aldWY2q6xABQZQb5+ml7H/kyytef
i+uGo3jHKepVALHmdpCGr9Yv+yCElup+ekv6cPy8qcmMBqGMISL1i1FEONxLcKWp
GEJi6QECgYEA3ZPGMdUm3f2spdHn3C+/+xskQpz6efiPYpnqFys2TZD7j5OOnpcP
ftNokA5oEgETg9ExJQ8aOCykseDc/abHerYyGw6SQxmDbyBLmkZmp9O3iMv2N8Pb
Nrn9kQKSr6LXZ3gXzlrDvvRoYUlfWuLSxF4b4PYifkA5AfsdiKkj+5sCgYEAzseF
XDTRKHHJnzxZDDdHQcwA0G9agsNj64BGUEjsAGmDiDyqOZnIjDLRt0O2X3oiIE5S
TXySSEiIkxjfErVJMumLaIwqVvlS4pYKdQo1dkM7Jbt8wKRQdleRXOPPN7msoEUk
Ta9ZsftHVUknPqblz9Uthb5h+sRaxIaE1llqDiECgYATS4oHzuL6k9uT+Qpyzymt
qThoIJljQ7TgxjxvVhD9gjGV2CikQM2Vov1JBigj4Toc0XuxGXaUC7cv0kAMSpi2
Y+VLG+K6ux8J70sGHTlVRgeGfxRq2MBfLKUbGplBeDG/zeJs0tSW7VullSkblgL6
nKNa3LQ2QEt2k7KHswryHwKBgENDxk8bY1q7wTHKiNEffk+aFD25q4DUHMH0JWti
fVsY98+upFU+gG2S7oOmREJE0aser0lDl7Zp2fu34IEOdfRY4p+s0O0gB+Vrl5VB
L+j7r9bzaX6lNQN6MvA7ryHahZxRQaD/xLbQHgFRXbHUyvdTyo4yQ1821qwNclLk
HUrhAoGAUtjR3nPFR4TEHlpTSQQovS8QtGTnOi7s7EzzdPWmjHPATrdLhMA0ezPj
Mr+u5TRncZBIzAZtButlh1AHnpN/qO3P0c0Rbdep3XBc/82JWO8qdb5QvAkxga3X
BpA7MNLxiqss+rCbwf3NbWxEMiDQ2zRwVoafVFys7tjmv6t2Xck=
-----END RSA PRIVATE KEY-----
`)

这样子能理解吗

lesismal commented 1 year ago

client代码也提供下。是server还是client拿不到err信息,期望在哪里能拿到err?尽量清晰些

Kotodian commented 1 year ago

client代码也提供下。是server还是client拿不到err信息,期望在哪里能拿到err?尽量清晰些

error那个无所谓了,是我自己的问题,就是我想拿到Request的TLS那个字段

Kotodian commented 1 year ago

client代码也提供下。是server还是client拿不到err信息,期望在哪里能拿到err?尽量清晰些

看看我server的那个中文的注释

lesismal commented 1 year ago

http.Request目前没有设置这个TLS,可以下面这样拿,标准库的tls这个接口是返回的拷贝后的结构体而非指针、魔改也只是去掉了锁,所以肯定能拿到,但是魔改后的用在nbio里并发模型不同,目前没有经过大量测试:

func onWebsocket(w http.ResponseWriter, r *http.Request) {
    h, ok := w.(http.Hijacker)
    if !ok {
        panic("invalid Hijacker")
    }
    c, _, err := h.Hijack()
    if err != nil {
        panic(err)
    }

    tlsConn, ok := c.(*tls.Conn)
    if !ok {
        panic("invalid tls conn")
    }

    tlsState := tlsConn.ConnectionState()
    log.Printf("tlsState: %v", tlsState)

    // time.Sleep(time.Second * 5)
    upgrader := newUpgrader()
    conn, err := upgrader.Upgrade(w, r, nil)
    if err != nil {
        panic(err)
    }

    wsConn := conn.(*websocket.Conn)
    if *print {
        fmt.Println("OnOpen:", wsConn.RemoteAddr().String())
    }
}

也可以在Upgrade后通过wsConn.Conn的断言拿到tlsConn,但你们如果需要更多TLS相关的处理,应该是Upgrade前更合理些。

你可以先试下这样用满足需要不,我考虑下是否需要把Request.TLS设置上

通常TLS自己握手阶段就足够安全验证相关的了,你们拿这个是要做哪些进一步的验证吗?

Kotodian commented 1 year ago

感谢感谢,领导需要序列号和ou,所以才会去看,我先测试下,到时候给你反馈

lesismal commented 1 year ago

好的

感谢感谢,领导需要序列号和ou,所以才会去看,我先测试下,到时候给你反馈

Kotodian commented 1 year ago

我是这样想的,因为你的tls和普通的是能够放在一个server里的,我是想着区分tls和普通连接的,如果能在Request里加上tls,我能在接口里知道这是一个tls连接,这样感觉处理起来会更方便一点

lesismal commented 1 year ago

我想了下,不应该设置Request.TLS。原因:

  1. 对于标准库,是一个协程去循环读取,读到一个Request就处理一个,在这个连接的生命周期中,TLS的读写都是串行的,不存在并发的问题。
  2. 对于nbio,是poller协程在读,读到一个Request就交给另外的逻辑协程去处理,这期间存在并发的读写的race、TLS字段一致性问题,并且标准库Request.TLS上是指针,而不是通过接口获取,所以不好处理这个并发的问题。除非再单独实现个Request、不使用标准库的Request,但那就得不偿失了。
Kotodian commented 1 year ago

好吧 能看看这个问题吗,使用tls时遇到的

image
lesismal commented 1 year ago

求你不要在这种贴日志图了。 贴可以复现的完整的server和client的代码和这个日志。

Kotodian commented 1 year ago

求你不要在这种贴日志图了。 贴可以复现的完整的server和client的代码和这个日志。

client不是我写的 我也贴不了,因为我自己测试是可以的,所以能提供的信息很有限 就这个空指针就是certReq是空的感觉

lesismal commented 1 year ago

不是非要别人完整的client,能复现的代码就行。还有server是怎么使用TLS的

lesismal commented 1 year ago

调用栈都打印出来了,你也可以自己在源码对应的位置加日志或者调试去确认下哪个是nil

Kotodian commented 1 year ago

server代码

package ws

import (
    "context"
    "crypto/x509"
    "fmt"
    "net/http"
    "os"
    "strings"
    "time"

    "gitee.com/csms/jxeu-ocpp/api"
    esam "gitee.com/csms/jxeu-ocpp/api/esam/v1"
    services "gitee.com/csms/jxeu-ocpp/api/services/v1/equip"
    "gitee.com/csms/jxeu-ocpp/util/config"
    "gitee.com/csms/jxeu-ocpp/util/log"
    "github.com/Kotodian/gokit/sync/errgroup"
    "github.com/gorilla/mux"
    "github.com/lesismal/llib/std/crypto/tls"
    "github.com/lesismal/nbio/nbhttp"
    "github.com/lesismal/nbio/nbhttp/websocket"
    "go.uber.org/zap"
)

type NbWebsocket struct {
    id                string
    conn              *websocket.Conn
    baseURL           string
    keepaliveInterval int
    encryptionKey     string
    remoteAddress     string
    connTime          int64
    close             chan struct{}
}

func newNbWebsocket(id string, conn *websocket.Conn, baseURL string, interval int, encryptionKey, remoteAddress string) *NbWebsocket {
    return &NbWebsocket{
        id:                id,
        conn:              conn,
        baseURL:           baseURL,
        keepaliveInterval: interval,
        encryptionKey:     encryptionKey,
        remoteAddress:     remoteAddress,
        connTime:          time.Now().Unix(),
        close:             make(chan struct{}),
    }
}

func (nb *NbWebsocket) ID() string {
    return nb.id
}

func (nb *NbWebsocket) Interval() int {
    return nb.keepaliveInterval
}

func (nb *NbWebsocket) RemoteAddress() string {
    return nb.remoteAddress
}

func (nb *NbWebsocket) ConnTime() int64 {
    return nb.connTime
}

func (nb *NbWebsocket) BaseURL() string {
    return nb.baseURL
}

type NbServer struct {
    server      *nbhttp.Server
    pprofServer *http.Server
    // connections ConnectionPoolG[*NbWebsocket]
    connections ConnectionPool
    hostname    string

    clientHandler             func(Channel)
    disconnectedClientHandler func(Channel)
    pingHandler               func(Channel)
    subProtocols              []string
    handleFunc                http.HandlerFunc
    // 处理消息方法
    messageHandler func(Channel, []byte) error

    logger *log.Logger
    close  chan struct{}

    ctx    context.Context
    cancel context.CancelFunc
}

func NewNbServer(logger *log.Logger) *NbServer {
    hostname, _ := os.Hostname()
    nb := &NbServer{
        pprofServer: &http.Server{Addr: ":6060", Handler: nil},
        hostname:    hostname,
        close:       make(chan struct{}, 1),
        logger:      logger,
    }
    configs := nbhttp.Config{Addrs: []string{config.AppConfig.WSPort}}

    if config.AppConfig.WSTLSCertKeyPath != "" && config.AppConfig.WSTLSCertPath != "" {
        keyPath := config.AppConfig.WSTLSCertKeyPath
        certPath := config.AppConfig.WSTLSCertPath
        tlsConfig := &tls.Config{ClientAuth: tls.RequireAndVerifyClientCert}
        cert, err := tls.LoadX509KeyPair(certPath, keyPath)
        if err != nil {
            nb.logger.Panic("failed to read server cert or server key", zap.Error(err))
        }
        tlsConfig.Certificates = []tls.Certificate{cert}
        if config.AppConfig.WSTLSCACertPath != "" {
            caPath := config.AppConfig.WSTLSCACertPath
            if caCert, err := os.ReadFile(caPath); err != nil {
                nb.logger.Error("failed to read ca cert, InsecureSkipVerify will be true", zap.Error(err))
            } else {
                pool := x509.NewCertPool()
                pool.AppendCertsFromPEM(caCert)
                tlsConfig.ClientCAs = pool
            }
        } else {
            nb.logger.Warn("no ca cert")
        }

        configs.TLSConfig = tlsConfig
        configs.AddrsTLS = []string{config.AppConfig.WSTLSPort}
    }
    nb.server = nbhttp.NewServer(configs)
    nb.server.CheckUtf8 = func(data []byte) bool {
        return true
    }
    nb.ctx, nb.cancel = context.WithCancel(context.Background())
    return nb
}

// func (nb *NbServer) SetConnectionPool(pool ConnectionPoolG[*NbWebsocket]) {
//  nb.connections = pool
// }

func (nb *NbServer) SetConnectionPool(pool ConnectionPool) {
    nb.connections = pool
}

func (nb *NbServer) Start(routePath string) {
    router := mux.NewRouter()
    router.HandleFunc(routePath, func(w http.ResponseWriter, r *http.Request) {
        if nb.handleFunc != nil {
            nb.handleFunc(w, r)
        } else {
            nb.onWebsocket(w, r)
        }
    })

    nb.server.Handler = router

    // nb.upgrader.CheckOrigin = func(r *http.Request) bool {
    //  return true
    // }

    // nb.upgrader.KeepaliveTime = readWait

    // nb.upgrader.SetPingHandler(func(conn *websocket.Conn, s string) {
    //  _ = conn.SetReadDeadline(time.Now().Add(readWait))
    //  session := conn.Session().(*NbWebsocket)
    //  if nb.pingHandler != nil {
    //      nb.pingHandler(session)
    //  }
    //  _ = conn.SetWriteDeadline(time.Now().Add(writeWait))
    //  if err := conn.WriteMessage(websocket.PongMessage, []byte(s)); err != nil {
    //      nb.logger.Error("failed to send pong message", zap.String("id", session.ID()), zap.Error(err))
    //      _ = conn.Close()
    //  }
    // })

    // nb.upgrader.OnClose(func(conn *websocket.Conn, err error) {
    //  session := conn.Session().(*NbWebsocket)

    //  nb.connections.Delete(session.id)

    //  if nb.disconnectedClientHandler != nil {
    //      nb.disconnectedClientHandler(session)
    //  }
    //  nb.logger.Error("websocket is closed", zap.String("id", session.ID()), zap.Error(err))
    // })

    // nb.upgrader.OnMessage(func(conn *websocket.Conn, messageType websocket.MessageType, bytes []byte) {
    //  session := conn.Session().(*NbWebsocket)
    //  if messageType == websocket.TextMessage {
    //      if nb.messageHandler != nil {
    //          if err := nb.messageHandler(session, bytes); err != nil {
    //              nb.logger.Error("failed to handle message", zap.String("id", session.ID()), zap.Error(err))
    //          }
    //      }
    //  }
    // })

    group, _ := errgroup.WithContext(nb.ctx)
    group.Go(func() error {
        if err := nb.server.Start(); err != nil {
            nb.logger.Error("failed to start websocket server", zap.Error(err))
            return err
        }
        <-nb.ctx.Done()
        ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
        defer cancel()
        return nb.server.Shutdown(ctx)
    })

    group.Go(func() error {
        nb.logger.Debug("start pprof server", zap.String("port", ":6060"))
        return nb.pprofServer.ListenAndServe()
    })
    _ = group.Wait()
}

func (nb *NbServer) onWebsocket(w http.ResponseWriter, r *http.Request) {
    snPassword := mux.Vars(r)["client"]
    if snPassword == "" {
        w.WriteHeader(http.StatusBadRequest)
        _, _ = w.Write([]byte("client is empty"))
        return
    }

    var password string
    spilt := strings.Split(snPassword, ":")
    if len(spilt) > 1 {
        password = spilt[1]
    }

    sn := spilt[0]

    var certSerialNumber string

    port := "31887"
    if host := strings.Split(r.Host, ":"); len(host) == 2 {
        port = host[1]
    } else {
        nb.logger.Warn("get host failed", zap.String("id", sn), zap.String("host", r.Host))
    }

    if port == "31900" || port == "8846" {
        h, ok := w.(http.Hijacker)
        if !ok {
            nb.logger.Warn("cannot convert http.ResponseWriter to Hijacker", zap.String("sn", sn))
        } else {
            c, _, err := h.Hijack()
            if err != nil {
                nb.logger.Error("cannot hijack", zap.String("sn", sn), zap.Error(err))
            } else {
                tlsConn, ok := c.(*tls.Conn)
                if !ok {
                    nb.logger.Warn("invalid tls conn", zap.String("sn", sn))
                } else {
                    if len(tlsConn.ConnectionState().PeerCertificates) > 0 {
                        certSerialNumber = tlsConn.ConnectionState().PeerCertificates[0].SerialNumber.String()
                    } else {
                        nb.logger.Warn("no peer certificates", zap.String("sn", sn))
                    }
                }
            }
        }
    }

    var protocol api.OcppVersion
    if config.AppConfig.OcppVersion == "1.6" {
        protocol = api.Ocpp16
    } else if config.AppConfig.OcppVersion == "2.0.1" {
        protocol = api.Ocpp201
    }

    remoteAddress := readUserIP(r)

    verifyReq := esam.NewAccessVerifyRequest(sn, port, protocol)

    if password != "" {
        verifyReq.AccountPassword = &password
    }

    if certSerialNumber != "" {
        verifyReq.CertSerialNumber = &certSerialNumber
    }

    if remoteAddress != "" {
        verifyReq.RemoteAddress = &remoteAddress
    }

    resp, err := esam.AccessVerifyRequest(verifyReq)

    if err != nil {
        nb.logger.Error("failed to request to AccessVerify", zap.String("id", sn), zap.Error(err))
        w.WriteHeader(http.StatusBadRequest)
        _, _ = w.Write([]byte(err.Error()))
        return
    }

    oldWs, ok := nb.getConnection(sn)
    if ok {
        nb.logger.Warn("the duplicate connection has already existed", zap.String("id", sn))
        _ = oldWs.conn.WriteMessage(websocket.CloseMessage, []byte("the old sn hasn't been closed"))
        nb.cleanupConnection(oldWs)
        // we need to make sure the connection is closed
        t := time.NewTimer(3 * time.Second)
        select {
        case <-t.C:
        case <-oldWs.close:
            t.Stop()
        }
    }

    // if compression(r) {
    // nb.upgrader.EnableCompression(true)
    // }

    if resp.Data.Registered {
        onlineReq := services.NewEquipOnlineRequest(sn, protocol, nb.hostname)
        if remoteAddress != "" {
            onlineReq.Data.RemoteAddress = &remoteAddress
        }
        err = services.OnlineRequest(onlineReq)
    } else {
        registerReq := services.NewEquipRegisterRequest(sn, protocol, nb.hostname)
        if remoteAddress != "" {
            registerReq.Data.RemoteAddress = &remoteAddress
        }
        err = services.RegisterRequest(registerReq)
    }

    if err != nil {
        nb.logger.Error("failed to register or be online", zap.String("id", sn), zap.Error(err))
        w.WriteHeader(http.StatusBadRequest)
        _, _ = w.Write([]byte("failed to register or be online because of a bad request to services"))
        return
    }
    up := websocket.NewUpgrader()
    up.CheckOrigin = func(r *http.Request) bool {
        return true
    }
    up.KeepaliveTime = readWait
    up.Subprotocols = nb.subProtocols
    up.SetPingHandler(nb.wrapPingHandler())
    up.OnMessage(nb.wrapMessageHandler())
    up.OnClose(nb.wrapCloseHandler())

    if compression(r) {
        up.EnableCompression(true)
    }

    conn, err := up.Upgrade(w, r, nil)
    if err != nil {
        nb.logger.Error(err.Error(), zap.String("id", sn), zap.Error(err))
        w.WriteHeader(http.StatusBadRequest)
        _, _ = w.Write([]byte(fmt.Sprintf("upgrade failed because of %s", err.Error())))

        offlineReq := services.NewEquipOfflineRequest(sn, protocol, nb.hostname)
        if err = services.OfflineRequest(offlineReq); err != nil {
            nb.logger.Error("failed to send offline request to services", zap.String("id", sn), zap.Error(err))
        }
        return
    }

    nb.logger.Debug("new websocket connection", zap.String("id", sn), zap.String("port", port))

    nwsConn := conn.(*websocket.Conn)

    nws := newNbWebsocket(sn, nwsConn, resp.Data.BaseUrl, resp.Data.KeepaliveInterval, resp.Data.EncryptionKey, remoteAddress)

    nwsConn.SetSession(nws)

    _ = nb.connections.Set(nws.id, nws)

    if nb.clientHandler != nil {
        nb.clientHandler(nws)
    }
}

func (nb *NbServer) getConnection(id string) (*NbWebsocket, bool) {
    conn, err := nb.connections.Get(id)
    if err != nil || conn == nil {
        return nil, false
    }
    return conn.(*NbWebsocket), true
}

func (nb *NbServer) Write(id string, data []byte) error {
    conn, ok := nb.getConnection(id)
    if !ok {
        return fmt.Errorf("connection not found; %s", id)
    }
    _ = conn.conn.SetWriteDeadline(time.Now().Add(writeWait))
    if err := conn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
        return err
    }
    return nil
}

func (nb *NbServer) Stop() {
    ctx := context.Background()
    _ = nb.pprofServer.Shutdown(ctx)
    nb.cancel()
}

func (nb *NbServer) SetHandleFunc(handleFunc http.HandlerFunc) {
    nb.handleFunc = handleFunc
}

func (nb *NbServer) SetDisconnectedClientHandler(handler func(Channel)) {
    nb.disconnectedClientHandler = handler
}

func (nb *NbServer) SetNewClientHandler(handler func(Channel)) {
    nb.clientHandler = handler
}

func (nb *NbServer) SetPingHandler(handler func(channel Channel)) {
    nb.pingHandler = handler
}

func (nb *NbServer) SetMessageHandler(handler func(Channel, []byte) error) {
    nb.messageHandler = handler
}

func (nb *NbServer) IsOnline(id string) (Channel, bool) {
    conn, ok := nb.getConnection(id)
    return conn, ok
}

func (nb *NbServer) AddSupportedSubprotocol(protocol string) {
    for _, sub := range nb.subProtocols {
        if sub == protocol {
            return
        }
    }
    nb.subProtocols = append(nb.subProtocols, protocol)
}

func (nb *NbServer) cleanupConnection(nws *NbWebsocket) {
    _ = nws.conn.Close()
}

func NewEchoNbServer(logger *log.Logger) *NbServer {
    server := NewNbServer(logger)
    server.SetHandleFunc(func(w http.ResponseWriter, r *http.Request) {
        snPassword := mux.Vars(r)["client"]
        if snPassword == "" {
            w.WriteHeader(http.StatusBadRequest)
            _, _ = w.Write([]byte("client is empty"))
            return
        }
        fmt.Println("this is started")
        sn := snPassword
        oldWs, ok := server.getConnection(snPassword)
        if ok {
            server.logger.Warn("the duplicate connection has already existed", zap.String("id", sn))
            _ = oldWs.conn.WriteMessage(websocket.CloseMessage, []byte("the old sn hasn't been closed"))
            server.cleanupConnection(oldWs)
            t := time.NewTimer(3 * time.Second)
            select {
            case <-t.C:
            case <-oldWs.close:
                t.Stop()
            }
        }

        up := websocket.NewUpgrader()
        up.CheckOrigin = func(r *http.Request) bool {
            return true
        }
        up.KeepaliveTime = readWait
        up.Subprotocols = server.subProtocols
        up.SetPingHandler(server.wrapPingHandler())
        up.OnMessage(server.wrapMessageHandler())
        up.OnClose(server.wrapCloseHandler())

        if compression(r) {
            up.EnableCompression(true)
        }

        conn, err := up.Upgrade(w, r, nil)
        if err != nil {
            server.logger.Error(err.Error())
            w.WriteHeader(http.StatusBadRequest)
            _, _ = w.Write([]byte("upgrade failed"))
            return
        }

        ws := conn.(*websocket.Conn)
        wsConn := newNbWebsocket(snPassword, ws, "https://www.baidu.com", 180, "", "")

        err = server.connections.Set(sn, wsConn)
        if err != nil {
            server.logger.Error("cannot set the connection into the connection pool", zap.String("id", sn))
            _ = ws.WriteMessage(websocket.CloseMessage, []byte("failed set"))
            _ = ws.Close()
            return
        }

        ws.SetSession(wsConn)
        server.logger.Debug("new websocket connection", zap.String("id", wsConn.id))
    })

    return server
}

func (nb *NbServer) wrapPingHandler() func(conn *websocket.Conn, s string) {
    return func(conn *websocket.Conn, s string) {
        _ = conn.SetReadDeadline(time.Now().Add(readWait))
        session := conn.Session().(*NbWebsocket)
        if nb.pingHandler != nil {
            nb.pingHandler(session)
        }
        _ = conn.SetWriteDeadline(time.Now().Add(writeWait))
        if err := conn.WriteMessage(websocket.PongMessage, []byte(s)); err != nil {
            nb.logger.Error("failed to send pong message", zap.String("id", session.ID()), zap.Error(err))
            _ = conn.Close()
        }
    }
}

func (nb *NbServer) wrapMessageHandler() func(conn *websocket.Conn, messageType websocket.MessageType, bytes []byte) {
    return func(conn *websocket.Conn, messageType websocket.MessageType, bytes []byte) {
        session := conn.Session().(*NbWebsocket)
        if messageType == websocket.TextMessage {
            if nb.messageHandler != nil {
                if err := nb.messageHandler(session, bytes); err != nil {
                    nb.logger.Error("failed to handle message", zap.String("id", session.ID()), zap.Error(err))
                }
            }
        }
    }
}

func (nb *NbServer) wrapCloseHandler() func(conn *websocket.Conn, err error) {
    return func(conn *websocket.Conn, err error) {
        session := conn.Session().(*NbWebsocket)

        _ = nb.connections.Delete(session.id)

        if nb.disconnectedClientHandler != nil {
            nb.disconnectedClientHandler(session)
        }
        close(session.close)
        nb.logger.Error("websocket is closed", zap.String("id", session.ID()), zap.Error(err))
    }
}

空指针panic在这里

func (hs *serverHandshakeState) doFullHandshake() error {
    c := hs.c

    var err error
    var msg interface{}
    var pub crypto.PublicKey // public key for client auth, if any
    var certReq *certificateRequestMsg
    if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2 {
        c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2

        if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
            hs.hello.ocspStapling = true
        }

        hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
        hs.hello.cipherSuite = hs.suite.id

        hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
        if c.config.ClientAuth == NoClientCert {
            // No need to keep a full record of the handshake if client
            // certificates won't be used.
            hs.finishedHash.discardHandshakeBuffer()
        }

        c.buffering = true
        hs.finishedHash.Write(hs.clientHello.marshal())
        hs.finishedHash.Write(hs.hello.marshal())
        if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
            return err
        }

        certMsg := new(certificateMsg)
        certMsg.certificates = hs.cert.Certificate
        hs.finishedHash.Write(certMsg.marshal())
        if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
            return err
        }

        if hs.hello.ocspStapling {
            certStatus := new(certificateStatusMsg)
            certStatus.response = hs.cert.OCSPStaple
            hs.finishedHash.Write(certStatus.marshal())
            if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
                return err
            }
        }

        hs.ka = hs.suite.ka(c.vers)
        skx, err := hs.ka.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
        if err != nil {
            c.sendAlert(alertHandshakeFailure)
            return err
        }
        if skx != nil {
            hs.finishedHash.Write(skx.marshal())
            if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
                return err
            }
        }

        if c.config.ClientAuth >= RequestClientCert {
            fmt.Println("init certReq")
            // Request a client certificate
            certReq = new(certificateRequestMsg)
            certReq.certificateTypes = []byte{
                byte(certTypeRSASign),
                byte(certTypeECDSASign),
            }
            if c.vers >= VersionTLS12 {
                certReq.hasSignatureAlgorithm = true
                certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms
            }

            // An empty list of certificateAuthorities signals to
            // the client that it may send any certificate in response
            // to our request. When we know the CAs we trust, then
            // we can send them down, so that the client can choose
            // an appropriate certificate to give to us.
            if c.config.ClientCAs != nil {
                certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
            }
            hs.finishedHash.Write(certReq.marshal())
            if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
                return err
            }
        }

        helloDone := new(serverHelloDoneMsg)
        hs.finishedHash.Write(helloDone.marshal())
        if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
            return err
        }

        if _, err := c.flush(); err != nil {
            return err
        }
    }

    if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2ReadHandshake1 {
        msg, err = c.readHandshake()
        if err != nil {
            if err != errDataNotEnough {
                c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake1
            }
            return err
        }
        c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake1
    }

    // If we requested a client certificate, then the client must send a
    // certificate message, even if it's empty.

    if c.config.ClientAuth >= RequestClientCert {
        if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2HandleCertificateMsg {
            c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2HandleCertificateMsg

            certMsg, ok := msg.(*certificateMsg)
            if !ok {
                c.sendAlert(alertUnexpectedMessage)
                return unexpectedMessageError(certMsg, msg)
            }
            hs.finishedHash.Write(certMsg.marshal())

            if err := c.processCertsFromClient(Certificate{
                Certificate: certMsg.certificates,
            }); err != nil {
                return err
            }
            if len(certMsg.certificates) != 0 {
                pub = c.peerCertificates[0].PublicKey
            }
        }
        if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2ReadHandshake2 {
            msg, err = c.readHandshake()
            if err != nil {
                if err != errDataNotEnough {
                    c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake2
                }
                return err
            }
            c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake2

        }
    }

    if c.handshakeStatusAsync < stateServerHandshakeDoFullHandshake2HandleVerifyConnection {
        c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2HandleVerifyConnection
        if c.config.VerifyConnection != nil {
            if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
                c.sendAlert(alertBadCertificate)
                return err
            }
        }

        // Get client key exchange
        ckx, ok := msg.(*clientKeyExchangeMsg)
        if !ok {
            c.sendAlert(alertUnexpectedMessage)
            return unexpectedMessageError(ckx, msg)
        }
        hs.finishedHash.Write(ckx.marshal())

        preMasterSecret, err := hs.ka.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
        if err != nil {
            c.sendAlert(alertHandshakeFailure)
            return err
        }
        hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
        if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
            c.sendAlert(alertInternalError)
            return err
        }

    }

    if c.handshakeStatusAsync >= stateServerHandshakeDoFullHandshake2ReadHandshake3 {
        return nil
    }
    // If we received a client cert in response to our certificate request message,
    // the client will send us a certificateVerifyMsg immediately after the
    // clientKeyExchangeMsg. This message is a digest of all preceding
    // handshake-layer messages that is signed using the private key corresponding
    // to the client's certificate. This allows us to verify that the client is in
    // possession of the private key of the certificate.
    if len(c.peerCertificates) > 0 {
        msg, err := c.readHandshake()
        if err != nil {
            if err != errDataNotEnough {
                c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
            }
            return err
        }
        certVerify, ok := msg.(*certificateVerifyMsg)
        if !ok {
            c.sendAlert(alertUnexpectedMessage)
            c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
            return unexpectedMessageError(certVerify, msg)
        }

        var sigType uint8
        var sigHash crypto.Hash
        if c.vers >= VersionTLS12 {
            fmt.Println(certReq)
                         // certReq是空指针,但是初始化跟客户端没什么关系,初始化是根据server的tls.Config去做的
            if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
                c.sendAlert(alertIllegalParameter)
                c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
                return errors.New("tls: client certificate used with invalid signature algorithm")
            }
            sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
            if err != nil {
                c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
                return c.sendAlert(alertInternalError)
            }
        } else {
            sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
            if err != nil {
                c.sendAlert(alertIllegalParameter)
                c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
                return err
            }
        }

        signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret)
        if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
            c.sendAlert(alertDecryptError)
            c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3
            return errors.New("tls: invalid signature by the client certificate: " + err.Error())
        }

        hs.finishedHash.Write(certVerify.marshal())
    }

    hs.finishedHash.discardHandshakeBuffer()

    c.handshakeStatusAsync = stateServerHandshakeDoFullHandshake2ReadHandshake3

    return nil
}
Kotodian commented 1 year ago

明明已经init过了但是到了下面变成了空指针 RequireAndVerifyClientCert init certReq nil 2022/09/30 11:44:30.025 [ERR] execute parser failed: runtime error: invalid memory address or nil pointer dereference goroutine 102 [running]: github.com/lesismal/nbio/nbhttp.(Engine).TLSDataHandler.func1() /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/nbhttp/engine.go:469 +0x72 panic({0xcee6e0, 0x151b3e0}) /usr/local/go/src/runtime/panic.go:884 +0x212 github.com/lesismal/llib/std/crypto/tls.(serverHandshakeState).doFullHandshake(0xc002712000) /go/pkg/mod/github.com/!kotodian/llib@v0.0.2/std/crypto/tls/handshake_server.go:773 +0x12c1 github.com/lesismal/llib/std/crypto/tls.(serverHandshakeState).handshake(0xc002712000) /go/pkg/mod/github.com/!kotodian/llib@v0.0.2/std/crypto/tls/handshake_server.go:170 +0x394 github.com/lesismal/llib/std/crypto/tls.(Conn).serverHandshake(0xc000388e00) /go/pkg/mod/github.com/!kotodian/llib@v0.0.2/std/crypto/tls/handshake_server.go:116 +0x172 github.com/lesismal/llib/std/crypto/tls.(Conn).Handshake(0xc000388e00) /go/pkg/mod/github.com/!kotodian/llib@v0.0.2/std/crypto/tls/conn.go:1706 +0x43 github.com/lesismal/llib/std/crypto/tls.(Conn).AppendAndRead(0xc000388e00, {0xc000318000, 0x10d, 0x8000}, {0xc0005e6000, 0x8000, 0x8000?}) /go/pkg/mod/github.com/!kotodian/llib@v0.0.2/std/crypto/tls/conn.go:1565 +0x28d github.com/lesismal/nbio/nbhttp.(Engine).TLSDataHandler(0xc0004540c0, 0xc0006c2340, {0xc000318000, 0x10d, 0x8000}) /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/nbhttp/engine.go:485 +0x1e7 github.com/lesismal/nbio/nbhttp.NewEngine.func6(0xc0006c2340?, {0xc000318000?, 0xf50020?, 0xf4a968?}) /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/nbhttp/engine.go:705 +0x2a github.com/lesismal/nbio.(poller).readWriteLoop(0xc0005a4380) /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/poller_epoll.go:188 +0x444 github.com/lesismal/nbio.(*poller).start(0xc0005a4380) /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/poller_epoll.go:101 +0x29c created by github.com/lesismal/nbio.noRacePollerRun /go/pkg/mod/github.com/!kotodian/nbio@v0.0.9/race_disable.go:8 +0x27

lesismal commented 1 year ago

刚提交了下llib:https://github.com/lesismal/llib/blob/master/std/crypto/tls/handshake_server.go llib更到最新版试试看行不

Kotodian commented 1 year ago

抱歉 可能得等到国庆后了,只是看了下改动,就是包装一个函数去初始化变量,我好奇为什么之前初始化好的这个局部变量会变成空指针

lesismal commented 1 year ago

是包装一个函数去初始化

可能单次收到的可能只有half packet的数据,所以第一次走了初始化那个 if 分支并且 state 更新了。下次再来数据的时候,state 已经不满足第一次的那个初始化的条件,所以没走刚开始那个初始化的 if 分支

Kotodian commented 1 year ago

是包装一个函数去初始化

可能单次收到的可能只有half packet的数据,所以第一次走了初始化那个 if 分支并且 state 更新了。下次再来数据的时候,state 已经不满足第一次的那个初始化的条件,所以没走刚开始那个初始化的 if 分支

好的 理解了

Kotodian commented 1 year ago

还是没成功,我再排查下,就是 https://github.com/Kotodian/nbio/blob/c2ce42e4556609c02b6e058daa8e5d446a642f1e/nbhttp/engine.go#L500 这里能不能加点日志打印

Kotodian commented 1 year ago

在llib里加了好几个fmt.Println都没打印出来,算了感觉这个tls太麻烦了 还使用gorilla吧 不知道什么原因