protocol_v2.go

package nsqd

import (
    "bytes"
    "encoding/binary"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "math"
    "math/rand"
    "net"
    "sync/atomic"
    "time"
    "unsafe"

    "github.com/nsqio/nsq/internal/protocol"
    "github.com/nsqio/nsq/internal/version"
)

// maxTimeout is a one-hour duration constant.
// NOTE(review): it is not referenced in the visible part of this file.
const maxTimeout = time.Hour

// Frame type identifiers passed as the frameType argument of Send.
const (
    frameTypeResponse int32 = iota
    frameTypeError
    frameTypeMessage
)

// Byte-slice constants used by the protocol loop.
var (
    // separatorBytes splits a command line into its parameters.
    separatorBytes = []byte{' '}
    // heartbeatBytes is the response body sent on every heartbeat tick.
    heartbeatBytes = []byte("_heartbeat_")
    // okBytes is the generic success response body.
    okBytes = []byte("OK")
)

// protocolV2 carries the shared context used by the V2 protocol methods
// defined in this file (IOLoop, Exec and the per-command handlers).
type protocolV2 struct {
    ctx *context
}

func (p *protocolV2) IOLoop(conn net.Conn) error {
    var err error
    var line []byte
    var zeroTime time.Time

    clientID := atomic.AddInt64(&p.ctx.nsqd.clientIDSequence, 1)
    client := newClientV2(clientID, conn, p.ctx)

    // synchronize the startup of messagePump in order
    // to guarantee that it gets a chance to initialize
    // goroutine local state derived from client attributes
    // and avoid a potential race with IDENTIFY (where a client
    // could have changed or disabled said attributes)
    messagePumpStartedChan := make(chan bool)
    go p.messagePump(client, messagePumpStartedChan)
    <-messagePumpStartedChan

    for {
        if client.HeartbeatInterval > 0 {
            client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
        } else {
            client.SetReadDeadline(zeroTime)
        }

        // ReadSlice does not allocate new space for the data each request
        // ie. the returned slice is only valid until the next call to it
        line, err = client.Reader.ReadSlice('\n')
        if err != nil {
            if err == io.EOF {
                err = nil
            } else {
                err = fmt.Errorf("failed to read command - %s", err)
            }
            break
        }

        // trim the '\n'
        line = line[:len(line)-1]
        // optionally trim the '\r'
        if len(line) > 0 && line[len(line)-1] == '\r' {
            line = line[:len(line)-1]
        }
        params := bytes.Split(line, separatorBytes)

        if p.ctx.nsqd.getOpts().Verbose {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %s", client, params)
        }

        var response []byte
        response, err = p.Exec(client, params)
        if err != nil {
            ctx := ""
            if parentErr := err.(protocol.ChildErr).Parent(); parentErr != nil {
                ctx = " - " + parentErr.Error()
            }
            p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, err, ctx)

            sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
            if sendErr != nil {
                p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, sendErr, ctx)
                break
            }

            // errors of type FatalClientErr should forceably close the connection
            if _, ok := err.(*protocol.FatalClientErr); ok {
                break
            }
            continue
        }

        if response != nil {
            err = p.Send(client, frameTypeResponse, response)
            if err != nil {
                err = fmt.Errorf("failed to send response - %s", err)
                break
            }
        }
    }

    p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting ioloop", client)
    conn.Close()
    close(client.ExitChan)
    if client.Channel != nil {
        client.Channel.RemoveClient(client.ID)
    }

    return err
}

func (p *protocolV2) SendMessage(client *clientV2, msg *Message, buf *bytes.Buffer) error {
    if p.ctx.nsqd.getOpts().Verbose {
        p.ctx.nsqd.logf("PROTOCOL(V2): writing msg(%s) to client(%s) - %s",
            msg.ID, client, msg.Body)
    }

    buf.Reset()
    _, err := msg.WriteTo(buf)
    if err != nil {
        return err
    }

    err = p.Send(client, frameTypeMessage, buf.Bytes())
    if err != nil {
        return err
    }

    return nil
}

func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
    client.writeLock.Lock()

    var zeroTime time.Time
    if client.HeartbeatInterval > 0 {
        client.SetWriteDeadline(time.Now().Add(client.HeartbeatInterval))
    } else {
        client.SetWriteDeadline(zeroTime)
    }

    _, err := protocol.SendFramedResponse(client.Writer, frameType, data)
    if err != nil {
        client.writeLock.Unlock()
        return err
    }

    if frameType != frameTypeMessage {
        err = client.Flush()
    }

    client.writeLock.Unlock()

    return err
}

func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
    if bytes.Equal(params[0], []byte("IDENTIFY")) {
        return p.IDENTIFY(client, params)
    }
    err := enforceTLSPolicy(client, p, params[0])
    if err != nil {
        return nil, err
    }
    switch {
    case bytes.Equal(params[0], []byte("FIN")):
        return p.FIN(client, params)
    case bytes.Equal(params[0], []byte("RDY")):
        return p.RDY(client, params)
    case bytes.Equal(params[0], []byte("REQ")):
        return p.REQ(client, params)
    case bytes.Equal(params[0], []byte("PUB")):
        return p.PUB(client, params)
    case bytes.Equal(params[0], []byte("MPUB")):
        return p.MPUB(client, params)
    case bytes.Equal(params[0], []byte("DPUB")):
        return p.DPUB(client, params)
    case bytes.Equal(params[0], []byte("NOP")):
        return p.NOP(client, params)
    case bytes.Equal(params[0], []byte("TOUCH")):
        return p.TOUCH(client, params)
    case bytes.Equal(params[0], []byte("SUB")):
        return p.SUB(client, params)
    case bytes.Equal(params[0], []byte("CLS")):
        return p.CLS(client, params)
    case bytes.Equal(params[0], []byte("AUTH")):
        return p.AUTH(client, params)
    }
    return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", params[0]))
}

func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
    var err error
    var buf bytes.Buffer
    var memoryMsgChan chan *Message
    var backendMsgChan chan []byte
    var subChannel *Channel
    // NOTE: `flusherChan` is used to bound message latency for
    // the pathological case of a channel on a low volume topic
    // with >1 clients having >1 RDY counts
    var flusherChan <-chan time.Time
    var sampleRate int32

    subEventChan := client.SubEventChan
    identifyEventChan := client.IdentifyEventChan
    outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
    heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
    heartbeatChan := heartbeatTicker.C
    msgTimeout := client.MsgTimeout

    // v2 opportunistically buffers data to clients to reduce write system calls
    // we force flush in two cases:
    //    1. when the client is not ready to receive messages
    //    2. we're buffered and the channel has nothing left to send us
    //       (ie. we would block in this loop anyway)
    //
    flushed := true

    // signal to the goroutine that started the messagePump
    // that we've started up
    close(startedChan)

    for {
        if subChannel == nil || !client.IsReadyForMessages() {
            // the client is not ready to receive messages...
            memoryMsgChan = nil
            backendMsgChan = nil
            flusherChan = nil
            // force flush
            client.writeLock.Lock()
            err = client.Flush()
            client.writeLock.Unlock()
            if err != nil {
                goto exit
            }
            flushed = true
        } else if flushed {
            // last iteration we flushed...
            // do not select on the flusher ticker channel
            memoryMsgChan = subChannel.memoryMsgChan
            backendMsgChan = subChannel.backend.ReadChan()
            flusherChan = nil
        } else {
            // we're buffered (if there isn't any more data we should flush)...
            // select on the flusher ticker channel, too
            memoryMsgChan = subChannel.memoryMsgChan
            backendMsgChan = subChannel.backend.ReadChan()
            flusherChan = outputBufferTicker.C
        }

        select {
        case <-flusherChan:
            // if this case wins, we're either starved
            // or we won the race between other channels...
            // in either case, force flush
            client.writeLock.Lock()
            err = client.Flush()
            client.writeLock.Unlock()
            if err != nil {
                goto exit
            }
            flushed = true
        case <-client.ReadyStateChan:
        case subChannel = <-subEventChan:
            // you can't SUB anymore
            subEventChan = nil
        case identifyData := <-identifyEventChan:
            // you can't IDENTIFY anymore
            identifyEventChan = nil

            outputBufferTicker.Stop()
            if identifyData.OutputBufferTimeout > 0 {
                outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
            }

            heartbeatTicker.Stop()
            heartbeatChan = nil
            if identifyData.HeartbeatInterval > 0 {
                heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
                heartbeatChan = heartbeatTicker.C
            }

            if identifyData.SampleRate > 0 {
                sampleRate = identifyData.SampleRate
            }

            msgTimeout = identifyData.MsgTimeout
        case <-heartbeatChan:
            err = p.Send(client, frameTypeResponse, heartbeatBytes)
            if err != nil {
                goto exit
            }
        case b := <-backendMsgChan:
            if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                continue
            }

            msg, err := decodeMessage(b)
            if err != nil {
                p.ctx.nsqd.logf("ERROR: failed to decode message - %s", err)
                continue
            }
            msg.Attempts++

            subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
            client.SendingMessage()
            err = p.SendMessage(client, msg, &buf)
            if err != nil {
                goto exit
            }
            flushed = false
        case msg := <-memoryMsgChan:
            if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                continue
            }
            msg.Attempts++

            subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
            client.SendingMessage()
            err = p.SendMessage(client, msg, &buf)
            if err != nil {
                goto exit
            }
            flushed = false
        case <-client.ExitChan:
            goto exit
        }
    }

exit:
    p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting messagePump", client)
    heartbeatTicker.Stop()
    outputBufferTicker.Stop()
    if err != nil {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] messagePump error - %s", client, err)
    }
}

func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
    var err error

    if atomic.LoadInt32(&client.State) != stateInit {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
    }

    bodyLen, err := readLen(client.Reader, client.lenSlice)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
    }

    if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
    }

    if bodyLen <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
    }

    body := make([]byte, bodyLen)
    _, err = io.ReadFull(client.Reader, body)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
    }

    // body is a json structure with producer information
    var identifyData identifyDataV2
    err = json.Unmarshal(body, &identifyData)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
    }

    if p.ctx.nsqd.getOpts().Verbose {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %+v", client, identifyData)
    }

    err = client.Identify(identifyData)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
    }

    // bail out early if we're not negotiating features
    if !identifyData.FeatureNegotiation {
        return okBytes, nil
    }

    tlsv1 := p.ctx.nsqd.tlsConfig != nil && identifyData.TLSv1
    deflate := p.ctx.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
    deflateLevel := 0
    if deflate {
        if identifyData.DeflateLevel <= 0 {
            deflateLevel = 6
        }
        deflateLevel = int(math.Min(float64(deflateLevel), float64(p.ctx.nsqd.getOpts().MaxDeflateLevel)))
    }
    snappy := p.ctx.nsqd.getOpts().SnappyEnabled && identifyData.Snappy

    if deflate && snappy {
        return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
    }

    resp, err := json.Marshal(struct {
        MaxRdyCount         int64  `json:"max_rdy_count"`
        Version             string `json:"version"`
        MaxMsgTimeout       int64  `json:"max_msg_timeout"`
        MsgTimeout          int64  `json:"msg_timeout"`
        TLSv1               bool   `json:"tls_v1"`
        Deflate             bool   `json:"deflate"`
        DeflateLevel        int    `json:"deflate_level"`
        MaxDeflateLevel     int    `json:"max_deflate_level"`
        Snappy              bool   `json:"snappy"`
        SampleRate          int32  `json:"sample_rate"`
        AuthRequired        bool   `json:"auth_required"`
        OutputBufferSize    int    `json:"output_buffer_size"`
        OutputBufferTimeout int64  `json:"output_buffer_timeout"`
    }{
        MaxRdyCount:         p.ctx.nsqd.getOpts().MaxRdyCount,
        Version:             version.Binary,
        MaxMsgTimeout:       int64(p.ctx.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
        MsgTimeout:          int64(client.MsgTimeout / time.Millisecond),
        TLSv1:               tlsv1,
        Deflate:             deflate,
        DeflateLevel:        deflateLevel,
        MaxDeflateLevel:     p.ctx.nsqd.getOpts().MaxDeflateLevel,
        Snappy:              snappy,
        SampleRate:          client.SampleRate,
        AuthRequired:        p.ctx.nsqd.IsAuthEnabled(),
        OutputBufferSize:    client.OutputBufferSize,
        OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
    })
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
    }

    err = p.Send(client, frameTypeResponse, resp)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
    }

    if tlsv1 {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to TLS", client)
        err = client.UpgradeTLS()
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }

        err = p.Send(client, frameTypeResponse, okBytes)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }
    }

    if snappy {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to snappy", client)
        err = client.UpgradeSnappy()
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }

        err = p.Send(client, frameTypeResponse, okBytes)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }
    }

    if deflate {
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to deflate", client)
        err = client.UpgradeDeflate(deflateLevel)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }

        err = p.Send(client, frameTypeResponse, okBytes)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }
    }

    return nil, nil
}

func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
    if atomic.LoadInt32(&client.State) != stateInit {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
    }

    if len(params) != 1 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
    }

    bodyLen, err := readLen(client.Reader, client.lenSlice)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
    }

    if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("AUTH body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
    }

    if bodyLen <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("AUTH invalid body size %d", bodyLen))
    }

    body := make([]byte, bodyLen)
    _, err = io.ReadFull(client.Reader, body)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
    }

    if client.HasAuthorizations() {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH Already set")
    }

    if !client.ctx.nsqd.IsAuthEnabled() {
        return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH Disabled")
    }

    if err := client.Auth(string(body)); err != nil {
        // we don't want to leak errors contacting the auth server to untrusted clients
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, err)
        return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
    }

    if !client.HasAuthorizations() {
        return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH No authorizations found")
    }

    resp, err := json.Marshal(struct {
        Identity        string `json:"identity"`
        IdentityURL     string `json:"identity_url"`
        PermissionCount int    `json:"permission_count"`
    }{
        Identity:        client.AuthState.Identity,
        IdentityURL:     client.AuthState.IdentityURL,
        PermissionCount: len(client.AuthState.Authorizations),
    })
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
    }

    err = p.Send(client, frameTypeResponse, resp)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
    }

    return nil, nil

}

func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
    // if auth is enabled, the client must have authorized already
    // compare topic/channel against cached authorization data (refetching if expired)
    if client.ctx.nsqd.IsAuthEnabled() {
        if !client.HasAuthorizations() {
            return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
                fmt.Sprintf("AUTH required before %s", cmd))
        }
        ok, err := client.IsAuthorized(topicName, channelName)
        if err != nil {
            // we don't want to leak errors contacting the auth server to untrusted clients
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, err)
            return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
        }
        if !ok {
            return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
                fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
        }
    }
    return nil
}

func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
    if atomic.LoadInt32(&client.State) != stateInit {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
    }

    if client.HeartbeatInterval <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
    }

    if len(params) < 3 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
    }

    topicName := string(params[1])
    if !protocol.IsValidTopicName(topicName) {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
            fmt.Sprintf("SUB topic name %q is not valid", topicName))
    }

    channelName := string(params[2])
    if !protocol.IsValidChannelName(channelName) {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
            fmt.Sprintf("SUB channel name %q is not valid", channelName))
    }

    if err := p.CheckAuth(client, "SUB", topicName, channelName); err != nil {
        return nil, err
    }

    topic := p.ctx.nsqd.GetTopic(topicName)
    channel := topic.GetChannel(channelName)
    channel.AddClient(client.ID, client)

    atomic.StoreInt32(&client.State, stateSubscribed)
    client.Channel = channel
    // update message pump
    client.SubEventChan <- channel

    return okBytes, nil
}

func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
    state := atomic.LoadInt32(&client.State)

    if state == stateClosing {
        // just ignore ready changes on a closing channel
        p.ctx.nsqd.logf(
            "PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
            client)
        return nil, nil
    }

    if state != stateSubscribed {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
    }

    count := int64(1)
    if len(params) > 1 {
        b10, err := protocol.ByteToBase10(params[1])
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_INVALID",
                fmt.Sprintf("RDY could not parse count %s", params[1]))
        }
        count = int64(b10)
    }

    if count < 0 || count > p.ctx.nsqd.getOpts().MaxRdyCount {
        // this needs to be a fatal error otherwise clients would have
        // inconsistent state
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
            fmt.Sprintf("RDY count %d out of range 0-%d", count, p.ctx.nsqd.getOpts().MaxRdyCount))
    }

    client.SetReadyCount(count)

    return nil, nil
}

func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
    state := atomic.LoadInt32(&client.State)
    if state != stateSubscribed && state != stateClosing {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
    }

    if len(params) < 2 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
    }

    id, err := getMessageID(params[1])
    if err != nil {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
    }

    err = client.Channel.FinishMessage(client.ID, *id)
    if err != nil {
        return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
            fmt.Sprintf("FIN %s failed %s", *id, err.Error()))
    }

    client.FinishedMessage()

    return nil, nil
}

func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
    state := atomic.LoadInt32(&client.State)
    if state != stateSubscribed && state != stateClosing {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
    }

    if len(params) < 3 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
    }

    id, err := getMessageID(params[1])
    if err != nil {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
    }

    timeoutMs, err := protocol.ByteToBase10(params[2])
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_INVALID",
            fmt.Sprintf("REQ could not parse timeout %s", params[2]))
    }
    timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

    if timeoutDuration < 0 || timeoutDuration > p.ctx.nsqd.getOpts().MaxReqTimeout {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
            fmt.Sprintf("REQ timeout %d out of range 0-%d", timeoutDuration, p.ctx.nsqd.getOpts().MaxReqTimeout))
    }

    err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
    if err != nil {
        return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
            fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
    }

    client.RequeuedMessage()

    return nil, nil
}

func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
    if atomic.LoadInt32(&client.State) != stateSubscribed {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
    }

    client.StartClose()

    return []byte("CLOSE_WAIT"), nil
}

// NOP handles the NOP command. It does nothing and returns neither a
// response nor an error.
func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
    return nil, nil
}

func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
    var err error

    if len(params) < 2 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
    }

    topicName := string(params[1])
    if !protocol.IsValidTopicName(topicName) {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
            fmt.Sprintf("PUB topic name %q is not valid", topicName))
    }

    bodyLen, err := readLen(client.Reader, client.lenSlice)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
    }

    if bodyLen <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
            fmt.Sprintf("PUB invalid message body size %d", bodyLen))
    }

    if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxMsgSize {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
            fmt.Sprintf("PUB message too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxMsgSize))
    }

    messageBody := make([]byte, bodyLen)
    _, err = io.ReadFull(client.Reader, messageBody)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
    }

    if err := p.CheckAuth(client, "PUB", topicName, ""); err != nil {
        return nil, err
    }

    topic := p.ctx.nsqd.GetTopic(topicName)
    msg := NewMessage(<-p.ctx.nsqd.idChan, messageBody)
    err = topic.PutMessage(msg)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
    }

    return okBytes, nil
}

func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
    var err error

    if len(params) < 2 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
    }

    topicName := string(params[1])
    if !protocol.IsValidTopicName(topicName) {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
            fmt.Sprintf("E_BAD_TOPIC MPUB topic name %q is not valid", topicName))
    }

    bodyLen, err := readLen(client.Reader, client.lenSlice)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
    }

    if bodyLen <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("MPUB invalid body size %d", bodyLen))
    }

    if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
            fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
    }

    messages, err := readMPUB(client.Reader, client.lenSlice, p.ctx.nsqd.idChan,
        p.ctx.nsqd.getOpts().MaxMsgSize)
    if err != nil {
        return nil, err
    }

    if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
        return nil, err
    }

    topic := p.ctx.nsqd.GetTopic(topicName)

    // if we've made it this far we've validated all the input,
    // the only possible error is that the topic is exiting during
    // this next call (and no messages will be queued in that case)
    err = topic.PutMessages(messages)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
    }

    return okBytes, nil
}

func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
    var err error

    if len(params) < 3 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
    }

    topicName := string(params[1])
    if !protocol.IsValidTopicName(topicName) {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
            fmt.Sprintf("DPUB topic name %q is not valid", topicName))
    }

    timeoutMs, err := protocol.ByteToBase10(params[2])
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_INVALID",
            fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
    }
    timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

    if timeoutDuration < 0 || timeoutDuration > p.ctx.nsqd.getOpts().MaxReqTimeout {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
            fmt.Sprintf("DPUB timeout %d out of range 0-%d",
                timeoutMs, p.ctx.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
    }

    bodyLen, err := readLen(client.Reader, client.lenSlice)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
    }

    if bodyLen <= 0 {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
            fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
    }

    if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxMsgSize {
        return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
            fmt.Sprintf("DPUB message too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxMsgSize))
    }

    messageBody := make([]byte, bodyLen)
    _, err = io.ReadFull(client.Reader, messageBody)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
    }

    if err := p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
        return nil, err
    }

    topic := p.ctx.nsqd.GetTopic(topicName)
    msg := NewMessage(<-p.ctx.nsqd.idChan, messageBody)
    msg.deferred = timeoutDuration
    err = topic.PutMessage(msg)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
    }

    return okBytes, nil
}

func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
    state := atomic.LoadInt32(&client.State)
    if state != stateSubscribed && state != stateClosing {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
    }

    if len(params) < 2 {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
    }

    id, err := getMessageID(params[1])
    if err != nil {
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
    }

    client.writeLock.RLock()
    msgTimeout := client.MsgTimeout
    client.writeLock.RUnlock()
    err = client.Channel.TouchMessage(client.ID, *id, msgTimeout)
    if err != nil {
        return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
            fmt.Sprintf("TOUCH %s failed %s", *id, err.Error()))
    }

    return nil, nil
}

func readMPUB(r io.Reader, tmp []byte, idChan chan MessageID, maxMessageSize int64) ([]*Message, error) {
    numMessages, err := readLen(r, tmp)
    if err != nil {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
    }

    if numMessages <= 0 {
        return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY",
            fmt.Sprintf("MPUB invalid message count %d", numMessages))
    }

    messages := make([]*Message, 0, numMessages)
    for i := int32(0); i < numMessages; i++ {
        messageSize, err := readLen(r, tmp)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
                fmt.Sprintf("MPUB failed to read message(%d) body size", i))
        }

        if messageSize <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
        }

        if int64(messageSize) > maxMessageSize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
        }

        msgBody := make([]byte, messageSize)
        _, err = io.ReadFull(r, msgBody)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
        }

        messages = append(messages, NewMessage(<-idChan, msgBody))
    }

    return messages, nil
}

// validate and cast the bytes on the wire to a message ID
func getMessageID(p []byte) (*MessageID, error) {
    if len(p) != MsgIDLength {
        return nil, errors.New("Invalid Message ID")
    }
    return (*MessageID)(unsafe.Pointer(&p[0])), nil
}

// readLen fills tmp from r and decodes its first four bytes as a big-endian
// 32-bit value, returned as an int32 (so values with the top bit set come
// back negative). If tmp cannot be filled it returns 0 and the read error.
func readLen(r io.Reader, tmp []byte) (int32, error) {
    if _, err := io.ReadFull(r, tmp); err != nil {
        return 0, err
    }
    size := binary.BigEndian.Uint32(tmp)
    return int32(size), nil
}

func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
    if p.ctx.nsqd.getOpts().TLSRequired != TLSNotRequired && atomic.LoadInt32(&client.TLS) != 1 {
        return protocol.NewFatalClientErr(nil, "E_INVALID",
            fmt.Sprintf("cannot %s in current state (TLS required)", command))
    }
    return nil
}

上一篇:【原创】NIO框架入门(三):iOS与MINA2、Netty4的跨平台UDP双向通信实战


下一篇:(转)Maven学习-处理资源文件