go-chi / cors

CORS net/http middleware for Go
MIT License
332 stars 32 forks source link

Add custom handler option for failures #31

Open ablankz opened 6 months ago

ablankz commented 6 months ago

With this change, a custom error handler such as the following can be supplied:

// customErrorHandler reports a CORS failure to the client as a JSON body of the
// form {"message": "..."} and always returns false, so the request is not passed
// on to the next handler.
//
// Status codes:
//   - 405 Method Not Allowed for method-related CORS errors,
//   - 403 Forbidden for any other CORS Error,
//   - 500 Internal Server Error for errors that are not a CORS Error, so a
//     rejected request is never answered with an implicit 200 OK.
func customErrorHandler(w http.ResponseWriter, _ *http.Request, cors Cors, err error) bool {
	cors.logf("%v", err)

	status := http.StatusInternalServerError
	message := "CORS error: An unexpected error has occurred"

	// errors.As (unlike the type assertion err.(Error)) also recognises CORS
	// errors that were wrapped with fmt.Errorf("...: %w", err).
	var corsErr Error
	if errors.As(err, &corsErr) {
		message = "CORS error: " + err.Error()
		// NOTE(review): errors.Is against a freshly allocated pointer only matches
		// if these error types implement an Is method; confirm they do.
		switch {
		case errors.Is(err, &PreflightNotOptionMethodError{}),
			errors.Is(err, &PreflightNotAllowedMethodError{}),
			errors.Is(err, &ActualMethodNotAllowedError{}):
			status = http.StatusMethodNotAllowed
		default:
			status = http.StatusForbidden
		}
	}

	res := struct {
		Message string `json:"message"`
	}{Message: message}

	// Headers must be set before WriteHeader; later changes are ignored.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if encErr := json.NewEncoder(w).Encode(res); encErr != nil {
		cors.logf("CORS error encoding failed: %v", encErr)
	}
	return false
}

The handler is then passed through the optional `ErrorHandler` field of `cors.Options`:

// Register the CORS middleware; ErrorHandler is optional and, when set,
// decides how CORS failures are reported to the client.
r.Use(cors.Handler(cors.Options{
	AllowOriginFunc:  AllowOriginFunc,
	AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
	AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
	ExposedHeaders:   []string{"Link"},
	AllowCredentials: true,
	MaxAge:           300, // Maximum value not ignored by any of the major browsers
	ErrorHandler:     customErrorHandler,
}))

The change remains backward compatible with the existing code. New tests were added, and the full test suite passes.

ablankz commented 6 months ago

The same customization, written from an external package (so it uses only the exported `cors` identifiers, e.g. `c.Log` instead of the unexported `logf`), would look like this:

    ErrorHandler: func(w http.ResponseWriter, _ *http.Request, c cors.Cors, err error) bool {
        _, ok := err.(cors.Error)
        if ok {
            c.Log.Printf("CORS error: %v", err)
            res := struct {
                Message string `json:"message"`
            }{
                Message: "CORS error: " + err.Error(),
            }
            w.Header().Set("Content-Type", "application/json")
            noOrigin := false
            switch {
            case errors.Is(err, &cors.PreflightEmptyOriginError{}):
                fallthrough
            case errors.Is(err, &cors.ActualMissingOriginError{}):
                noOrigin = true
            case errors.Is(err, &cors.PreflightNotOptionMethodError{}):
                fallthrough
            case errors.Is(err, &cors.PreflightNotAllowedMethodError{}):
                fallthrough
            case errors.Is(err, &cors.ActualMethodNotAllowedError{}):
                w.WriteHeader(http.StatusMethodNotAllowed)
            default:
                w.WriteHeader(http.StatusForbidden)
            }
            // For requests that do not conform to the browser's same-origin policy (no Origin header,
            // such as postman, is given), pass through processing.
            if noOrigin {
                return true
            }
            if err := json.NewEncoder(w).Encode(res); err != nil {
                c.Log.Printf("CORS error encoding failed: %v", err)
            }
            return false
        }
        res := struct {
            Message string `json:"message"`
        }{
            Message: "CORS error: An unexpected error has occurred",
        }
        if err := json.NewEncoder(w).Encode(res); err != nil {
            c.Log.Printf("CORS error encoding failed: %v", err)
        }
        return false
    },