wind-c / comqtt

A lightweight, high-performance Go MQTT server (v3.0 | v3.1.1 | v5.0) supporting distributed clusters
MIT License
869 stars 50 forks source link

Authenticating against a MongoDB collection with Bcrypt-hashed passwords #83

Closed devilslane-com closed 4 months ago

devilslane-com commented 4 months ago

Firstly - thank you for this amazing library.

To save people some trouble if you want to authenticate against a database other than Redis or PostgreSQL, here's a quick guide. In my case, I needed to check users stored in MongoDB whose passwords had been hashed with bcrypt by Laravel. This wasn't obvious from the examples, particularly how to read the connecting client's credentials from the `packets.Packet` passed to the hook (`pk.Connect.Username` / `pk.Connect.Password`).

This is obviously debug code (it prints out usernames, passwords, etc.), and a real setup involves more than this, but it gives a rough overview.

package main

import (
    "context"
    "crypto/tls"
    "crypto/x509"
    "flag"
    "fmt"
    "io/ioutil"
    "log"
    "os"
    "os/signal"
    "strconv"
    "syscall"
    "time"

    "github.com/wind-c/comqtt/v2/mqtt"
    "github.com/wind-c/comqtt/v2/mqtt/hooks/auth"
    "github.com/wind-c/comqtt/v2/mqtt/hooks/debug"
    "github.com/wind-c/comqtt/v2/mqtt/listeners"
    "github.com/wind-c/comqtt/v2/mqtt/packets"
    "golang.org/x/crypto/bcrypt"

    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/options"
)

// Auth is a comqtt hook that authenticates connecting clients against user
// documents stored in a MongoDB collection, whose passwords are bcrypt hashes.
// Build it with auth_hook so the database and collection names are filled in.
type Auth struct {
	mqtt.HookBase

	mongo_client *mongo.Client // client used for the user lookups
	database     string        // database that holds the users collection
	collection   string        // collection that holds the user documents
}

// auth_hook builds the MongoDB authentication hook around an already
// connected client. The database name comes from DB_DATABASE (default
// "mycooldb") and the collection name from DB_AUTH_COLLECTION (default
// "users"); both are read once, at construction time.
func auth_hook(mongo_client *mongo.Client) *Auth {
	hook := new(Auth)
	hook.mongo_client = mongo_client
	hook.database = env("DB_DATABASE", "mycooldb")
	hook.collection = env("DB_AUTH_COLLECTION", "users")
	return hook
}

// ID returns the identifier of this hook, "mongodb-auth".
func (hook *Auth) ID() string {
	const hookID = "mongodb-auth"
	return hookID
}

// Provides reports whether this hook handles the given hook event.
// It claims only OnConnectAuthenticate.
//
// NOTE(review): this hook does not claim OnACLCheck. If no other registered
// hook grants ACLs, publishes/subscribes may be denied. Confirm how the
// server combines hook results.
func (hook *Auth) Provides(b byte) bool {
	switch b {
	case mqtt.OnConnectAuthenticate:
		return true
	default:
		return false
	}
}

// OnConnectAuthenticate decides whether a connecting client may log in.
//
// It looks up the user document whose username field matches the CONNECT
// username in hook.database / hook.collection. It then compares the CONNECT
// password against the bcrypt hash stored in that document's password field.
// The field names default to "username" and "password" and can be overridden
// with DB_AUTH_USERNAME_FIELD and DB_AUTH_PASSWORD_FIELD.
//
// It returns true only when the user exists, the stored password is a
// string, and the hash matches. Empty credentials, lookup errors, a missing
// or non-string password field, and hash mismatches all return false.
func (hook *Auth) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool {
	username := string(pk.Connect.Username)

	// Never log the password: it is the secret being verified.
	// %q escapes control characters in this client-supplied value.
	log.Printf("Client trying to connect with username: %q", username)

	// An empty username would match a document whose username field is "",
	// and an empty password is not accepted as a login here, so reject both
	// without touching the database.
	if len(username) == 0 || len(pk.Connect.Password) == 0 {
		log.Printf("Authentication failed: missing username or password")
		return false
	}

	// Bound the lookup so a slow or unreachable database cannot stall the
	// CONNECT handling for this client indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	users := hook.mongo_client.Database(hook.database).Collection(hook.collection)

	var user bson.M
	filter := bson.M{env("DB_AUTH_USERNAME_FIELD", "username"): username}
	if err := users.FindOne(ctx, filter).Decode(&user); err != nil {
		log.Printf("Authentication failed for user %q: %v", username, err)
		return false
	}

	password, ok := user[env("DB_AUTH_PASSWORD_FIELD", "password")].(string)
	if !ok {
		log.Printf("Invalid password format for user %q", username)
		return false
	}

	// Compare the provided password with the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(password), []byte(pk.Connect.Password)); err != nil {
		log.Printf("Bcrypt hash comparison failed. Invalid password for user %q", username)
		return false
	}

	log.Printf("User %q authenticated successfully.", username)
	return true
}

// env returns the value of the environment variable key, or def when that
// variable is unset or set to the empty string.
func env(key, def string) string {
	if value := os.Getenv(key); value != "" {
		return value
	}
	return def
}

// main connects to MongoDB, builds the MQTT server, registers the debug hook
// and the MongoDB authentication hook, and starts serving. It then blocks
// until SIGINT or SIGTERM arrives, closes the server, and disconnects from
// MongoDB.
//
// Parts of the program are elided ("other stuff"). The settings passed to
// mqtt.New are declared elsewhere.
func main() {
	done := make(chan bool, 1)
	sigs := make(chan os.Signal, 1)

	// Register the signals to catch
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		sig := <-sigs
		log.Printf("Received signal: %s", sig)
		done <- true
	}()

	// Other stuff

	mongo_uri_flag := flag.String("mongo_uri", env("MONGO_URI", "mongodb://user:pass@localhost:27017/?authSource=admin"), "MongoDB URI")

	// Without Parse, the -mongo_uri flag is never read and the default above
	// always wins. Any other flags must be defined before this call.
	flag.Parse()

	// NOTE: the variable shadows the mongo_client constructor function from
	// here on; the right-hand side still refers to the function.
	mongo_client := mongo_client(*mongo_uri_flag)

	// vars declared elsewhere
	mqtt_server := mqtt.New(&mqtt.Options{
		Capabilities: &mqtt.Capabilities{
			MaximumSessionExpiryInterval: uint32(max_session_expiry_interval),
			Compatibilities: mqtt.Compatibilities{
				ObscureNotAuthorized: obscure_not_authorized,
			},
		},
		ClientNetWriteBufferSize: client_net_write_buffer_size,
		ClientNetReadBufferSize:  client_net_read_buffer_size,
		SysTopicResendInterval:   int64(sys_topic_resend_interval),
	})

	debug_hook := new(debug.Hook)

	err := mqtt_server.AddHook(debug_hook, &debug.Options{})
	if err != nil {
		log.Fatalf("Failed to add debug hook: %v", err)
	}

	err = mqtt_server.AddHook(auth_hook(mongo_client), nil)
	if err != nil {
		log.Fatalf("Failed to add MongoDB authentication hook: %v", err)
	}

	// Start the MQTT server
	go func() {
		if err := mqtt_server.Serve(); err != nil {
			log.Fatalf("Buzz MQTT server failed to start: %v", err)
		}
	}()

	// loads of other stuff

	<-done
	log.Println("Shutting down the server...")

	// Clean up resources
	if err := mqtt_server.Close(); err != nil {
		log.Printf("Error shutting down server: %v", err)
	} else {
		log.Println("Server shutdown successfully.")
	}

	// Disconnect MongoDB only after the MQTT server is closed, so the auth hook
	// is (presumably) no longer being invoked. The timeout bounds shutdown if
	// the database is unreachable.
	disconnectCtx, cancelDisconnect := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelDisconnect()
	if err := mongo_client.Disconnect(disconnectCtx); err != nil {
		log.Printf("Error disconnecting from MongoDB: %v", err)
	}
}