danielgtaylor / huma

Huma REST/HTTP API Framework for Golang with OpenAPI 3.1
https://huma.rocks/
MIT License

Best practice for middleware testing #541

Open benkawecki opened 1 month ago

benkawecki commented 1 month ago

Description

When I recently wrote tests for a Huma router-agnostic middleware, I didn't find an obvious pattern to follow. I wasn't sure whether the problem was on my end, so I'd like to explore the best way to test Huma router-agnostic middleware and then contribute either a documentation change or an update to the humatest package that makes it easier to develop such middleware.

Example

Let's say I have some authentication middleware that does the following:

File under test...

package server

import (
    "net/http"

    "github.com/danielgtaylor/huma/v2"
)

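// contextKey is an unexported type, so this key cannot collide with
// context values set by other packages.
type contextKey string

const claimsContextKey contextKey = "claims"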

type Claims struct {
    Claims []string
}

// authMiddleware rejects requests that lack an Authorization header and
// otherwise attaches the caller's claims to the context passed to next.
func authMiddleware(api huma.API) func(huma.Context, func(huma.Context)) {
    return func(ctx huma.Context, next func(huma.Context)) {
        authHeader := ctx.Header("Authorization")
        if authHeader == "" {
            // Short-circuit: write a 401 response and do not call next.
            huma.WriteErr(api, ctx, http.StatusUnauthorized, "no authorization header present")
            return
        }
        c := Claims{
            Claims: []string{"foo", "bar"},
        }
        authorizedCtx := huma.WithValue(ctx, claimsContextKey, c)
        next(authorizedCtx)
    }
}

Test file...

package server

import (
    "context"
    "net/http"
    "testing"

    "github.com/danielgtaylor/huma/v2"
    "github.com/danielgtaylor/huma/v2/humatest"
    "github.com/stretchr/testify/assert"
)

func TestAuthMiddleware(t *testing.T) {
    tests := []struct {
        desc      string
        header    string
        expCode   int
        expClaims Claims
    }{
        {
            desc:      "no auth header",
            header:    "",
            expCode:   401,
            expClaims: Claims{},
        },
        {
            desc:      "correct header",
            header:    "Authorization: bearer foo",
            expCode:   204,
            expClaims: Claims{Claims: []string{"foo", "bar"}},
        },
    }

    for _, tt := range tests {
        t.Run(tt.desc, func(t *testing.T) {

            _, api := humatest.New(t)
            api.UseMiddleware(authMiddleware(api))
            huma.Register(api, huma.Operation{
                Method:      http.MethodGet,
                Path:        "/test",
                Summary:     "test",
                Description: "test route",
            }, func(ctx context.Context, _ *struct{}) (*struct{}, error) {
                if tt.expCode == 204 {
                    claims, ok := ctx.Value(claimsContextKey).(Claims)
                    assert.True(t, ok)
                    assert.Equal(t, tt.expClaims, claims)
                }
                return nil, nil
            })

            resp := api.Get("/test", tt.header)
            if tt.expCode != 204 {
                assert.Equal(t, tt.expCode, resp.Code)
            }

        })
    }
}

Comments / Concerns

What felt off during testing was that I had no way to work with the context directly; instead I had to invoke the middleware through the api.Get pattern. To make any assertions about how the middleware changed the context, I had to register a custom operation that itself contained the asserts. The other thing that felt off was that my assertions had to live in different places depending on the outcome: in the case of a 4xx the next handler is never called, so I had to check the response returned by the API instead.

Questions

  1. Is this the best approach to test middleware?
  2. If it is, is there some utility we can provide in the humatest package to streamline this process?

danielgtaylor commented 4 weeks ago

@benkawecki you can use humatest.NewContext(...) for this purpose. You can pass a dummy operation unless you specifically need some operation for the middleware to work. For example:

req, _ := http.NewRequest(http.MethodGet, "/demo", nil)
w := httptest.NewRecorder()
ctx := humatest.NewContext(&huma.Operation{}, req, w)

https://go.dev/play/p/jr3gEy2NejH

In general I'm a big fan of "if client does X, client will see Y" which is why you see a lot of end-to-end tests via the API, but I do understand there are times you may want to just test the functionality of a small component like an individual piece of middleware. The functionality is there so feel free to use whatever makes the most sense for your use case!

benkawecki commented 3 weeks ago

@danielgtaylor That makes sense for how to generate a test context, and I agree on testing from the client perspective most of the time.

I'm not sure this approach reduces the testing complexity compared to the original example. I attempted to rewrite my original tests and came up with the following (note: the tests currently pass, but I'm actually letting a bug through):

func TestAuthMiddleWare2(t *testing.T) {
    tests := []struct {
        desc      string
        header    string
        expCode   int
        expClaims Claims
    }{
        {
            desc:      "no auth header",
            header:    "",
            expCode:   401,
            expClaims: Claims{},
        },
        {
            desc:      "correct header",
            header:    "Authorization: bearer foo",
            expCode:   204,
            expClaims: Claims{Claims: []string{"foo", "bar"}},
        },
    }

    for _, tt := range tests {
        t.Run(tt.desc, func(t *testing.T) {

            _, api := humatest.New(t)
            mw := authMiddleware(api)

            req, _ := http.NewRequest(http.MethodGet, "/demo", nil)
            w := httptest.NewRecorder()

            ctx := humatest.NewContext(&huma.Operation{}, req, w)
            next := func(ctx huma.Context) {
                if tt.expCode == 204 {
                    claims, ok := ctx.Context().Value(claimsContextKey).(Claims)
                    assert.True(t, ok)
                    assert.Equal(t, tt.expClaims, claims)
                } else {
                    assert.Fail(t, "next should not be called when the request is rejected")
                }
            }

            mw(ctx, next)

            if tt.expCode != 204 {
                assert.Equal(t, tt.expCode, w.Code)
            }
        })
    }
}

Here the second test case shouldn't pass, since I never set the header from the test case on the request. It passes anyway because the middleware returns before invoking next, which means my assertions are never called. I could find a way around this, but I think it demonstrates the complexity of this approach in general.
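One way to rule out a skipped assertion is to let next do nothing but record that it ran, and to make every check after the middleware returns. A minimal stdlib-only sketch, where requireAuth and checkCase are hypothetical stand-ins for the Huma code above: because the call count is always compared against the expectation, a case that expects next to run but forgets to set the header fails instead of passing silently.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// requireAuth is a hypothetical middleware: it rejects requests without an
// Authorization header and otherwise forwards them to next.
func requireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("Authorization") == "" {
			http.Error(w, "no authorization header present", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// checkCase runs the middleware once. The next handler only counts calls;
// every check happens after the middleware returns, so none can be skipped.
func checkCase(header string, wantNext bool) error {
	req := httptest.NewRequest(http.MethodGet, "/demo", nil)
	if header != "" {
		req.Header.Set("Authorization", header)
	}
	w := httptest.NewRecorder()

	calls := 0
	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { calls++ })
	requireAuth(next).ServeHTTP(w, req)

	wantCalls := 0
	if wantNext {
		wantCalls = 1
	}
	if calls != wantCalls {
		return fmt.Errorf("next called %d times, want %d", calls, wantCalls)
	}
	if !wantNext && w.Code != http.StatusUnauthorized {
		return fmt.Errorf("status %d, want %d", w.Code, http.StatusUnauthorized)
	}
	return nil
}
```

With this shape, checkCase("", true), which is the bug described above, returns an error rather than passing.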

I think the core issue is that there are two key paths:

  1. Cases where the middleware ends the chain of execution.
  2. Cases where middleware updates the context.

Because the assertions for the second path have to be enclosed in either an operation or a next function, the test cases become verbose, and it is easy to make a mistake such as writing an assertion that silently never runs.

I think using a test utility makes this much nicer. See the example below.

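// NextRecorder stands in for the next handler in a middleware test and
// records whether it was called and with which context.
type NextRecorder struct {
    Called  bool
    context huma.Context
}

// Next returns a function suitable for passing to the middleware as next.
func (nr *NextRecorder) Next() func(huma.Context) {
    return func(ctx huma.Context) {
        nr.Called = true
        nr.context = ctx
    }
}

// Context returns the context the middleware passed to next, or nil if
// next was never called.
func (nr *NextRecorder) Context() huma.Context {
    return nr.context
}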

func TestAuthMiddleware3(t *testing.T) {
    tests := []struct {
        desc      string
        header    string
        expCode   int
        expClaims Claims
    }{
        {
            desc:      "no auth header",
            expCode:   401,
            expClaims: Claims{},
        },
        {
            desc:      "correct header",
            header:    "bearer foo",
            expCode:   204,
            expClaims: Claims{Claims: []string{"foo", "bar"}},
        },
    }

    for _, tt := range tests {
        t.Run(tt.desc, func(t *testing.T) {

            _, api := humatest.New(t)
            mw := authMiddleware(api)

            req, _ := http.NewRequest(http.MethodGet, "/demo", nil)
            if tt.header != "" {
                req.Header.Set("Authorization", tt.header)
            }
            w := httptest.NewRecorder()

            ctx := humatest.NewContext(&huma.Operation{}, req, w)

            nr := &NextRecorder{}

            mw(ctx, nr.Next())

            if tt.expCode == 204 {
                // The middleware should have called next with the claims attached.
                require.True(t, nr.Called)
                claims, ok := nr.Context().Context().Value(claimsContextKey).(Claims)
                assert.True(t, ok)
                assert.Equal(t, tt.expClaims, claims)
            } else {
                // The middleware should have short-circuited with an error response.
                assert.Equal(t, tt.expCode, w.Code)
                assert.False(t, nr.Called)
            }
        })
    }
}
benkawecki commented 2 weeks ago

Hey @danielgtaylor, I wanted to check back in on this and get your opinion. If you think this method makes more sense for testing, I'd like to open a PR and update the documentation.

danielgtaylor commented 7 hours ago

@benkawecki sorry for the delay, I had a family emergency and had to fly to Germany last minute to help with some things so haven't been able to work on Huma in a few weeks.

I'd say go for it if you want to add a small utility to capture the context for testing & update the docs. That seems reasonable and useful to me, thanks!