From 767a3b6d85a5679a775970db380ed78c245032c7 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:19:26 -0500 Subject: [PATCH 01/33] Rough draft of idea. TODO: Determine how to short-circuit Thrift requests. TODO: Refactor into a common shared function with func parameters. --- faults/common.go | 11 +++ faults/headers.go | 11 +++ httpbp/client_middlewares.go | 85 +++++++++++++++++++++ thriftbp/client_middlewares.go | 132 ++++++++++++++++++++++++++++++--- 4 files changed, 228 insertions(+), 11 deletions(-) create mode 100644 faults/common.go create mode 100644 faults/headers.go diff --git a/faults/common.go b/faults/common.go new file mode 100644 index 000000000..5d4906802 --- /dev/null +++ b/faults/common.go @@ -0,0 +1,11 @@ +package faults + +import "strings" + +func GetShortenedAddress(address string) string { + parts := strings.Split(address, ".") + if len(parts) < 2 { + return "" + } + return strings.Join(parts[:2], ".") +} diff --git a/faults/headers.go b/faults/headers.go new file mode 100644 index 000000000..c694bb208 --- /dev/null +++ b/faults/headers.go @@ -0,0 +1,11 @@ +package faults + +const ( + FaultServerAddressHeader = "X-Bp-Fault-Server-Address" + FaultDelayMsHeader = "X-Bp-Fault-Delay-Ms" + FaultAbortCodeHeader = "X-Bp-Fault-Abort-Code" + FaultServerMethodHeader = "X-Bp-Fault-Server-Method" + FaultAbortMessageHeader = "X-Bp-Fault-Abort-Message" + FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" + FaultAbortPercentageHeader = "X-Bp-Fault-Abort-Percentage" +) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index b02948e78..4a8d9271c 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -5,8 +5,11 @@ import ( "fmt" "io" "log/slog" + "math/rand" "net/http" + "net/url" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -15,6 +18,7 @@ import ( "github.com/prometheus/client_golang/prometheus" 
"github.com/reddit/baseplate.go/breakerbp" + "github.com/reddit/baseplate.go/faults" //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" @@ -349,3 +353,84 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { }) } } + +type ServiceAddressParts struct { + Name string + Namespace string +} + +func GetShortenedAddress(url *url.URL) string { + hostParts := strings.Split(url.Host, ".") + if len(hostParts) < 2 { + return "" + } + return strings.Join(hostParts[:2], ".") +} + +func FaultInjection(serverSlug string) ClientMiddleware { + return func(next http.RoundTripper) http.RoundTripper { + return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + serverAddress := req.Header.Get(faults.FaultServerAddressHeader) + if serverAddress == "" || serverAddress != GetShortenedAddress(req.URL) { + return next.RoundTrip(req) + } + + serverMethod := req.Header.Get(faults.FaultServerMethodHeader) + if serverMethod != "" && serverMethod != req.URL.Path { + return next.RoundTrip(req) + } + + delayMs := req.Header.Get(faults.FaultDelayMsHeader) + if delayMs != "" { + delayPercentage := req.Header.Get(faults.FaultDelayPercentageHeader) + if delayPercentage != "" { + percentage, err := strconv.Atoi(delayPercentage) + if err != nil { + return nil, errors.New("provided delay percentage is not a valid integer") + } + if percentage < 0 || percentage > 100 { + return nil, errors.New("provided delay percentage is outside the valid range of [0-100]") + } + if percentage == 0 || (percentage != 100 && rand.Intn(100) >= percentage) { + return next.RoundTrip(req) + } + } + + delay, err := strconv.Atoi(delayMs) + if err != nil { + return nil, errors.New("provided delay is not a valid integer") + } + time.Sleep(time.Duration(delay) * time.Millisecond) + } + + abortCode := req.Header.Get(faults.FaultAbortCodeHeader) + if abortCode != "" { + 
abortPercentage := req.Header.Get(faults.FaultAbortPercentageHeader) + if abortPercentage != "" { + percentage, err := strconv.Atoi(abortPercentage) + if err != nil { + return nil, errors.New("provided abort percentage is not a valid integer") + } + if percentage < 0 || percentage > 100 { + return nil, errors.New("provided abort percentage is outside the valid range of [0-100]") + } + if percentage != 100 && rand.Intn(100) >= percentage { + return next.RoundTrip(req) + } + } + + code, err := strconv.Atoi(abortCode) + if err != nil { + return nil, errors.New("provided abort code is not a valid integer") + } + if code < 100 || code >= 600 { + return nil, errors.New("provided abort code is outside the valid HTTP status code range of [100-599]") + } + resp.StatusCode = code + return resp, nil + } + + return next.RoundTrip(req) + }) + } +} diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 00c33016f..2b1cab3c7 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -11,12 +11,15 @@ import ( "github.com/apache/thrift/lib/go/thrift" "github.com/avast/retry-go" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/exp/rand" "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/errorsbp" + "github.com/reddit/baseplate.go/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" + //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/prometheusbp" @@ -66,6 +69,11 @@ type DefaultClientMiddlewareArgs struct { // ImageUploadService -> image-upload ServiceSlug string + // Address is the DNS address of the thrift service you are creating clients for. + // + // If not provided, the client will be unable to use the fault injection middleware. 
+ Address string + // RetryOptions is the list of retry.Options to apply as the defaults for the // Retry middleware. // @@ -105,38 +113,42 @@ type DefaultClientMiddlewareArgs struct { // // Currently they are (in order): // -// 1. ForwardEdgeRequestContext. +// 1. FaultInjectionClientMiddleware - This injects faults at the client side if +// the request matches the provided configuration. // -// 2. SetClientName(clientName) +// 2. ForwardEdgeRequestContext // -// 3. MonitorClient with MonitorClientWrappedSlugSuffix - This creates the spans +// 3. SetClientName(clientName) +// +// 4. MonitorClient with MonitorClientWrappedSlugSuffix - This creates the spans // from the view of the client that group all retries into a single, // wrapped span. // -// 4. PrometheusClientMiddleware with MonitorClientWrappedSlugSuffix - This +// 5. PrometheusClientMiddleware with MonitorClientWrappedSlugSuffix - This // creates the prometheus client metrics from the view of the client that group // all retries into a single operation. // -// 5. Retry(retryOptions) - If retryOptions is empty/nil, default to only +// 6. Retry(retryOptions) - If retryOptions is empty/nil, default to only // retry.Attempts(1), this will not actually retry any calls but your client is // configured to set retry logic per-call using retrybp.WithOptions. // -// 6. FailureRatioBreaker - Only if BreakerConfig is non-nil. +// 7. FailureRatioBreaker - Only if BreakerConfig is non-nil. // -// 7. MonitorClient - This creates the spans of the raw client calls. +// 8. MonitorClient - This creates the spans of the raw client calls. // -// 8. PrometheusClientMiddleware +// 9. PrometheusClientMiddleware // -// 9. BaseplateErrorWrapper +// 10. BaseplateErrorWrapper // -// 10. thrift.ExtractIDLExceptionClientMiddleware +// 11. thrift.ExtractIDLExceptionClientMiddleware // -// 11. SetDeadlineBudget +// 12. 
SetDeadlineBudget func BaseplateDefaultClientMiddlewares(args DefaultClientMiddlewareArgs) []thrift.ClientMiddleware { if len(args.RetryOptions) == 0 { args.RetryOptions = []retry.Option{retry.Attempts(1)} } middlewares := []thrift.ClientMiddleware{ + FaultInjectionClientMiddleware(args.Address), ForwardEdgeRequestContext(args.EdgeContextImpl), SetClientName(args.ClientName), MonitorClient(MonitorClientArgs{ @@ -390,6 +402,104 @@ func PrometheusClientMiddleware(remoteServerSlug string) thrift.ClientMiddleware } } +// differences: +// -- resume function +// -- get header function +func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { + return func(next thrift.TClient) thrift.TClient { + return thrift.WrappedTClient{ + Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { + serverAddress, ok := thrift.GetHeader(ctx, faults.FaultServerAddressHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if serverAddress == "" || serverAddress != faults.GetShortenedAddress(address) { + return next.Call(ctx, method, args, result) + } + + serverMethod, ok := thrift.GetHeader(ctx, faults.FaultServerMethodHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if serverMethod != "" && serverMethod != method { + return next.Call(ctx, method, args, result) + } + + delayMs, ok := thrift.GetHeader(ctx, faults.FaultDelayMsHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if delayMs != "" { + delayPercentage, ok := thrift.GetHeader(ctx, faults.FaultDelayPercentageHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if delayPercentage != "" { + percentage, err := strconv.Atoi(delayPercentage) + if err != nil { + // log "provided delay percentage is not a valid integer" + return next.Call(ctx, method, args, result) + } + if percentage < 0 || percentage > 100 { + // log "provided delay percentage is outside the valid range of [0-100]" + 
return next.Call(ctx, method, args, result) + } + if percentage == 0 || (percentage != 100 && rand.Intn(100) >= percentage) { + return next.Call(ctx, method, args, result) + } + } + + delay, err := strconv.Atoi(delayMs) + if err != nil { + // log "provided delay is not a valid integer" + return next.Call(ctx, method, args, result) + } + time.Sleep(time.Duration(delay) * time.Millisecond) + } + + abortCode, ok := thrift.GetHeader(ctx, faults.FaultAbortCodeHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if abortCode != "" { + abortPercentage, ok := thrift.GetHeader(ctx, faults.FaultAbortPercentageHeader) + if !ok { + return next.Call(ctx, method, args, result) + } + if abortPercentage != "" { + percentage, err := strconv.Atoi(abortPercentage) + if err != nil { + // log "provided abort percentage is not a valid integer" + return next.Call(ctx, method, args, result) + } + if percentage < 0 || percentage > 100 { + // log "provided abort percentage is outside the valid range of [0-100]" + return next.Call(ctx, method, args, result) + } + if percentage != 100 && rand.Intn(100) >= percentage { + return next.Call(ctx, method, args, result) + } + } + + code, err := strconv.Atoi(abortCode) + if err != nil { + // log "provided abort code is not a valid integer" + return next.Call(ctx, method, args, result) + } + if code < 100 || code >= 600 { + // log "provided abort code is outside the valid HTTP status code range of [100-599]" + return next.Call(ctx, method, args, result) + } + resp.StatusCode = code + return resp, nil + } + + return next.Call(ctx, method, args, result) + }, + } + } +} + func getClientError(result thrift.TStruct, err error) error { if err != nil { return err From 89dee2f63db59c67aca3b2828d9a3e1d2c8f4eb3 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:32:20 -0500 Subject: [PATCH 02/33] Refactored into common library. 
Need to add logging, tests, and potentially special Thrift error logic. --- faults/common.go | 82 ++++++++++++++++++++++++++++- httpbp/client_middlewares.go | 90 ++++++++------------------------ thriftbp/client_middlewares.go | 94 +++++----------------------------- 3 files changed, 113 insertions(+), 153 deletions(-) diff --git a/faults/common.go b/faults/common.go index 5d4906802..d35f4cc65 100644 --- a/faults/common.go +++ b/faults/common.go @@ -1,11 +1,89 @@ package faults -import "strings" +import ( + "strconv" + "strings" + "time" -func GetShortenedAddress(address string) string { + "github.com/reddit/baseplate.go/faults" + "golang.org/x/exp/rand" +) + +func getShortenedAddress(address string) string { parts := strings.Split(address, ".") if len(parts) < 2 { return "" } return strings.Join(parts[:2], ".") } + +func InjectFault(address, method string, getHeaderFn func(key string) string, resumeFn func() (interface{}, error), responseFn func(code int, message string) interface{}) (interface{}, error) { + serverAddress := getHeaderFn(faults.FaultServerAddressHeader) + if serverAddress == "" || serverAddress != getShortenedAddress(address) { + return resumeFn() + } + + serverMethod := getHeaderFn(faults.FaultServerMethodHeader) + if serverMethod != "" && serverMethod != method { + return resumeFn() + } + + delayMs := getHeaderFn(faults.FaultDelayMsHeader) + if delayMs != "" { + delayPercentage := getHeaderFn(faults.FaultDelayPercentageHeader) + if delayPercentage != "" { + percentage, err := strconv.Atoi(delayPercentage) + if err != nil { + // log "provided delay percentage is not a valid integer" + return resumeFn() + } + if percentage < 0 || percentage > 100 { + // log "provided delay percentage is outside the valid range of [0-100]" + return resumeFn() + } + if percentage == 0 || (percentage != 100 && rand.Intn(100) >= percentage) { + return resumeFn() + } + } + + delay, err := strconv.Atoi(delayMs) + if err != nil { + // log "provided delay is not a valid 
integer" + return resumeFn() + } + time.Sleep(time.Duration(delay) * time.Millisecond) + } + + abortCode := getHeaderFn(faults.FaultAbortCodeHeader) + if abortCode != "" { + abortPercentage := getHeaderFn(faults.FaultAbortPercentageHeader) + if abortPercentage != "" { + percentage, err := strconv.Atoi(abortPercentage) + if err != nil { + // log "provided abort percentage is not a valid integer" + return resumeFn() + } + if percentage < 0 || percentage > 100 { + // log "provided abort percentage is outside the valid range of [0-100]" + return resumeFn() + } + if percentage != 100 && rand.Intn(100) >= percentage { + return resumeFn() + } + } + + code, err := strconv.Atoi(abortCode) + if err != nil { + // log "provided abort code is not a valid integer" + return resumeFn() + } + if code < 100 || code >= 600 { + // log "provided abort code is outside the valid HTTP status code range of [100-599]" + return resumeFn() + } + abortMessage := getHeaderFn(faults.FaultAbortMessageHeader) + return responseFn(code, abortMessage), nil + } + + return resumeFn() +} diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 4a8d9271c..cc3a49938 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -5,11 +5,8 @@ import ( "fmt" "io" "log/slog" - "math/rand" "net/http" - "net/url" "strconv" - "strings" "sync" "sync/atomic" "time" @@ -19,6 +16,7 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/faults" + //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" @@ -359,78 +357,32 @@ type ServiceAddressParts struct { Namespace string } -func GetShortenedAddress(url *url.URL) string { - hostParts := strings.Split(url.Host, ".") - if len(hostParts) < 2 { - return "" - } - return strings.Join(hostParts[:2], ".") -} - func FaultInjection(serverSlug string) ClientMiddleware { return func(next 
http.RoundTripper) http.RoundTripper { - return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { - serverAddress := req.Header.Get(faults.FaultServerAddressHeader) - if serverAddress == "" || serverAddress != GetShortenedAddress(req.URL) { - return next.RoundTrip(req) - } - - serverMethod := req.Header.Get(faults.FaultServerMethodHeader) - if serverMethod != "" && serverMethod != req.URL.Path { + return roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resumeFn := func() (interface{}, error) { return next.RoundTrip(req) } - - delayMs := req.Header.Get(faults.FaultDelayMsHeader) - if delayMs != "" { - delayPercentage := req.Header.Get(faults.FaultDelayPercentageHeader) - if delayPercentage != "" { - percentage, err := strconv.Atoi(delayPercentage) - if err != nil { - return nil, errors.New("provided delay percentage is not a valid integer") - } - if percentage < 0 || percentage > 100 { - return nil, errors.New("provided delay percentage is outside the valid range of [0-100]") - } - if percentage == 0 || (percentage != 100 && rand.Intn(100) >= percentage) { - return next.RoundTrip(req) - } - } - - delay, err := strconv.Atoi(delayMs) - if err != nil { - return nil, errors.New("provided delay is not a valid integer") - } - time.Sleep(time.Duration(delay) * time.Millisecond) - } - - abortCode := req.Header.Get(faults.FaultAbortCodeHeader) - if abortCode != "" { - abortPercentage := req.Header.Get(faults.FaultAbortPercentageHeader) - if abortPercentage != "" { - percentage, err := strconv.Atoi(abortPercentage) - if err != nil { - return nil, errors.New("provided abort percentage is not a valid integer") - } - if percentage < 0 || percentage > 100 { - return nil, errors.New("provided abort percentage is outside the valid range of [0-100]") - } - if percentage != 100 && rand.Intn(100) >= percentage { - return next.RoundTrip(req) - } - } - - code, err := strconv.Atoi(abortCode) - if err != nil { - return nil, 
errors.New("provided abort code is not a valid integer") + responseFn := func(code int, message string) interface{} { + return &http.Response{ + Status: http.StatusText(code), + StatusCode: code, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: map[string][]string{ + // Copied from the standard http.Error() function. + "Content-Type": {"text/plain; charset=utf-8"}, + "X-Content-Type-Options": {"nosniff"}, + }, + ContentLength: 0, + TransferEncoding: req.TransferEncoding, + Request: req, + TLS: req.TLS, } - if code < 100 || code >= 600 { - return nil, errors.New("provided abort code is outside the valid HTTP status code range of [100-599]") - } - resp.StatusCode = code - return resp, nil } - - return next.RoundTrip(req) + resp, err := faults.InjectFault(req.URL.Host, req.URL.Path, req.Header.Get, resumeFn, responseFn) + return resp.(*http.Response), err }) } } diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 2b1cab3c7..0c77975fd 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -11,13 +11,13 @@ import ( "github.com/apache/thrift/lib/go/thrift" "github.com/avast/retry-go" "github.com/prometheus/client_golang/prometheus" - "golang.org/x/exp/rand" "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/errorsbp" "github.com/reddit/baseplate.go/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" + baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" //lint:ignore SA1019 This library is internal only, not actually deprecated @@ -402,99 +402,29 @@ func PrometheusClientMiddleware(remoteServerSlug string) thrift.ClientMiddleware } } -// differences: -// -- resume function -// -- get header function func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return func(next 
thrift.TClient) thrift.TClient { return thrift.WrappedTClient{ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { - serverAddress, ok := thrift.GetHeader(ctx, faults.FaultServerAddressHeader) - if !ok { - return next.Call(ctx, method, args, result) - } - if serverAddress == "" || serverAddress != faults.GetShortenedAddress(address) { - return next.Call(ctx, method, args, result) - } - - serverMethod, ok := thrift.GetHeader(ctx, faults.FaultServerMethodHeader) - if !ok { - return next.Call(ctx, method, args, result) - } - if serverMethod != "" && serverMethod != method { - return next.Call(ctx, method, args, result) - } - - delayMs, ok := thrift.GetHeader(ctx, faults.FaultDelayMsHeader) - if !ok { - return next.Call(ctx, method, args, result) - } - if delayMs != "" { - delayPercentage, ok := thrift.GetHeader(ctx, faults.FaultDelayPercentageHeader) + getHeaderFn := func(key string) string { + header, ok := thrift.GetHeader(ctx, key) if !ok { - return next.Call(ctx, method, args, result) - } - if delayPercentage != "" { - percentage, err := strconv.Atoi(delayPercentage) - if err != nil { - // log "provided delay percentage is not a valid integer" - return next.Call(ctx, method, args, result) - } - if percentage < 0 || percentage > 100 { - // log "provided delay percentage is outside the valid range of [0-100]" - return next.Call(ctx, method, args, result) - } - if percentage == 0 || (percentage != 100 && rand.Intn(100) >= percentage) { - return next.Call(ctx, method, args, result) - } - } - - delay, err := strconv.Atoi(delayMs) - if err != nil { - // log "provided delay is not a valid integer" - return next.Call(ctx, method, args, result) + return "" } - time.Sleep(time.Duration(delay) * time.Millisecond) + return header } - - abortCode, ok := thrift.GetHeader(ctx, faults.FaultAbortCodeHeader) - if !ok { + resumeFn := func() (interface{}, error) { return next.Call(ctx, method, args, result) } - if abortCode != 
"" { - abortPercentage, ok := thrift.GetHeader(ctx, faults.FaultAbortPercentageHeader) - if !ok { - return next.Call(ctx, method, args, result) - } - if abortPercentage != "" { - percentage, err := strconv.Atoi(abortPercentage) - if err != nil { - // log "provided abort percentage is not a valid integer" - return next.Call(ctx, method, args, result) - } - if percentage < 0 || percentage > 100 { - // log "provided abort percentage is outside the valid range of [0-100]" - return next.Call(ctx, method, args, result) - } - if percentage != 100 && rand.Intn(100) >= percentage { - return next.Call(ctx, method, args, result) - } + responseFn := func(code int, message string) interface{} { + return &baseplatethrift.Error{ + Code: thrift.Int32Ptr(int32(code)), + Message: thrift.StringPtr(message), } - - code, err := strconv.Atoi(abortCode) - if err != nil { - // log "provided abort code is not a valid integer" - return next.Call(ctx, method, args, result) - } - if code < 100 || code >= 600 { - // log "provided abort code is outside the valid HTTP status code range of [100-599]" - return next.Call(ctx, method, args, result) - } - resp.StatusCode = code - return resp, nil } - return next.Call(ctx, method, args, result) + responseMeta, err := faults.InjectFault(address, method, getHeaderFn, resumeFn, responseFn) + return responseMeta.(thrift.ResponseMeta), err }, } } From 0aff005fa466632de67a15c500634a26b1325ab2 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:05:33 -0500 Subject: [PATCH 03/33] Return universal Thrift errors. 
--- faults/headers.go | 4 ++++ thriftbp/client_middlewares.go | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/faults/headers.go b/faults/headers.go index c694bb208..1d5739583 100644 --- a/faults/headers.go +++ b/faults/headers.go @@ -1,6 +1,7 @@ package faults const ( + // General FaultServerAddressHeader = "X-Bp-Fault-Server-Address" FaultDelayMsHeader = "X-Bp-Fault-Delay-Ms" FaultAbortCodeHeader = "X-Bp-Fault-Abort-Code" @@ -8,4 +9,7 @@ const ( FaultAbortMessageHeader = "X-Bp-Fault-Abort-Message" FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" FaultAbortPercentageHeader = "X-Bp-Fault-Abort-Percentage" + + // Thrift-specific + FaultThriftErrorTypeHeader = "X-Bp-Fault-Thrift-Error-Type" ) diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 0c77975fd..67f68a390 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -17,7 +17,6 @@ import ( "github.com/reddit/baseplate.go/errorsbp" "github.com/reddit/baseplate.go/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" - baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" //lint:ignore SA1019 This library is internal only, not actually deprecated @@ -417,10 +416,23 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return next.Call(ctx, method, args, result) } responseFn := func(code int, message string) interface{} { - return &baseplatethrift.Error{ - Code: thrift.Int32Ptr(int32(code)), - Message: thrift.StringPtr(message), + switch errorType := getHeaderFn(faults.FaultThriftErrorTypeHeader); errorType { + case "transport": + if code <= thrift.UNKNOWN_TRANSPORT_EXCEPTION || code > thrift.END_OF_FILE { + return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) + } + case "protocol": + if code <= thrift.UNKNOWN_PROTOCOL_EXCEPTION || code > thrift.DEPTH_LIMIT 
{ + return thrift.NewTProtocolExceptionWithType(thrift.UNKNOWN_PROTOCOL_EXCEPTION, errors.New(message)) + } + case "application": + if code <= thrift.UNKNOWN_APPLICATION_EXCEPTION || code > thrift.VALIDATION_FAILED { + return thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, message) + } } + + // Log exception type doesn't match + return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) } responseMeta, err := faults.InjectFault(address, method, getHeaderFn, resumeFn, responseFn) From fb9035e721df35f6cfddac1bbd01def5cff504d2 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:11:09 -0500 Subject: [PATCH 04/33] Revert to only support Transport errors. --- faults/headers.go | 4 ---- thriftbp/client_middlewares.go | 19 +++---------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/faults/headers.go b/faults/headers.go index 1d5739583..c694bb208 100644 --- a/faults/headers.go +++ b/faults/headers.go @@ -1,7 +1,6 @@ package faults const ( - // General FaultServerAddressHeader = "X-Bp-Fault-Server-Address" FaultDelayMsHeader = "X-Bp-Fault-Delay-Ms" FaultAbortCodeHeader = "X-Bp-Fault-Abort-Code" @@ -9,7 +8,4 @@ const ( FaultAbortMessageHeader = "X-Bp-Fault-Abort-Message" FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" FaultAbortPercentageHeader = "X-Bp-Fault-Abort-Percentage" - - // Thrift-specific - FaultThriftErrorTypeHeader = "X-Bp-Fault-Thrift-Error-Type" ) diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 67f68a390..9f241c384 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -416,23 +416,10 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return next.Call(ctx, method, args, result) } responseFn := func(code int, message string) interface{} { - switch errorType := getHeaderFn(faults.FaultThriftErrorTypeHeader); errorType { - case 
"transport": - if code <= thrift.UNKNOWN_TRANSPORT_EXCEPTION || code > thrift.END_OF_FILE { - return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) - } - case "protocol": - if code <= thrift.UNKNOWN_PROTOCOL_EXCEPTION || code > thrift.DEPTH_LIMIT { - return thrift.NewTProtocolExceptionWithType(thrift.UNKNOWN_PROTOCOL_EXCEPTION, errors.New(message)) - } - case "application": - if code <= thrift.UNKNOWN_APPLICATION_EXCEPTION || code > thrift.VALIDATION_FAILED { - return thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, message) - } + if code <= thrift.UNKNOWN_TRANSPORT_EXCEPTION || code > thrift.END_OF_FILE { + return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) } - - // Log exception type doesn't match - return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) + return thrift.NewTTransportException(code, message) } responseMeta, err := faults.InjectFault(address, method, getHeaderFn, resumeFn, responseFn) From c4c7c1c257282f3be392587cd8dcb547fbf755dc Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:29:17 -0500 Subject: [PATCH 05/33] Initial POC test. Needs refactoring still. 
--- faults/common.go | 23 ++++---- faults/headers.go | 4 +- go.mod | 1 + go.sum | 1 + httpbp/client_middlewares.go | 9 ++-- thriftbp/client_middlewares.go | 13 +++-- thriftbp/client_middlewares_test.go | 83 +++++++++++++++++++++++++++++ 7 files changed, 110 insertions(+), 24 deletions(-) diff --git a/faults/common.go b/faults/common.go index d35f4cc65..5d9d4a4fb 100644 --- a/faults/common.go +++ b/faults/common.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/reddit/baseplate.go/faults" "golang.org/x/exp/rand" ) @@ -17,20 +16,20 @@ func getShortenedAddress(address string) string { return strings.Join(parts[:2], ".") } -func InjectFault(address, method string, getHeaderFn func(key string) string, resumeFn func() (interface{}, error), responseFn func(code int, message string) interface{}) (interface{}, error) { - serverAddress := getHeaderFn(faults.FaultServerAddressHeader) +func InjectFault(address, method string, abortCodeMin, abortCodeMax int, getHeaderFn func(key string) string, resumeFn func() (interface{}, error), responseFn func(code int, message string) (interface{}, error)) (interface{}, error) { + serverAddress := getHeaderFn(FaultServerAddressHeader) if serverAddress == "" || serverAddress != getShortenedAddress(address) { return resumeFn() } - serverMethod := getHeaderFn(faults.FaultServerMethodHeader) + serverMethod := getHeaderFn(FaultServerMethodHeader) if serverMethod != "" && serverMethod != method { return resumeFn() } - delayMs := getHeaderFn(faults.FaultDelayMsHeader) + delayMs := getHeaderFn(FaultDelayMsHeader) if delayMs != "" { - delayPercentage := getHeaderFn(faults.FaultDelayPercentageHeader) + delayPercentage := getHeaderFn(FaultDelayPercentageHeader) if delayPercentage != "" { percentage, err := strconv.Atoi(delayPercentage) if err != nil { @@ -54,9 +53,9 @@ func InjectFault(address, method string, getHeaderFn func(key string) string, re time.Sleep(time.Duration(delay) * time.Millisecond) } - abortCode := 
getHeaderFn(faults.FaultAbortCodeHeader) + abortCode := getHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - abortPercentage := getHeaderFn(faults.FaultAbortPercentageHeader) + abortPercentage := getHeaderFn(FaultAbortPercentageHeader) if abortPercentage != "" { percentage, err := strconv.Atoi(abortPercentage) if err != nil { @@ -77,12 +76,12 @@ func InjectFault(address, method string, getHeaderFn func(key string) string, re // log "provided abort code is not a valid integer" return resumeFn() } - if code < 100 || code >= 600 { - // log "provided abort code is outside the valid HTTP status code range of [100-599]" + if code < abortCodeMin || code > abortCodeMax { + // log "provided abort code is outside of the valid range" return resumeFn() } - abortMessage := getHeaderFn(faults.FaultAbortMessageHeader) - return responseFn(code, abortMessage), nil + abortMessage := getHeaderFn(FaultAbortMessageHeader) + return responseFn(code, abortMessage) } return resumeFn() diff --git a/faults/headers.go b/faults/headers.go index c694bb208..8302e97e1 100644 --- a/faults/headers.go +++ b/faults/headers.go @@ -2,10 +2,10 @@ package faults const ( FaultServerAddressHeader = "X-Bp-Fault-Server-Address" + FaultServerMethodHeader = "X-Bp-Fault-Server-Method" FaultDelayMsHeader = "X-Bp-Fault-Delay-Ms" + FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" FaultAbortCodeHeader = "X-Bp-Fault-Abort-Code" - FaultServerMethodHeader = "X-Bp-Fault-Server-Method" FaultAbortMessageHeader = "X-Bp-Fault-Abort-Message" - FaultDelayPercentageHeader = "X-Bp-Fault-Delay-Percentage" FaultAbortPercentageHeader = "X-Bp-Fault-Abort-Percentage" ) diff --git a/go.mod b/go.mod index 9eb78ce30..90c8c6738 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/sony/gobreaker v0.4.1 go.uber.org/automaxprocs v1.5.1 go.uber.org/zap v1.24.0 + golang.org/x/exp v0.0.0-20190121172915-509febef88a4 golang.org/x/sys v0.18.0 golang.org/x/time v0.0.0-20220609170525-579cf78fd858 
google.golang.org/grpc v1.56.3 diff --git a/go.sum b/go.sum index 2d48c7c67..e8c0106dc 100644 --- a/go.sum +++ b/go.sum @@ -337,6 +337,7 @@ golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4 h1:c2HOrn5iMezYjSlGPncknSEr/8x5LELb/ilJbXi9DEA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index cc3a49938..c7403eb80 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -363,7 +363,7 @@ func FaultInjection(serverSlug string) ClientMiddleware { resumeFn := func() (interface{}, error) { return next.RoundTrip(req) } - responseFn := func(code int, message string) interface{} { + responseFn := func(code int, message string) (interface{}, error) { return &http.Response{ Status: http.StatusText(code), StatusCode: code, @@ -379,9 +379,12 @@ func FaultInjection(serverSlug string) ClientMiddleware { TransferEncoding: req.TransferEncoding, Request: req, TLS: req.TLS, - } + }, nil } - resp, err := faults.InjectFault(req.URL.Host, req.URL.Path, req.Header.Get, resumeFn, responseFn) + + abortCodeMin := 100 + abortCodeMax := 599 + resp, err := faults.InjectFault(req.URL.Host, req.URL.Path, abortCodeMin, abortCodeMax, req.Header.Get, resumeFn, responseFn) return resp.(*http.Response), err }) } diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go 
index 9f241c384..a3e6ccf7f 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -415,15 +415,14 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { resumeFn := func() (interface{}, error) { return next.Call(ctx, method, args, result) } - responseFn := func(code int, message string) interface{} { - if code <= thrift.UNKNOWN_TRANSPORT_EXCEPTION || code > thrift.END_OF_FILE { - return thrift.NewTTransportException(thrift.UNKNOWN_TRANSPORT_EXCEPTION, message) - } - return thrift.NewTTransportException(code, message) + responseFn := func(code int, message string) (interface{}, error) { + return thrift.ResponseMeta{}, thrift.NewTTransportException(code, message) } - responseMeta, err := faults.InjectFault(address, method, getHeaderFn, resumeFn, responseFn) - return responseMeta.(thrift.ResponseMeta), err + abortCodeMin := thrift.UNKNOWN_TRANSPORT_EXCEPTION + abortCodeMax := thrift.END_OF_FILE + resp, err := faults.InjectFault(address, method, abortCodeMin, abortCodeMax, getHeaderFn, resumeFn, responseFn) + return resp.(thrift.ResponseMeta), err }, } } diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index 7d3f24c0e..6f1a43c2b 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -12,6 +12,7 @@ import ( "github.com/reddit/baseplate.go" "github.com/reddit/baseplate.go/ecinterface" + "github.com/reddit/baseplate.go/faults" baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/prometheusbpint/spectest" "github.com/reddit/baseplate.go/prometheusbp" @@ -23,6 +24,7 @@ import ( const ( service = "testService" + address = "testService.testNamespace.svc.cluster.local" method = "testMethod" ) @@ -38,6 +40,7 @@ func initClients(ecImpl ecinterface.Interface) (*thrifttest.MockClient, *thriftt thriftbp.DefaultClientMiddlewareArgs{ EdgeContextImpl: ecImpl, ServiceSlug: service, + 
Address: address, }, )..., ) @@ -395,6 +398,86 @@ func TestPrometheusClientMiddleware(t *testing.T) { } } +func TestFaultInjectionClientMiddleware(t *testing.T) { + testCases := []struct { + name string + + faultServerAddrHeader string + faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantErr error + }{ + { + name: "no fault", + wantErr: nil, + }, + { + name: "abort", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: thrift.NewTTransportException(1, "test fault"), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + + impl := ecinterface.Mock() + ctx := context.Background() + + if tt.faultServerAddrHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultServerAddressHeader, tt.faultServerAddrHeader) + } + if tt.faultServerMethodHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultServerMethodHeader, tt.faultServerMethodHeader) + } + if tt.faultDelayMsHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultDelayMsHeader, tt.faultDelayMsHeader) + } + if tt.faultDelayPercentageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultDelayPercentageHeader, tt.faultDelayPercentageHeader) + } + if tt.faultAbortCodeHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortCodeHeader, tt.faultAbortCodeHeader) + } + if tt.faultAbortMessageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortMessageHeader, tt.faultAbortMessageHeader) + } + if tt.faultAbortPercentageHeader != "" { + ctx = thriftbp.AddClientHeader(ctx, faults.FaultAbortPercentageHeader, tt.faultAbortPercentageHeader) + } + + mock, _, client := initClients(impl) + mock.AddMockCall( + method, + func(ctx context.Context, args, 
result thrift.TStruct) (meta thrift.ResponseMeta, err error) { + return + }, + ) + + _, err := client.Call(ctx, method, nil, nil) + if tt.wantErr == nil && err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tt.wantErr != nil && err == nil { + t.Fatal("expected an error, got nil") + } + if err != nil && err.Error() != tt.wantErr.Error() { + t.Fatalf("expected error %v, got %v", tt.wantErr, err) + } + }) + } +} + type mockBaseplateService struct { fail bool err error From ed571bf61c1c8f6358ff6aa7c1c0c4aff9d1f2d6 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:38:30 -0500 Subject: [PATCH 06/33] Refactored and moved heavy testing logic into common test. --- faults/common.go | 121 +++++++++++++++++----------- thriftbp/client_middlewares.go | 25 +++--- thriftbp/client_middlewares_test.go | 46 ++++++++++- 3 files changed, 136 insertions(+), 56 deletions(-) diff --git a/faults/common.go b/faults/common.go index 5d9d4a4fb..ba5fbbb90 100644 --- a/faults/common.go +++ b/faults/common.go @@ -8,6 +8,22 @@ import ( "golang.org/x/exp/rand" ) +type GetHeaderFn func(key string) string +type ResumeFn func() (interface{}, error) +type ResponseFn func(code int, message string) (interface{}, error) +type SleepFn func(d time.Duration) + +type randSingleton struct { + randInt *int +} + +func (r randSingleton) getRandInt() int { + if r.randInt == nil { + *r.randInt = rand.Intn(100) + } + return *r.randInt +} + func getShortenedAddress(address string) string { parts := strings.Split(address, ".") if len(parts) < 2 { @@ -16,73 +32,88 @@ func getShortenedAddress(address string) string { return strings.Join(parts[:2], ".") } -func InjectFault(address, method string, abortCodeMin, abortCodeMax int, getHeaderFn func(key string) string, resumeFn func() (interface{}, error), responseFn func(code int, message string) (interface{}, error)) (interface{}, error) { - serverAddress := 
getHeaderFn(FaultServerAddressHeader) - if serverAddress == "" || serverAddress != getShortenedAddress(address) { - return resumeFn() +func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) bool { + percentageStr := GetHeaderFn(percentageHeader) + if percentageStr == "" { + return true + } + percentage, err := strconv.Atoi(percentageStr) + if err != nil { + // log "provided delay percentage is not a valid integer" + return false + } + if percentage < 0 || percentage > 100 { + // log "provided delay percentage is outside the valid range of [0-100]" + return false } + return singleRand.getRandInt() < percentage +} + +type InjectFaultParams struct { + Address, Method string + AbortCodeMin, AbortCodeMax int - serverMethod := getHeaderFn(FaultServerMethodHeader) - if serverMethod != "" && serverMethod != method { - return resumeFn() + GetHeaderFn GetHeaderFn + ResumeFn ResumeFn + ResponseFn ResponseFn + + // Exposed for tests + RandInt *int + SleepFn *SleepFn +} + +func InjectFault(params InjectFaultParams) (interface{}, error) { + serverAddress := params.GetHeaderFn(FaultServerAddressHeader) + if serverAddress == "" || serverAddress != getShortenedAddress(params.Address) { + return params.ResumeFn() + } + + serverMethod := params.GetHeaderFn(FaultServerMethodHeader) + if serverMethod != "" && serverMethod != params.Method { + return params.ResumeFn() } - delayMs := getHeaderFn(FaultDelayMsHeader) + singleRand := randSingleton{ + randInt: params.RandInt, + } + + delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - delayPercentage := getHeaderFn(FaultDelayPercentageHeader) - if delayPercentage != "" { - percentage, err := strconv.Atoi(delayPercentage) - if err != nil { - // log "provided delay percentage is not a valid integer" - return resumeFn() - } - if percentage < 0 || percentage > 100 { - // log "provided delay percentage is outside the valid range of [0-100]" - return resumeFn() - } - if percentage 
== 0 || (percentage != 100 && rand.Intn(100) >= percentage) { - return resumeFn() - } + if !isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand) { + return params.ResumeFn() } delay, err := strconv.Atoi(delayMs) if err != nil { // log "provided delay is not a valid integer" - return resumeFn() + return params.ResumeFn() + } + + sleepFn := time.Sleep + if params.SleepFn != nil { + sleepFn = *params.SleepFn } - time.Sleep(time.Duration(delay) * time.Millisecond) + sleepFn(time.Duration(delay) * time.Millisecond) } - abortCode := getHeaderFn(FaultAbortCodeHeader) + abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - abortPercentage := getHeaderFn(FaultAbortPercentageHeader) - if abortPercentage != "" { - percentage, err := strconv.Atoi(abortPercentage) - if err != nil { - // log "provided abort percentage is not a valid integer" - return resumeFn() - } - if percentage < 0 || percentage > 100 { - // log "provided abort percentage is outside the valid range of [0-100]" - return resumeFn() - } - if percentage != 100 && rand.Intn(100) >= percentage { - return resumeFn() - } + if !isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand) { + return params.ResumeFn() } code, err := strconv.Atoi(abortCode) if err != nil { // log "provided abort code is not a valid integer" - return resumeFn() + return params.ResumeFn() } - if code < abortCodeMin || code > abortCodeMax { + if code < params.AbortCodeMin || code > params.AbortCodeMax { // log "provided abort code is outside of the valid range" - return resumeFn() + return params.ResumeFn() } - abortMessage := getHeaderFn(FaultAbortMessageHeader) - return responseFn(code, abortMessage) + abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) + return params.ResponseFn(code, abortMessage) } - return resumeFn() + return params.ResumeFn() } diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index a3e6ccf7f..02f82fc22 100644 --- 
a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -405,23 +405,28 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return func(next thrift.TClient) thrift.TClient { return thrift.WrappedTClient{ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { - getHeaderFn := func(key string) string { + getHeaderFn := faults.GetHeaderFn(func(key string) string { header, ok := thrift.GetHeader(ctx, key) if !ok { return "" } return header - } - resumeFn := func() (interface{}, error) { + }) + resumeFn := faults.ResumeFn(func() (interface{}, error) { return next.Call(ctx, method, args, result) - } - responseFn := func(code int, message string) (interface{}, error) { + }) + responseFn := faults.ResponseFn(func(code int, message string) (interface{}, error) { return thrift.ResponseMeta{}, thrift.NewTTransportException(code, message) - } - - abortCodeMin := thrift.UNKNOWN_TRANSPORT_EXCEPTION - abortCodeMax := thrift.END_OF_FILE - resp, err := faults.InjectFault(address, method, abortCodeMin, abortCodeMax, getHeaderFn, resumeFn, responseFn) + }) + + resp, err := faults.InjectFault(faults.InjectFaultParams{ + Address: address, + Method: method, + AbortCodeMin: thrift.UNKNOWN_TRANSPORT_EXCEPTION, + AbortCodeMax: thrift.END_OF_FILE, + GetHeaderFn: getHeaderFn, + ResumeFn: resumeFn, + ResponseFn: responseFn}) return resp.(thrift.ResponseMeta), err }, } diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index 6f1a43c2b..256de4fe7 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -398,6 +398,10 @@ func TestPrometheusClientMiddleware(t *testing.T) { } } +func intPtr(i int) *int { + return &i +} + func TestFaultInjectionClientMiddleware(t *testing.T) { testCases := []struct { name string @@ -413,7 +417,7 @@ func TestFaultInjectionClientMiddleware(t *testing.T) { wantErr error }{ { - name: "no fault", + 
name: "no fault specified", wantErr: nil, }, { @@ -426,6 +430,46 @@ func TestFaultInjectionClientMiddleware(t *testing.T) { wantErr: thrift.NewTTransportException(1, "test fault"), }, + { + name: "service does not match", + + faultServerAddrHeader: "fooService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "method does not match", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "1", // NOT_OPEN + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "less than min abort code", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "-1", + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, + { + name: "greater than max abort code", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "5", + faultAbortMessageHeader: "test fault", + + wantErr: nil, + }, } for _, tt := range testCases { From 73cee20388d853d1bd0b671741525a1d6451232a Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Mon, 25 Nov 2024 18:07:56 -0500 Subject: [PATCH 07/33] WIP HTTP tests. Need to adjust server name matching either for tests or permanently. The issue with using server prefix instead of the . paradigm is that care needs to be taken to ensure the prefix doesn't match too many destinations. Blast radius can explode in that model. Prometheus exports and logging are next after that. 
--- httpbp/client_middlewares.go | 12 ++- httpbp/client_middlewares_test.go | 127 ++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index c7403eb80..7aec8f6d0 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -78,6 +78,7 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien } defaults := []ClientMiddleware{ + FaultInjection(config.Slug), MonitorClient(config.Slug + transport.WithRetrySlugSuffix), PrometheusClientMetrics(config.Slug + transport.WithRetrySlugSuffix), Retries(config.MaxErrorReadAhead, config.RetryOptions...), @@ -382,9 +383,14 @@ func FaultInjection(serverSlug string) ClientMiddleware { }, nil } - abortCodeMin := 100 - abortCodeMax := 599 - resp, err := faults.InjectFault(req.URL.Host, req.URL.Path, abortCodeMin, abortCodeMax, req.Header.Get, resumeFn, responseFn) + resp, err := faults.InjectFault(faults.InjectFaultParams{ + Address: req.URL.Host, + Method: req.URL.Path, + AbortCodeMin: 100, + AbortCodeMax: 599, + GetHeaderFn: req.Header.Get, + ResumeFn: resumeFn, + ResponseFn: responseFn}) return resp.(*http.Response), err }) } diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 735d2fe9f..1020baaaf 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "runtime/debug" "sync" "sync/atomic" "testing" @@ -16,6 +17,7 @@ import ( "github.com/sony/gobreaker" "github.com/reddit/baseplate.go/breakerbp" + "github.com/reddit/baseplate.go/faults" ) func TestNewClient(t *testing.T) { @@ -395,3 +397,128 @@ func TestCircuitBreaker(t *testing.T) { t.Errorf("Expected the third request to return %v, got %v", gobreaker.ErrOpenState, err) } } + +func TestFaultInjection(t *testing.T) { + testCases := []struct { + name string + faultServerAddrHeader string + 
faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantResp *http.Response + }{ + { + name: "no fault specified", + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "abort", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusInternalServerError, + }, + }, + { + name: "service does not match", + + faultServerAddrHeader: "fooService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "method does not match", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "500", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "less than min abort code", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "99", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + { + name: "greater than max abort code", + + faultServerAddrHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "600", + + wantResp: &http.Response{ + StatusCode: http.StatusOK, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Success!") + })) + defer server.Close() + + client := server.Client() + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + t.Fatalf("unexpected error when creating request: %v", err) + } + + if tt.faultServerAddrHeader != "" { + req.Header.Set(faults.FaultServerAddressHeader, 
tt.faultServerAddrHeader) + } + if tt.faultServerMethodHeader != "" { + req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader) + } + if tt.faultDelayMsHeader != "" { + req.Header.Set(faults.FaultDelayMsHeader, tt.faultDelayMsHeader) + } + if tt.faultDelayPercentageHeader != "" { + req.Header.Set(faults.FaultDelayPercentageHeader, tt.faultDelayPercentageHeader) + } + if tt.faultAbortCodeHeader != "" { + req.Header.Set(faults.FaultAbortCodeHeader, tt.faultAbortCodeHeader) + } + if tt.faultAbortMessageHeader != "" { + req.Header.Set(faults.FaultAbortMessageHeader, tt.faultAbortMessageHeader) + } + if tt.faultAbortPercentageHeader != "" { + req.Header.Set(faults.FaultAbortPercentageHeader, tt.faultAbortPercentageHeader) + } + + resp, err := client.Do(req) + + if err != nil { + t.Log(string(debug.Stack())) + t.Fatalf("expected no error, got %v", err) + } + if tt.wantResp.StatusCode != resp.StatusCode { + t.Fatalf("expected response code %v, got %v", tt.wantResp.StatusCode, resp.StatusCode) + } + }) + } +} From f18a4c3e302438c10c0745084d95f5183424f62f Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:17:27 -0500 Subject: [PATCH 08/33] Fixed HTTP tests. 
--- httpbp/client_middlewares.go | 3 ++- httpbp/client_middlewares_test.go | 32 +++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 7aec8f6d0..d7a092c7b 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -385,7 +386,7 @@ func FaultInjection(serverSlug string) ClientMiddleware { resp, err := faults.InjectFault(faults.InjectFaultParams{ Address: req.URL.Host, - Method: req.URL.Path, + Method: strings.TrimPrefix(req.URL.Path, "/"), AbortCodeMin: 100, AbortCodeMax: 599, GetHeaderFn: req.Header.Get, diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 1020baaaf..45feabb7b 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -401,7 +401,7 @@ func TestCircuitBreaker(t *testing.T) { func TestFaultInjection(t *testing.T) { testCases := []struct { name string - faultServerAddrHeader string + faultServerAddrMatch bool faultServerMethodHeader string faultDelayMsHeader string faultDelayPercentageHeader string @@ -420,7 +420,7 @@ func TestFaultInjection(t *testing.T) { { name: "abort", - faultServerAddrHeader: "testService.testNamespace", + faultServerAddrMatch: true, faultServerMethodHeader: "testMethod", faultAbortCodeHeader: "500", @@ -431,7 +431,7 @@ func TestFaultInjection(t *testing.T) { { name: "service does not match", - faultServerAddrHeader: "fooService.testNamespace", + faultServerAddrMatch: false, faultServerMethodHeader: "testMethod", faultAbortCodeHeader: "500", @@ -442,7 +442,7 @@ func TestFaultInjection(t *testing.T) { { name: "method does not match", - faultServerAddrHeader: "testService.testNamespace", + faultServerAddrMatch: true, faultServerMethodHeader: "fooMethod", faultAbortCodeHeader: "500", @@ -453,7 +453,7 @@ func TestFaultInjection(t *testing.T) { { 
name: "less than min abort code", - faultServerAddrHeader: "testService.testNamespace", + faultServerAddrMatch: true, faultServerMethodHeader: "testMethod", faultAbortCodeHeader: "99", @@ -464,7 +464,7 @@ func TestFaultInjection(t *testing.T) { { name: "greater than max abort code", - faultServerAddrHeader: "testService.testNamespace", + faultServerAddrMatch: true, faultServerMethodHeader: "testMethod", faultAbortCodeHeader: "600", @@ -482,14 +482,26 @@ func TestFaultInjection(t *testing.T) { })) defer server.Close() - client := server.Client() - req, err := http.NewRequest("GET", server.URL, nil) + client, err := NewClient(ClientConfig{ + Slug: "test", + }) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + + req, err := http.NewRequest("GET", server.URL+"/testMethod", nil) if err != nil { t.Fatalf("unexpected error when creating request: %v", err) } - if tt.faultServerAddrHeader != "" { - req.Header.Set(faults.FaultServerAddressHeader, tt.faultServerAddrHeader) + if tt.faultServerAddrMatch { + // This is a hack to get the fault injection middleware address + // matching to work with an HTTP test server, as the fault injection + // middleware relies on the DNS address for request matching. HTTP + // test servers run on the local machine, and modifying the DNS + // logic to route an arbitrary test address to this local server is + // extremely non-trivial. + req.Header.Set(faults.FaultServerAddressHeader, "127.0") } if tt.faultServerMethodHeader != "" { req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader) From 43a5c5743c45c2043599dfa7b1e3e3088c71718a Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:37:15 -0500 Subject: [PATCH 09/33] Add logging and refine tests. 
--- faults/common.go | 38 ++++++++++++++----------------- httpbp/client_middlewares.go | 15 ++++++------ httpbp/client_middlewares_test.go | 12 ++++------ thriftbp/client_middlewares.go | 1 + thriftbp/client_pool.go | 1 + 5 files changed, 32 insertions(+), 35 deletions(-) diff --git a/faults/common.go b/faults/common.go index ba5fbbb90..326f4b05f 100644 --- a/faults/common.go +++ b/faults/common.go @@ -1,6 +1,8 @@ package faults import ( + "fmt" + "log/slog" "strconv" "strings" "time" @@ -24,32 +26,24 @@ func (r randSingleton) getRandInt() int { return *r.randInt } -func getShortenedAddress(address string) string { - parts := strings.Split(address, ".") - if len(parts) < 2 { - return "" - } - return strings.Join(parts[:2], ".") -} - -func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) bool { +func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) (bool, string) { percentageStr := GetHeaderFn(percentageHeader) if percentageStr == "" { - return true + return true, "" } percentage, err := strconv.Atoi(percentageStr) if err != nil { - // log "provided delay percentage is not a valid integer" - return false + return false, fmt.Sprintf("provided delay percentage %s is not a valid integer", percentageStr) } if percentage < 0 || percentage > 100 { - // log "provided delay percentage is outside the valid range of [0-100]" - return false + return false, fmt.Sprintf("provided delay percentage %d is outside the valid range of [0-100]", percentage) } - return singleRand.getRandInt() < percentage + return singleRand.getRandInt() < percentage, "" } type InjectFaultParams struct { + CallerName string + Address, Method string AbortCodeMin, AbortCodeMax int @@ -64,7 +58,7 @@ type InjectFaultParams struct { func InjectFault(params InjectFaultParams) (interface{}, error) { serverAddress := params.GetHeaderFn(FaultServerAddressHeader) - if serverAddress == "" || serverAddress != 
getShortenedAddress(params.Address) { + if serverAddress == "" || serverAddress != strings.TrimSuffix(params.Address, ".svc.cluster.local") { return params.ResumeFn() } @@ -79,13 +73,14 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - if !isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand) { + if selected, msg := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { + slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, msg)) return params.ResumeFn() } delay, err := strconv.Atoi(delayMs) if err != nil { - // log "provided delay is not a valid integer" + slog.Warn(fmt.Sprintf("%s: provided delay %s is not a valid integer", params.CallerName, delayMs)) return params.ResumeFn() } @@ -98,17 +93,18 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - if !isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand) { + if selected, msg := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { + slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, msg)) return params.ResumeFn() } code, err := strconv.Atoi(abortCode) if err != nil { - // log "provided abort code is not a valid integer" + slog.Warn(fmt.Sprintf("%s: provided abort code %s is not a valid integer", params.CallerName, abortCode)) return params.ResumeFn() } if code < params.AbortCodeMin || code > params.AbortCodeMax { - // log "provided abort code is outside of the valid range" + slog.Warn(fmt.Sprintf("%s: provided abort code %d is outside of the valid range", params.CallerName, code)) return params.ResumeFn() } abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index d7a092c7b..aa6217500 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ 
-79,7 +79,7 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien } defaults := []ClientMiddleware{ - FaultInjection(config.Slug), + FaultInjection(), MonitorClient(config.Slug + transport.WithRetrySlugSuffix), PrometheusClientMetrics(config.Slug + transport.WithRetrySlugSuffix), Retries(config.MaxErrorReadAhead, config.RetryOptions...), @@ -359,13 +359,13 @@ type ServiceAddressParts struct { Namespace string } -func FaultInjection(serverSlug string) ClientMiddleware { +func FaultInjection() ClientMiddleware { return func(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resumeFn := func() (interface{}, error) { + resumeFn := faults.ResumeFn(func() (interface{}, error) { return next.RoundTrip(req) - } - responseFn := func(code int, message string) (interface{}, error) { + }) + responseFn := faults.ResponseFn(func(code int, message string) (interface{}, error) { return &http.Response{ Status: http.StatusText(code), StatusCode: code, @@ -382,14 +382,15 @@ func FaultInjection(serverSlug string) ClientMiddleware { Request: req, TLS: req.TLS, }, nil - } + }) resp, err := faults.InjectFault(faults.InjectFaultParams{ + CallerName: "httpbp.FaultInjection", Address: req.URL.Host, Method: strings.TrimPrefix(req.URL.Path, "/"), AbortCodeMin: 100, AbortCodeMax: 599, - GetHeaderFn: req.Header.Get, + GetHeaderFn: faults.GetHeaderFn(req.Header.Get), ResumeFn: resumeFn, ResponseFn: responseFn}) return resp.(*http.Response), err diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 45feabb7b..4d4312a17 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "runtime/debug" + "strings" "sync" "sync/atomic" "testing" @@ -495,13 +496,10 @@ func TestFaultInjection(t *testing.T) { } if tt.faultServerAddrMatch { - // This is a hack to get the fault injection middleware 
address - // matching to work with an HTTP test server, as the fault injection - // middleware relies on the DNS address for request matching. HTTP - // test servers run on the local machine, and modifying the DNS - // logic to route an arbitrary test address to this local server is - // extremely non-trivial. - req.Header.Set(faults.FaultServerAddressHeader, "127.0") + // We can't set a specific address here because the middleware + // relies on the DNS address, which is not customizeable when making + // real requests to a local HTTP test server. + req.Header.Set(faults.FaultServerAddressHeader, strings.TrimPrefix(server.URL, "http://")) } if tt.faultServerMethodHeader != "" { req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader) diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 02f82fc22..f03279a5b 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -420,6 +420,7 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { }) resp, err := faults.InjectFault(faults.InjectFaultParams{ + CallerName: "thriftpb.FaultInjectionClientMiddleware", Address: address, Method: method, AbortCodeMin: thrift.UNKNOWN_TRANSPORT_EXCEPTION, diff --git a/thriftbp/client_pool.go b/thriftbp/client_pool.go index 6a84ab671..547b757a0 100644 --- a/thriftbp/client_pool.go +++ b/thriftbp/client_pool.go @@ -401,6 +401,7 @@ func NewBaseplateClientPoolWithContext(ctx context.Context, cfg ClientPoolConfig } defaults := BaseplateDefaultClientMiddlewares( DefaultClientMiddlewareArgs{ + Address: cfg.Addr, EdgeContextImpl: cfg.EdgeContextImpl, ServiceSlug: cfg.ServiceSlug, RetryOptions: cfg.DefaultRetryOptions, From ac4e91a6b5a139069a38da4244b4cab8af81a204 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:09:28 -0500 Subject: [PATCH 10/33] Use official rand library and update logging. 
--- faults/common.go | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/faults/common.go b/faults/common.go index 326f4b05f..0d6df019c 100644 --- a/faults/common.go +++ b/faults/common.go @@ -3,25 +3,26 @@ package faults import ( "fmt" "log/slog" + "math/rand/v2" "strconv" "strings" "time" - - "golang.org/x/exp/rand" ) +// Function signatures for use in protocol-specific dependency injection. type GetHeaderFn func(key string) string type ResumeFn func() (interface{}, error) type ResponseFn func(code int, message string) (interface{}, error) type SleepFn func(d time.Duration) +// Object to ensure a random number is only generated at most 1 time. type randSingleton struct { randInt *int } func (r randSingleton) getRandInt() int { if r.randInt == nil { - *r.randInt = rand.Intn(100) + *r.randInt = rand.IntN(100) } return *r.randInt } @@ -33,10 +34,10 @@ func isSelected(percentageHeader string, GetHeaderFn func(key string) string, si } percentage, err := strconv.Atoi(percentageStr) if err != nil { - return false, fmt.Sprintf("provided delay percentage %s is not a valid integer", percentageStr) + return false, fmt.Sprintf("provided percentage \"%s\" is not a valid integer", percentageStr) } if percentage < 0 || percentage > 100 { - return false, fmt.Sprintf("provided delay percentage %d is outside the valid range of [0-100]", percentage) + return false, fmt.Sprintf("provided percentage \"%d\" is outside the valid range of [0-100]", percentage) } return singleRand.getRandInt() < percentage, "" } @@ -51,7 +52,7 @@ type InjectFaultParams struct { ResumeFn ResumeFn ResponseFn ResponseFn - // Exposed for tests + // Exposed for tests. 
RandInt *int SleepFn *SleepFn } @@ -73,14 +74,16 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - if selected, msg := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { - slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, msg)) + if selected, reason := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { + if reason != "" { + slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, reason)) + } return params.ResumeFn() } delay, err := strconv.Atoi(delayMs) if err != nil { - slog.Warn(fmt.Sprintf("%s: provided delay %s is not a valid integer", params.CallerName, delayMs)) + slog.Warn(fmt.Sprintf("%s: provided delay \"%s\" is not a valid integer", params.CallerName, delayMs)) return params.ResumeFn() } @@ -93,18 +96,20 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - if selected, msg := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { - slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, msg)) + if selected, reason := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { + if reason != "" { + slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, reason)) + } return params.ResumeFn() } code, err := strconv.Atoi(abortCode) if err != nil { - slog.Warn(fmt.Sprintf("%s: provided abort code %s is not a valid integer", params.CallerName, abortCode)) + slog.Warn(fmt.Sprintf("%s: provided abort code \"%s\" is not a valid integer", params.CallerName, abortCode)) return params.ResumeFn() } if code < params.AbortCodeMin || code > params.AbortCodeMax { - slog.Warn(fmt.Sprintf("%s: provided abort code %d is outside of the valid range", params.CallerName, code)) + slog.Warn(fmt.Sprintf("%s: provided abort code \"%d\" is outside of the valid range", params.CallerName, code)) return 
params.ResumeFn() } abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) From b2d100346b27ebe111a9e7d49de631228309733f Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:11:00 -0500 Subject: [PATCH 11/33] go mod tidy --- go.mod | 1 - go.sum | 1 - 2 files changed, 2 deletions(-) diff --git a/go.mod b/go.mod index 90c8c6738..9eb78ce30 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,6 @@ require ( github.com/sony/gobreaker v0.4.1 go.uber.org/automaxprocs v1.5.1 go.uber.org/zap v1.24.0 - golang.org/x/exp v0.0.0-20190121172915-509febef88a4 golang.org/x/sys v0.18.0 golang.org/x/time v0.0.0-20220609170525-579cf78fd858 google.golang.org/grpc v1.56.3 diff --git a/go.sum b/go.sum index e8c0106dc..2d48c7c67 100644 --- a/go.sum +++ b/go.sum @@ -337,7 +337,6 @@ golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4 h1:c2HOrn5iMezYjSlGPncknSEr/8x5LELb/ilJbXi9DEA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= From 99dc665b01c879c9487cec34087204c952ffce34 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:35:17 -0500 Subject: [PATCH 12/33] Fixed shortened address to include port trimming. 
--- faults/common.go | 6 +++++- httpbp/client_middlewares.go | 2 ++ thriftbp/client_middlewares_test.go | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/faults/common.go b/faults/common.go index 0d6df019c..0c456849d 100644 --- a/faults/common.go +++ b/faults/common.go @@ -59,7 +59,11 @@ type InjectFaultParams struct { func InjectFault(params InjectFaultParams) (interface{}, error) { serverAddress := params.GetHeaderFn(FaultServerAddressHeader) - if serverAddress == "" || serverAddress != strings.TrimSuffix(params.Address, ".svc.cluster.local") { + shortAddress := params.Address + if i := strings.Index(params.Address, ".svc.cluster.local"); i != -1 { + shortAddress = params.Address[:i] + } + if serverAddress == "" || serverAddress != shortAddress { return params.ResumeFn() } diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index aa6217500..9d524921c 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -46,6 +46,8 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { // plus any additional client middleware passed into this function. Default // middlewares are: // +// * FaultInjection +// // * MonitorClient with transport.WithRetrySlugSuffix // // * PrometheusClientMetrics with transport.WithRetrySlugSuffix diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index 256de4fe7..c5c4cc0e4 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -24,7 +24,7 @@ import ( const ( service = "testService" - address = "testService.testNamespace.svc.cluster.local" + address = "testService.testNamespace.svc.cluster.local:12345" method = "testMethod" ) From 3f63e058625e831758bdf60e737651ced9df502c Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:48:50 -0500 Subject: [PATCH 13/33] Remove unused function in test code. 
--- thriftbp/client_middlewares_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index c5c4cc0e4..d0ba2932e 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -398,10 +398,6 @@ func TestPrometheusClientMiddleware(t *testing.T) { } } -func intPtr(i int) *int { - return &i -} - func TestFaultInjectionClientMiddleware(t *testing.T) { testCases := []struct { name string From 4444f33230a1ee0fe1aaa9aae40a270b3defc9d9 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:53:22 -0500 Subject: [PATCH 14/33] Fix lint errors. --- faults/common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/faults/common.go b/faults/common.go index 0c456849d..77f6a9435 100644 --- a/faults/common.go +++ b/faults/common.go @@ -1,3 +1,4 @@ +// Package faults provides common headers and client-side fault injection functionality. package faults import ( @@ -9,7 +10,6 @@ import ( "time" ) -// Function signatures for use in protocol-specific dependency injection. type GetHeaderFn func(key string) string type ResumeFn func() (interface{}, error) type ResponseFn func(code int, message string) (interface{}, error) From 1c7754886cc68f1b632b955bc03a811191d83d59 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:38:50 -0500 Subject: [PATCH 15/33] Update based on PR review comments. 
--- faults/common.go | 52 ++++++++++++++++++++----------- httpbp/client_middlewares.go | 17 +++++----- httpbp/client_middlewares_test.go | 2 -- thriftbp/client_middlewares.go | 23 ++++++++------ 4 files changed, 57 insertions(+), 37 deletions(-) diff --git a/faults/common.go b/faults/common.go index 77f6a9435..9c39eebf2 100644 --- a/faults/common.go +++ b/faults/common.go @@ -1,4 +1,5 @@ -// Package faults provides common headers and client-side fault injection functionality. +// Package faults provides common headers and client-side fault injection +// functionality. package faults import ( @@ -10,9 +11,20 @@ import ( "time" ) +// GetHeaderFn is the function type to return the value of a protocol-specific +// header with the given key. type GetHeaderFn func(key string) string -type ResumeFn func() (interface{}, error) -type ResponseFn func(code int, message string) (interface{}, error) + +// ResumeFn is the function type to continue processing the protocol-specific +// request without injecting a fault. +type ResumeFn[T any] func() (T, error) + +// ResponseFn is the function type to inject a protocol-specific fault with the +// given code and message. +type ResponseFn[T any] func(code int, message string) (T, error) + +// SleepFn is the function type to sleep for the given duration. It is only +// exposed for testing purposes. type SleepFn func(d time.Duration) // Object to ensure a random number is only generated at most 1 time. 
@@ -27,38 +39,42 @@ func (r randSingleton) getRandInt() int { return *r.randInt } -func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) (bool, string) { +func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) (bool, error) { percentageStr := GetHeaderFn(percentageHeader) if percentageStr == "" { - return true, "" + return true, nil } percentage, err := strconv.Atoi(percentageStr) if err != nil { - return false, fmt.Sprintf("provided percentage \"%s\" is not a valid integer", percentageStr) + return false, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentageStr, err) } if percentage < 0 || percentage > 100 { - return false, fmt.Sprintf("provided percentage \"%d\" is outside the valid range of [0-100]", percentage) + return false, fmt.Errorf("provided percentage %q is outside the valid range of [0-100]", percentage) } - return singleRand.getRandInt() < percentage, "" + return singleRand.getRandInt() < percentage, nil } -type InjectFaultParams struct { +type InjectFaultParams[T any] struct { CallerName string Address, Method string AbortCodeMin, AbortCodeMax int GetHeaderFn GetHeaderFn - ResumeFn ResumeFn - ResponseFn ResponseFn + ResumeFn ResumeFn[T] + ResponseFn ResponseFn[T] // Exposed for tests. RandInt *int SleepFn *SleepFn } -func InjectFault(params InjectFaultParams) (interface{}, error) { +func InjectFault[T any](params InjectFaultParams[T]) (T, error) { serverAddress := params.GetHeaderFn(FaultServerAddressHeader) + + // The short address should just be ., without the + // local cluster suffix or port. Non-cluster-local addresses are not + // shortened. 
shortAddress := params.Address if i := strings.Index(params.Address, ".svc.cluster.local"); i != -1 { shortAddress = params.Address[:i] @@ -78,9 +94,9 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - if selected, reason := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { - if reason != "" { - slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, reason)) + if selected, err := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) } return params.ResumeFn() } @@ -100,9 +116,9 @@ func InjectFault(params InjectFaultParams) (interface{}, error) { abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - if selected, reason := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { - if reason != "" { - slog.Warn(fmt.Sprintf("%s: %s", params.CallerName, reason)) + if selected, err := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) } return params.ResumeFn() } diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 9d524921c..0daf8cc67 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -364,10 +364,10 @@ type ServiceAddressParts struct { func FaultInjection() ClientMiddleware { return func(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resumeFn := faults.ResumeFn(func() (interface{}, error) { + var resumeFn faults.ResumeFn[*http.Response] = func() (*http.Response, error) { return next.RoundTrip(req) - }) - responseFn := faults.ResponseFn(func(code int, message string) (interface{}, error) { + } + var responseFn faults.ResponseFn[*http.Response] = func(code int, 
message string) (*http.Response, error) { return &http.Response{ Status: http.StatusText(code), StatusCode: code, @@ -384,18 +384,19 @@ func FaultInjection() ClientMiddleware { Request: req, TLS: req.TLS, }, nil - }) + } - resp, err := faults.InjectFault(faults.InjectFaultParams{ + resp, err := faults.InjectFault(faults.InjectFaultParams[*http.Response]{ CallerName: "httpbp.FaultInjection", Address: req.URL.Host, Method: strings.TrimPrefix(req.URL.Path, "/"), - AbortCodeMin: 100, + AbortCodeMin: 400, AbortCodeMax: 599, GetHeaderFn: faults.GetHeaderFn(req.Header.Get), ResumeFn: resumeFn, - ResponseFn: responseFn}) - return resp.(*http.Response), err + ResponseFn: responseFn, + }) + return resp, err }) } } diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 4d4312a17..4b87ad916 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "runtime/debug" "strings" "sync" "sync/atomic" @@ -523,7 +522,6 @@ func TestFaultInjection(t *testing.T) { resp, err := client.Do(req) if err != nil { - t.Log(string(debug.Stack())) t.Fatalf("expected no error, got %v", err) } if tt.wantResp.StatusCode != resp.StatusCode { diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index f03279a5b..1f6817897 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -405,21 +405,25 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return func(next thrift.TClient) thrift.TClient { return thrift.WrappedTClient{ Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { - getHeaderFn := faults.GetHeaderFn(func(key string) string { + if address == "" { + return next.Call(ctx, method, args, result) + } + + var getHeaderFn faults.GetHeaderFn = func(key string) string { header, ok := thrift.GetHeader(ctx, key) if !ok { return "" } return 
header - }) - resumeFn := faults.ResumeFn(func() (interface{}, error) { + } + var resumeFn faults.ResumeFn[thrift.ResponseMeta] = func() (thrift.ResponseMeta, error) { return next.Call(ctx, method, args, result) - }) - responseFn := faults.ResponseFn(func(code int, message string) (interface{}, error) { + } + var responseFn faults.ResponseFn[thrift.ResponseMeta] = func(code int, message string) (thrift.ResponseMeta, error) { return thrift.ResponseMeta{}, thrift.NewTTransportException(code, message) - }) + } - resp, err := faults.InjectFault(faults.InjectFaultParams{ + resp, err := faults.InjectFault(faults.InjectFaultParams[thrift.ResponseMeta]{ CallerName: "thriftpb.FaultInjectionClientMiddleware", Address: address, Method: method, @@ -427,8 +431,9 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { AbortCodeMax: thrift.END_OF_FILE, GetHeaderFn: getHeaderFn, ResumeFn: resumeFn, - ResponseFn: responseFn}) - return resp.(thrift.ResponseMeta), err + ResponseFn: responseFn, + }) + return resp, err }, } } From 71114b049d1a7507c9c07a6091fd0c567f1ff1be Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:57:28 -0500 Subject: [PATCH 16/33] Simplified random int generation. --- faults/common.go | 49 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/faults/common.go b/faults/common.go index 9c39eebf2..7cc974ef6 100644 --- a/faults/common.go +++ b/faults/common.go @@ -27,31 +27,19 @@ type ResponseFn[T any] func(code int, message string) (T, error) // exposed for testing purposes. type SleepFn func(d time.Duration) -// Object to ensure a random number is only generated at most 1 time. 
-type randSingleton struct { - randInt *int -} - -func (r randSingleton) getRandInt() int { - if r.randInt == nil { - *r.randInt = rand.IntN(100) - } - return *r.randInt -} - -func isSelected(percentageHeader string, GetHeaderFn func(key string) string, singleRand randSingleton) (bool, error) { +func getPercentage(percentageHeader string, GetHeaderFn func(key string) string) (int, error) { percentageStr := GetHeaderFn(percentageHeader) if percentageStr == "" { - return true, nil + return 100, nil } percentage, err := strconv.Atoi(percentageStr) if err != nil { - return false, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentageStr, err) + return 0, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentageStr, err) } if percentage < 0 || percentage > 100 { - return false, fmt.Errorf("provided percentage %q is outside the valid range of [0-100]", percentage) + return 0, fmt.Errorf("provided percentage %q is outside the valid range of [0-100]", percentage) } - return singleRand.getRandInt() < percentage, nil + return percentage, nil } type InjectFaultParams[T any] struct { @@ -88,16 +76,21 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { return params.ResumeFn() } - singleRand := randSingleton{ - randInt: params.RandInt, + var randInt int + if params.RandInt != nil { + randInt = *params.RandInt + } else { + randInt = rand.IntN(100) } delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - if selected, err := isSelected(FaultDelayPercentageHeader, params.GetHeaderFn, singleRand); !selected { - if err != nil { - slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) - } + percentage, err := getPercentage(FaultDelayPercentageHeader, params.GetHeaderFn) + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) + return params.ResumeFn() + } + if randInt >= percentage { return params.ResumeFn() } @@ -116,10 +109,12 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, 
error) { abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - if selected, err := isSelected(FaultAbortPercentageHeader, params.GetHeaderFn, singleRand); !selected { - if err != nil { - slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) - } + percentage, err := getPercentage(FaultAbortPercentageHeader, params.GetHeaderFn) + if err != nil { + slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) + return params.ResumeFn() + } + if randInt >= percentage { return params.ResumeFn() } From 4fd70181b6fbffdc54dfc874d95a5cf970a81bce Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:17:12 -0500 Subject: [PATCH 17/33] Move faults into internal directory. --- httpbp/client_middlewares.go | 2 +- httpbp/client_middlewares_test.go | 2 +- {faults => internal/faults}/common.go | 0 internal/faults/common_test.go | 335 +++++++++++++++++++++++++ {faults => internal/faults}/headers.go | 0 thriftbp/client_middlewares.go | 2 +- thriftbp/client_middlewares_test.go | 2 +- 7 files changed, 339 insertions(+), 4 deletions(-) rename {faults => internal/faults}/common.go (100%) create mode 100644 internal/faults/common_test.go rename {faults => internal/faults}/headers.go (100%) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 0daf8cc67..eda55ed45 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -16,7 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/reddit/baseplate.go/breakerbp" - "github.com/reddit/baseplate.go/faults" + "github.com/reddit/baseplate.go/internal/faults" //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 4b87ad916..929ad87ab 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -17,7 +17,7 @@ 
import ( "github.com/sony/gobreaker" "github.com/reddit/baseplate.go/breakerbp" - "github.com/reddit/baseplate.go/faults" + "github.com/reddit/baseplate.go/internal/faults" ) func TestNewClient(t *testing.T) { diff --git a/faults/common.go b/internal/faults/common.go similarity index 100% rename from faults/common.go rename to internal/faults/common.go diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go new file mode 100644 index 000000000..9f0c06ba0 --- /dev/null +++ b/internal/faults/common_test.go @@ -0,0 +1,335 @@ +package faults_test + +import ( + "testing" + "time" + + "github.com/reddit/baseplate.go/internal/faults" +) + +const ( + defaultAddress = "testService.testNamespace.svc.cluster.local:12345" + method = "testMethod" + minAbortCode = 0 + maxAbortCode = 10 +) + +type Response struct { + code int + message string +} + +func intPtr(i int) *int { + return &i +} + +func TestInjectFault(t *testing.T) { + testCases := []struct { + name string + address string + randInt *int + + faultServerAddressHeader string + faultServerMethodHeader string + faultDelayMsHeader string + faultDelayPercentageHeader string + faultAbortCodeHeader string + faultAbortMessageHeader string + faultAbortPercentageHeader string + + wantDelayMs int + wantResponse *Response + }{ + { + name: "no fault specified", + wantResponse: nil, + }, + { + name: "delay", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "1", + + wantDelayMs: 1, + }, + { + name: "abort", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "invalid server address", + address: "foo", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + 
faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "server address does not match", + + faultServerAddressHeader: "fooService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "method does not match", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "fooMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "guaranteed percent", + randInt: intPtr(99), // Maximum possible integer returned by rand.Intn(100) + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "100", // All requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "100", // All requests aborted + + wantDelayMs: 250, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "fence post below percent", + randInt: intPtr(49), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "50", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "50", + + wantDelayMs: 250, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, + { + name: "fence post at percent", + randInt: intPtr(50), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "50", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "50", + + wantDelayMs: 0, + wantResponse: nil, + }, + { + name: "guaranteed skip percent", + randInt: intPtr(0), // Minimum possible integer returned by rand.Intn(100) + + 
faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "0", // No requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "0", // No requests aborted + + wantDelayMs: 0, + wantResponse: nil, + }, + { + name: "invalid delay percentage negative", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "-1", + + wantDelayMs: 0, + }, + { + name: "invalid delay percentage over 100", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "101", + + wantDelayMs: 0, + }, + { + name: "invalid delay ms", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "NaN", + + wantDelayMs: 0, + }, + { + name: "invalid abort percentage negative", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "-1", + + wantResponse: nil, + }, + { + name: "invalid abort percentage over 100", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "101", + + wantResponse: nil, + }, + { + name: "invalid abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "NaN", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "less than min abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "-1", + 
faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "greater than max abort code", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "11", + faultAbortMessageHeader: "test fault", + + wantResponse: nil, + }, + { + name: "invalid abort percentage", + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "NaN", + + wantResponse: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + address := tc.address + if address == "" { + address = defaultAddress + } + + getHeaderFn := faults.GetHeaderFn(func(key string) string { + if key == faults.FaultServerAddressHeader { + return tc.faultServerAddressHeader + } + if key == faults.FaultServerMethodHeader { + return tc.faultServerMethodHeader + } + if key == faults.FaultDelayMsHeader { + return tc.faultDelayMsHeader + } + if key == faults.FaultDelayPercentageHeader { + return tc.faultDelayPercentageHeader + } + if key == faults.FaultAbortCodeHeader { + return tc.faultAbortCodeHeader + } + if key == faults.FaultAbortMessageHeader { + return tc.faultAbortMessageHeader + } + if key == faults.FaultAbortPercentageHeader { + return tc.faultAbortPercentageHeader + } + return "" + }) + var resumeFn faults.ResumeFn[*Response] = func() (*Response, error) { + return nil, nil + } + var responseFn faults.ResponseFn[*Response] = func(code int, message string) (*Response, error) { + return &Response{ + code: code, + message: message, + }, nil + } + delayMs := 0 + sleepFn := faults.SleepFn(func(d time.Duration) { + delayMs = int(d.Milliseconds()) + }) + + resp, err := faults.InjectFault(faults.InjectFaultParams[*Response]{ + CallerName: "faults_test.TestInjectFault", + Address: address, + Method: method, + AbortCodeMin: minAbortCode, + AbortCodeMax: maxAbortCode, + 
GetHeaderFn: getHeaderFn, + ResumeFn: resumeFn, + ResponseFn: responseFn, + SleepFn: &sleepFn, + RandInt: tc.randInt, + }) + + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tc.wantDelayMs != delayMs { + t.Fatalf("expected delay of %v ms, got %v ms", tc.wantDelayMs, delayMs) + } + if tc.wantResponse == nil && resp != nil { + t.Fatalf("expected no response, got %v", resp) + } + if tc.wantResponse != nil && resp == nil { + t.Fatalf("expected response %v, got nil", tc.wantResponse) + } + if resp != nil && *tc.wantResponse != *resp { + t.Fatalf("expected response %v, got %v", tc.wantResponse, resp) + } + }) + } +} diff --git a/faults/headers.go b/internal/faults/headers.go similarity index 100% rename from faults/headers.go rename to internal/faults/headers.go diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 1f6817897..27b4e3364 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -15,7 +15,7 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/ecinterface" "github.com/reddit/baseplate.go/errorsbp" - "github.com/reddit/baseplate.go/faults" + "github.com/reddit/baseplate.go/internal/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" diff --git a/thriftbp/client_middlewares_test.go b/thriftbp/client_middlewares_test.go index d0ba2932e..6a5933856 100644 --- a/thriftbp/client_middlewares_test.go +++ b/thriftbp/client_middlewares_test.go @@ -12,7 +12,7 @@ import ( "github.com/reddit/baseplate.go" "github.com/reddit/baseplate.go/ecinterface" - "github.com/reddit/baseplate.go/faults" + "github.com/reddit/baseplate.go/internal/faults" baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/prometheusbpint/spectest" "github.com/reddit/baseplate.go/prometheusbp" From 291b7bc040844f28bb9b4aed92ebad2b641bbc61 Mon Sep 17 
00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:22:21 -0500 Subject: [PATCH 18/33] Fix style. --- httpbp/client_middlewares.go | 1 - 1 file changed, 1 deletion(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index eda55ed45..8acdb0029 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -17,7 +17,6 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/internal/faults" - //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" From b6d720ff13564fac278162d65853e41e2aa136d2 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:26:19 -0500 Subject: [PATCH 19/33] Remove unused type struct and fix imports. --- httpbp/client_middlewares.go | 5 ----- thriftbp/client_middlewares.go | 1 - 2 files changed, 6 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 8acdb0029..ec3d86ede 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -355,11 +355,6 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { } } -type ServiceAddressParts struct { - Name string - Namespace string -} - func FaultInjection() ClientMiddleware { return func(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 27b4e3364..a11baef1e 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -18,7 +18,6 @@ import ( "github.com/reddit/baseplate.go/internal/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" - //lint:ignore SA1019 This library is internal only, not actually 
deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/prometheusbp" From a892c0149e69fec8f8401f8c3b32d8fb1a0c1677 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:27:12 -0500 Subject: [PATCH 20/33] Expand testing and fix PR review nits. --- httpbp/client_middlewares.go | 5 +- internal/faults/common.go | 61 +++++++++--------- internal/faults/common_test.go | 114 ++++++++++++++++++++++++++++----- thriftbp/client_middlewares.go | 7 +- 4 files changed, 135 insertions(+), 52 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index ec3d86ede..cbacea709 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -17,6 +17,7 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/internal/faults" + //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" @@ -358,10 +359,10 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware { func FaultInjection() ClientMiddleware { return func(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (*http.Response, error) { - var resumeFn faults.ResumeFn[*http.Response] = func() (*http.Response, error) { + resumeFn := func() (*http.Response, error) { return next.RoundTrip(req) } - var responseFn faults.ResponseFn[*http.Response] = func(code int, message string) (*http.Response, error) { + responseFn := func(code int, message string) (*http.Response, error) { return &http.Response{ Status: http.StatusText(code), StatusCode: code, diff --git a/internal/faults/common.go b/internal/faults/common.go index 7cc974ef6..12e0a745e 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -23,23 +23,32 @@ type ResumeFn[T any] func() (T, error) // given code and message. 
type ResponseFn[T any] func(code int, message string) (T, error) -// SleepFn is the function type to sleep for the given duration. It is only -// exposed for testing purposes. -type SleepFn func(d time.Duration) +// sleepFn is the function type to sleep for the given duration. Only used in +// tests. +type sleepFn func(d time.Duration) + +// The canonical address for a cluster-local address is <service>.<namespace>, +// without the local cluster suffix or port. The canonical address for a +// non-cluster-local address is the full original address.n +func getCanonicalAddress(serverAddress string) string { + if i := strings.Index(serverAddress, ".svc.cluster.local"); i != -1 { + return serverAddress[:i] + } + return serverAddress +} -func getPercentage(percentageHeader string, GetHeaderFn func(key string) string) (int, error) { - percentageStr := GetHeaderFn(percentageHeader) - if percentageStr == "" { +func parsePercentage(percentage string) (int, error) { + if percentage == "" { return 100, nil } - percentage, err := strconv.Atoi(percentageStr) + intPercentage, err := strconv.Atoi(percentage) if err != nil { - return 0, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentageStr, err) + return 0, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentage, err) } - if percentage < 0 || percentage > 100 { - return 0, fmt.Errorf("provided percentage %q is outside the valid range of [0-100]", percentage) + if intPercentage < 0 || intPercentage > 100 { + return 0, fmt.Errorf("provided percentage \"%d\" is outside the valid range of [0-100]", intPercentage) } - return percentage, nil + return intPercentage, nil } type InjectFaultParams[T any] struct { @@ -52,22 +61,14 @@ type InjectFaultParams[T any] struct { ResumeFn ResumeFn[T] ResponseFn ResponseFn[T] - // Exposed for tests.
- RandInt *int - SleepFn *SleepFn + randInt *int + sleepFn *sleepFn } func InjectFault[T any](params InjectFaultParams[T]) (T, error) { - serverAddress := params.GetHeaderFn(FaultServerAddressHeader) - - // The short address should just be <service>.<namespace>, without the - // local cluster suffix or port. Non-cluster-local addresses are not - // shortened. - shortAddress := params.Address - if i := strings.Index(params.Address, ".svc.cluster.local"); i != -1 { - shortAddress = params.Address[:i] - } - if serverAddress == "" || serverAddress != shortAddress { + faultHeaderAddress := params.GetHeaderFn(FaultServerAddressHeader) + requestAddress := getCanonicalAddress(params.Address) + if faultHeaderAddress == "" || faultHeaderAddress != requestAddress { return params.ResumeFn() } @@ -77,15 +78,15 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { } var randInt int - if params.RandInt != nil { - randInt = *params.RandInt + if params.randInt != nil { + randInt = *params.randInt } else { randInt = rand.IntN(100) } delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { - percentage, err := getPercentage(FaultDelayPercentageHeader, params.GetHeaderFn) + percentage, err := parsePercentage(params.GetHeaderFn(FaultDelayPercentageHeader)) if err != nil { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() @@ -101,15 +102,15 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { } sleepFn := time.Sleep - if params.SleepFn != nil { - sleepFn = *params.SleepFn + if params.sleepFn != nil { + sleepFn = *params.sleepFn } sleepFn(time.Duration(delay) * time.Millisecond) } abortCode := params.GetHeaderFn(FaultAbortCodeHeader) if abortCode != "" { - percentage, err := getPercentage(FaultAbortPercentageHeader, params.GetHeaderFn) + percentage, err := parsePercentage(params.GetHeaderFn(FaultAbortPercentageHeader)) if err != nil { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() diff --git
a/internal/faults/common_test.go b/internal/faults/common_test.go index 9f0c06ba0..82274228b 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -1,10 +1,9 @@ -package faults_test +package faults import ( + "strings" "testing" "time" - - "github.com/reddit/baseplate.go/internal/faults" ) const ( @@ -14,6 +13,87 @@ const ( maxAbortCode = 10 ) +func TestGetCanonicalAddress(t *testing.T) { + testCases := []struct { + name string + address string + want string + }{ + { + name: "cluster local address", + address: "testService.testNamespace.svc.cluster.local:12345", + want: "testService.testNamespace", + }, + { + name: "external address", + address: "foo.bar:12345", + want: "foo.bar:12345", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := getCanonicalAddress(tc.address) + if got != tc.want { + t.Fatalf("expected %q, got %q", tc.want, got) + } + }) + } +} + +func TestParsePercentage(t *testing.T) { + testCases := []struct { + name string + percentage string + want int + wantErr string + }{ + { + name: "empty", + percentage: "", + want: 100, + }, + { + name: "valid", + percentage: "50", + want: 50, + }, + { + name: "NaN", + percentage: "NaN", + want: 0, + wantErr: "not a valid integer", + }, + { + name: "under min", + percentage: "-1", + want: 0, + wantErr: "outside the valid range", + }, + { + name: "over max", + percentage: "101", + want: 0, + wantErr: "outside the valid range", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := parsePercentage(tc.percentage) + if got != tc.want { + t.Fatalf("expected %v, got %v", tc.want, got) + } + if tc.wantErr == "" && err != nil { + t.Fatalf("expected no error, got %v", err) + } + if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error to contain %q, got %v", tc.wantErr, err) + } + }) + } +} + type Response struct { code int message string @@ -264,45 +344,45 @@ func 
TestInjectFault(t *testing.T) { address = defaultAddress } - getHeaderFn := faults.GetHeaderFn(func(key string) string { - if key == faults.FaultServerAddressHeader { + getHeaderFn := GetHeaderFn(func(key string) string { + if key == FaultServerAddressHeader { return tc.faultServerAddressHeader } - if key == faults.FaultServerMethodHeader { + if key == FaultServerMethodHeader { return tc.faultServerMethodHeader } - if key == faults.FaultDelayMsHeader { + if key == FaultDelayMsHeader { return tc.faultDelayMsHeader } - if key == faults.FaultDelayPercentageHeader { + if key == FaultDelayPercentageHeader { return tc.faultDelayPercentageHeader } - if key == faults.FaultAbortCodeHeader { + if key == FaultAbortCodeHeader { return tc.faultAbortCodeHeader } - if key == faults.FaultAbortMessageHeader { + if key == FaultAbortMessageHeader { return tc.faultAbortMessageHeader } - if key == faults.FaultAbortPercentageHeader { + if key == FaultAbortPercentageHeader { return tc.faultAbortPercentageHeader } return "" }) - var resumeFn faults.ResumeFn[*Response] = func() (*Response, error) { + var resumeFn ResumeFn[*Response] = func() (*Response, error) { return nil, nil } - var responseFn faults.ResponseFn[*Response] = func(code int, message string) (*Response, error) { + var responseFn ResponseFn[*Response] = func(code int, message string) (*Response, error) { return &Response{ code: code, message: message, }, nil } delayMs := 0 - sleepFn := faults.SleepFn(func(d time.Duration) { + sleepFn := sleepFn(func(d time.Duration) { delayMs = int(d.Milliseconds()) }) - resp, err := faults.InjectFault(faults.InjectFaultParams[*Response]{ + resp, err := InjectFault(InjectFaultParams[*Response]{ CallerName: "faults_test.TestInjectFault", Address: address, Method: method, @@ -311,8 +391,8 @@ func TestInjectFault(t *testing.T) { GetHeaderFn: getHeaderFn, ResumeFn: resumeFn, ResponseFn: responseFn, - SleepFn: &sleepFn, - RandInt: tc.randInt, + sleepFn: &sleepFn, + randInt: tc.randInt, }) if err 
!= nil { diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index a11baef1e..75383c50f 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -18,6 +18,7 @@ import ( "github.com/reddit/baseplate.go/internal/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" + //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/prometheusbp" @@ -408,17 +409,17 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { return next.Call(ctx, method, args, result) } - var getHeaderFn faults.GetHeaderFn = func(key string) string { + getHeaderFn := func(key string) string { header, ok := thrift.GetHeader(ctx, key) if !ok { return "" } return header } - var resumeFn faults.ResumeFn[thrift.ResponseMeta] = func() (thrift.ResponseMeta, error) { + resumeFn := func() (thrift.ResponseMeta, error) { return next.Call(ctx, method, args, result) } - var responseFn faults.ResponseFn[thrift.ResponseMeta] = func(code int, message string) (thrift.ResponseMeta, error) { + responseFn := func(code int, message string) (thrift.ResponseMeta, error) { return thrift.ResponseMeta{}, thrift.NewTTransportException(code, message) } From 3bb8fd0a0cad5d94b8165aafdf8baa28d0ff673a Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:29:46 -0500 Subject: [PATCH 21/33] Fix imports formatting. 
--- httpbp/client_middlewares.go | 1 - thriftbp/client_middlewares.go | 1 - 2 files changed, 2 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index cbacea709..f2fcd66a7 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -17,7 +17,6 @@ import ( "github.com/reddit/baseplate.go/breakerbp" "github.com/reddit/baseplate.go/internal/faults" - //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/retrybp" diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 75383c50f..3e0edb8dc 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -18,7 +18,6 @@ import ( "github.com/reddit/baseplate.go/internal/faults" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/internal/thriftint" - //lint:ignore SA1019 This library is internal only, not actually deprecated "github.com/reddit/baseplate.go/internalv2compat" "github.com/reddit/baseplate.go/prometheusbp" From 66851ea123d6e037287ae4aeb261a225d78c69f3 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:30:31 -0500 Subject: [PATCH 22/33] Fix typo. --- internal/faults/common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 12e0a745e..230478c7c 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -29,7 +29,7 @@ type sleepFn func(d time.Duration) // The canonical address for a cluster-local address is <service>.<namespace>, // without the local cluster suffix or port. The canonical address for a -// non-cluster-local address is the full original address.n +// non-cluster-local address is the full original address.
func getCanonicalAddress(serverAddress string) string { if i := strings.Index(serverAddress, ".svc.cluster.local"); i != -1 { return serverAddress[:i] From e2908ff1486d77ae7a5bf8fc21319800182dc9a8 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 12:25:08 -0500 Subject: [PATCH 23/33] Strip port and anything after it for non-cluster-local addresses. --- internal/faults/common.go | 12 +++++++++++- internal/faults/common_test.go | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 230478c7c..1825efd1f 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -29,11 +29,21 @@ type sleepFn func(d time.Duration) // The canonical address for a cluster-local address is <service>.<namespace>, // without the local cluster suffix or port. The canonical address for a -// non-cluster-local address is the full original address. +// non-cluster-local address is the full original address without the port. func getCanonicalAddress(serverAddress string) string { + // Cluster-local address. if i := strings.Index(serverAddress, ".svc.cluster.local"); i != -1 { return serverAddress[:i] } + // External host:port address. + if i := strings.LastIndex(serverAddress, ":"); i != -1 { + port := serverAddress[i+1:] + // Verify this is actually a port number. + if port != "" && port[0] >= '0' && port[0] <= '9' { + return serverAddress[:i] + } + } + // Other address, i.e. unix domain socket.
return serverAddress } diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go index 82274228b..bf4750fb9 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -25,9 +25,19 @@ func TestGetCanonicalAddress(t *testing.T) { want: "testService.testNamespace", }, { - name: "external address", + name: "external address port stripped", address: "foo.bar:12345", - want: "foo.bar:12345", + want: "foo.bar", + }, + { + name: "unexpected address path stripped", + address: "foo.bar:12345/path", + want: "foo.bar", + }, + { + name: "external address without port untouched", + address: "unix://foo", + want: "unix://foo", }, } From 2e2882b0cf5085f8d3a4f32201beb7549680e1ab Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 12:43:40 -0500 Subject: [PATCH 24/33] Fix tests and update HTTP address to strip port automatically. --- httpbp/client_middlewares.go | 2 +- httpbp/client_middlewares_test.go | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index f2fcd66a7..6e08caee6 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -382,7 +382,7 @@ func FaultInjection() ClientMiddleware { resp, err := faults.InjectFault(faults.InjectFaultParams[*http.Response]{ CallerName: "httpbp.FaultInjection", - Address: req.URL.Host, + Address: req.URL.Hostname(), Method: strings.TrimPrefix(req.URL.Path, "/"), AbortCodeMin: 400, AbortCodeMax: 599, diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 929ad87ab..ea0c087d1 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "sync" "sync/atomic" "testing" @@ -498,7 +497,11 @@ func TestFaultInjection(t *testing.T) { // We can't set a specific address here because the middleware // relies on 
the DNS address, which is not customizeable when making // real requests to a local HTTP test server. - req.Header.Set(faults.FaultServerAddressHeader, strings.TrimPrefix(server.URL, "http://")) + parsed, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("unexpected error when parsing httptest server URL: %v", err) + } + req.Header.Set(faults.FaultServerAddressHeader, parsed.Hostname()) } if tt.faultServerMethodHeader != "" { req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader) From 6f2024ea24a979489eb3a7d186c7d8dded3d213f Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 12:48:21 -0500 Subject: [PATCH 25/33] Add edge case test. --- internal/faults/common_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go index bf4750fb9..17080826c 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -34,6 +34,11 @@ func TestGetCanonicalAddress(t *testing.T) { address: "foo.bar:12345/path", want: "foo.bar", }, + { + name: "unexpected trailing colon untouched", + address: "foo.bar:", + want: "foo.bar:", + }, { name: "external address without port untouched", address: "unix://foo", From d467f04b0c9163dd54cdeb9737bcd5321d57fe74 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:32:15 -0500 Subject: [PATCH 26/33] Update to use a different random integer per feature. This is described in https://github.com/grpc/proposal/blob/master/A33-Fault-Injection.md#evaluate-possibility-fraction. 
--- internal/faults/common.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 1825efd1f..5ca26430a 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -61,6 +61,15 @@ func parsePercentage(percentage string) (int, error) { return intPercentage, nil } +func selected(randInt *int, percentage int) bool { + if randInt != nil { + return *randInt < percentage + } + // Use a different random integer per feature as per + // https://github.com/grpc/proposal/blob/master/A33-Fault-Injection.md#evaluate-possibility-fraction. + return rand.IntN(100) < percentage +} + type InjectFaultParams[T any] struct { CallerName string @@ -87,13 +96,6 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { return params.ResumeFn() } - var randInt int - if params.randInt != nil { - randInt = *params.randInt - } else { - randInt = rand.IntN(100) - } - delayMs := params.GetHeaderFn(FaultDelayMsHeader) if delayMs != "" { percentage, err := parsePercentage(params.GetHeaderFn(FaultDelayPercentageHeader)) @@ -101,7 +103,7 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() } - if randInt >= percentage { + if !selected(params.randInt, percentage) { return params.ResumeFn() } @@ -125,7 +127,7 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() } - if randInt >= percentage { + if !selected(params.randInt, percentage) { return params.ResumeFn() } From d602696305fbbfa06c0c317f15e9a93bbb81e12f Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:29:54 -0500 Subject: [PATCH 27/33] Fix typo. 
Co-authored-by: Andrew Boyle --- httpbp/client_middlewares_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index ea0c087d1..f0c213806 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -495,7 +495,7 @@ func TestFaultInjection(t *testing.T) { if tt.faultServerAddrMatch { // We can't set a specific address here because the middleware - // relies on the DNS address, which is not customizeable when making + // relies on the DNS address, which is not customizable when making // real requests to a local HTTP test server. parsed, err := url.Parse(server.URL) if err != nil { From 788b3e4189c701fdab852416ee02f9589ecb9387 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:44:48 -0500 Subject: [PATCH 28/33] Abort delay if request context cancelled. --- httpbp/client_middlewares.go | 1 + internal/faults/common.go | 22 +++++++++++++++++++--- internal/faults/common_test.go | 25 +++++++++++++++++++++---- thriftbp/client_middlewares.go | 1 + 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index 6e08caee6..b131f4b6e 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -381,6 +381,7 @@ func FaultInjection() ClientMiddleware { } resp, err := faults.InjectFault(faults.InjectFaultParams[*http.Response]{ + Context: req.Context(), CallerName: "httpbp.FaultInjection", Address: req.URL.Hostname(), Method: strings.TrimPrefix(req.URL.Path, "/"), diff --git a/internal/faults/common.go b/internal/faults/common.go index 5ca26430a..6cc8e41d1 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -3,6 +3,7 @@ package faults import ( + "context" "fmt" "log/slog" "math/rand/v2" @@ -25,7 +26,7 @@ type ResponseFn[T any] func(code int, message string) (T, error) // sleepFn is the 
function type to sleep for the given duration. Only used in // tests. -type sleepFn func(d time.Duration) +type sleepFn func(ctx context.Context, d time.Duration) error // The canonical address for a cluster-local address is <service>.<namespace>, // without the local cluster suffix or port. The canonical address for a @@ -71,6 +72,7 @@ func selected(randInt *int, percentage int) bool { } type InjectFaultParams[T any] struct { + Context context.Context CallerName string Address, Method string @@ -84,6 +86,17 @@ type InjectFaultParams[T any] struct { sleepFn *sleepFn } +func sleep(ctx context.Context, d time.Duration) error { + t := time.NewTimer(d) + select { + case <-t.C: + case <-ctx.Done(): + t.Stop() + return ctx.Err() + } + return nil +} + func InjectFault[T any](params InjectFaultParams[T]) (T, error) { faultHeaderAddress := params.GetHeaderFn(FaultServerAddressHeader) requestAddress := getCanonicalAddress(params.Address) @@ -113,11 +126,14 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { return params.ResumeFn() } - sleepFn := time.Sleep + sleepFn := sleep if params.sleepFn != nil { sleepFn = *params.sleepFn } - sleepFn(time.Duration(delay) * time.Millisecond) + if err := sleepFn(params.Context, time.Duration(delay)*time.Millisecond); err != nil { + slog.Warn(fmt.Sprintf("%s: error when delaying request: %v", params.CallerName, err)) + return params.ResumeFn() + } } abortCode := params.GetHeaderFn(FaultAbortCodeHeader) diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go index 17080826c..a6b2cd08d 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -1,6 +1,8 @@ package faults import ( + "context" + "fmt" "strings" "testing" "time" @@ -120,9 +122,10 @@ func intPtr(i int) *int { func TestInjectFault(t *testing.T) { testCases := []struct { - name string - address string - randInt *int + name string + address string + randInt *int + sleepErr bool faultServerAddressHeader string faultServerMethodHeader
string @@ -287,6 +290,16 @@ func TestInjectFault(t *testing.T) { wantDelayMs: 0, }, + { + name: "error while sleeping", + sleepErr: true, + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "1", + + wantDelayMs: 0, + }, { name: "invalid abort percentage negative", @@ -393,8 +406,12 @@ func TestInjectFault(t *testing.T) { }, nil } delayMs := 0 - sleepFn := sleepFn(func(d time.Duration) { + sleepFn := sleepFn(func(ctx context.Context, d time.Duration) error { + if tc.sleepErr { + return fmt.Errorf("context cancelled") + } delayMs = int(d.Milliseconds()) + return nil }) resp, err := InjectFault(InjectFaultParams[*Response]{ diff --git a/thriftbp/client_middlewares.go b/thriftbp/client_middlewares.go index 3e0edb8dc..4f5014edf 100644 --- a/thriftbp/client_middlewares.go +++ b/thriftbp/client_middlewares.go @@ -423,6 +423,7 @@ func FaultInjectionClientMiddleware(address string) thrift.ClientMiddleware { } resp, err := faults.InjectFault(faults.InjectFaultParams[thrift.ResponseMeta]{ + Context: ctx, CallerName: "thriftpb.FaultInjectionClientMiddleware", Address: address, Method: method, From 01cf3f0b05abaaf6a01068c3b6e29fb7bc7e2f56 Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:52:08 -0500 Subject: [PATCH 29/33] Update test to be more clear what the effect is. 
--- internal/faults/common_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go index a6b2cd08d..a93f83a25 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -291,12 +291,14 @@ func TestInjectFault(t *testing.T) { wantDelayMs: 0, }, { - name: "error while sleeping", + name: "error while sleeping short circuits", sleepErr: true, faultServerAddressHeader: "testService.testNamespace", faultServerMethodHeader: "testMethod", faultDelayMsHeader: "1", + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", wantDelayMs: 0, }, From e75e5ea6e298c68f54c569d61f8388f12e16d5fd Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:12:13 -0500 Subject: [PATCH 30/33] Fix printf formatters. --- internal/faults/common.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 6cc8e41d1..3fec02f03 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -122,7 +122,7 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { delay, err := strconv.Atoi(delayMs) if err != nil { - slog.Warn(fmt.Sprintf("%s: provided delay \"%s\" is not a valid integer", params.CallerName, delayMs)) + slog.Warn(fmt.Sprintf("%s: provided delay %q is not a valid integer", params.CallerName, delayMs)) return params.ResumeFn() } @@ -149,7 +149,7 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { code, err := strconv.Atoi(abortCode) if err != nil { - slog.Warn(fmt.Sprintf("%s: provided abort code \"%s\" is not a valid integer", params.CallerName, abortCode)) + slog.Warn(fmt.Sprintf("%s: provided abort code %q is not a valid integer", params.CallerName, abortCode)) return params.ResumeFn() } if code < params.AbortCodeMin || code > params.AbortCodeMax { From d03c802dbae862a2a4ac50f0c7b0b3e70ee525bd 
Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 13 Dec 2024 15:53:06 -0500 Subject: [PATCH 31/33] Fix oversight where skipping the delay would skip abort as well. --- internal/faults/common.go | 31 +++++++++++++++---------------- internal/faults/common_test.go | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 3fec02f03..1b8b0b78c 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -116,23 +116,22 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() } - if !selected(params.randInt, percentage) { - return params.ResumeFn() - } - delay, err := strconv.Atoi(delayMs) - if err != nil { - slog.Warn(fmt.Sprintf("%s: provided delay %q is not a valid integer", params.CallerName, delayMs)) - return params.ResumeFn() - } - - sleepFn := sleep - if params.sleepFn != nil { - sleepFn = *params.sleepFn - } - if err := sleepFn(params.Context, time.Duration(delay)*time.Millisecond); err != nil { - slog.Warn(fmt.Sprintf("%s: error when delaying request: %v", params.CallerName, err)) - return params.ResumeFn() + if selected(params.randInt, percentage) { + delay, err := strconv.Atoi(delayMs) + if err != nil { + slog.Warn(fmt.Sprintf("%s: provided delay %q is not a valid integer", params.CallerName, delayMs)) + return params.ResumeFn() + } + + sleepFn := sleep + if params.sleepFn != nil { + sleepFn = *params.sleepFn + } + if err := sleepFn(params.Context, time.Duration(delay)*time.Millisecond); err != nil { + slog.Warn(fmt.Sprintf("%s: error when delaying request: %v", params.CallerName, err)) + return params.ResumeFn() + } } } diff --git a/internal/faults/common_test.go b/internal/faults/common_test.go index a93f83a25..dbbe4d8a2 100644 --- a/internal/faults/common_test.go +++ b/internal/faults/common_test.go @@ -261,6 
+261,24 @@ func TestInjectFault(t *testing.T) { wantDelayMs: 0, wantResponse: nil, }, + { + name: "only skip delay", + randInt: intPtr(50), + + faultServerAddressHeader: "testService.testNamespace", + faultServerMethodHeader: "testMethod", + faultDelayMsHeader: "250", + faultDelayPercentageHeader: "0", // No requests delayed + faultAbortCodeHeader: "1", + faultAbortMessageHeader: "test fault", + faultAbortPercentageHeader: "100", // All requests aborted + + wantDelayMs: 0, + wantResponse: &Response{ + code: 1, + message: "test fault", + }, + }, { name: "invalid delay percentage negative", From 065e336c026405bb52b1153d0965278f3f6ec3cf Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:26:59 -0500 Subject: [PATCH 32/33] Update abort section to mirror delay section more closely in style. --- internal/faults/common.go | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 1b8b0b78c..062fbc925 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -142,21 +142,20 @@ func InjectFault[T any](params InjectFaultParams[T]) (T, error) { slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err)) return params.ResumeFn() } - if !selected(params.randInt, percentage) { - return params.ResumeFn() - } - code, err := strconv.Atoi(abortCode) - if err != nil { - slog.Warn(fmt.Sprintf("%s: provided abort code %q is not a valid integer", params.CallerName, abortCode)) - return params.ResumeFn() - } - if code < params.AbortCodeMin || code > params.AbortCodeMax { - slog.Warn(fmt.Sprintf("%s: provided abort code \"%d\" is outside of the valid range", params.CallerName, code)) - return params.ResumeFn() + if selected(params.randInt, percentage) { + code, err := strconv.Atoi(abortCode) + if err != nil { + slog.Warn(fmt.Sprintf("%s: provided abort code %q is not a valid integer", params.CallerName, abortCode)) + 
return params.ResumeFn() + } + if code < params.AbortCodeMin || code > params.AbortCodeMax { + slog.Warn(fmt.Sprintf("%s: provided abort code \"%d\" is outside of the valid range", params.CallerName, code)) + return params.ResumeFn() + } + abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) + return params.ResponseFn(code, abortMessage) } - abortMessage := params.GetHeaderFn(FaultAbortMessageHeader) - return params.ResponseFn(code, abortMessage) } return params.ResumeFn() From c2d6cd23ab5390b2a3fc167aa0ff486dc533685f Mon Sep 17 00:00:00 2001 From: Hiram Silvey <9485196+HiramSilvey@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:46:52 -0500 Subject: [PATCH 33/33] Move sleep fn up for better readability. --- internal/faults/common.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/faults/common.go b/internal/faults/common.go index 062fbc925..08d0d3b08 100644 --- a/internal/faults/common.go +++ b/internal/faults/common.go @@ -71,6 +71,17 @@ func selected(randInt *int, percentage int) bool { return rand.IntN(100) < percentage } +func sleep(ctx context.Context, d time.Duration) error { + t := time.NewTimer(d) + select { + case <-t.C: + case <-ctx.Done(): + t.Stop() + return ctx.Err() + } + return nil +} + type InjectFaultParams[T any] struct { Context context.Context CallerName string @@ -86,17 +97,6 @@ type InjectFaultParams[T any] struct { sleepFn *sleepFn } -func sleep(ctx context.Context, d time.Duration) error { - t := time.NewTimer(d) - select { - case <-t.C: - case <-ctx.Done(): - t.Stop() - return ctx.Err() - } - return nil -} - func InjectFault[T any](params InjectFaultParams[T]) (T, error) { faultHeaderAddress := params.GetHeaderFn(FaultServerAddressHeader) requestAddress := getCanonicalAddress(params.Address)