xtaci / smux

A stream multiplexing library for Go with minimal memory usage (TDMA)
MIT License
1.27k stars 189 forks source link

修正一处细节, #68

Open xdown-dev opened 4 years ago

xdown-dev commented 4 years ago

在 session.go 文件中搜索 `s.notifyWriteError(err)`，

然后在其下方添加一行：
`s.Close()`

完整文件:

package smux

import (
    "container/heap"
    "encoding/binary"
    "errors"
    "io"
    "net"
    "sync"
    "sync/atomic"
    "time"
)

// defaultAcceptBacklog is the capacity of the queue holding peer-opened
// streams that have not yet been returned by Accept/AcceptStream.
const defaultAcceptBacklog = 1024

var (
    // ErrInvalidProtocol is recorded when a received frame header carries an
    // unexpected protocol version or an unknown command.
    ErrInvalidProtocol = errors.New("invalid protocol")
    // ErrConsumed describes a peer that reports consuming more data than was
    // sent. NOTE(review): not referenced in the code shown here; confirm where
    // it is raised.
    ErrConsumed        = errors.New("peer consumed more than sent")
    // ErrGoAway is returned by OpenStream once the stream identifier space is
    // exhausted; a new connection is needed to open further streams.
    ErrGoAway          = errors.New("stream id overflows, should start a new connection")
    // ErrTimeout is returned when a deadline expires (AcceptStream,
    // writeFrameInternal).
    ErrTimeout         = errors.New("timeout")
    // ErrWouldBlock reports an operation that would block on I/O.
    // NOTE(review): not referenced in the code shown here.
    ErrWouldBlock      = errors.New("operation would block on IO")
)

// writeRequest is a frame queued for sending. It travels from
// writeFrameInternal through shaperLoop (via s.shaper) to sendLoop (via
// s.writes).
type writeRequest struct {
    // prio is passed through from writeFrameInternal; presumably shaperHeap
    // orders requests by it (TODO confirm, shaperHeap is not shown here).
    prio   uint64
    // frame is the frame sendLoop encodes and writes.
    frame  Frame
    // result: sendLoop sends one writeResult on it and then closes it.
    result chan writeResult
}

// writeResult is the outcome of writing one frame, reported by sendLoop.
type writeResult struct {
    // n is the number of payload bytes written, excluding the frame header;
    // sendLoop clamps it to be non-negative.
    n   int
    // err is the error returned by the underlying write, if any.
    err error
}

// buffersWriter is an optional capability of the underlying connection.
// When conn implements it, sendLoop passes the frame header and payload as a
// two-element slice to a single WriteBuffers call instead of copying both
// into one scratch buffer.
type buffersWriter interface {
    WriteBuffers(v [][]byte) (n int, err error)
}

// Session defines a multiplexed connection for streams.
// newSession starts four goroutines per session (shaperLoop, recvLoop,
// sendLoop, keepalive) that share the single underlying connection.
type Session struct {
    // conn is the underlying transport carrying every stream.
    conn io.ReadWriteCloser

    config           *Config
    // nextStreamID is advanced by 2 in OpenStream; it starts at 1 for a
    // client and 0 for a server, so the two sides never pick the same id.
    nextStreamID     uint32 // next stream identifier
    // nextStreamIDLock guards nextStreamID and goAway.
    nextStreamIDLock sync.Mutex

    // bucket is the session-wide receive budget in bytes: initialised to
    // config.MaxReceiveBuffer, decreased by recvLoop for buffered payload and
    // increased by returnTokens / streamClosed. recvLoop stops reading while
    // it is <= 0.
    bucket       int32         // token bucket
    // bucketNotify (capacity 1) wakes recvLoop when tokens come back.
    bucketNotify chan struct{} // used for waiting for tokens

    streams    map[uint32]*Stream // all streams in this session
    streamLock sync.Mutex         // locks streams

    // die is closed exactly once, by Close.
    die     chan struct{} // flag session has died
    dieOnce sync.Once

    // socket error handling: the first read/write error is stored in the
    // atomic.Value and broadcast by closing the matching channel.
    socketReadError      atomic.Value
    socketWriteError     atomic.Value
    chSocketReadError    chan struct{}
    chSocketWriteError   chan struct{}
    socketReadErrorOnce  sync.Once
    socketWriteErrorOnce sync.Once

    // smux protocol errors, recorded and broadcast the same way.
    protoError     atomic.Value
    chProtoError   chan struct{}
    protoErrorOnce sync.Once

    // chAccepts queues peer-opened streams for AcceptStream
    // (capacity defaultAcceptBacklog).
    chAccepts chan *Stream

    // dataReady is set to 1 by recvLoop on every received header and reset
    // by keepalive at each timeout check.
    dataReady int32 // flag data has arrived

    // goAway is set to 1 once stream ids are exhausted (guarded by
    // nextStreamIDLock).
    goAway int32 // flag id exhausted

    // deadline holds the time.Time consulted by AcceptStream.
    deadline atomic.Value

    // shaper carries requests from writers to shaperLoop; writes carries the
    // shaped requests from shaperLoop to sendLoop. Both are unbuffered.
    shaper chan writeRequest // a shaper for writing
    writes chan writeRequest
}

// newSession wraps conn in a Session and starts its background goroutines
// (shaperLoop, recvLoop, sendLoop, keepalive).
// Locally opened stream ids are odd for a client and even for a server:
// nextStreamID starts at 1 or 0 and OpenStream advances it by 2.
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
    var firstID uint32
    if client {
        firstID = 1
    }

    s := &Session{
        conn:               conn,
        config:             config,
        nextStreamID:       firstID,
        bucket:             int32(config.MaxReceiveBuffer),
        bucketNotify:       make(chan struct{}, 1),
        streams:            make(map[uint32]*Stream),
        die:                make(chan struct{}),
        chSocketReadError:  make(chan struct{}),
        chSocketWriteError: make(chan struct{}),
        chProtoError:       make(chan struct{}),
        chAccepts:          make(chan *Stream, defaultAcceptBacklog),
        shaper:             make(chan writeRequest),
        writes:             make(chan writeRequest),
    }

    go s.shaperLoop()
    go s.recvLoop()
    go s.sendLoop()
    go s.keepalive()
    return s
}

// OpenStream opens a new outgoing stream. It reserves a stream id, announces
// it to the peer with a SYN frame and, if the session is still healthy,
// registers the stream. It returns ErrGoAway once the id space is exhausted,
// io.ErrClosedPipe if the session is closed, or the recorded socket error.
func (s *Session) OpenStream() (*Stream, error) {
    if s.IsClosed() {
        return nil, io.ErrClosedPipe
    }

    // Reserve the next id. Ids advance by two; wrapping around to 0 or 1
    // (id == id%2) means the space is used up for good.
    s.nextStreamIDLock.Lock()
    if s.goAway > 0 {
        s.nextStreamIDLock.Unlock()
        return nil, ErrGoAway
    }
    s.nextStreamID += 2
    id := s.nextStreamID
    exhausted := id == id%2
    if exhausted {
        s.goAway = 1
    }
    s.nextStreamIDLock.Unlock()
    if exhausted {
        return nil, ErrGoAway
    }

    st := newStream(id, s.config.MaxFrameSize, s)

    syn := newFrame(byte(s.config.Version), cmdSYN, id)
    if _, err := s.writeFrame(syn); err != nil {
        return nil, err
    }

    // Register only if nothing failed while the SYN was in flight.
    s.streamLock.Lock()
    defer s.streamLock.Unlock()
    select {
    case <-s.chSocketReadError:
        return nil, s.socketReadError.Load().(error)
    case <-s.chSocketWriteError:
        return nil, s.socketWriteError.Load().(error)
    case <-s.die:
        return nil, io.ErrClosedPipe
    default:
    }
    s.streams[id] = st
    return st, nil
}

// Open is OpenStream returning the stream as a generic io.ReadWriteCloser.
//
// On failure the returned io.ReadWriteCloser is an untyped nil. Returning
// OpenStream's (*Stream)(nil) directly would wrap it in a non-nil interface
// value, so a caller's `rwc != nil` check would wrongly succeed.
func (s *Session) Open() (io.ReadWriteCloser, error) {
    stream, err := s.OpenStream()
    if err != nil {
        return nil, err
    }
    return stream, nil
}

// AcceptStream blocks until the peer opens a stream and returns it.
// It fails with ErrTimeout once the deadline set by SetDeadline passes, with
// the recorded error after a socket read error or protocol error, and with
// io.ErrClosedPipe after the session is closed.
func (s *Session) AcceptStream() (*Stream, error) {
    // A nil channel never becomes ready, so without a deadline the timeout
    // case below is effectively disabled.
    var expired <-chan time.Time
    d, ok := s.deadline.Load().(time.Time)
    if ok && !d.IsZero() {
        t := time.NewTimer(time.Until(d))
        defer t.Stop()
        expired = t.C
    }

    select {
    case stream := <-s.chAccepts:
        return stream, nil
    case <-expired:
        return nil, ErrTimeout
    case <-s.chSocketReadError:
        return nil, s.socketReadError.Load().(error)
    case <-s.chProtoError:
        return nil, s.protoError.Load().(error)
    case <-s.die:
        return nil, io.ErrClosedPipe
    }
}

// Accept is AcceptStream returning the stream as a generic io.ReadWriteCloser.
//
// On failure the returned io.ReadWriteCloser is an untyped nil. Returning
// AcceptStream's (*Stream)(nil) directly would wrap it in a non-nil interface
// value, so a caller's `rwc != nil` check would wrongly succeed.
func (s *Session) Accept() (io.ReadWriteCloser, error) {
    stream, err := s.AcceptStream()
    if err != nil {
        return nil, err
    }
    return stream, nil
}

// Close shuts the session down. It closes s.die, calls sessionClose on every
// registered stream and closes the underlying connection, returning that
// connection's Close error. Every call after the first returns
// io.ErrClosedPipe without doing anything.
func (s *Session) Close() error {
    first := false
    s.dieOnce.Do(func() {
        close(s.die)
        first = true
    })
    if !first {
        return io.ErrClosedPipe
    }

    s.streamLock.Lock()
    for _, stream := range s.streams {
        stream.sessionClose()
    }
    s.streamLock.Unlock()
    return s.conn.Close()
}

// notifyBucket wakes recvLoop if it is waiting for receive tokens. The send
// never blocks: bucketNotify has capacity 1, so when a wake-up is already
// pending this call does nothing.
func (s *Session) notifyBucket() {
    wake := struct{}{}
    select {
    case s.bucketNotify <- wake:
    default:
    }
}

// notifyReadError records err as the session's socket read error and
// broadcasts it by closing chSocketReadError. Only the first call has an
// effect; later errors are dropped.
func (s *Session) notifyReadError(err error) {
    publish := func() {
        s.socketReadError.Store(err)
        close(s.chSocketReadError)
    }
    s.socketReadErrorOnce.Do(publish)
}

// notifyWriteError records err as the session's socket write error and
// broadcasts it by closing chSocketWriteError. Only the first call has an
// effect; later errors are dropped.
func (s *Session) notifyWriteError(err error) {
    publish := func() {
        s.socketWriteError.Store(err)
        close(s.chSocketWriteError)
    }
    s.socketWriteErrorOnce.Do(publish)
}

// notifyProtoError records err as the session's protocol error and
// broadcasts it by closing chProtoError. Only the first call has an effect;
// later errors are dropped.
func (s *Session) notifyProtoError(err error) {
    publish := func() {
        s.protoError.Store(err)
        close(s.chProtoError)
    }
    s.protoErrorOnce.Do(publish)
}

// IsClosed reports, without blocking, whether the session has been closed
// (i.e. s.die has been closed by Close).
func (s *Session) IsClosed() bool {
    select {
    case <-s.die:
    default:
        return false
    }
    return true
}

// NumStreams returns how many streams are currently registered with the
// session, or 0 once the session is closed.
func (s *Session) NumStreams() int {
    if s.IsClosed() {
        return 0
    }
    s.streamLock.Lock()
    count := len(s.streams)
    s.streamLock.Unlock()
    return count
}

// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline. AcceptStream reads the value once
// when it starts, so a new deadline does not affect calls already blocked.
// It always returns nil.
func (s *Session) SetDeadline(t time.Time) error {
    s.deadline.Store(t)
    return nil
}

// LocalAddr mirrors net.Conn.LocalAddr. It returns the underlying
// connection's local address, or nil if the connection does not expose one.
func (s *Session) LocalAddr() net.Addr {
    type localAddrer interface{ LocalAddr() net.Addr }
    la, ok := s.conn.(localAddrer)
    if !ok {
        return nil
    }
    return la.LocalAddr()
}

// RemoteAddr mirrors net.Conn.RemoteAddr. It returns the underlying
// connection's remote address, or nil if the connection does not expose one.
func (s *Session) RemoteAddr() net.Addr {
    type remoteAddrer interface{ RemoteAddr() net.Addr }
    ra, ok := s.conn.(remoteAddrer)
    if !ok {
        return nil
    }
    return ra.RemoteAddr()
}

// streamClosed removes stream sid from the session. Any receive tokens the
// stream still holds are returned to the session bucket, and recvLoop is woken
// if the bucket becomes positive.
//
// An id that is not registered (for example if this is called a second time
// for the same stream) is ignored. Previously the map lookup produced a nil
// *Stream and recycleTokens was invoked on it.
func (s *Session) streamClosed(sid uint32) {
    s.streamLock.Lock()
    defer s.streamLock.Unlock()

    stream, ok := s.streams[sid]
    if !ok {
        return
    }
    if n := stream.recycleTokens(); n > 0 { // return remaining tokens to the bucket
        if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
            s.notifyBucket()
        }
    }
    delete(s.streams, sid)
}

// returnTokens gives n receive tokens back to the session bucket (called by a
// stream after its data has been read), waking recvLoop if the bucket is
// positive again.
func (s *Session) returnTokens(n int) {
    balance := atomic.AddInt32(&s.bucket, int32(n))
    if balance > 0 {
        s.notifyBucket()
    }
}

// recvLoop reads frames from the underlying connection while the receive
// bucket holds tokens, dispatching each frame by its command. It returns when
// the session dies. It also returns after recording the first read error
// (notifyReadError) or protocol error (notifyProtoError).
func (s *Session) recvLoop() {
    var hdr rawHeader
    var updHdr updHeader

    for {
        // Back-pressure: while the bucket is exhausted, wait for tokens to be
        // returned (or keepalive's periodic nudge), or for the session to die.
        for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
            select {
            case <-s.bucketNotify:
            case <-s.die:
                return
            }
        }

        // Every frame starts with a fixed-size header.
        if _, err := io.ReadFull(s.conn, hdr[:]); err != nil {
            s.notifyReadError(err)
            return
        }
        // Tell keepalive that something arrived from the peer.
        atomic.StoreInt32(&s.dataReady, 1)

        if hdr.Version() != byte(s.config.Version) {
            s.notifyProtoError(ErrInvalidProtocol)
            return
        }

        sid := hdr.StreamID()
        switch hdr.Cmd() {
        case cmdNOP:
            // Keepalive frame: only the dataReady update above matters.
        case cmdSYN:
            // Peer opened a stream: register it once and queue it for Accept.
            s.streamLock.Lock()
            if _, exists := s.streams[sid]; !exists {
                incoming := newStream(sid, s.config.MaxFrameSize, s)
                s.streams[sid] = incoming
                select {
                case s.chAccepts <- incoming:
                case <-s.die:
                }
            }
            s.streamLock.Unlock()
        case cmdFIN:
            s.streamLock.Lock()
            if st, found := s.streams[sid]; found {
                st.fin()
                st.notifyReadEvent()
            }
            s.streamLock.Unlock()
        case cmdPSH:
            if size := int(hdr.Length()); size > 0 {
                payload := defaultAllocator.Get(size)
                got, err := io.ReadFull(s.conn, payload)
                if err != nil {
                    s.notifyReadError(err)
                    return
                }
                // Payload for an unknown stream is still read off the wire,
                // then discarded without charging the bucket.
                s.streamLock.Lock()
                if st, found := s.streams[sid]; found {
                    st.pushBytes(payload)
                    atomic.AddInt32(&s.bucket, -int32(got))
                    st.notifyReadEvent()
                }
                s.streamLock.Unlock()
            }
        case cmdUPD:
            if _, err := io.ReadFull(s.conn, updHdr[:]); err != nil {
                s.notifyReadError(err)
                return
            }
            s.streamLock.Lock()
            if st, found := s.streams[sid]; found {
                st.update(updHdr.Consumed(), updHdr.Window())
            }
            s.streamLock.Unlock()
        default:
            s.notifyProtoError(ErrInvalidProtocol)
            return
        }
    }
}

// keepalive does two things on two tickers:
//   - Every KeepAliveInterval it sends a NOP frame and nudges recvLoop.
//   - Every KeepAliveTimeout it closes the session if no frame header
//     arrived since the previous check.
//
// Silence is tolerated while the receive bucket is empty, because recvLoop
// then stops reading on purpose. The goroutine returns when the session dies
// or after it closes the session.
func (s *Session) keepalive() {
    pingTicker := time.NewTicker(s.config.KeepAliveInterval)
    timeoutTicker := time.NewTicker(s.config.KeepAliveTimeout)
    defer pingTicker.Stop()
    defer timeoutTicker.Stop()

    for {
        select {
        case <-s.die:
            return
        case <-pingTicker.C:
            // The next ping tick doubles as this ping's deadline; a failed or
            // timed-out ping is deliberately ignored here.
            nop := newFrame(byte(s.config.Version), cmdNOP, 0)
            s.writeFrameInternal(nop, pingTicker.C, 0)
            s.notifyBucket() // force a signal to the recvLoop
        case <-timeoutTicker.C:
            if atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
                continue // data arrived since the last check
            }
            if atomic.LoadInt32(&s.bucket) <= 0 {
                continue // recvLoop is throttled, so silence is expected
            }
            s.Close()
            return
        }
    }
}

// shaperLoop orders pending write requests before they reach sendLoop.
// Requests arriving on s.shaper are kept in a heap (shaperHeap). Whenever the
// heap is non-empty, its top element is popped and offered on s.writes. If a
// new request arrives while that element is still being offered, the element
// is pushed back so the heap re-orders it together with the newcomer. The loop
// exits when the session dies.
// NOTE(review): the ordering itself is defined by shaperHeap (not shown here),
// presumably by writeRequest.prio — confirm.
func (s *Session) shaperLoop() {
    var reqs shaperHeap
    var next writeRequest
    var chWrite chan writeRequest // nil disables the send case while reqs is empty

    for {
        if len(reqs) > 0 {
            chWrite = s.writes
            next = heap.Pop(&reqs).(writeRequest)
        } else {
            chWrite = nil
        }

        select {
        case <-s.die:
            return
        case r := <-s.shaper:
            if chWrite != nil { // next is valid, reshape
                heap.Push(&reqs, next)
            }
            heap.Push(&reqs, r)
        case chWrite <- next:
        }
    }
}

// sendLoop writes frames to the underlying connection. It takes requests from
// s.writes (fed by shaperLoop) and encodes each frame header as:
//   - byte 0: version
//   - byte 1: command
//   - bytes 2-3: payload length, little-endian uint16
//   - bytes 4-7: stream id, little-endian uint32
//
// It then writes the header and the payload. The request's result channel
// receives the number of payload bytes written (header excluded, clamped at 0)
// and the write error, and is then closed. On the first write error the error
// is recorded via notifyWriteError and the whole session is closed. The loop
// also exits when the session dies.
func (s *Session) sendLoop() {
    var buf []byte
    var n int
    var err error
    var vec [][]byte // vector for writeBuffers

    bw, ok := s.conn.(buffersWriter)
    if ok {
        // Vectored path: buf holds only the header; the payload is passed
        // by reference in vec[1].
        buf = make([]byte, headerSize)
        vec = make([][]byte, 2)
    } else {
        // Copy path: one scratch buffer large enough for the header plus any
        // payload whose length fits the uint16 length field.
        // NOTE(review): a payload longer than 65535 bytes would be truncated
        // both here and in the length field; presumably MaxFrameSize prevents
        // that — confirm.
        buf = make([]byte, (1<<16)+headerSize)
    }

    for {
        select {
        case <-s.die:
            return
        case request := <-s.writes:
            buf[0] = request.frame.ver
            buf[1] = request.frame.cmd
            binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
            binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)

            // vec is non-empty only when conn implements buffersWriter.
            if len(vec) > 0 {
                vec[0] = buf[:headerSize]
                vec[1] = request.frame.data
                n, err = bw.WriteBuffers(vec)
            } else {
                copy(buf[headerSize:], request.frame.data)
                n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
            }

            // Report payload bytes only.
            n -= headerSize
            if n < 0 {
                n = 0
            }

            result := writeResult{
                n:   n,
                err: err,
            }

            // writeFrameInternal creates result with capacity 1, so this send
            // does not block even if the writer has already given up waiting.
            request.result <- result
            close(request.result)

            // store conn error
            if err != nil {
                s.notifyWriteError(err)
                // Close the whole session rather than only storing the error,
                // so every goroutine observes s.die and the connection is
                // released.
                s.Close()
                return
            }
        }
    }
}

// writeFrame queues f for sending with priority 0 and no deadline. It blocks
// until sendLoop has written it and returns the number of payload bytes
// written (frame header excluded).
func (s *Session) writeFrame(f Frame) (n int, err error) {
    var noDeadline <-chan time.Time // nil channel: never times out
    n, err = s.writeFrameInternal(f, noDeadline, 0)
    return n, err
}

// writeFrameInternal submits f to shaperLoop with the given prio and waits for
// sendLoop to report the outcome.
//
// Either phase (submitting or waiting) can be cut short:
//   - io.ErrClosedPipe if the session dies;
//   - the recorded write error if a socket write error has occurred;
//   - ErrTimeout if deadline fires. A nil deadline never fires.
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
    // Capacity 1 lets sendLoop deliver the result even if we stop waiting.
    done := make(chan writeResult, 1)
    req := writeRequest{prio: prio, frame: f, result: done}

    // Phase 1: hand the request to shaperLoop.
    select {
    case s.shaper <- req:
    case <-s.die:
        return 0, io.ErrClosedPipe
    case <-s.chSocketWriteError:
        return 0, s.socketWriteError.Load().(error)
    case <-deadline:
        return 0, ErrTimeout
    }

    // Phase 2: wait for sendLoop's verdict.
    select {
    case res := <-done:
        return res.n, res.err
    case <-s.die:
        return 0, io.ErrClosedPipe
    case <-s.chSocketWriteError:
        return 0, s.socketWriteError.Load().(error)
    case <-deadline:
        return 0, ErrTimeout
    }
}
xdown-dev commented 4 years ago

photo_2020-02-22_14-25-01

mux 的连接 session 在 socket 写数据时发生了 broken pipe（管道破裂），但 session 只是存储了这个错误，并没有关闭这个 session。

f4nff commented 2 years ago

这个bug这么久都没人同步一下,

Coffee-kitten commented 1 year ago

过了这么久还没改吗

xtaci commented 1 year ago

这个不需要修改啊。下次任何 write 操作都会返回错误，这也是标准的 socket 行为。