Open atdiar opened 1 year ago
package main
import (
"play.ground/cors"
)
// main is a placeholder entry point for exercising the cors package.
//
// NOTE(review): "play.ground/cors" is imported above but nothing in this
// file uses it, and Go rejects unused imports at compile time. Replace the
// commented-out call below with a real use of the cors package (or switch to
// a blank import) before expecting this to build.
func main() {
//foo.Bar()
}
-- go.mod --
module play.ground
-- cors/cors.go --
// Package cors implements the server-side logic that decides, for a given API
// route, whether a response to a cross-origin request may be returned.
package cors
import (
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
)
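// Access control reference: https://fetch.spec.whatwg.org/#http-cors-protocol
// (the older W3C CORS Recommendation, https://www.w3.org/TR/cors/, has been
// superseded by the Fetch Standard).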
/*
Rationale
=========
A cross-origin HTTP request is made when a user agent, while rendering a
resource from one origin, needs a resource that lives on another origin (for
instance, an image stored on a foreign CDN).
The current package can be used to specify the conditions under which we allow a
resource at a given endpoint to be accessed.
The default is the `same-origin` policy (same protocol/scheme, same host, same
port). It can be relaxed by specifying which kinds of cross-origin requests the
server allows (by origin, by headers, by content type, etc.).
Hence, the presence of these headers determines whether a resource is accessible.
*/
var (
	// SimpleRequestMethods is the set of methods a cross-origin request may
	// use without the user agent sending a preflight request first.
	SimpleRequestMethods = newSet().Add("GET", "HEAD", "POST")
	// SimpleRequestHeaders is the set of request header names a cross-origin
	// request may carry without the user agent sending a preflight request
	// first.
	SimpleRequestHeaders = newSet().Add("Accept", "Accept-Language", "Content-Language", "Content-Type")
	// SimpleRequestContentTypes is the set of Content-Type media types a
	// cross-origin request may declare without the user agent sending a
	// preflight request first.
	SimpleRequestContentTypes = newSet().Add("application/x-www-form-urlencoded", "multipart/form-data", "text/plain")
	// SimpleResponseHeaders is the set of response header names that a user
	// agent exposes to a cross-origin caller even when they are not listed in
	// Access-Control-Expose-Headers (the Fetch Standard's CORS-safelisted
	// response-header names, which include Content-Length).
	SimpleResponseHeaders = newSet().Add("Cache-Control", "Content-Language", "Content-Length", "Content-Type", "Expires", "Last-Modified", "Pragma")
)
// PolicyStore maps API routes to the CORS Policy that governs them.
//
// CORS restricts access to server resources by constraining incoming
// cross-origin HTTP requests: which origins may call, which methods and
// headers they may use, and so on.
type PolicyStore struct {
	// Policies holds one Policy per route. Both the preflight handler and the
	// route's own handler are meant to consult the entry for their route.
	Policies map[string]Policy
}
// New returns an empty PolicyStore whose Policies map is already allocated,
// ready to receive per-route CORS policies.
//
// NOTE(review): the *http.ServeMux argument is currently ignored; presumably
// it is meant for registering preflight handlers later — confirm intent.
func New(s *http.ServeMux) *PolicyStore {
	return &PolicyStore{Policies: map[string]Policy{}}
}
// New registers rules as the CORS policy for the route path. Any policy
// previously registered for the same path is replaced.
//
// The Policies map is allocated on first use, so a zero-value PolicyStore
// can be used directly.
func (p *PolicyStore) New(path string, rules Policy) {
	if p.Policies == nil {
		p.Policies = make(map[string]Policy)
	}
	p.Policies[path] = rules
}
// Policy defines how a route responds to Cross-Origin requests.
//
// "*" in a set denotes that anything is accepted (origins, headers, methods,
// content types).
//
// The fields AllowedOrigins, AllowedHeaders, AllowedContentTypes,
// ExposeHeaders and AllowedMethods are sets of strings. Strings are inserted
// with the variadic `Add(strs ...string)` method, which returns the set so
// calls can be chained. Membership is tested with
// `Contains(str string, caseSensitive bool)`.
//
// Sets are maps. In a zero-value Policy they are nil: they can be queried,
// but they must be allocated (e.g. with newSet) before Add is called on them.
type Policy struct {
	AllowedOrigins      set
	AllowedHeaders      set
	AllowedContentTypes set
	ExposeHeaders       set
	AllowedMethods      set
	AllowCredentials    bool
	MaxAge              time.Duration // preflight cache lifetime, intended for the Access-Control-Max-Age header
}
// preflightHandler is the handler for CORS preflight requests. It embeds
// *PolicyStore, so the store's per-route Policies map is reachable as
// p.Policies.
type preflightHandler struct {
*PolicyStore
}
// MaxAge sets, on every policy currently registered in the store, how long a
// user agent may cache the result of a preflight request. Policy.MaxAge is
// intended to be sent to clients, in whole seconds, through the
// Access-Control-Max-Age preflight response header.
//
// Policies registered after this call are not affected.
func (p *preflightHandler) MaxAge(t time.Duration) {
	if p.PolicyStore == nil {
		return
	}
	// Access-Control-Max-Age is a response header written by the server, not
	// a request header a client sends, so AllowedHeaders is deliberately not
	// touched here.
	for path, policy := range p.Policies {
		policy.MaxAge = t
		p.Policies[path] = policy // Policy is stored by value; write the copy back
	}
}
// NewHandler builds a Handler that enforces a CORS policy. The policy starts
// with these sets:
//   - origins, methods and exposed headers: empty
//   - allowed headers: Accept, Accept-Language, Content-Language, Content-Type
//     and Origin
//   - content types: the three allowed for simple requests
//
// NOTE(review): this populates h.Policy, while WithCredentials and ServeHTTP
// read h.Parameters; Handler's definition is not visible here — confirm which
// field is the real one.
func NewHandler() Handler {
	policy := &Policy{
		AllowedOrigins:      newSet(),
		AllowedHeaders:      newSet().Add("Accept", "Accept-Language", "Content-Language", "Content-Type", "Origin"),
		AllowedContentTypes: newSet().Add("application/x-www-form-urlencoded", "multipart/form-data", "text/plain"),
		ExposeHeaders:       newSet(),
		AllowedMethods:      newSet(),
	}
	return Handler{Policy: policy}
}
// ServeHTTP answers a CORS preflight request for the route r.URL.Path. It uses
// the Policy registered under exactly that path in the embedded PolicyStore.
//
// Request requirements:
//   - It must carry an Origin header allowed by the policy.
//   - It must carry an Access-Control-Request-Method header allowed by the
//     policy.
//   - It may carry Access-Control-Request-Headers. Per the Fetch Standard,
//     this is a comma-separated list of header names and is omitted when the
//     actual request uses no extra headers. Every listed name must be allowed.
//
// On success the response carries Access-Control-Allow-Origin,
// -Allow-Credentials, -Allow-Methods, -Allow-Headers (when headers were
// requested) and, for a positive policy MaxAge, Access-Control-Max-Age.
// On any failure it responds 403 without granting anything.
//
// NOTE(review): policies are looked up by exact path; ServeMux patterns
// (e.g. subtree "/api/") will not match unless registered per path — confirm.
func (p *preflightHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The answer depends on these request headers, so caches must key on them,
	// including for error responses.
	w.Header().Add("Vary", "Origin")
	w.Header().Add("Vary", "Access-Control-Request-Method")
	w.Header().Add("Vary", "Access-Control-Request-Headers")

	if p.PolicyStore == nil {
		http.Error(w, "no CORS policy for this route", http.StatusForbidden)
		return
	}
	policy, ok := p.Policies[r.URL.Path]
	if !ok {
		http.Error(w, "no CORS policy for this route", http.StatusForbidden)
		return
	}

	// Checking origin
	origin := r.Header.Get("Origin")
	if origin == "" {
		http.Error(w, "origin header is absent", http.StatusForbidden)
		return
	}
	if !policy.AllowedOrigins.Contains(origin, true) && !policy.AllowedOrigins.Contains("*", false) {
		http.Error(w, "origin not allowed", http.StatusForbidden)
		return
	}

	// Checking method
	method := r.Header.Get("Access-Control-Request-Method")
	if method == "" {
		http.Error(w, "method header absent", http.StatusForbidden)
		return
	}
	if !policy.AllowedMethods.Contains(method, true) && !policy.AllowedMethods.Contains("*", true) {
		http.Error(w, "method not allowed", http.StatusForbidden)
		return
	}

	// Checking headers: every requested name must be allowed, unless the
	// policy allows any header.
	requested := parseHeaderList(r.Header.Values("Access-Control-Request-Headers"))
	if !policy.AllowedHeaders.Contains("*", false) {
		for _, name := range requested {
			if !policy.AllowedHeaders.Contains(name, false) {
				http.Error(w, "unallowed headers present", http.StatusForbidden)
				return
			}
		}
	}

	// Setting the appropriate headers on the HTTP response.
	setAllowOrigin(w, r, policy.AllowedOrigins)
	setAllowCredentials(w, policy.AllowCredentials)
	if policy.MaxAge > 0 {
		setMaxAge(w, int(policy.MaxAge.Seconds()))
	}
	w.Header().Set("Access-Control-Allow-Methods", method)
	if len(requested) > 0 {
		w.Header().Set("Access-Control-Allow-Headers", strings.Join(requested, ", "))
	}
}

// parseHeaderList splits comma-separated header-list values (as used by
// Access-Control-Request-Headers) into trimmed, non-empty header names.
func parseHeaderList(values []string) []string {
	var names []string
	for _, v := range values {
		for _, name := range strings.Split(v, ",") {
			if name = strings.TrimSpace(name); name != "" {
				names = append(names, name)
			}
		}
	}
	return names
}
// WithCredentials returns h with its policy's AllowCredentials flag set. This
// allows the client's requests to carry credentials: cookies, authorization
// headers and TLS client certificates.
//
// NOTE(review): this writes h.Parameters, whereas NewHandler populates
// h.Policy; Handler's definition is not visible here — confirm which field is
// the real one. If that field is a pointer, the change is also visible
// through the original h, not only through the returned copy.
func (h Handler) WithCredentials() Handler {
h.Parameters.AllowCredentials = true
return h
}
// ServeHTTP writes the CORS response headers for an actual (non-preflight)
// request, then delegates to the wrapped handler, if any.
//
// A request without an Origin header is not a CORS request. It is passed
// through with only Vary: Origin added.
//
// Every cross-origin response needs Access-Control-Allow-Origin, including
// responses to "simple" requests that needed no preflight. Without it the
// user agent will not expose the response to the calling script. The
// simple/non-simple distinction only decides whether a preflight was sent;
// it does not change what the actual response must carry.
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The response differs by Origin, so shared caches must key on it.
	w.Header().Add("Vary", "Origin")
	if originHeaderIsPresent(r) {
		setAllowOrigin(w, r, h.Parameters.AllowedOrigins)
		setAllowCredentials(w, h.Parameters.AllowCredentials)
		setExposeHeaders(w, h.Parameters.ExposeHeaders)
	}
	if h.next != nil {
		h.next.ServeHTTP(w, r)
	}
}
// setAllowOrigin writes the Access-Control-Allow-Origin response header when
// the request's single Origin value is permitted by AllowedOrigins. An origin
// is permitted by an exact (case-sensitive) match or by the "*" wildcard.
//
// The request's own origin is echoed back instead of a literal "*", so the
// response stays valid for credentialed requests.
//
// No header is written when:
//   - the origin is not allowed, or
//   - the request carries zero or several Origin values.
//
// In particular, "null" is never sent for a rejected origin. Sandboxed
// documents and some redirects have the opaque origin "null", so answering
// "Access-Control-Allow-Origin: null" would grant those callers access.
func setAllowOrigin(w http.ResponseWriter, r *http.Request, AllowedOrigins set) {
	origin := textproto.MIMEHeader(r.Header)["Origin"]
	if len(origin) != 1 {
		return
	}
	ori := origin[0]
	if AllowedOrigins.Contains(ori, true) || AllowedOrigins.Contains("*", true) {
		w.Header().Set("Access-Control-Allow-Origin", ori)
	}
}
// setAllowMethods adds one Access-Control-Allow-Methods field to the response
// for each method in s. It belongs in a preflight response, where it tells
// the user agent which methods the actual request may use.
func setAllowMethods(w http.ResponseWriter, s set) {
	hdr := w.Header()
	for m := range s {
		hdr.Add("Access-Control-Allow-Methods", m)
	}
}
// setAllowHeaders adds one Access-Control-Allow-Headers field to the response
// for each header name in s. It belongs in a preflight response, where it
// tells the user agent which headers the actual request may carry.
func setAllowHeaders(w http.ResponseWriter, s set) {
	const field = "Access-Control-Allow-Headers"
	for name := range s {
		w.Header().Add(field, name)
	}
}
// setExposeHeaders adds one Access-Control-Expose-Headers field to the
// response for each header name in s. The field lists which response headers
// the user agent may reveal to the script that made the CORS request.
func setExposeHeaders(w http.ResponseWriter, s set) {
	out := w.Header()
	for name := range s {
		out.Add("Access-Control-Expose-Headers", name)
	}
}
// setAllowCredentials writes "Access-Control-Allow-Credentials: true" when b
// is true. Its meaning depends on the response:
//   - preflight response: the actual request may include user credentials
//     (cookies, HTTP authentication, TLS client certificates);
//   - actual response: the user agent may expose the response to the script.
//
// "true" is the only meaningful value of this header, so when b is false the
// header is removed rather than set to "false".
func setAllowCredentials(w http.ResponseWriter, b bool) {
	if !b {
		w.Header().Del("Access-Control-Allow-Credentials")
		return
	}
	w.Header().Set("Access-Control-Allow-Credentials", "true")
}
// setMaxAge sets Access-Control-Max-Age to the given number of seconds,
// replacing any earlier value. The header tells the user agent (e.g. a
// browser) how long it may cache this preflight result.
func setMaxAge(w http.ResponseWriter, seconds int) {
	value := strconv.FormatInt(int64(seconds), 10)
	w.Header().Set("Access-Control-Max-Age", value)
}
// headersAreAllowed reports whether every header name present on r is found
// in s, using s.Contains(name, false), i.e. a lookup without case
// sensitivity.
//
// NOTE(review): r.Header also holds headers the user agent adds on its own
// (e.g. User-Agent), so this is likely rarely true for real browser
// traffic — confirm intent.
func headersAreAllowed(r *http.Request, s set) bool {
	for name := range r.Header {
		if s.Contains(name, false) {
			continue
		}
		return false
	}
	return true
}
// methodIsAllowed reports whether r's method is a member of s. The lookup is
// case-sensitive, matching HTTP's case-sensitive method names.
func methodIsAllowed(r *http.Request, s set) bool {
	method := r.Method
	allowed := s.Contains(method, true)
	return allowed
}
// contentTypeIsAllowed reports whether every Content-Type value on r names a
// media type found in s (looked up without case sensitivity).
//
// Parameters such as "; charset=utf-8" are ignored, so
// "text/plain; charset=utf-8" matches "text/plain". A request with no
// Content-Type at all (e.g. a plain GET) is allowed.
func contentTypeIsAllowed(r *http.Request, s set) bool {
	for _, val := range textproto.MIMEHeader(r.Header)["Content-Type"] {
		// Keep only the media type, dropping any ";"-separated parameters.
		mediaType := strings.TrimSpace(strings.SplitN(val, ";", 2)[0])
		if !s.Contains(mediaType, false) {
			return false
		}
	}
	return true
}
// originHeaderIsPresent reports whether req carries a non-empty Origin
// header. The header name is canonicalized, so it matches however the name
// was cased.
func originHeaderIsPresent(req *http.Request) bool {
	return textproto.MIMEHeader(req.Header).Get("Origin") != ""
}
// set is an unordered collection of distinct strings, stored as map keys.
// Three methods are available:
//   - Add, which inserts strings exactly as given (no case normalization)
//   - Remove, which deletes a string
//   - Contains, which looks a string up
//
// Remove and Contains take a caseSensitive flag. When it is false, stored
// elements are compared with strings.EqualFold, so an element added as
// "Content-Type" is found by a lookup for "content-type".
type set map[string]struct{}

// newSet returns an empty, allocated set.
func newSet() set {
	return make(set)
}

// Add inserts every string in strls into s and returns s, so calls can be
// chained: newSet().Add("a", "b"). s must be non-nil.
func (s set) Add(strls ...string) set {
	for _, str := range strls {
		s[str] = struct{}{}
	}
	return s
}

// Remove deletes str from s. When caseSensitive is false, every element equal
// to str under Unicode case folding is deleted.
func (s set) Remove(str string, caseSensitive bool) {
	if caseSensitive {
		delete(s, str)
		return
	}
	for k := range s {
		if strings.EqualFold(k, str) {
			delete(s, k) // deleting during range is permitted in Go
		}
	}
}

// Contains reports whether str is in s. When caseSensitive is true this is a
// single map lookup; otherwise elements are compared with strings.EqualFold,
// which suits HTTP header names and media types.
func (s set) Contains(str string, caseSensitive bool) bool {
	if caseSensitive {
		_, ok := s[str]
		return ok
	}
	for k := range s {
		if strings.EqualFold(k, str) {
			return true
		}
	}
	return false
}
https://go.dev/play/p/62hs8jlh0Un WIP