golang-jwt / jwt

Go implementation of JSON Web Tokens (JWT).
https://golang-jwt.github.io/jwt/
MIT License
6.98k stars 335 forks source link

I've made a small library to help with JWT #355

Closed ivanjaros closed 11 months ago

ivanjaros commented 11 months ago

This library is lacking in the "batteries included" department: a lot of manual work is needed to make it work. So I made a small library to help with this. Consider adopting whatever parts you think are suitable.

// Package jwtool provides helpers on top of golang-jwt for signing and
// parsing tokens and for building and checking their registered claims.
package jwtool

// contains custom imports:
// random is just a random string generator
// protoj is just a JSON encoder/decoder
// NOTE(review): the import block is not part of this snippet; the code below
// also refers to crypto/ecdsa, encoding/json, errors, time and the
// golang-jwt package (as jwt).

// Validatable is implemented by values that can check their own consistency.
// NewToken validates a Validatable payload before signing it and ParseToken
// validates the decoded payload; a non-nil error aborts either operation.
type Validatable interface {
    Validate() error
}

// TokenIdLength is the number of characters in an identifier returned by
// MakeTokenId.
const TokenIdLength = 15

// MakeTokenId returns a fresh identifier TokenIdLength characters long; it is
// the default "jti" claim used by MakeClaims.
//
// The value comes from random.StandardString and, per the original note, is
// NOT cryptographically random: it draws only from a-z, A-Z and 0-9, with no
// special characters. Do not rely on it being unguessable.
func MakeTokenId() string {
	const size = TokenIdLength
	id := random.StandardString(size)
	return id
}

// Sign creates a token carrying claims and returns its compact serialized form.
//
// The signing algorithm is selected from the type of secret:
//   - ecdsa.PrivateKey or *ecdsa.PrivateKey: ES256, ES384 or ES512, chosen by
//     the key's curve size (256, 384 or 521 bits);
//   - string or []byte: HS384, using the bytes as the HMAC key.
//
// An error is returned for any other secret type, a nil key pointer, an ECDSA
// key with a missing or unsupported curve, and an empty HMAC secret (an empty
// key would make every token trivially forgeable).
func Sign(secret any, claims jwt.Claims) (string, error) {
	// Normalize the value forms so each key kind is handled exactly once below.
	switch s := secret.(type) {
	case ecdsa.PrivateKey:
		secret = &s
	case string:
		secret = []byte(s)
	}

	switch key := secret.(type) {
	case *ecdsa.PrivateKey:
		if key == nil {
			return "", errors.New("nil ecdsa private key")
		}
		if key.Curve == nil {
			return "", errors.New("unknown elliptic curve")
		}
		var method jwt.SigningMethod
		switch key.Curve.Params().BitSize {
		case 256:
			method = jwt.SigningMethodES256
		case 384:
			method = jwt.SigningMethodES384
		case 521:
			method = jwt.SigningMethodES512
		default:
			return "", errors.New("unknown elliptic curve")
		}
		return jwt.NewWithClaims(method, claims).SignedString(key)
	case []byte:
		if len(key) == 0 {
			return "", errors.New("empty hmac secret")
		}
		return jwt.NewWithClaims(jwt.SigningMethodHS384, claims).SignedString(key)
	default:
		return "", errors.New("unknown secret type")
	}
}

// Parse verifies the signature of str, decodes its claims into claims and
// returns the parsed token.
//
// The only accepted algorithm is derived from the type of secret, so a token
// cannot downgrade to a different one:
//   - ecdsa.PublicKey or *ecdsa.PublicKey: ES256, ES384 or ES512, chosen by
//     the key's curve size (256, 384 or 521 bits);
//   - string or []byte: HS384, using the bytes as the HMAC key.
//
// An error is returned for any other secret type, a nil key pointer, an ECDSA
// key with a missing or unsupported curve, an empty HMAC secret, any parse or
// verification failure reported by jwt.ParseWithClaims, or a token that jwt
// does not mark as Valid.
func Parse(secret any, str string, claims jwt.Claims) (*jwt.Token, error) {
	// Normalize the value forms so each key kind is handled exactly once below.
	switch s := secret.(type) {
	case ecdsa.PublicKey:
		secret = &s
	case string:
		secret = []byte(s)
	}

	var method string
	switch key := secret.(type) {
	case *ecdsa.PublicKey:
		if key == nil {
			return nil, errors.New("nil ecdsa public key")
		}
		if key.Curve == nil {
			return nil, errors.New("unknown elliptic curve")
		}
		switch key.Curve.Params().BitSize {
		case 256:
			method = jwt.SigningMethodES256.Alg()
		case 384:
			method = jwt.SigningMethodES384.Alg()
		case 521:
			method = jwt.SigningMethodES512.Alg()
		default:
			return nil, errors.New("unknown elliptic curve")
		}
	case []byte:
		if len(key) == 0 {
			return nil, errors.New("empty hmac secret")
		}
		method = jwt.SigningMethodHS384.Alg()
	default:
		return nil, errors.New("unknown secret type")
	}

	// secret is now either *ecdsa.PublicKey or []byte, the key forms the jwt
	// ECDSA and HMAC verifiers expect.
	tok, err := jwt.ParseWithClaims(str, claims,
		func(*jwt.Token) (any, error) { return secret, nil },
		jwt.WithValidMethods([]string{method}),
	)
	if err != nil {
		return nil, err
	}
	if !tok.Valid {
		return nil, errors.New("invalid token")
	}
	return tok, nil
}

// MakeClaims builds a jwt.RegisteredClaims from defaults and then applies
// options in order, so any option overrides the matching default.
//
// Defaults, all relative to a single time.Now() reading:
//   - exp: 15 minutes later
//   - jti: a new MakeTokenId value
//   - iat: now
//   - nbf: one minute earlier, to tolerate clock differences between client
//     and server
func MakeClaims(options ...MakeOption) jwt.RegisteredClaims {
	issued := time.Now()
	defaults := []MakeOption{
		WithExpiration(issued.Add(15 * time.Minute)),
		WithId(MakeTokenId()),
		WithIssuedAt(issued),
		WithNotBefore(issued.Add(-time.Minute)),
	}

	var claims jwt.RegisteredClaims
	for _, apply := range defaults {
		apply(&claims)
	}
	for _, apply := range options {
		apply(&claims)
	}
	return claims
}

// MakeOption mutates the jwt.RegisteredClaims being assembled by MakeClaims.
type MakeOption func(*jwt.RegisteredClaims)

// WithAudience adds aud to the "aud" claim. It appends rather than replaces,
// so supplying it several times produces a multi-valued audience.
func WithAudience(aud string) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.Audience = append(rc.Audience, aud)
	}
}

// WithExpiration sets the "exp" claim to the instant exp.
func WithExpiration(exp time.Time) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.ExpiresAt = jwt.NewNumericDate(exp)
	}
}

// WithLifespan sets the "exp" claim to l after the token's issue time.
//
// The base time is read when the option is applied, not when WithLifespan is
// called. It is the "iat" claim if that is already set (MakeClaims sets it
// before applying options), otherwise the current time. This keeps the
// lifespan correct even when one option value is built once and reused for
// many tokens.
func WithLifespan(l time.Duration) MakeOption {
	return func(c *jwt.RegisteredClaims) {
		base := time.Now()
		if c.IssuedAt != nil {
			base = c.IssuedAt.Time
		}
		c.ExpiresAt = jwt.NewNumericDate(base.Add(l))
	}
}

// WithId sets the "jti" (token ID) claim to id.
func WithId(id string) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.ID = id
	}
}

// WithIssuedAt sets the "iat" claim to the instant ia.
func WithIssuedAt(ia time.Time) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.IssuedAt = jwt.NewNumericDate(ia)
	}
}

// WithIssuer sets the "iss" claim to iss.
func WithIssuer(iss string) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.Issuer = iss
	}
}

// WithNotBefore sets the "nbf" claim to the instant nbf.
func WithNotBefore(nbf time.Time) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.NotBefore = jwt.NewNumericDate(nbf)
	}
}

// WithSubject sets the "sub" claim to sub.
func WithSubject(sub string) MakeOption {
	return func(rc *jwt.RegisteredClaims) {
		rc.Subject = sub
	}
}

// UniversalClaims is the claim set produced by NewToken and consumed by
// ParseToken: the standard registered claims plus an application payload
// stored, already encoded, under the "pld" key.
type UniversalClaims struct {
    jwt.RegisteredClaims
    // Payload holds the payload bytes encoded by protoj in NewToken;
    // json.RawMessage keeps them verbatim when the claims are (de)serialized.
    Payload json.RawMessage `json:"pld"`
}

// NewToken signs a token whose "pld" claim carries payload and whose
// registered claims come from MakeClaims(options...). See Sign for the
// accepted secret types.
//
// If payload implements Validatable it is validated first, and any error is
// returned without signing. The payload is encoded with
// protoj.MarshalDenseData; encoding errors are returned as-is.
func NewToken(payload any, secret any, options ...MakeOption) (string, error) {
	if v, ok := payload.(Validatable); ok {
		if err := v.Validate(); err != nil {
			return "", err
		}
	}

	encoded, err := protoj.MarshalDenseData(payload)
	if err != nil {
		return "", err
	}

	return Sign(secret, UniversalClaims{
		RegisteredClaims: MakeClaims(options...),
		Payload:          encoded,
	})
}

// ParseToken verifies str with secret (see Parse for accepted secret types),
// runs validators against its registered claims in order, and, when payload is
// non-nil, decodes the "pld" claim into payload.
//
// payload must be a pointer to the destination value (for example &p). After
// decoding, payload is validated if it implements Validatable. The first error
// from parsing, a validator, decoding or validation is returned.
func ParseToken(secret any, str string, payload any, validators ...Validator) error {
	var c UniversalClaims
	if _, err := Parse(secret, str, &c); err != nil {
		return err
	}
	for _, validate := range validators {
		if err := validate(c.RegisteredClaims); err != nil {
			return err
		}
	}
	if payload == nil {
		return nil
	}
	// Hand over the caller's pointer itself, not &payload. With a *any
	// destination, a JSON-style decoder may overwrite the local interface
	// value (e.g. with a map) instead of writing through the caller's pointer,
	// silently losing the data for non-pointer or typed-nil payloads.
	// NOTE(review): assumes protoj.UnmarshalData takes a pointer destination
	// like encoding/json.Unmarshal — confirm.
	if err := protoj.UnmarshalData(c.Payload, payload); err != nil {
		return err
	}
	if v, ok := payload.(Validatable); ok {
		return v.Validate()
	}
	return nil
}

type Validator func(jwt.RegisteredClaims) error

func VerifySubject(sub string) Validator {
    return func(c jwt.RegisteredClaims) error {
        if c.Subject != sub {
            return errors.New("expected subject '"+sub"', got '"+c.Subject+"'")
        }
        return nil
    }
}

// VerifyIssuer returns a Validator that accepts a token only when its "iss"
// claim is non-empty and equals one of iss. A token without an issuer is
// rejected, and so is every token when iss is empty.
func VerifyIssuer(iss ...string) Validator {
	return func(c jwt.RegisteredClaims) error {
		// Compare directly instead of calling RegisteredClaims.VerifyIssuer(x, false):
		// with required=false, that method treats a missing "iss" claim as a
		// match and lets issuer-less tokens through.
		if c.Issuer != "" {
			for _, want := range iss {
				if c.Issuer == want {
					return nil
				}
			}
		}
		return errors.New("no issuer match")
	}
}

// VerifyAudience returns a Validator that accepts a token only when at least
// one non-empty entry of its "aud" claim equals one of aud. A token with no
// (or only empty) audience entries is rejected, and so is every token when aud
// is empty.
func VerifyAudience(aud ...string) Validator {
	return func(c jwt.RegisteredClaims) error {
		// Compare directly instead of calling RegisteredClaims.VerifyAudience(x, false):
		// with required=false, that method treats a missing or empty "aud"
		// claim as a match.
		for _, have := range c.Audience {
			if have == "" {
				continue
			}
			for _, want := range aud {
				if have == want {
					return nil
				}
			}
		}
		return errors.New("no audience match")
	}
}

Usage example:

// NewFooToken issues a token for p signed with secret. Only the "Foo" issuer
// is allowed. The token has subject "foo_bar", audience "baz" and a 5-minute
// lifespan instead of jwtool's 15-minute default.
func NewFooToken(secret ecdsa.PrivateKey, issuer string, p FooPayload) (string, error) {
	const allowedIssuer = "Foo"
	if issuer != allowedIssuer {
		return "", errors.New("invalid issuer")
	}

	opts := []jwtool.MakeOption{
		jwtool.WithIssuer(issuer),
		jwtool.WithSubject("foo_bar"),
		jwtool.WithAudience("baz"),
		jwtool.WithLifespan(5 * time.Minute), // instead of default 15 min
	}
	return jwtool.NewToken(p, secret, opts...)
}

// ParseFooToken verifies str with secret and requires issuer "Foo", subject
// "foo_bar" and audience "baz". It returns the decoded payload together with
// any error from jwtool.ParseToken.
func ParseFooToken(secret ecdsa.PublicKey, str string) (FooPayload, error) {
	checks := []jwtool.Validator{
		jwtool.VerifyIssuer("Foo"),
		jwtool.VerifySubject("foo_bar"),
		jwtool.VerifyAudience("baz"),
	}

	var p FooPayload
	err := jwtool.ParseToken(secret, str, &p, checks...)
	return p, err
}