From 95c785b0bc2fec76c2901bbded1e685d5a6804fb Mon Sep 17 00:00:00 2001
From: Stefan Benz <46600784+stebenz@users.noreply.github.com>
Date: Wed, 11 Dec 2024 11:18:03 +0100
Subject: [PATCH] fix: corrected inflate and deflate logic for xml, and
corresponding tests with some refactoring for SAML sessions (#92)
* refactor: remove separate functions for error reasons
* fix: corrected inflate and deflate logic, and corresponding tests
* fix: corrected inflate and deflate logic, and corresponding tests
* fix: corrected inflate and deflate logic, and corresponding tests
* fix: switch over status codes to errors
* fix: review changes
* fix: review changes
---
pkg/provider/attribute_query.go | 8 +-
pkg/provider/identityprovider.go | 8 +-
pkg/provider/login.go | 59 ++++++++------
pkg/provider/login_test.go | 33 ++++----
pkg/provider/logout.go | 12 +--
pkg/provider/logout_response.go | 39 ++--------
pkg/provider/post.go | 14 ++--
pkg/provider/provider.go | 30 ++++++-
pkg/provider/redirect.go | 36 ++++-----
pkg/provider/redirect_test.go | 2 +-
pkg/provider/response.go | 101 +++++++-----------------
pkg/provider/response_test.go | 4 +-
pkg/provider/sso.go | 41 +++++-----
pkg/provider/sso_test.go | 130 ++++++++-----------------------
pkg/provider/xml/xml.go | 128 +++++++++++-------------------
15 files changed, 255 insertions(+), 390 deletions(-)
diff --git a/pkg/provider/attribute_query.go b/pkg/provider/attribute_query.go
index 00ef547..a17f842 100644
--- a/pkg/provider/attribute_query.go
+++ b/pkg/provider/attribute_query.go
@@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
- response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat)
+ response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.TimeFormat, p.Expiration)
return nil
},
func() {
@@ -139,7 +139,11 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
// create enveloped signature
checkerInstance.WithLogicStep(
func() error {
- return createPostSignature(r.Context(), response, p)
+ cert, key, err := getResponseCert(r.Context(), p.storage)
+ if err != nil {
+ return err
+ }
+ return createPostSignature(response, key, cert, p.conf.SignatureAlgorithm)
},
func() {
http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError)
diff --git a/pkg/provider/identityprovider.go b/pkg/provider/identityprovider.go
index 8d87057..002fe60 100644
--- a/pkg/provider/identityprovider.go
+++ b/pkg/provider/identityprovider.go
@@ -71,7 +71,8 @@ type IdentityProvider struct {
metadataEndpoint *Endpoint
endpoints *Endpoints
- timeFormat string
+ TimeFormat string
+ Expiration time.Duration
}
type Endpoints struct {
@@ -90,7 +91,8 @@ func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storag
postTemplate: conf.PostTemplate,
logoutTemplate: conf.LogoutTemplate,
endpoints: endpointConfigToEndpoints(conf.Endpoints),
- timeFormat: DefaultTimeFormat,
+ TimeFormat: DefaultTimeFormat,
+ Expiration: DefaultExpiration,
}
if conf.PostTemplate == nil {
@@ -160,7 +162,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto
return nil, nil, err
}
- metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat)
+ metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.TimeFormat)
return metadata, aaMetadata, nil
}
diff --git a/pkg/provider/login.go b/pkg/provider/login.go
index 4b83a8a..435f951 100644
--- a/pkg/provider/login.go
+++ b/pkg/provider/login.go
@@ -1,10 +1,14 @@
package provider
import (
+ "context"
"fmt"
"net/http"
"github.com/zitadel/logging"
+
+ "github.com/zitadel/saml/pkg/provider/models"
+ "github.com/zitadel/saml/pkg/provider/xml/samlp"
)
func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Request) {
@@ -16,7 +20,6 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
Issuer: p.GetEntityID(r.Context()),
}
- ctx := r.Context()
if err := r.ParseForm(); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
@@ -34,7 +37,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID)
if err != nil {
logging.Error(err)
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, p.errorResponse(response, StatusCodeRequestDenied, fmt.Errorf("failed to get request: %w", err).Error()))
return
}
response.RequestID = authRequest.GetAuthRequestID()
@@ -42,44 +45,50 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
response.ProtocolBinding = authRequest.GetBindingType()
response.AcsUrl = authRequest.GetAccessConsumerServiceURL()
- if !authRequest.Done() {
+ entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
+ if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
return
}
+ response.Audience = entityID
- entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
+ samlResponse, err := p.loginResponse(r.Context(), authRequest, response)
if err != nil {
- logging.Error(err)
- http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
+ response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat))
return
}
- response.Audience = entityID
+
+ response.sendBackResponse(r, w, samlResponse)
+ return
+}
+
+func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
+ if !authRequest.Done() {
+ logging.Error(StatusCodeAuthNFailed)
+ return nil, fmt.Errorf(StatusCodeAuthNFailed)
+ }
attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
- http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
- return
+ return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}
- samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat)
+ cert, key, err := getResponseCert(ctx, p.storage)
+ if err != nil {
+ logging.Error(err)
+ return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
+ }
- switch response.ProtocolBinding {
- case PostBinding:
- if err := createPostSignature(r.Context(), samlResponse, p); err != nil {
- logging.Error(err)
- response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
- return
- }
- case RedirectBinding:
- if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil {
- logging.Error(err)
- response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
- return
- }
+ samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration)
+ if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil {
+ logging.Error(err)
+ return nil, fmt.Errorf(StatusCodeResponder)
}
+ return samlResponse, nil
+}
- response.sendBackResponse(r, w, samlResponse)
- return
+func (p *IdentityProvider) errorResponse(response *Response, reason string, description string) *samlp.ResponseType {
+ return response.makeFailedResponse(reason, description, p.TimeFormat)
}
diff --git a/pkg/provider/login_test.go b/pkg/provider/login_test.go
index 1bf9e81..c188ec4 100644
--- a/pkg/provider/login_test.go
+++ b/pkg/provider/login_test.go
@@ -1,9 +1,9 @@
package provider
import (
- "io/ioutil"
"net/http"
"net/http/httptest"
+ "net/url"
"testing"
"github.com/golang/mock/gomock"
@@ -23,9 +23,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
Done bool
}
type res struct {
- code int
- err bool
- state string
+ code int
+ err bool
+ state string
+ inflate bool
+ b64 bool
}
type sp struct {
appID string
@@ -235,7 +237,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
ID: "test",
AuthRequestID: "test",
Binding: RedirectBinding,
- AcsURL: "url",
+ AcsURL: "https://sp.example.com",
RelayState: "relaystate",
UserID: "userid",
Done: false,
@@ -247,9 +249,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
},
},
res{
- code: 500,
- state: "",
- err: false,
+ code: 302,
+ state: StatusCodeAuthNFailed,
+ err: false,
+ inflate: true,
+ b64: true,
}},
}
@@ -297,14 +301,15 @@ func TestSSO_loginHandleFunc(t *testing.T) {
defer func() {
_ = res.Body.Close()
}()
- response, err := ioutil.ReadAll(res.Body)
- if res.StatusCode != tt.res.code {
- t.Errorf("ssoHandleFunc() code got = %v, want %v", res.StatusCode, tt.res)
- return
- }
+ // currently only checked for redirect binding
if tt.res.state != "" {
- if err := parseForState(string(response), tt.res.state); err != nil {
+ responseURL, err := url.Parse(res.Header.Get("Location"))
+ if err != nil {
+ t.Errorf("error while parsing url")
+ }
+
+ if err := parseForState(tt.res.inflate, tt.res.b64, responseURL.Query().Get("SAMLResponse"), tt.res.state); err != nil {
t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state)
return
}
diff --git a/pkg/provider/logout.go b/pkg/provider/logout.go
index 03eaef9..5c8614b 100644
--- a/pkg/provider/logout.go
+++ b/pkg/provider/logout.go
@@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
- response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
+ response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
},
)
@@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
- response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
+ response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
},
)
@@ -69,10 +69,10 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
checkIfRequestTimeIsStillValid(
func() string { return logoutRequest.IssueInstant },
func() string { return logoutRequest.NotOnOrAfter },
- p.timeFormat,
+ p.TimeFormat,
),
func() {
- response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat))
+ response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
},
)
@@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
- response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
+ response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)
@@ -106,7 +106,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
response.sendBackLogoutResponse(
w,
- response.makeSuccessfulLogoutResponse(p.timeFormat),
+ response.makeSuccessfulLogoutResponse(p.TimeFormat),
)
logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text))
}
diff --git a/pkg/provider/logout_response.go b/pkg/provider/logout_response.go
index 63a5765..407ccd8 100644
--- a/pkg/provider/logout_response.go
+++ b/pkg/provider/logout_response.go
@@ -55,32 +55,8 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam
}
}
-func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
- return makeLogoutResponse(
- r.RequestID,
- r.LogoutURL,
- time.Now().UTC().Format(timeFormat),
- StatusCodeSuccess,
- "",
- getIssuer(r.Issuer),
- )
-}
-
-func (r *LogoutResponse) makeUnsupportedlLogoutResponse(
- message string,
- timeFormat string,
-) *samlp.LogoutResponseType {
- return makeLogoutResponse(
- r.RequestID,
- r.LogoutURL,
- time.Now().UTC().Format(timeFormat),
- StatusCodeRequestUnsupported,
- message,
- getIssuer(r.Issuer),
- )
-}
-
-func (r *LogoutResponse) makePartialLogoutResponse(
+func (r *LogoutResponse) makeFailedLogoutResponse(
+ reason string,
message string,
timeFormat string,
) *samlp.LogoutResponseType {
@@ -88,22 +64,19 @@ func (r *LogoutResponse) makePartialLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
- StatusCodePartialLogout,
+ reason,
message,
getIssuer(r.Issuer),
)
}
-func (r *LogoutResponse) makeDeniedLogoutResponse(
- message string,
- timeFormat string,
-) *samlp.LogoutResponseType {
+func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(timeFormat),
- StatusCodeRequestDenied,
- message,
+ StatusCodeSuccess,
+ "",
getIssuer(r.Issuer),
)
}
diff --git a/pkg/provider/post.go b/pkg/provider/post.go
index 82ef22e..38a44dc 100644
--- a/pkg/provider/post.go
+++ b/pkg/provider/post.go
@@ -1,7 +1,7 @@
package provider
import (
- "context"
+ "crypto/rsa"
"encoding/base64"
"reflect"
@@ -63,16 +63,12 @@ func verifyPostSignature(
}
func createPostSignature(
- ctx context.Context,
samlResponse *samlp.ResponseType,
- idp *IdentityProvider,
+ key *rsa.PrivateKey,
+ cert []byte,
+ signatureAlgorithm string,
) error {
- cert, key, err := getResponseCert(ctx, idp.storage)
- if err != nil {
- return err
- }
-
- signer, err := signature.GetSigner(cert, key, idp.conf.SignatureAlgorithm)
+ signer, err := signature.GetSigner(cert, key, signatureAlgorithm)
if err != nil {
return err
}
diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go
index 0577862..df0471d 100644
--- a/pkg/provider/provider.go
+++ b/pkg/provider/provider.go
@@ -5,20 +5,23 @@ import (
"crypto/rsa"
"fmt"
"net/http"
+ "time"
"github.com/google/uuid"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
+ "github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/signature"
"github.com/zitadel/saml/pkg/provider/xml/md"
+ "github.com/zitadel/saml/pkg/provider/xml/samlp"
)
const (
DefaultTimeFormat = "2006-01-02T15:04:05.999999Z"
+ DefaultExpiration = 5 * time.Minute
PostBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
RedirectBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
- SOAPBinding = "urn:oasis:names:tc:SAML:2.0:bindings:SOAP"
DefaultMetadataEndpoint = "/metadata"
)
@@ -80,7 +83,6 @@ type Provider struct {
conf *Config
issuerFromRequest IssuerFromRequest
identityProvider *IdentityProvider
- timeFormat string
}
func NewProvider(
@@ -217,12 +219,32 @@ var allowAllOrigins = func(_ string) bool {
}
// AuthCallbackURL builds the url for the redirect (with the requestID) after a successful login
-func AuthCallbackURL(p *Provider) func(context.Context, string) string {
+func (p *Provider) AuthCallbackURL() func(context.Context, string) string {
return func(ctx context.Context, requestID string) string {
return p.identityProvider.endpoints.callbackEndpoint.Absolute(IssuerFromContext(ctx)) + "?id=" + requestID
}
}
+// AuthCallbackResponse returns the SAMLResponse from as successful SAMLRequest
+func (p *Provider) AuthCallbackResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
+ return p.identityProvider.loginResponse(ctx, authRequest, response)
+}
+
+// AuthCallbackErrorResponse returns the SAMLResponse from as failed SAMLRequest
+func (p *Provider) AuthCallbackErrorResponse(response *Response, reason string, description string) *samlp.ResponseType {
+ return p.identityProvider.errorResponse(response, reason, description)
+}
+
+// Timeformat return the used timeformat in messages
+func (p *Provider) Timeformat() string {
+ return p.identityProvider.TimeFormat
+}
+
+// Expiration return the used expiration in messages
+func (p *Provider) Expiration() time.Duration {
+ return p.identityProvider.Expiration
+}
+
func intercept(i IssuerFromRequest, interceptors ...HttpInterceptor) func(handler http.Handler) http.Handler {
cors := handlers.CORS(
handlers.AllowCredentials(),
@@ -250,7 +272,7 @@ func WithAllowInsecure() Option {
// WithCustomTimeFormat allows the use of a custom timeformat instead of the default
func WithCustomTimeFormat(timeFormat string) Option {
return func(p *Provider) error {
- p.identityProvider.timeFormat = timeFormat
+ p.identityProvider.TimeFormat = timeFormat
return nil
}
}
diff --git a/pkg/provider/redirect.go b/pkg/provider/redirect.go
index 45f76e5..a31bac9 100644
--- a/pkg/provider/redirect.go
+++ b/pkg/provider/redirect.go
@@ -1,7 +1,7 @@
package provider
import (
- "context"
+ "crypto/rsa"
"encoding/base64"
"fmt"
"net/url"
@@ -66,47 +66,41 @@ func verifyRedirectSignature(
}
func createRedirectSignature(
- ctx context.Context,
samlResponse *samlp.ResponseType,
- idp *IdentityProvider,
- response *Response,
-) error {
+ key *rsa.PrivateKey,
+ cert []byte,
+ signatureAlgorithm string,
+ relayState string,
+) (string, string, error) {
resp, err := xml.Marshal(samlResponse)
if err != nil {
- return err
+ return "", "", err
}
respData, err := xml.DeflateAndBase64(resp)
if err != nil {
- return err
- }
-
- cert, key, err := getResponseCert(ctx, idp.storage)
- if err != nil {
- return err
+ return "", "", err
}
tlsCert, err := signature.ParseTlsKeyPair(cert, key)
if err != nil {
- return err
+ return "", "", err
}
- signingContext, err := signature.GetSigningContext(tlsCert, idp.conf.SignatureAlgorithm)
+ signingContext, err := signature.GetSigningContext(tlsCert, signatureAlgorithm)
if err != nil {
- return err
+ return "", "", err
}
- sig, err := signature.CreateRedirect(signingContext, buildRedirectQuery(string(respData), response.RelayState, idp.conf.SignatureAlgorithm, ""))
+ sig, err := signature.CreateRedirect(signingContext, BuildRedirectQuery(string(respData), relayState, signatureAlgorithm, ""))
if err != nil {
- return err
+ return "", "", err
}
- response.Signature = url.QueryEscape(base64.StdEncoding.EncodeToString(sig))
- response.SigAlg = url.QueryEscape(base64.StdEncoding.EncodeToString([]byte(idp.conf.SignatureAlgorithm)))
- return nil
+ return url.QueryEscape(base64.StdEncoding.EncodeToString(sig)), url.QueryEscape(base64.StdEncoding.EncodeToString([]byte(signatureAlgorithm))), nil
}
-func buildRedirectQuery(
+func BuildRedirectQuery(
response string,
relayState string,
sigAlg string,
diff --git a/pkg/provider/redirect_test.go b/pkg/provider/redirect_test.go
index 18adeb7..87c689c 100644
--- a/pkg/provider/redirect_test.go
+++ b/pkg/provider/redirect_test.go
@@ -385,7 +385,7 @@ func TestRedirect_buildRedirectQuery(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := buildRedirectQuery(tt.args.response, tt.args.relayState, tt.args.sigAlg, tt.args.sig)
+ got := BuildRedirectQuery(tt.args.response, tt.args.relayState, tt.args.sigAlg, tt.args.sig)
if got != tt.res {
t.Errorf("verifyRedirectSignature() got = %v, want %v", got, tt.res)
return
diff --git a/pkg/provider/response.go b/pkg/provider/response.go
index ddd4346..f868602 100644
--- a/pkg/provider/response.go
+++ b/pkg/provider/response.go
@@ -1,6 +1,7 @@
package provider
import (
+ "crypto/rsa"
"encoding/base64"
"fmt"
"html/template"
@@ -12,7 +13,7 @@ import (
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)
-const (
+var (
StatusCodeSuccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
StatusCodeVersionMissmatch = "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch"
StatusCodeAuthNFailed = "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed"
@@ -40,11 +41,7 @@ type Response struct {
SendIP string
}
-func (r *Response) doResponse(request *http.Request, w http.ResponseWriter, response string) {
-
-}
-
-type AuthResponseForm struct {
+type authResponseForm struct {
RelayState string
SAMLResponse string
AssertionConsumerServiceURL string
@@ -73,7 +70,7 @@ func (r *Response) sendBackResponse(
case PostBinding:
respData := base64.StdEncoding.EncodeToString(respData)
- data := AuthResponseForm{
+ data := authResponseForm{
r.RelayState,
respData,
r.AcsUrl,
@@ -90,76 +87,42 @@ func (r *Response) sendBackResponse(
return
}
- http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, buildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
+ http.Redirect(w, req, fmt.Sprintf("%s?%s", r.AcsUrl, BuildRedirectQuery(string(respData), r.RelayState, r.SigAlg, r.Signature)), http.StatusFound)
return
default:
//TODO: no binding
}
}
-func (r *Response) makeUnsupportedBindingResponse(
- message string,
- timeFormat string,
-) *samlp.ResponseType {
- now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
- return makeResponse(
- NewID(),
- r.RequestID,
- r.AcsUrl,
- nowStr,
- StatusCodeUnsupportedBinding,
- message,
- r.Issuer,
- )
-}
-
-func (r *Response) makeResponderFailResponse(
- message string,
- timeFormat string,
-) *samlp.ResponseType {
- now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
- return makeResponse(
- NewID(),
- r.RequestID,
- r.AcsUrl,
- nowStr,
- StatusCodeResponder,
- message,
- r.Issuer,
- )
-}
-
-func (r *Response) makeDeniedResponse(
- message string,
- timeFormat string,
-) *samlp.ResponseType {
- now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
- return makeResponse(
- NewID(),
- r.RequestID,
- r.AcsUrl,
- nowStr,
- StatusCodeRequestDenied,
- message,
- r.Issuer,
- )
+func createSignature(response *Response, samlResponse *samlp.ResponseType, key *rsa.PrivateKey, cert []byte, signatureAlgorithm string) error {
+ switch response.ProtocolBinding {
+ case PostBinding:
+ if err := createPostSignature(samlResponse, key, cert, signatureAlgorithm); err != nil {
+ return fmt.Errorf("failed to sign response: %w", err)
+ }
+ case RedirectBinding:
+ sig, sigAlg, err := createRedirectSignature(samlResponse, key, cert, signatureAlgorithm, response.RelayState)
+ if err != nil {
+ return fmt.Errorf("failed to sign response: %w", err)
+ }
+ response.Signature = sig
+ response.SigAlg = sigAlg
+ }
+ return nil
}
func (r *Response) makeFailedResponse(
+ reason string,
message string,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
return makeResponse(
NewID(),
r.RequestID,
r.AcsUrl,
- nowStr,
- StatusCodeAuthNFailed,
+ now.Format(timeFormat),
+ reason,
message,
r.Issuer,
)
@@ -168,14 +131,12 @@ func (r *Response) makeFailedResponse(
func (r *Response) makeSuccessfulResponse(
attributes *Attributes,
timeFormat string,
+ expiration time.Duration,
) *samlp.ResponseType {
now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
- fiveFromNowStr := now.Add(5 * time.Minute).Format(timeFormat)
-
return r.makeAssertionResponse(
- nowStr,
- fiveFromNowStr,
+ now.Format(timeFormat),
+ now.Add(expiration).Format(timeFormat),
attributes,
)
}
@@ -206,13 +167,9 @@ func makeAttributeQueryResponse(
attributes *Attributes,
queriedAttrs []saml.AttributeType,
timeFormat string,
+ expiration time.Duration,
) *samlp.ResponseType {
now := time.Now().UTC()
- nowStr := now.Format(timeFormat)
- fiveMinutes, _ := time.ParseDuration("5m")
- fiveFromNow := now.Add(fiveMinutes)
- fiveFromNowStr := fiveFromNow.Format(timeFormat)
-
providedAttrs := []*saml.AttributeType{}
attrsSaml := attributes.GetSAML()
if queriedAttrs == nil || len(queriedAttrs) == 0 {
@@ -229,8 +186,8 @@ func makeAttributeQueryResponse(
}
}
- response := makeResponse(NewID(), requestID, "", nowStr, StatusCodeSuccess, "", issuer)
- assertion := makeAssertion(requestID, "", "", nowStr, fiveFromNowStr, issuer, attributes.GetNameID(), providedAttrs, entityID, false)
+ response := makeResponse(NewID(), requestID, "", now.Format(timeFormat), StatusCodeSuccess, "", issuer)
+ assertion := makeAssertion(requestID, "", "", now.Format(timeFormat), now.Add(expiration).Format(timeFormat), issuer, attributes.GetNameID(), providedAttrs, entityID, false)
response.Assertion = *assertion
return response
}
diff --git a/pkg/provider/response_test.go b/pkg/provider/response_test.go
index 693f71c..358cf6a 100644
--- a/pkg/provider/response_test.go
+++ b/pkg/provider/response_test.go
@@ -74,7 +74,7 @@ func TestResponse_sendBackResponse(t *testing.T) {
signature: "sig",
},
res{
- body: []byte("Found.\n\n"),
+ body: []byte("Found.\n\n"),
},
},
{
@@ -94,7 +94,7 @@ func TestResponse_sendBackResponse(t *testing.T) {
signature: "sig",
},
res{
- body: []byte("Found.\n\n"),
+ body: []byte("Found.\n\n"),
},
},
{
diff --git a/pkg/provider/sso.go b/pkg/provider/sso.go
index dafca50..451566b 100644
--- a/pkg/provider/sso.go
+++ b/pkg/provider/sso.go
@@ -61,7 +61,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to parse form").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form").Error(), p.TimeFormat))
},
)
@@ -70,7 +70,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
"SAMLRequest",
func() string { return authRequestForm.AuthRequest },
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("no auth request provided").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("no auth request provided").Error(), p.TimeFormat))
},
)
@@ -80,7 +80,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
"Signature",
func() string { return authRequestForm.Sig },
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("signature algorith provided but no signature").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("signature algorith provided but no signature").Error(), p.TimeFormat))
},
)
@@ -95,7 +95,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to decode request").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request").Error(), p.TimeFormat))
},
)
@@ -110,7 +110,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return nil
},
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)
@@ -125,7 +125,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
func() *md.EntityDescriptorType { return sp.Metadata },
),
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to validate certificate from request: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate certificate from request: %w", err).Error(), p.TimeFormat))
},
)
@@ -146,7 +146,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
func(errF error) { err = errF },
),
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to verify signature: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to verify signature: %w", err).Error(), p.TimeFormat))
},
)
@@ -164,14 +164,14 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
func(errF error) { err = errF },
),
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to verify signature: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to verify signature: %w", err).Error(), p.TimeFormat))
},
)
// work out used acs url and protocolbinding for response
checkerInstance.WithValueStep(
func() {
- response.AcsUrl, response.ProtocolBinding = getAcsUrlAndBindingForResponse(sp, authNRequest.ProtocolBinding)
+ response.AcsUrl, response.ProtocolBinding = GetAcsUrlAndBindingForResponse(sp.Metadata.SPSSODescriptor.AssertionConsumerService, authNRequest.ProtocolBinding)
},
)
@@ -180,7 +180,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
"acsUrl",
func() string { return response.AcsUrl },
func() {
- response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("missing usable assertion consumer url").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("missing usable assertion consumer url").Error(), p.TimeFormat))
},
)
@@ -189,7 +189,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
"protocol binding",
func() string { return response.ProtocolBinding },
func() {
- response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("missing usable protocol binding").Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("missing usable protocol binding").Error(), p.TimeFormat))
},
)
@@ -200,7 +200,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
func() *samlp.AuthnRequestType { return authNRequest },
),
func() {
- response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to validate request content: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request content: %w", err).Error(), p.TimeFormat))
},
)
@@ -218,7 +218,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
return err
},
func() {
- response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to persist request: %w", err).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeResponder, fmt.Errorf("failed to persist request: %w", err).Error(), p.TimeFormat))
},
)
@@ -232,7 +232,7 @@ func (p *IdentityProvider) ssoHandleFunc(w http.ResponseWriter, r *http.Request)
http.Redirect(w, r, sp.LoginURL(authRequest.GetID()), http.StatusSeeOther)
default:
logging.Error(err)
- response.sendBackResponse(r, w, response.makeUnsupportedBindingResponse(fmt.Errorf("unsupported binding: %s", response.ProtocolBinding).Error(), p.timeFormat))
+ response.sendBackResponse(r, w, response.makeFailedResponse(StatusCodeUnsupportedBinding, fmt.Errorf("unsupported binding: %s", response.ProtocolBinding).Error(), p.TimeFormat))
}
return
}
@@ -258,6 +258,9 @@ func getAuthRequestFromRequest(r *http.Request) (*AuthRequestForm, error) {
Sig: r.FormValue("Signature"),
Binding: binding,
}
+ if request.Encoding == "" && binding == RedirectBinding {
+ request.Encoding = xml.EncodingDeflate
+ }
return request, nil
}
@@ -348,14 +351,14 @@ func checkCertificate(
}
}
-func getAcsUrlAndBindingForResponse(
- sp *serviceprovider.ServiceProvider,
+func GetAcsUrlAndBindingForResponse(
+ acs []md.IndexedEndpointType,
requestProtocolBinding string,
) (string, string) {
acsUrl := ""
protocolBinding := ""
- for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService {
+ for _, acs := range acs {
if acs.Binding == requestProtocolBinding {
acsUrl = acs.Location
protocolBinding = acs.Binding
@@ -364,7 +367,7 @@ func getAcsUrlAndBindingForResponse(
}
if acsUrl == "" {
isDefaultFound := false
- for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService {
+ for _, acs := range acs {
if acs.IsDefault == "true" {
isDefaultFound = true
acsUrl = acs.Location
@@ -374,7 +377,7 @@ func getAcsUrlAndBindingForResponse(
}
if !isDefaultFound {
index := 0
- for _, acs := range sp.Metadata.SPSSODescriptor.AssertionConsumerService {
+ for _, acs := range acs {
i, _ := strconv.Atoi(acs.Index)
if index == 0 || i < index {
acsUrl = acs.Location
diff --git a/pkg/provider/sso_test.go b/pkg/provider/sso_test.go
index 4b68d8c..2b3d599 100644
--- a/pkg/provider/sso_test.go
+++ b/pkg/provider/sso_test.go
@@ -25,7 +25,7 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
binding string
}
type args struct {
- sp *serviceprovider.ServiceProvider
+ acs []md.IndexedEndpointType
requestBinding string
}
tests := []struct {
@@ -35,15 +35,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
}{{
"sp with post and redirect, default used",
args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Index: "1", IsDefault: "true", Binding: RedirectBinding, Location: "redirect"},
- {Index: "2", Binding: PostBinding, Location: "post"},
- },
- },
- },
+ []md.IndexedEndpointType{
+ {Index: "1", IsDefault: "true", Binding: RedirectBinding, Location: "redirect"},
+ {Index: "2", Binding: PostBinding, Location: "post"},
},
RedirectBinding,
},
@@ -55,15 +49,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
{
"sp with post and redirect, first index used",
args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Index: "1", Binding: RedirectBinding, Location: "redirect"},
- {Index: "2", Binding: PostBinding, Location: "post"},
- },
- },
- },
+ []md.IndexedEndpointType{
+ {Index: "1", Binding: RedirectBinding, Location: "redirect"},
+ {Index: "2", Binding: PostBinding, Location: "post"},
},
RedirectBinding,
},
@@ -75,15 +63,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
{
"sp with post and redirect, redirect used",
args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Binding: RedirectBinding, Location: "redirect"},
- {Binding: PostBinding, Location: "post"},
- },
- },
- },
+ []md.IndexedEndpointType{
+ {Binding: RedirectBinding, Location: "redirect"},
+ {Binding: PostBinding, Location: "post"},
},
RedirectBinding,
},
@@ -95,15 +77,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
{
"sp with post and redirect, post used",
args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Binding: RedirectBinding, Location: "redirect"},
- {Binding: PostBinding, Location: "post"},
- },
- },
- },
+ []md.IndexedEndpointType{
+ {Binding: RedirectBinding, Location: "redirect"},
+ {Binding: PostBinding, Location: "post"},
},
PostBinding,
},
@@ -114,16 +90,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
},
{
"sp with redirect, post used",
- args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Binding: RedirectBinding, Location: "redirect"},
- },
- },
- },
- },
+ args{[]md.IndexedEndpointType{
+ {Binding: RedirectBinding, Location: "redirect"},
+ },
PostBinding,
},
res{
@@ -134,14 +103,8 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
{
"sp with post, redirect used",
args{
- &serviceprovider.ServiceProvider{
- Metadata: &md.EntityDescriptorType{
- SPSSODescriptor: &md.SPSSODescriptorType{
- AssertionConsumerService: []md.IndexedEndpointType{
- {Binding: PostBinding, Location: "post"},
- },
- },
- },
+ []md.IndexedEndpointType{
+ {Binding: PostBinding, Location: "post"},
},
RedirectBinding,
},
@@ -153,9 +116,9 @@ func TestSSO_getAcsUrlAndBindingForResponse(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- acs, binding := getAcsUrlAndBindingForResponse(tt.args.sp, tt.args.requestBinding)
+ acs, binding := GetAcsUrlAndBindingForResponse(tt.args.acs, tt.args.requestBinding)
if acs != tt.res.acs && binding != tt.res.binding {
- t.Errorf("getAcsUrlAndBindingForResponse() got = %v/%v, want %v/%v", acs, binding, tt.res.acs, tt.res.binding)
+ t.Errorf("GetAcsUrlAndBindingForResponse() got = %v/%v, want %v/%v", acs, binding, tt.res.acs, tt.res.binding)
return
}
})
@@ -447,9 +410,11 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
Binding string
}
type res struct {
- code int
- err bool
- state string
+ code int
+ err bool
+ state string
+ inflate bool
+ b64 bool
}
type sp struct {
entityID string
@@ -628,39 +593,6 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
state: StatusCodeRequestDenied,
err: false,
}},
- {
- "redirect request unknown service provider",
- args{
- issuer: "http://localhost:50002",
- metadataEndpoint: "/saml/metadata",
- config: &IdentityProviderConfig{
- SignatureAlgorithm: dsig.RSASHA256SignatureMethod,
- MetadataIDPConfig: &MetadataIDPConfig{},
- Endpoints: &EndpointConfig{
- SingleSignOn: getEndpointPointer("/saml/SSO", "http://localhost:50002/saml/SSO"),
- },
- },
- certificate: "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVt\neXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQw\nNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+U\nerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8\nG+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5q\nLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF\n2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5\nnyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVo\nLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW\n8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+u\nla1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4Wsfa\nwWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oa\nLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n-----END CERTIFICATE-----\n",
- key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7XKdCRxUZXjdq\nVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWc\nWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2\ngIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+a\nEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7\nLfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v\n4cxTNPn/AgMBAAECggEAF+rV9yH30Ysza8GwrXCR9qDN1Dp3QmmsavnXkonEvPoq\nEr2T3o0//6mBp6CLDboMQGQBjblJwl+3Y6PgZolvHAMOsMdHfYNPEo7FSzUBzEw+\nqRrs5HkMyvoPgfV6X8F97W3tiD4Q/AmHkMILl+MxbnfPXM54gWqPuwIqxY1uaCk5\nREwyb7WBon3rd58ceOI1SLRjod6SbqWBMMSN3cJ+5VEPObFjw/RlhNQ5rBI8G5Kt\nso2zBU5C4BB2CvqlWy98WDKJkTvWHbiTjZCy8BQ+gQ6UJM2vaNELFOVpuMGQnMIi\noWiX10Jg2e1gP9j3TdrohlGF8M3+TXjSFKNmeX0DUQKBgQDx7UazUWS5RtkgnjH9\nw2xH2xkstJVD7nAS8VTxNwcrgjVXPvTJha9El904obUjyRX7ppb02tuH5ML/bZh6\n9lL4bP5+SHcJ10e4q8CK/KAGHD6BYAbaGXRq0CoSk5a3vv5XPdob4T5qKCIHFpnu\nMfbvdbEoameLOyRYOGu/yVZIiwKBgQDGQs7FRTisHV0xooiRmlvYF0dcd19qpLed\nqhgJNqBPOTEvvGvJNRoi39haEY3cuTqsxZ5FAlFlVFMUUozz+d0xBLLInoVY/Y4h\nhSdGmdw/A6oHodLqyEp3N5RZNdLlh8/nDS3xXzMotAl75bW5kc2ttcRhRdtyNJ9Z\nup0PgppO3QKBgEC45upAQz8iCiKkz+EA8C4FGqYQJcLHvmoC8GOcAioMqrKNoDVt\ns2cZbdChynEpcd0iQ058YrDnbZeiPWHgFnBp0Gf+gQI7+u8X2+oTDci0s7Au/YZJ\nuxB8YlUX8QF1clvqqzg8OVNzKy9UR5gm+9YyWVPjq5HfH6kOZx0nAxNjAoGAERt8\nqgsCC9/wxbKnpCC0oh3IG5N1WUdjTKh7sHfVN2DQ/LR+fHsniTDVg1gWbKBTDsty\nj7PWgC7ZiFxjKz45NtyX7LW4/efLFttdezsVhR500nnFMFseCdFy7Iu3afThHKfH\nehdj27RFSTqWBrAtFjsj+dzERcOCqIRwvwDe/cUCgYEA5+1mzVXDVjKsWylKJPk+\nZZA4LUfvmTj3VLNDZrlSAI/xEikCFio0QWEA2TQYTAwbXTrKwQSeHQRhv7OTc1h+\nMhpAgvs189ze5J4jiNmULEkkrO+Cxxnw8tyV+UFRZtzW9gUoVBwXiZ/Wbl9sfnlO\nwLJHc0j6OltPcPJmxHP8gQI=\n-----END PRIVATE KEY-----\n",
- request: request{
- ID: "test",
- Binding: RedirectBinding,
- SAMLRequest: url.QueryEscape("nJJBj9MwEIX/ijX3NG6a7DbWJlLZClFpYatN4cBt6k6oJccungmw/x61XaQioRy42vP5ved5D4yDP5nVKMfwQt9HYlG/Bh/YnC8aGFMwEdmxCTgQG7GmW318MsVMG2SmJC4GuEFO08wpRYk2elCbdQPukFlNd/c9LQpczPve6r3taVHWdbWoal3bfr7c03JJc1BfKLGLoYFipkFtmEfaBBYM0kChiyLTZVbc7XRtyntTVrOyrr6CWhOLCygX8ihyMnnuo0V/jCym0loX+dl33nXPoFZ/Ij3GwONAqaP0w1n6/PL0D3qptb7CaBnU9i3bOxcOLnyb/oj9dYjNh91um22fux20l2WYS7Kk3sc0oEw/cj5xh6y/jBoK4uQV2gmfAwkeUPAhv5Fq30rwCQfarLfRO/v6H/KSMLCjIKBW3sefj4lQqAFJI0HeXiX/rlr7OwAA//8="),
- RelayState: url.QueryEscape("K6LS7mdqUO4SGedbfa8nBIyX-7K8gGbrHMqIMwVn6zCKLLoADHjEHUAm"),
- Signature: url.QueryEscape("PWZ6JPNpAGE7mYLKD3dCUG9AZcThrMRQGtvdv31ewx3hms5Oglc677iAUEcbIBrvKtMrCPVwXPNxT6wQ0rg4qIgyKgoyS53ZTaxaFHPrB7wkkzqtK7GvWgdEqceT8iooK5SCLHFMJ3m30LqEbX7zFw62yE34+e7ypfZSM5Lrf0QFwPzX+LNCuYA+Ob9D5SKc132tn21J2vBRmNJ1zCY0ksRzQfyfErjAzcGVx8qK9jpaeyvsVBZSkH/I6+1hb8lQWE48xala9NbqfbMATGBCQj1UvpVMMfp6PE7KPk5Y1YDeSqPeRIEKH+Gnip6Hve5Ji1aiRp5bytVf1VHwTHSq8w=="),
- SigAlg: url.QueryEscape("http://www.w3.org/2000/09/xmldsig#rsa-sha1"),
- },
- sp: sp{
- entityID: "http://localhost:8000/saml/metadata",
- metadata: "\n \n \n \n \n MIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQwNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVoLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+ula1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4WsfawWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oaLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n \n \n \n \n \n \n \n \n \n \n MIICvDCCAaQCCQD6E8ZGsQ2usjANBgkqhkiG9w0BAQsFADAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wHhcNMjIwMjE3MTQwNjM5WhcNMjMwMjE3MTQwNjM5WjAgMR4wHAYDVQQDDBVteXNlcnZpY2UuZXhhbXBsZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7XKdCRxUZXjdqVqwwwOJqc1Ch0nOSmk+UerkUqlviWHdeLR+FolHKjqLzCBloAz4xVc0DFfR76gWcWAHJloqZ7GBS7NpDhzV8G+cXQ+bTU0Lu2e73zCQb30XUdKhWiGfDKaU+1xg9CD/2gIfsYPs3TTq1sq7oCs5qLdUHaVL5kcRaHKdnTi7cs5i9xzs3TsUnXcrJPwydjp+aEkyRh07oMpXBEobGisfF2p1MA6pVW2gjmywf7D5iYEFELQhM7poqPN3/kfBvU1n7Lfgq7oxmv/8LFi4Zopr5nyqsz26XPtUy1WqTzgznAmP+nN0oBTERFVbXXdRa3k2v4cxTNPn/AgMBAAEwDQYJKoZIhvcNAQELBQADggEBAJYxROWSOZbOzXzafdGjQKsMgN948G/hHwVuZneyAcVoLMFTs1Weya9Z+snMp1u0AdDGmQTS9zGnD7syDYGOmgigOLcMvLMoWf5tCQBbEukW8O7DPjRR0XypChGSsHsqLGO0B0HaTel0HdP9Si827OCkc9Q+WbsFG/8/4ToGWL+ula1WuLawozoj8umPi9D8iXCoW35y2STU+WFQG7W+Kfdu+2CYz/0tGdwVqNG4WsfawWchrS00vGFKjm/fJc876gAfxiMH1I9fZvYSAxAZ3sVI//Ml2sUdgf067ywQ75oaLSS2NImmz5aos3vuWmOXhILd7iTU+BD8Uv6vWbI7I1M=\n \n \n \n \n \n \n \n",
- err: fmt.Errorf("unknown"),
- },
- },
- res{
- code: 200,
- state: StatusCodeRequestDenied,
- err: false,
- }},
{
"signed post request",
args{
@@ -693,7 +625,6 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
err: false,
}},
}
-
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
endpoint := NewEndpoint(tt.args.metadataEndpoint)
@@ -766,7 +697,7 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
}
if tt.res.state != "" {
- if err := parseForState(string(response), tt.res.state); err != nil {
+ if err := parseForState(tt.res.inflate, tt.res.b64, string(response), tt.res.state); err != nil {
t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state)
return
}
@@ -775,8 +706,13 @@ func TestSSO_ssoHandleFunc(t *testing.T) {
}
}
-func parseForState(responseXML string, state string) error {
- response, err := xml.DecodeResponse("", responseXML)
+func parseForState(inflate bool, b64 bool, responseXML string, state string) error {
+ encoding := ""
+ if inflate {
+ encoding = xml.EncodingDeflate
+ }
+
+ response, err := xml.DecodeResponse(encoding, b64, responseXML)
if err != nil {
return err
}
diff --git a/pkg/provider/xml/xml.go b/pkg/provider/xml/xml.go
index ad4e976..f717b5a 100644
--- a/pkg/provider/xml/xml.go
+++ b/pkg/provider/xml/xml.go
@@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/xml"
"fmt"
+ "io"
"net/http"
"strings"
@@ -43,25 +44,20 @@ func Marshal(data interface{}) ([]byte, error) {
}
func DeflateAndBase64(data []byte) ([]byte, error) {
- b := &bytes.Buffer{}
- w1 := base64.NewEncoder(base64.StdEncoding, b)
- defer w1.Close()
-
- w2, _ := flate.NewWriter(w1, 1)
- defer w2.Close()
-
- bw := bufio.NewWriter(w1)
- if _, err := bw.Write(data); err != nil {
+ buff := &bytes.Buffer{}
+ b64Encoder := base64.NewEncoder(base64.StdEncoding, buff)
+ // compression level is set at 9 as BestCompression, also used by other SAML application like crewjam/saml
+ flateWriter, _ := flate.NewWriter(b64Encoder, 9)
+ if _, err := flateWriter.Write(data); err != nil {
return nil, err
}
- if err := bw.Flush(); err != nil {
+ if err := flateWriter.Close(); err != nil {
return nil, err
}
-
- if err := w2.Flush(); err != nil {
+ if err := b64Encoder.Close(); err != nil {
return nil, err
}
- return b.Bytes(), nil
+ return buff.Bytes(), nil
}
func WriteXMLMarshalled(w http.ResponseWriter, body interface{}) error {
@@ -86,53 +82,26 @@ func Write(w http.ResponseWriter, body []byte) error {
}
func DecodeAuthNRequest(encoding string, message string) (*samlp.AuthnRequestType, error) {
- reqBytes, err := base64.StdEncoding.DecodeString(message)
+ data, err := InflateAndDecode(encoding, true, message)
if err != nil {
- return nil, fmt.Errorf("failed to base64 decode: %w", err)
+ return nil, err
}
-
req := &samlp.AuthnRequestType{}
- switch encoding {
- case EncodingDeflate:
- reader := flate.NewReader(bytes.NewReader(reqBytes))
- decoder := xml.NewDecoder(reader)
- if err = decoder.Decode(req); err != nil {
- return nil, fmt.Errorf("failed to defalte decode: %w", err)
- }
- default:
- reader := flate.NewReader(bytes.NewReader(reqBytes))
- decoder := xml.NewDecoder(reader)
- if err = decoder.Decode(req); err != nil {
- if err := xml.Unmarshal(reqBytes, req); err != nil {
- return nil, fmt.Errorf("failed to unmarshal: %w", err)
- }
- }
+ if err := xml.Unmarshal(data, req); err != nil {
+ return nil, err
}
-
return req, nil
}
-func DecodeSignature(encoding string, message string) (*xml_dsig.SignatureType, error) {
- retBytes := []byte(message)
-
+func DecodeSignature(encoding string, b64 bool, message string) (*xml_dsig.SignatureType, error) {
+ data, err := InflateAndDecode(encoding, b64, message)
+ if err != nil {
+ return nil, err
+ }
ret := &xml_dsig.SignatureType{}
- switch encoding {
- case EncodingDeflate:
- reader := flate.NewReader(bytes.NewReader(retBytes))
- decoder := xml.NewDecoder(reader)
- if err := decoder.Decode(ret); err != nil {
- return nil, fmt.Errorf("failed to defalte decode: %w", err)
- }
- default:
- reader := flate.NewReader(bytes.NewReader(retBytes))
- decoder := xml.NewDecoder(reader)
- if err := decoder.Decode(ret); err != nil {
- if err := xml.Unmarshal(retBytes, ret); err != nil {
- return nil, fmt.Errorf("failed to unmarshal: %w", err)
- }
- }
+ if err := xml.Unmarshal(data, ret); err != nil {
+ return nil, err
}
-
return ret, nil
}
@@ -148,50 +117,45 @@ func DecodeAttributeQuery(request string) (*samlp.AttributeQueryType, error) {
}
func DecodeLogoutRequest(encoding string, message string) (*samlp.LogoutRequestType, error) {
- reqBytes, err := base64.StdEncoding.DecodeString(message)
+ data, err := InflateAndDecode(encoding, true, message)
if err != nil {
return nil, err
}
-
req := &samlp.LogoutRequestType{}
- switch encoding {
- case "":
- reader := flate.NewReader(bytes.NewReader(reqBytes))
- decoder := xml.NewDecoder(reader)
- if err = decoder.Decode(req); err != nil {
- return nil, err
- }
- case EncodingDeflate:
- reader := flate.NewReader(bytes.NewReader(reqBytes))
- decoder := xml.NewDecoder(reader)
- if err = decoder.Decode(req); err != nil {
- return nil, err
- }
- default:
- return nil, fmt.Errorf("unknown encoding")
+ if err := xml.Unmarshal(data, req); err != nil {
+ return nil, err
}
-
return req, nil
}
-func DecodeResponse(encoding string, message string) (*samlp.ResponseType, error) {
-
+func DecodeResponse(encoding string, b64 bool, message string) (*samlp.ResponseType, error) {
+ data, err := InflateAndDecode(encoding, b64, message)
+ if err != nil {
+ return nil, err
+ }
req := &samlp.ResponseType{}
- switch encoding {
- case "":
- decoder := xml.NewDecoder(bytes.NewReader([]byte(message)))
- if err := decoder.Decode(req); err != nil {
+ if err := xml.Unmarshal(data, req); err != nil {
+ return nil, err
+ }
+ return req, nil
+}
+
+func InflateAndDecode(encoding string, b64 bool, message string) (_ []byte, err error) {
+ data := []byte(message)
+ if b64 {
+ data, err = base64.StdEncoding.DecodeString(message)
+ if err != nil {
return nil, err
}
+ }
+ switch encoding {
+ case "":
+ return data, nil
case EncodingDeflate:
- reader := flate.NewReader(bytes.NewReader([]byte(message)))
- decoder := xml.NewDecoder(reader)
- if err := decoder.Decode(req); err != nil {
- return nil, err
- }
+ r := flate.NewReader(bytes.NewBuffer(data))
+ defer r.Close()
+ return io.ReadAll(r)
default:
return nil, fmt.Errorf("unknown encoding")
}
-
- return req, nil
}