From 95c785b0bc2fec76c2901bbded1e685d5a6804fb Mon Sep 17 00:00:00 2001 From: Stefan Benz <46600784+stebenz@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:18:03 +0100 Subject: [PATCH] fix: corrected inflate and deflate logic for xml, and corresponding tests with some refactoring for SAML sessions (#92) * refactor: remove separate functions for error reasons * fix: corrected inflate and deflate logic, and corresponding tests * fix: corrected inflate and deflate logic, and corresponding tests * fix: corrected inflate and deflate logic, and corresponding tests * fix: switch over status codes to errors * fix: review changes * fix: review changes --- pkg/provider/attribute_query.go | 8 +- pkg/provider/identityprovider.go | 8 +- pkg/provider/login.go | 59 ++++++++------ pkg/provider/login_test.go | 33 ++++---- pkg/provider/logout.go | 12 +-- pkg/provider/logout_response.go | 39 ++-------- pkg/provider/post.go | 14 ++-- pkg/provider/provider.go | 30 ++++++- pkg/provider/redirect.go | 36 ++++----- pkg/provider/redirect_test.go | 2 +- pkg/provider/response.go | 101 +++++++----------------- pkg/provider/response_test.go | 4 +- pkg/provider/sso.go | 41 +++++----- pkg/provider/sso_test.go | 130 ++++++++----------------------- pkg/provider/xml/xml.go | 128 +++++++++++------------------- 15 files changed, 255 insertions(+), 390 deletions(-) diff --git a/pkg/provider/attribute_query.go b/pkg/provider/attribute_query.go index 00ef547..a17f842 100644 --- a/pkg/provider/attribute_query.go +++ b/pkg/provider/attribute_query.go @@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht queriedAttrs = append(queriedAttrs, queriedAttr) } } - response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat) + response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.TimeFormat, p.Expiration) return nil }, func() { @@ -139,7 +139,11 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht // create enveloped signature checkerInstance.WithLogicStep( func() error { - return createPostSignature(r.Context(), response, p) + cert, key, err := getResponseCert(r.Context(), p.storage) + if err != nil { + return err + } + return createPostSignature(response, key, cert, p.conf.SignatureAlgorithm) }, func() { http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError) diff --git a/pkg/provider/identityprovider.go b/pkg/provider/identityprovider.go index 8d87057..002fe60 100644 --- a/pkg/provider/identityprovider.go +++ b/pkg/provider/identityprovider.go @@ -71,7 +71,8 @@ type IdentityProvider struct { metadataEndpoint *Endpoint endpoints *Endpoints - timeFormat string + TimeFormat string + Expiration time.Duration } type Endpoints struct { @@ -90,7 +91,8 @@ func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storag postTemplate: conf.PostTemplate, logoutTemplate: conf.LogoutTemplate, endpoints: endpointConfigToEndpoints(conf.Endpoints), - timeFormat: DefaultTimeFormat, + TimeFormat: DefaultTimeFormat, + Expiration: DefaultExpiration, } if conf.PostTemplate == nil { @@ -160,7 +162,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto return nil, nil, err } - metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat) + metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.TimeFormat) return metadata, aaMetadata, nil } diff --git a/pkg/provider/login.go b/pkg/provider/login.go index 4b83a8a..435f951 100644 --- a/pkg/provider/login.go +++ b/pkg/provider/login.go @@ -1,10 +1,14 @@ package provider import ( + "context" "fmt" "net/http" "github.com/zitadel/logging" + + "github.com/zitadel/saml/pkg/provider/models" + "github.com/zitadel/saml/pkg/provider/xml/samlp" ) func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Request) { @@ -16,7 +20,6 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req Issuer: p.GetEntityID(r.Context()), } - ctx := r.Context() if err := r.ParseForm(); err != nil { logging.Error(err) http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError) @@ -34,7 +37,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID) if err != nil { logging.Error(err) - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, p.errorResponse(response, StatusCodeRequestDenied, fmt.Errorf("failed to get request: %w", err).Error())) return } response.RequestID = authRequest.GetAuthRequestID() @@ -42,44 +45,50 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req response.ProtocolBinding = authRequest.GetBindingType() response.AcsUrl = authRequest.GetAccessConsumerServiceURL() - if !authRequest.Done() { + entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID()) + if err != nil { logging.Error(err) http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError) return } + response.Audience = entityID - entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID()) + samlResponse, err := p.loginResponse(r.Context(), authRequest, response) if err != nil { - logging.Error(err) - http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError) + response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat)) return } - response.Audience = entityID + + response.sendBackResponse(r, w, samlResponse) + return +} + +func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) { + if !authRequest.Done() { + logging.Error(StatusCodeAuthNFailed) + return nil, fmt.Errorf(StatusCodeAuthNFailed) + } attrs := &Attributes{} if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil { logging.Error(err) - http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError) - return + return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue) } - samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat) + cert, key, err := getResponseCert(ctx, p.storage) + if err != nil { + logging.Error(err) + return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue) + } - switch response.ProtocolBinding { - case PostBinding: - if err := createPostSignature(r.Context(), samlResponse, p); err != nil { - logging.Error(err) - response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat)) - return - } - case RedirectBinding: - if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil { - logging.Error(err) - response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat)) - return - } + samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration) + if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil { + logging.Error(err) + return nil, fmt.Errorf(StatusCodeResponder) } + return samlResponse, nil +} - response.sendBackResponse(r, w, samlResponse) - return +func (p *IdentityProvider) errorResponse(response *Response, reason string, description string) *samlp.ResponseType { + return response.makeFailedResponse(reason, description, p.TimeFormat) } diff --git a/pkg/provider/login_test.go b/pkg/provider/login_test.go index 1bf9e81..c188ec4 100644 --- a/pkg/provider/login_test.go +++ b/pkg/provider/login_test.go @@ -1,9 +1,9 @@ package provider import ( - "io/ioutil" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/golang/mock/gomock" @@ -23,9 +23,11 @@ func TestSSO_loginHandleFunc(t *testing.T) { Done bool } type res struct { - code int - err bool - state string + code int + err bool + state string + inflate bool + b64 bool } type sp struct { appID string @@ -235,7 +237,7 @@ func TestSSO_loginHandleFunc(t *testing.T) { ID: "test", AuthRequestID: "test", Binding: RedirectBinding, - AcsURL: "url", + AcsURL: "https://sp.example.com", RelayState: "relaystate", UserID: "userid", Done: false, @@ -247,9 +249,11 @@ func TestSSO_loginHandleFunc(t *testing.T) { }, }, res{ - code: 500, - state: "", - err: false, + code: 302, + state: StatusCodeAuthNFailed, + err: false, + inflate: true, + b64: true, }}, } @@ -297,14 +301,15 @@ func TestSSO_loginHandleFunc(t *testing.T) { defer func() { _ = res.Body.Close() }() - response, err := ioutil.ReadAll(res.Body) - if res.StatusCode != tt.res.code { - t.Errorf("ssoHandleFunc() code got = %v, want %v", res.StatusCode, tt.res) - return - } + // currently only checked for redirect binding if tt.res.state != "" { - if err := parseForState(string(response), tt.res.state); err != nil { + responseURL, err := url.Parse(res.Header.Get("Location")) + if err != nil { + t.Errorf("error while parsing url") + } + + if err := parseForState(tt.res.inflate, tt.res.b64, responseURL.Query().Get("SAMLResponse"), tt.res.state); err != nil { t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state) return } diff --git a/pkg/provider/logout.go b/pkg/provider/logout.go index 03eaef9..5c8614b 100644 --- a/pkg/provider/logout.go +++ b/pkg/provider/logout.go @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque return nil }, func() { - response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat)) + response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat)) }, ) @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque return nil }, func() { - response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat)) + response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat)) }, ) @@ -69,10 +69,10 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque checkIfRequestTimeIsStillValid( func() string { return logoutRequest.IssueInstant }, func() string { return logoutRequest.NotOnOrAfter }, - p.timeFormat, + p.TimeFormat, ), func() { - response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat)) + response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat)) }, ) @@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque return err }, func() { - response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat)) + response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat)) }, ) @@ -106,7 +106,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque response.sendBackLogoutResponse( w, - response.makeSuccessfulLogoutResponse(p.timeFormat), + response.makeSuccessfulLogoutResponse(p.TimeFormat), ) logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text)) } diff --git a/pkg/provider/logout_response.go b/pkg/provider/logout_response.go index 63a5765..407ccd8 100644 --- a/pkg/provider/logout_response.go +++ b/pkg/provider/logout_response.go @@ -55,32 +55,8 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam } } -func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType { - return makeLogoutResponse( - r.RequestID, - r.LogoutURL, - time.Now().UTC().Format(timeFormat), - StatusCodeSuccess, - "", - getIssuer(r.Issuer), - ) -} - -func (r *LogoutResponse) makeUnsupportedlLogoutResponse( - message string, - timeFormat string, -) *samlp.LogoutResponseType { - return makeLogoutResponse( - r.RequestID, - r.LogoutURL, - time.Now().UTC().Format(timeFormat), - StatusCodeRequestUnsupported, - message, - getIssuer(r.Issuer), - ) -} - -func (r *LogoutResponse) makePartialLogoutResponse( +func (r *LogoutResponse) makeFailedLogoutResponse( + reason string, message string, timeFormat string, ) *samlp.LogoutResponseType { @@ -88,22 +64,19 @@ func (r *LogoutResponse) makePartialLogoutResponse( r.RequestID, r.LogoutURL, time.Now().UTC().Format(timeFormat), - StatusCodePartialLogout, + reason, message, getIssuer(r.Issuer), ) } -func (r *LogoutResponse) makeDeniedLogoutResponse( - message string, - timeFormat string, -) *samlp.LogoutResponseType { +func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType { return makeLogoutResponse( r.RequestID, r.LogoutURL, time.Now().UTC().Format(timeFormat), - StatusCodeRequestDenied, - message, + StatusCodeSuccess, + "", getIssuer(r.Issuer), ) } diff --git a/pkg/provider/post.go b/pkg/provider/post.go index 82ef22e..38a44dc 100644 --- a/pkg/provider/post.go +++ b/pkg/provider/post.go @@ -1,7 +1,7 @@ package provider import ( - "context" + "crypto/rsa" "encoding/base64" "reflect" @@ -63,16 +63,12 @@ func verifyPostSignature( } func createPostSignature( - ctx context.Context, samlResponse *samlp.ResponseType, - idp *IdentityProvider, + key *rsa.PrivateKey, + cert []byte, + signatureAlgorithm string, ) error { - cert, key, err := getResponseCert(ctx, idp.storage) - if err != nil { - return err - } - - signer, err := signature.GetSigner(cert, key, idp.conf.SignatureAlgorithm) + signer, err := signature.GetSigner(cert, key, signatureAlgorithm) if err != nil { return err } diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 0577862..df0471d 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -5,20 +5,23 @@ import ( "crypto/rsa" "fmt" "net/http" + "time" "github.com/google/uuid" "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/zitadel/saml/pkg/provider/models" "github.com/zitadel/saml/pkg/provider/signature" "github.com/zitadel/saml/pkg/provider/xml/md" + "github.com/zitadel/saml/pkg/provider/xml/samlp" ) const ( DefaultTimeFormat = "2006-01-02T15:04:05.999999Z" + DefaultExpiration = 5 * time.Minute PostBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" RedirectBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" - SOAPBinding = "urn:oasis:names:tc:SAML:2.0:bindings:SOAP" DefaultMetadataEndpoint = "/metadata" ) @@ -80,7 +83,6 @@ type Provider struct { conf *Config issuerFromRequest IssuerFromRequest identityProvider *IdentityProvider - timeFormat string } func NewProvider( @@ -217,12 +219,32 @@ var allowAllOrigins = func(_ string) bool { } // AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login -func AuthCallbackURL(p *Provider) func(context.Context, string) string { +func (p *Provider) AuthCallbackURL() func(context.Context, string) string { return func(ctx context.Context, requestID string) string { return p.identityProvider.endpoints.callbackEndpoint.Absolute(IssuerFromContext(ctx)) + "?id=" + requestID } } +// AuthCallbackResponse returns the SAMLResponse from as successful SAMLRequest +func (p *Provider) AuthCallbackResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) { + return p.identityProvider.loginResponse(ctx, authRequest, response) +} + +// AuthCallbackErrorResponse returns the SAMLResponse from as failed SAMLRequest +func (p *Provider) AuthCallbackErrorResponse(response *Response, reason string, description string) *samlp.ResponseType { + return p.identityProvider.errorResponse(response, reason, description) +} + +// Timeformat return the used timeformat in messages +func (p *Provider) Timeformat() string { + return p.identityProvider.TimeFormat +} + +// Expiration return the used expiration in messages +func (p *Provider) Expiration() time.Duration { + return p.identityProvider.Expiration +} + func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler { cors := handlers.CORS( handlers.AllowCredentials(), @@ -250,7 +272,7 @@ func WithAllowInsecure() Option { // WithCustomTimeFormat allows the use of a custom timeformat instead of the default func WithCustomTimeFormat(timeFormat string) Option { return func(p *Provider) error { - p.identityProvider.timeFormat = timeFormat + p.identityProvider.TimeFormat = timeFormat return nil } } diff --git a/pkg/provider/redirect.go b/pkg/provider/redirect.go index 45f76e5..a31bac9 100644 --- a/pkg/provider/redirect.go +++ b/pkg/provider/redirect.go @@ -1,7 +1,7 @@ package provider import ( - "context" + "crypto/rsa" "encoding/base64" "fmt" "net/url" @@ -66,47 +66,41 @@ func verifyRedirectSignature( } func createRedirectSignature( - ctx context.Context, samlResponse *samlp.ResponseType, - idp *IdentityProvider, - response *Response, -) error { + key *rsa.PrivateKey, + cert []byte, + signatureAlgorithm string, + relayState string, +) (string, string, error) { resp, err := xml.Marshal(samlResponse) if err != nil { - return err + return "", "", err } respData, err := xml.DeflateAndBase64(resp) if err != nil { - return err - } - - cert, key, err := getResponseCert(ctx, idp.storage) - if err != nil { - return err + return "", "", err } tlsCert, err := signature.ParseTlsKeyPair(cert, key) if err != nil { - return err + return "", "", err } - signingContext, err := signature.GetSigningContext(tlsCert, idp.conf.SignatureAlgorithm) + signingContext, err := signature.GetSigningContext(tlsCert, signatureAlgorithm) if err != nil { - return err + return "", "", err } - sig, err := signature.CreateRedirect(signingContext, buildRedirectQuery(string(respData), response.RelayState, idp.conf.SignatureAlgorithm, "")) + sig, err := signature.CreateRedirect(signingContext, BuildRedirectQuery(string(respData), relayState, signatureAlgorithm, "")) if err != nil { - return err + return "", "", err } - response.Signature = url.QueryEscape(base64.StdEncoding.EncodeToString(sig)) - response.SigAlg = url.QueryEscape(base64.StdEncoding.EncodeToString([]byte(idp.conf.SignatureAlgorithm))) - return nil + return url.QueryEscape(base64.StdEncoding.EncodeToString(sig)), url.QueryEscape(base64.StdEncoding.EncodeToString([]byte(signatureAlgorithm))), nil } -func buildRedirectQuery( +func BuildRedirectQuery( response string, relayState string, sigAlg string, diff --git a/pkg/provider/redirect_test.go b/pkg/provider/redirect_test.go index 18adeb7..87c689c 100644 --- a/pkg/provider/redirect_test.go +++ b/pkg/provider/redirect_test.go @@ -385,7 +385,7 @@ func TestRedirect_buildRedirectQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := buildRedirectQuery(tt.args.response, tt.args.relayState, tt.args.sigAlg, tt.args.sig) + got := BuildRedirectQuery(tt.args.response, tt.args.relayState, tt.args.sigAlg, tt.args.sig) if got != tt.res { t.Errorf("verifyRedirectSignature() got = %v, want %v", got, tt.res) return diff --git a/pkg/provider/response.go b/pkg/provider/response.go index ddd4346..f868602 100644 --- a/pkg/provider/response.go +++ b/pkg/provider/response.go @@ -1,6 +1,7 @@ package provider import ( + "crypto/rsa" "encoding/base64" "fmt" "html/template" @@ -12,7 +13,7 @@ import ( "github.com/zitadel/saml/pkg/provider/xml/samlp" ) -const ( +var ( StatusCodeSuccess = "urn:oasis:names:tc:SAML:2.0:status:Success" StatusCodeVersionMissmatch = "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch" StatusCodeAuthNFailed = "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed" @@ -40,11 +41,7 @@ type Response struct { SendIP string } -func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) { - -} - -type AuthResponseForm struct { +type authResponseForm struct { RelayState string SAMLResponse string AssertionConsumerServiceURL string @@ -73,7 +70,7 @@ func (r *Response) sendBackResponse( case PostBinding: respData := base64.StdEncoding.EncodeToString(respData) - data := AuthResponseForm{ + data := authResponseForm{ r.RelayState, respData, r.AcsUrl, @@ -90,76 +87,42 @@ func (r *Response) sendBackResponse( return } - http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound) + http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, BuildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound) return default: //TODO: no binding } } -func (r *Response) makeUnsupportedBindingResponse( - message string, - timeFormat string, -) *samlp.ResponseType { - now := time.Now().UTC() - nowStr := now.Format(timeFormat) - return makeResponse( - NewID(), - r.RequestID, - r.AcsUrl, - nowStr, - StatusCodeUnsupportedBinding, - message, - r.Issuer, - ) -} - -func (r *Response) makeResponderFailResponse( - message string, - timeFormat string, -) *samlp.ResponseType { - now := time.Now().UTC() - nowStr := now.Format(timeFormat) - return makeResponse( - NewID(), - r.RequestID, - r.AcsUrl, - nowStr, - StatusCodeResponder, - message, - r.Issuer, - ) -} - -func (r *Response) makeDeniedResponse( - message string, - timeFormat string, -) *samlp.ResponseType { - now := time.Now().UTC() - nowStr := now.Format(timeFormat) - return makeResponse( - NewID(), - r.RequestID, - r.AcsUrl, - nowStr, - StatusCodeRequestDenied, - message, - r.Issuer, - ) +func createSignature(response *Response, samlResponse *samlp.ResponseType, key *rsa.PrivateKey, cert []byte, signatureAlgorithm string) error { + switch response.ProtocolBinding { + case PostBinding: + if err := createPostSignature(samlResponse, key, cert, signatureAlgorithm); err != nil { + return fmt.Errorf("failed to sign response: %w", err) + } + case RedirectBinding: + sig, sigAlg, err := createRedirectSignature(samlResponse, key, cert, signatureAlgorithm, response.RelayState) + if err != nil { + return fmt.Errorf("failed to sign response: %w", err) + } + response.Signature = sig + response.SigAlg = sigAlg + } + return nil } func (r *Response) makeFailedResponse( + reason string, message string, timeFormat string, ) *samlp.ResponseType { now := time.Now().UTC() - nowStr := now.Format(timeFormat) return makeResponse( NewID(), r.RequestID, r.AcsUrl, - nowStr, - StatusCodeAuthNFailed, + now.Format(timeFormat), + reason, message, r.Issuer, ) @@ -168,14 +131,12 @@ func (r *Response) makeFailedResponse( func (r *Response) makeSuccessfulResponse( attributes *Attributes, timeFormat string, + expiration time.Duration, ) *samlp.ResponseType { now := time.Now().UTC() - nowStr := now.Format(timeFormat) - fiveFromNowStr := now.Add(5 * time.Minute).Format(timeFormat) - return r.makeAssertionResponse( - nowStr, - fiveFromNowStr, + now.Format(timeFormat), + now.Add(expiration).Format(timeFormat), attributes, ) } @@ -206,13 +167,9 @@ func makeAttributeQueryResponse( attributes *Attributes, queriedAttrs []saml.AttributeType, timeFormat string, + expiration time.Duration, ) *samlp.ResponseType { now := time.Now().UTC() - nowStr := now.Format(timeFormat) - fiveMinutes, _ := time.ParseDuration("5m") - fiveFromNow := now.Add(fiveMinutes) - fiveFromNowStr := fiveFromNow.Format(timeFormat) - providedAttrs := []*saml.AttributeType{} attrsSaml := attributes.GetSAML() if queriedAttrs == nil || len(queriedAttrs) == 0 { @@ -229,8 +186,8 @@ func makeAttributeQueryResponse( } } - response := makeResponse(NewID(), requestID, "", nowStr, StatusCodeSuccess, "", issuer) - assertion := makeAssertion(requestID, "", "", nowStr, fiveFromNowStr, issuer, attributes.GetNameID(), providedAttrs, entityID, false) + response := makeResponse(NewID(), requestID, "", now.Format(timeFormat), StatusCodeSuccess, "", issuer) + assertion := makeAssertion(requestID, "", "", now.Format(timeFormat), now.Add(expiration).Format(timeFormat), issuer, attributes.GetNameID(), providedAttrs, entityID, false) response.Assertion = *assertion return response } diff --git a/pkg/provider/response_test.go b/pkg/provider/response_test.go index 693f71c..358cf6a 100644 --- a/pkg/provider/response_test.go +++ b/pkg/provider/response_test.go @@ -74,7 +74,7 @@ func TestResponse_sendBackResponse(t *testing.T) { signature: "sig", }, res{ - body: []byte("Found.\n\n"), + body: []byte("Found.\n\n"), }, }, { @@ -94,7 +94,7 @@ func TestResponse_sendBackResponse(t *testing.T) { signature: "sig", }, res{ - body: []byte("Found.\n\n"), + body: []byte("Found.\n\n"), }, }, { diff --git a/pkg/provider/sso.go b/pkg/provider/sso.go index dafca50..451566b 100644 --- a/pkg/provider/sso.go +++ b/pkg/provider/sso.go @@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) return nil }, func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form").Error(), p.TimeFormat)) }, ) @@ -70,7 +70,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) "SAMLRequest", func() string { return authRequestForm.AuthRequest }, func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("no auth request provided").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("no auth request provided").Error(), p.TimeFormat)) }, ) @@ -80,7 +80,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) "Signature", func() string { return authRequestForm.Sig }, func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("signature algorith provided but no signature").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("signature algorith provided but no signature").Error(), p.TimeFormat)) }, ) @@ -95,7 +95,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) return nil }, func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to decode request").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request").Error(), p.TimeFormat)) }, ) @@ -110,7 +110,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) return nil }, func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat)) }, ) @@ -125,7 +125,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) func() *md.EntityDescriptorType { return sp.Metadata }, ), func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to validate certificate from request: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate certificate from request: %w", err).Error(), p.TimeFormat)) }, ) @@ -146,7 +146,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) func(errF error) { err = errF }, ), func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to verify signature: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to verify signature: %w", err).Error(), p.TimeFormat)) }, ) @@ -164,14 +164,14 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) func(errF error) { err = errF }, ), func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to verify signature: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to verify signature: %w", err).Error(), p.TimeFormat)) }, ) // work out used acs url and protocolbinding for response checkerInstance.WithValueStep( func() { - response.AcsUrl, response.ProtocolBinding = getAcsUrlAndBindingForResponse(sp, authNRequest.ProtocolBinding) + response.AcsUrl, response.ProtocolBinding = GetAcsUrlAndBindingForResponse(sp.Metadata.SPSSODescriptor.AssertionConsumerService, authNRequest.ProtocolBinding) }, ) @@ -180,7 +180,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) "acsUrl", func() string { return response.AcsUrl }, func() { - response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("missing usable assertion consumer url").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("missing usable assertion consumer url").Error(), p.TimeFormat)) }, ) @@ -189,7 +189,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) "protocol binding", func() string { return response.ProtocolBinding }, func() { - response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("missing usable protocol binding").Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("missing usable protocol binding").Error(), p.TimeFormat)) }, ) @@ -200,7 +200,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) func() *samlp.AuthnRequestType { return authNRequest }, ), func() { - response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to validate request content: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request content: %w", err).Error(), p.TimeFormat)) }, ) @@ -218,7 +218,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) return err }, func() { - response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to persist request: %w", err).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeResponder, fmt.Errorf("failed to persist request: %w", err).Error(), p.TimeFormat)) }, ) @@ -232,7 +232,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, sp.LoginURL(authRequest.GetID()), http.StatusSeeOther) default: logging.Error(err) - response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("unsupported binding: %s", response.ProtocolBinding).Error(), p.timeFormat)) + response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("unsupported binding: %s", response.ProtocolBinding).Error(), p.TimeFormat)) } return } @@ -258,6 +258,9 @@ func getAuthRequestFromRequest(r *http.Request) (*AuthRequestForm, error) { Sig: r.FormValue("Signature"), Binding: binding, } + if request.Encoding == "" && binding == RedirectBinding { + request.Encoding = xml.EncodingDeflate + } return request, nil } @@ -348,14 +351,14 @@ func checkCertificate( } } -func getAcsUrlAndBindingForResponse( - sp *serviceprovider.ServiceProvider, +func GetAcsUrlAndBindingForResponse( + acs []md.IndexedEndpointType, requestProtocolBinding string, ) (string, string) { acsUrl := "" protocolBinding := "" - for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService { + for _, acs := range acs { if acs.Binding == requestProtocolBinding { acsUrl = acs.Location protocolBinding = acs.Binding @@ -364,7 +367,7 @@ func getAcsUrlAndBindingForResponse( } if acsUrl == "" { isDefaultFound := false - for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService { + for _, acs := range acs { if acs.IsDefault == "true" { isDefaultFound = true acsUrl = acs.Location @@ -374,7 +377,7 @@ func getAcsUrlAndBindingForResponse( } if !isDefaultFound { index := 0 - for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService { + for _, acs := range acs { i, _ := strconv.Atoi(acs.Index) if index == 0 || i < index { acsUrl = acs.Location diff --git a/pkg/provider/sso_test.go b/pkg/provider/sso_test.go index 4b68d8c..2b3d599 100644 --- a/pkg/provider/sso_test.go +++ b/pkg/provider/sso_test.go @@ -25,7 +25,7 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { binding string } type args struct { - sp *serviceprovider.ServiceProvider + acs []md.IndexedEndpointType requestBinding string } tests := []struct { @@ -35,15 +35,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { }{{ "sp with post and redirect, default used", args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Index: "1", IsDefault: "true", Binding: RedirectBinding, Location: "redirect"}, - {Index: "2", Binding: PostBinding, Location: "post"}, - }, - }, - }, + []md.IndexedEndpointType{ + {Index: "1", IsDefault: "true", Binding: RedirectBinding, Location: "redirect"}, + {Index: "2", Binding: PostBinding, Location: "post"}, }, RedirectBinding, }, @@ -55,15 +49,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { { "sp with post and redirect, first index used", args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Index: "1", Binding: RedirectBinding, Location: "redirect"}, - {Index: "2", Binding: PostBinding, Location: "post"}, - }, - }, - }, + []md.IndexedEndpointType{ + {Index: "1", Binding: RedirectBinding, Location: "redirect"}, + {Index: "2", Binding: PostBinding, Location: "post"}, }, RedirectBinding, }, @@ -75,15 +63,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { { "sp with post and redirect, redirect used", args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Binding: RedirectBinding, Location: "redirect"}, - {Binding: PostBinding, Location: "post"}, - }, - }, - }, + []md.IndexedEndpointType{ + {Binding: RedirectBinding, Location: "redirect"}, + {Binding: PostBinding, Location: "post"}, }, RedirectBinding, }, @@ -95,15 +77,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { { "sp with post and redirect, post used", args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Binding: RedirectBinding, Location: "redirect"}, - {Binding: PostBinding, Location: "post"}, - }, - }, - }, + []md.IndexedEndpointType{ + {Binding: RedirectBinding, Location: "redirect"}, + {Binding: PostBinding, Location: "post"}, }, PostBinding, }, @@ -114,16 +90,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { }, { "sp with redirect, post used", - args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Binding: RedirectBinding, Location: "redirect"}, - }, - }, - }, - }, + args{[]md.IndexedEndpointType{ + {Binding: RedirectBinding, Location: "redirect"}, + }, PostBinding, }, res{ @@ -134,14 +103,8 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { { "sp with post, redirect used", args{ - &serviceprovider.ServiceProvider{ - Metadata: &md.EntityDescriptorType{ - SPSSODescriptor: &md.SPSSODescriptorType{ - AssertionConsumerService: []md.IndexedEndpointType{ - {Binding: PostBinding, Location: "post"}, - }, - }, - }, + []md.IndexedEndpointType{ + {Binding: PostBinding, Location: "post"}, }, RedirectBinding, }, @@ -153,9 +116,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - acs, binding := getAcsUrlAndBindingForResponse(tt.args.sp, tt.args.requestBinding) + acs, binding := GetAcsUrlAndBindingForResponse(tt.args.acs, tt.args.requestBinding) if acs != tt.res.acs && binding != tt.res.binding { - t.Errorf("getAcsUrlAndBindingForResponse() got = %v/%v, want %v/%v", acs, binding, tt.res.acs, tt.res.binding) + t.Errorf("GetAcsUrlAndBindingForResponse() got = %v/%v, want %v/%v", acs, binding, tt.res.acs, tt.res.binding) return } }) @@ -447,9 +410,11 @@ func TestSSO_ssoHandleFunc(t *testing.T) { Binding string } type res struct { - code int - err bool - state string + code int + err bool + state string + inflate bool + b64 bool } type sp struct { entityID string @@ -628,39 +593,6 @@ func TestSSO_ssoHandleFunc(t *testing.T) { state: StatusCodeRequestDenied, err: false, }}, - { - "redirect request unknown service provider", - args{ - issuer: "http://localhost:50002", - metadataEndpoint: "/saml/metadata", - config: &IdentityProviderConfig{ - SignatureAlgorithm: dsig.RSASHA256SignatureMethod, - MetadataIDPConfig: &MetadataIDPConfig{}, - Endpoints: &EndpointConfig{ - SingleSignOn: getEndpointPointer("/saml/SSO", "http://localhost:50002/saml/SSO"), - }, - }, - certificate: "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVt\neXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQw\nNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+U\nerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8\nG+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5q\nLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF\n2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5\nnyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVo\nLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW\n8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+u\nla1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4Wsfa\nwWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oa\nLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n-----END CERTIFICATE-----\n", - key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7XKdCRxUZXjdq\nVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWc\nWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2\ngIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+a\nEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7\nLfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v\n4cxTNPn/AgMBAAECggEAF+rV9yH30Ysza8GwrXCR9qDN1Dp3QmmsavnXkonEvPoq\nEr2T3o0//6mBp6CLDboMQGQBjblJwl+3Y6PgZolvHAMOsMdHfYNPEo7FSzUBzEw+\nqRrs5HkMyvoPgfV6X8F97W3tiD4Q/AmHkMILl+MxbnfPXM54gWqPuwIqxY1uaCk5\nREwyb7WBon3rd58ceOI1SLRjod6SbqWBMMSN3cJ+5VEPObFjw/RlhNQ5rBI8G5Kt\nso2zBU5C4BB2CvqlWy98WDKJkTvWHbiTjZCy8BQ+gQ6UJM2vaNELFOVpuMGQnMIi\noWiX10Jg2e1gP9j3TdrohlGF8M3+TXjSFKNmeX0DUQKBgQDx7UazUWS5RtkgnjH9\nw2xH2xkstJVD7nAS8VTxNwcrgjVXPvTJha9El904obUjyRX7ppb02tuH5ML/bZh6\n9lL4bP5+SHcJ10e4q8CK/KAGHD6BYAbaGXRq0CoSk5a3vv5XPdob4T5qKCIHFpnu\nMfbvdbEoameLOyRYOGu/yVZIiwKBgQDGQs7FRTisHV0xooiRmlvYF0dcd19qpLed\nqhgJNqBPOTEvvGvJNRoi39haEY3cuTqsxZ5FAlFlVFMUUozz+d0xBLLInoVY/Y4h\nhSdGmdw/A6oHodLqyEp3N5RZNdLlh8/nDS3xXzMotAl75bW5kc2ttcRhRdtyNJ9Z\nup0PgppO3QKBgEC45upAQz8iCiKkz+EA8C4FGqYQJcLHvmoC8GOcAioMqrKNoDVt\ns2cZbdChynEpcd0iQ058YrDnbZeiPWHgFnBp0Gf+gQI7+u8X2+oTDci0s7Au/YZJ\nuxB8YlUX8QF1clvqqzg8OVNzKy9UR5gm+9YyWVPjq5HfH6kOZx0nAxNjAoGAERt8\nqgsCC9/wxbKnpCC0oh3IG5N1WUdjTKh7sHfVN2DQ/LR+fHsniTDVg1gWbKBTDsty\nj7PWgC7ZiFxjKz45NtyX7LW4/efLFttdezsVhR500nnFMFseCdFy7Iu3afThHKfH\nehdj27RFSTqWBrAtFjsj+dzERcOCqIRwvwDe/cUCgYEA5+1mzVXDVjKsWylKJPk+\nZZA4LUfvmTj3VLNDZrlSAI/xEikCFio0QWEA2TQYTAwbXTrKwQSeHQRhv7OTc1h+\nMhpAgvs189ze5J4jiNmULEkkrO+Cxxnw8tyV+UFRZtzW9gUoVBwXiZ/Wbl9sfnlO\nwLJHc0j6OltPcPJmxHP8gQI=\n-----END PRIVATE KEY-----\n", - request: request{ - ID: "test", - Binding: RedirectBinding, - SAMLRequest: url.QueryEscape("nJJBj9MwEIX/ijX3NG6a7DbWJlLZClFpYatN4cBt6k6oJccungmw/x61XaQioRy42vP5ved5D4yDP5nVKMfwQt9HYlG/Bh/YnC8aGFMwEdmxCTgQG7GmW318MsVMG2SmJC4GuEFO08wpRYk2elCbdQPukFlNd/c9LQpczPve6r3taVHWdbWoal3bfr7c03JJc1BfKLGLoYFipkFtmEfaBBYM0kChiyLTZVbc7XRtyntTVrOyrr6CWhOLCygX8ihyMnnuo0V/jCym0loX+dl33nXPoFZ/Ij3GwONAqaP0w1n6/PL0D3qptb7CaBnU9i3bOxcOLnyb/oj9dYjNh91um22fux20l2WYS7Kk3sc0oEw/cj5xh6y/jBoK4uQV2gmfAwkeUPAhv5Fq30rwCQfarLfRO/v6H/KSMLCjIKBW3sefj4lQqAFJI0HeXiX/rlr7OwAA//8="), - RelayState: url.QueryEscape("K6LS7mdqUO4SGedbfa8nBIyX-7K8gGbrHMqIMwVn6zCKLLoADHjEHUAm"), - Signature: url.QueryEscape("PWZ6JPNpAGE7mYLKD3dCUG9AZcThrMRQGtvdv31ewx3hms5Oglc677iAUEcbIBrvKtMrCPVwXPNxT6wQ0rg4qIgyKgoyS53ZTaxaFHPrB7wkkzqtK7GvWgdEqceT8iooK5SCLHFMJ3m30LqEbX7zFw62yE34+e7ypfZSM5Lrf0QFwPzX+LNCuYA+Ob9D5SKc132tn21J2vBRmNJ1zCY0ksRzQfyfErjAzcGVx8qK9jpaeyvsVBZSkH/I6+1hb8lQWE48xala9NbqfbMATGBCQj1UvpVMMfp6PE7KPk5Y1YDeSqPeRIEKH+Gnip6Hve5Ji1aiRp5bytVf1VHwTHSq8w=="), - SigAlg: url.QueryEscape("http://www.w3.org/2000/09/xmldsig#rsa-sha1"), - }, - sp: sp{ - entityID: "http://localhost:8000/saml/metadata", - metadata: "\n \n \n \n \n MIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQwNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVoLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+ula1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4WsfawWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oaLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n \n \n \n \n \n \n \n \n \n \n MIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQwNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVoLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+ula1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4WsfawWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oaLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n \n \n \n \n \n \n \n", - err: fmt.Errorf("unknown"), - }, - }, - res{ - code: 200, - state: StatusCodeRequestDenied, - err: false, - }}, { "signed post request", args{ @@ -693,7 +625,6 @@ func TestSSO_ssoHandleFunc(t *testing.T) { err: false, }}, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { endpoint := NewEndpoint(tt.args.metadataEndpoint) @@ -766,7 +697,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) { } if tt.res.state != "" { - if err := parseForState(string(response), tt.res.state); err != nil { + if err := parseForState(tt.res.inflate, tt.res.b64, string(response), tt.res.state); err != nil { t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state) return } @@ -775,8 +706,13 @@ func TestSSO_ssoHandleFunc(t *testing.T) { } } -func parseForState(responseXML string, state string) error { - response, err := xml.DecodeResponse("", responseXML) +func parseForState(inflate bool, b64 bool, responseXML string, state string) error { + encoding := "" + if inflate { + encoding = xml.EncodingDeflate + } + + response, err := xml.DecodeResponse(encoding, b64, responseXML) if err != nil { return err } diff --git a/pkg/provider/xml/xml.go b/pkg/provider/xml/xml.go index ad4e976..f717b5a 100644 --- a/pkg/provider/xml/xml.go +++ b/pkg/provider/xml/xml.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "encoding/xml" "fmt" + "io" "net/http" "strings" @@ -43,25 +44,20 @@ func Marshal(data interface{}) ([]byte, error) { } func DeflateAndBase64(data []byte) ([]byte, error) { - b := &bytes.Buffer{} - w1 := base64.NewEncoder(base64.StdEncoding, b) - defer w1.Close() - - w2, _ := flate.NewWriter(w1, 1) - defer w2.Close() - - bw := bufio.NewWriter(w1) - if _, err := bw.Write(data); err != nil { + buff := &bytes.Buffer{} + b64Encoder := base64.NewEncoder(base64.StdEncoding, buff) + // compression level is set at 9 as BestCompression, also used by other SAML application like crewjam/saml + flateWriter, _ := flate.NewWriter(b64Encoder, 9) + if _, err := flateWriter.Write(data); err != nil { return nil, err } - if err := bw.Flush(); err != nil { + if err := flateWriter.Close(); err != nil { return nil, err } - - if err := w2.Flush(); err != nil { + if err := b64Encoder.Close(); err != nil { return nil, err } - return b.Bytes(), nil + return buff.Bytes(), nil } func WriteXMLMarshalled(w http.ResponseWriter, body interface{}) error { @@ -86,53 +82,26 @@ func Write(w http.ResponseWriter, body []byte) error { } func DecodeAuthNRequest(encoding string, message string) (*samlp.AuthnRequestType, error) { - reqBytes, err := base64.StdEncoding.DecodeString(message) + data, err := InflateAndDecode(encoding, true, message) if err != nil { - return nil, fmt.Errorf("failed to base64 decode: %w", err) + return nil, err } - req := &samlp.AuthnRequestType{} - switch encoding { - case EncodingDeflate: - reader := flate.NewReader(bytes.NewReader(reqBytes)) - decoder := xml.NewDecoder(reader) - if err = decoder.Decode(req); err != nil { - return nil, fmt.Errorf("failed to defalte decode: %w", err) - } - default: - reader := flate.NewReader(bytes.NewReader(reqBytes)) - decoder := xml.NewDecoder(reader) - if err = decoder.Decode(req); err != nil { - if err := xml.Unmarshal(reqBytes, req); err != nil { - return nil, fmt.Errorf("failed to unmarshal: %w", err) - } - } + if err := xml.Unmarshal(data, req); err != nil { + return nil, err } - return req, nil } -func DecodeSignature(encoding string, message string) (*xml_dsig.SignatureType, error) { - retBytes := []byte(message) - +func DecodeSignature(encoding string, b64 bool, message string) (*xml_dsig.SignatureType, error) { + data, err := InflateAndDecode(encoding, b64, message) + if err != nil { + return nil, err + } ret := &xml_dsig.SignatureType{} - switch encoding { - case EncodingDeflate: - reader := flate.NewReader(bytes.NewReader(retBytes)) - decoder := xml.NewDecoder(reader) - if err := decoder.Decode(ret); err != nil { - return nil, fmt.Errorf("failed to defalte decode: %w", err) - } - default: - reader := flate.NewReader(bytes.NewReader(retBytes)) - decoder := xml.NewDecoder(reader) - if err := decoder.Decode(ret); err != nil { - if err := xml.Unmarshal(retBytes, ret); err != nil { - return nil, fmt.Errorf("failed to unmarshal: %w", err) - } - } + if err := xml.Unmarshal(data, ret); err != nil { + return nil, err } - return ret, nil } @@ -148,50 +117,45 @@ func DecodeAttributeQuery(request string) (*samlp.AttributeQueryType, error) { } func DecodeLogoutRequest(encoding string, message string) (*samlp.LogoutRequestType, error) { - reqBytes, err := base64.StdEncoding.DecodeString(message) + data, err := InflateAndDecode(encoding, true, message) if err != nil { return nil, err } - req := &samlp.LogoutRequestType{} - switch encoding { - case "": - reader := flate.NewReader(bytes.NewReader(reqBytes)) - decoder := xml.NewDecoder(reader) - if err = decoder.Decode(req); err != nil { - return nil, err - } - case EncodingDeflate: - reader := flate.NewReader(bytes.NewReader(reqBytes)) - decoder := xml.NewDecoder(reader) - if err = decoder.Decode(req); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unknown encoding") + if err := xml.Unmarshal(data, req); err != nil { + return nil, err } - return req, nil } -func DecodeResponse(encoding string, message string) (*samlp.ResponseType, error) { - +func DecodeResponse(encoding string, b64 bool, message string) (*samlp.ResponseType, error) { + data, err := InflateAndDecode(encoding, b64, message) + if err != nil { + return nil, err + } req := &samlp.ResponseType{} - switch encoding { - case "": - decoder := xml.NewDecoder(bytes.NewReader([]byte(message))) - if err := decoder.Decode(req); err != nil { + if err := xml.Unmarshal(data, req); err != nil { + return nil, err + } + return req, nil +} + +func InflateAndDecode(encoding string, b64 bool, message string) (_ []byte, err error) { + data := []byte(message) + if b64 { + data, err = base64.StdEncoding.DecodeString(message) + if err != nil { return nil, err } + } + switch encoding { + case "": + return data, nil case EncodingDeflate: - reader := flate.NewReader(bytes.NewReader([]byte(message))) - decoder := xml.NewDecoder(reader) - if err := decoder.Decode(req); err != nil { - return nil, err - } + r := flate.NewReader(bytes.NewBuffer(data)) + defer r.Close() + return io.ReadAll(r) default: return nil, fmt.Errorf("unknown encoding") } - - return req, nil }