atdiar / xhttp

A set of libraries that facilitate writing web servers (a lightweight web framework).
BSD 3-Clause "New" or "Revised" License
1 star, 0 forks, source link

cors: needs to be reimplemented, especially with respect to the new changes to ServeMux #1

Open atdiar opened 1 year ago

atdiar commented 1 year ago

https://go.dev/play/p/62hs8jlh0Un WIP

atdiar commented 1 year ago

package main

import (
    "play.ground/cors"
)

// main is an empty placeholder entry point for the playground; it does not
// call into the cors package yet.
//
// NOTE(review): Go rejects unused imports, so the play.ground/cors import
// needs a reference here (e.g. `_ = cors.New`) once that package builds.
func main() {}
-- go.mod --
module play.ground
-- cors/cors.go --
// Package cors implements the server-side logic that determines whether the
// response for a given API route may be shared with a requester from another
// origin (Cross-Origin Resource Sharing).
package cors

import (
    "net/http"
    "net/textproto"
    "strconv"
    "strings"
    "time"
)

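// Access control reference: https://fetch.spec.whatwg.org/#http-cors-protocol
// (the WHATWG Fetch standard supersedes the original W3C recommendation at
// https://www.w3.org/TR/cors/).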

/*
Rationale
=========

A cross-origin HTTP request is made when a user agent retrieves resources from
one domain that themselves depend on resources from another domain (for
instance, an image stored on a third-party CDN).

The current package can be used to specify the conditions under which we allow a
resource at a given endpoint to be accessed.

The default is the same-origin policy (same scheme, same host, same port). It
can be relaxed by specifying which kinds of cross-origin request the server
allows (by origin, headers, content type, etc.).

Hence, the presence of these response headers determines whether a resource is accessible across origins.
*/

var (
	// SimpleRequestMethods is the set of CORS-safelisted methods: a
	// cross-origin request using one of them needs no preflight, provided its
	// headers and content type are safelisted too.
	SimpleRequestMethods = newSet().Add("GET", "HEAD", "POST")

	// SimpleRequestHeaders is the set of CORS-safelisted request header names:
	// a cross-origin request that sets only these headers needs no preflight
	// (the Fetch standard additionally restricts some of their values).
	SimpleRequestHeaders = newSet().Add("Accept", "Accept-Language", "Content-Language", "Content-Type")

	// SimpleRequestContentTypes is the set of Content-Type media types for
	// which a cross-origin request needs no preflight.
	SimpleRequestContentTypes = newSet().Add("application/x-www-form-urlencoded", "multipart/form-data", "text/plain")

	// SimpleResponseHeaders is the set of CORS-safelisted response header
	// names: a user agent exposes them to a cross-origin caller even when they
	// are not listed in Access-Control-Expose-Headers.
	SimpleResponseHeaders = newSet().Add("Cache-Control", "Content-Language", "Content-Length", "Content-Type", "Expires", "Last-Modified", "Pragma")
)

// PolicyStore holds the CORS policy of each API route; the policies apply to
// incoming HTTP requests.
// CORS controls access to the resources served by defining constraints on the
// request (origin, allowed HTTP methods, allowed headers, etc.).
//
// NOTE(review): Policies is presumably keyed by route path, as suggested by
// the New(path, rules) method — confirm once routing is settled.
type PolicyStore struct {
	Policies map[string]Policy // route → policy; both the route's preflight handler and its Handler can refer to it
}

// New returns an empty PolicyStore, ready to hold the CORS policy of each
// CORS-enabled route.
//
// NOTE(review): s is currently unused; presumably the store is meant to
// register its preflight handlers on this ServeMux — confirm.
func New(s *http.ServeMux) *PolicyStore {
	return &PolicyStore{Policies: make(map[string]Policy)}
}

// New registers rules as the CORS policy for the route path, replacing any
// policy previously registered for that path. It allocates the Policies map
// when the store was not built by the package-level New.
//
// NOTE(review): this only records the policy; wiring a preflight handler for
// path into a ServeMux is not done here.
func (p *PolicyStore) New(path string, rules Policy) {
	if p.Policies == nil {
		p.Policies = make(map[string]Policy)
	}
	p.Policies[path] = rules
}

// Policy describes how to respond to a cross-origin request for a given
// resource.
//
// AllowedOrigins, AllowedHeaders, AllowedContentTypes, ExposeHeaders and
// AllowedMethods are sets of strings. In them, "*" denotes that any value is
// accepted (resp. origins, headers, content types, methods). Strings are
// inserted with the variadic `Add(...string)` method, and membership is tested
// with `Contains(str string, caseSensitive bool)`.
//
// The sets of a zero Policy are nil: Contains works on them, but Add panics,
// so create each set with newSet before adding to it.
type Policy struct {
	AllowedOrigins      set
	AllowedHeaders      set
	AllowedContentTypes set
	ExposeHeaders       set
	AllowedMethods      set
	AllowCredentials    bool          // whether credentialed cross-origin requests are allowed
	MaxAge              time.Duration // for preflight config: how long a preflight result may be cached
}

// preflightHandler handles CORS preflight requests. It embeds a PolicyStore,
// whose per-route policies the preflight is meant to refer to.
type preflightHandler struct {
	*PolicyStore
}

// MaxAge sets, on every policy currently registered in the store, how long a
// user agent may cache the result of a preflight request. Policies registered
// afterwards keep their own MaxAge. It is a no-op when the handler has no
// PolicyStore.
//
// Access-Control-Max-Age is a response header (see setMaxAge), so it is not
// added to any AllowedHeaders set, which lists the request headers a preflight
// may ask for.
//
// NOTE(review): confirm that a store-wide setter is the intended API.
func (p *preflightHandler) MaxAge(t time.Duration) {
	if p.PolicyStore == nil {
		return
	}
	for route, policy := range p.Policies {
		policy.MaxAge = t
		// Policy is stored by value, so write the modified copy back.
		p.Policies[route] = policy
	}
}

// NewHandler returns a CORS-enforcing request handler. Its policy starts with
// no allowed origins, methods or exposed headers. It already allows the
// headers Accept, Accept-Language, Content-Language, Content-Type and Origin,
// and the form, multipart and plain-text content types.
//
// NOTE(review): Handler is not defined in this file, and its other methods
// read h.Parameters rather than h.Policy — confirm which field is current.
func NewHandler() Handler {
	policy := &Policy{
		AllowedOrigins:      newSet(),
		AllowedHeaders:      newSet().Add("Accept", "Accept-Language", "Content-Language", "Content-Type", "Origin"),
		AllowedContentTypes: newSet().Add("application/x-www-form-urlencoded", "multipart/form-data", "text/plain"),
		ExposeHeaders:       newSet(),
		AllowedMethods:      newSet(),
	}
	h := Handler{}
	h.Policy = policy
	return h
}

// ServeHTTP answers a CORS preflight request. It looks up the Policy registered
// for r.URL.Path in the embedded PolicyStore. It then checks the request's
// Origin, Access-Control-Request-Method and Access-Control-Request-Headers
// against that policy.
//
// On success it replies 200 with the matching Access-Control-Allow-* headers.
// Otherwise it replies 403 without any Access-Control-Allow-* header, so the
// user agent will not send the actual request.
//
// NOTE(review): policies are looked up by exact request path. With pattern
// routes (Go 1.22+ ServeMux wildcards), the key should probably be the matched
// pattern instead — confirm how routes are registered.
func (p *preflightHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The answer depends on all three request headers, so shared caches must
	// key on them; set Vary up front so that rejections carry it too.
	w.Header().Add("Vary", "Origin")
	w.Header().Add("Vary", "Access-Control-Request-Method")
	w.Header().Add("Vary", "Access-Control-Request-Headers")

	var policy Policy
	found := false
	if p.PolicyStore != nil {
		policy, found = p.Policies[r.URL.Path]
	}
	if !found {
		http.Error(w, "no CORS policy for this route", http.StatusForbidden)
		return
	}

	// Checking origin
	origin := r.Header.Get("Origin")
	if origin == "" {
		http.Error(w, "origin header is absent", http.StatusForbidden)
		return
	}
	if !policy.AllowedOrigins.Contains(origin, true) && !policy.AllowedOrigins.Contains("*", false) {
		http.Error(w, "origin not allowed", http.StatusForbidden)
		return
	}

	// Checking method
	method := r.Header.Get("Access-Control-Request-Method")
	if method == "" {
		http.Error(w, "method header absent", http.StatusForbidden)
		return
	}
	if !policy.AllowedMethods.Contains(method, true) && !policy.AllowedMethods.Contains("*", true) {
		http.Error(w, "method not allowed", http.StatusForbidden)
		return
	}

	// Checking headers. Access-Control-Request-Headers is optional (absent when
	// the actual request adds no non-safelisted header). It holds a
	// comma-separated list, so each element is checked on its own.
	requested := splitHeaderList(r.Header.Values("Access-Control-Request-Headers"))
	if !policy.AllowedHeaders.Contains("*", false) {
		for _, name := range requested {
			if !policy.AllowedHeaders.Contains(name, false) {
				http.Error(w, "unallowed headers present", http.StatusForbidden)
				return
			}
		}
	}

	// Without Access-Control-Allow-Origin the user agent rejects the preflight.
	// Echo the validated origin rather than "*", which is not accepted on
	// credentialed requests.
	w.Header().Set("Access-Control-Allow-Origin", origin)
	setAllowCredentials(w, policy.AllowCredentials)
	if policy.MaxAge > 0 {
		setMaxAge(w, int(policy.MaxAge.Seconds()))
	}
	w.Header().Set("Access-Control-Allow-Methods", method)
	if len(requested) > 0 {
		w.Header().Set("Access-Control-Allow-Headers", strings.Join(requested, ", "))
	}
}

// splitHeaderList flattens comma-separated header values into their
// whitespace-trimmed, non-empty elements.
func splitHeaderList(values []string) []string {
	var out []string
	for _, v := range values {
		for _, part := range strings.Split(v, ",") {
			if part = strings.TrimSpace(part); part != "" {
				out = append(out, part)
			}
		}
	}
	return out
}

// WithCredentials returns a copy of h whose policy allows credentials
// (cookies, authorization headers, TLS client certificates) on cross-origin
// requests.
//
// NOTE(review): this writes h.Parameters while NewHandler populates h.Policy.
// The Handler type is not defined in this file, so confirm which field is
// current. If that field is a pointer, the copy shares it with h, and the
// change is visible through h as well.
func (h Handler) WithCredentials() Handler {
	h.Parameters.AllowCredentials = true
	return h
}

// ServeHTTP adds the CORS response headers for an actual (non-preflight)
// request and then hands it to the next handler, if any. Requests without an
// Origin header are passed through unchanged apart from Vary: Origin.
//
// Every cross-origin response needs Access-Control-Allow-Origin, whether or
// not the request required a preflight. Without it, the user agent hides the
// response from the calling script, so the headers are set for every request
// that carries an Origin.
//
// NOTE(review): Handler is not defined in this file. This reads h.Parameters
// while NewHandler populates h.Policy — confirm which field is current.
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The response depends on the Origin request header.
	w.Header().Add("Vary", "Origin")

	if originHeaderIsPresent(r) {
		setAllowOrigin(w, r, h.Parameters.AllowedOrigins)
		setAllowCredentials(w, h.Parameters.AllowCredentials)
		setExposeHeaders(w, h.Parameters.ExposeHeaders)
	}

	if h.next != nil {
		h.next.ServeHTTP(w, r)
	}
}

// setAllowOrigin echoes the request's Origin in Access-Control-Allow-Origin
// when that origin is in allowed, or when allowed contains "*".
//
// Nothing is written when the request has no Origin, carries several Origin
// values, or names an origin that is not allowed. In that last case the user
// agent then refuses to expose the response. The header is deliberately never
// set to "null": that would grant access to every document with an opaque
// origin (sandboxed iframes, file: URLs, redirected requests).
func setAllowOrigin(w http.ResponseWriter, r *http.Request, allowed set) {
	origin := textproto.MIMEHeader(r.Header)["Origin"]
	if len(origin) != 1 {
		return
	}
	ori := origin[0]

	if allowed.Contains(ori, true) || allowed.Contains("*", true) {
		w.Header().Set("Access-Control-Allow-Origin", ori)
	}
}

// setAllowMethods answers a preflight request by listing, in
// Access-Control-Allow-Methods, every method from s that the actual request
// may use. Each method is added as its own header line.
func setAllowMethods(w http.ResponseWriter, s set) {
	hdr := w.Header()
	for m := range s {
		hdr.Add("Access-Control-Allow-Methods", m)
	}
}

// setAllowHeaders answers a preflight request by listing, in
// Access-Control-Allow-Headers, every header name from s that the actual
// request may carry. Each name is added as its own header line.
func setAllowHeaders(w http.ResponseWriter, s set) {
	const field = "Access-Control-Allow-Headers"
	hdr := w.Header()
	for name := range s {
		hdr.Add(field, name)
	}
}

// setExposeHeaders lists in Access-Control-Expose-Headers the response header
// names from s that the user agent may let a cross-origin caller read. Each
// name is added as its own header line.
func setExposeHeaders(w http.ResponseWriter, s set) {
	const field = "Access-Control-Expose-Headers"
	out := w.Header()
	for exposed := range s {
		out.Add(field, exposed)
	}
}

// setAllowCredentials writes Access-Control-Allow-Credentials: true when b is
// set. In a preflight response, this tells the user agent that the actual
// request may include credentials (cookies, authorization headers, TLS client
// certificates). In an actual response, it lets the user agent expose a
// credentialed response to the calling script.
//
// When b is false, the header is removed rather than set to "false": "true" is
// the only value the CORS protocol defines, and omitting the header is how
// credentials are refused.
func setAllowCredentials(w http.ResponseWriter, b bool) {
	if !b {
		w.Header().Del("Access-Control-Allow-Credentials")
		return
	}
	w.Header().Set("Access-Control-Allow-Credentials", "true")
}

// setMaxAge writes Access-Control-Max-Age, the number of seconds for which the
// user agent (a browser, for instance) may cache the outcome of a preflight
// request. Any previous value is replaced.
func setMaxAge(w http.ResponseWriter, seconds int) {
	value := strconv.FormatInt(int64(seconds), 10)
	w.Header().Set("Access-Control-Max-Age", value)
}

// headersAreAllowed reports whether every header name present on r is a
// member of s, compared with caseSensitive set to false.
//
// NOTE(review): r.Header also holds headers the user agent adds on its own
// (e.g. Origin, User-Agent, Accept-Encoding). A browser request will therefore
// rarely pass this check against a CORS safelist — confirm the intent.
func headersAreAllowed(r *http.Request, s set) bool {
	allowed := true
	for name := range r.Header {
		if !s.Contains(name, false) {
			allowed = false
			break
		}
	}
	return allowed
}

// methodIsAllowed reports whether r's method belongs to s. The comparison is
// case-sensitive, as HTTP method names are.
func methodIsAllowed(r *http.Request, s set) bool {
	method := r.Method
	return s.Contains(method, true)
}

// contentTypeIsAllowed reports whether every Content-Type value on r names a
// media type in s, matched case-insensitively. Parameters such as
// "; charset=utf-8" or "; boundary=..." are ignored, since only the media type
// itself is relevant. A request without a Content-Type header is allowed.
func contentTypeIsAllowed(r *http.Request, s set) bool {
	for _, val := range textproto.MIMEHeader(r.Header)["Content-Type"] {
		mediaType, _, _ := strings.Cut(val, ";")
		if !s.Contains(strings.TrimSpace(mediaType), false) {
			return false
		}
	}
	return true
}

// originHeaderIsPresent reports whether req carries a non-empty Origin header.
func originHeaderIsPresent(req *http.Request) bool {
	return textproto.MIMEHeader(req.Header).Get("Origin") != ""
}

// set is an unordered collection of distinct strings. It offers three methods:
//   - Add, to insert elements;
//   - Remove, to delete an element;
//   - Contains, to look an element up.
//
// Remove and Contains take a caseSensitive flag. When it is false, elements
// are matched with Unicode case folding (strings.EqualFold), whatever case
// they were added with.
//
// A set must be created with newSet (or make) before Add is called: Add on a
// nil set panics, while Contains and Remove are safe on it.
type set map[string]struct{}

// newSet returns an empty, ready-to-use set.
func newSet() set {
	return make(set)
}

// Add inserts every string in strls into s, preserving its case, and returns
// s so that calls can be chained.
func (s set) Add(strls ...string) set {
	for _, str := range strls {
		s[str] = struct{}{}
	}
	return s
}

// Remove deletes str from s. When caseSensitive is false, every element equal
// to str under case folding is deleted.
func (s set) Remove(str string, caseSensitive bool) {
	if caseSensitive {
		delete(s, str)
		return
	}
	for k := range s {
		if strings.EqualFold(k, str) {
			delete(s, k)
		}
	}
}

// Contains reports whether s holds str. When caseSensitive is false, elements
// are compared with case folding, so an element added as "Content-Type"
// matches "content-type".
func (s set) Contains(str string, caseSensitive bool) bool {
	// An exact match satisfies both modes with a single map lookup.
	if _, ok := s[str]; ok {
		return true
	}
	if caseSensitive {
		return false
	}
	for k := range s {
		if strings.EqualFold(k, str) {
			return true
		}
	}
	return false
}