Skip to content

Commit

Permalink
Merge pull request #9 from arnaud-dfns/feat/ssm-profile
Browse files Browse the repository at this point in the history
feat: add aws ssm profile setting
  • Loading branch information
arnaud-dfns authored Jan 7, 2025
2 parents cce9dad + b15a509 commit 8ec4082
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 32 deletions.
6 changes: 5 additions & 1 deletion docs/data-sources/ssm.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ provider "postgresql" {
### Required

- `ssm_instance` (String) Specify the exact Instance ID of the managed node to connect to for the session
- `ssm_region` (String) AWS Region where the instance is located
- `target_host` (String) The DNS name or IP address of the remote host
- `target_port` (Number) The port number of the remote host

### Optional

- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.

### Read-Only

- `local_host` (String) The DNS name or IP address of the local host
Expand Down
6 changes: 5 additions & 1 deletion docs/ephemeral-resources/ssm.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ provider "postgresql" {
### Required

- `ssm_instance` (String) Specify the exact Instance ID of the managed node to connect to for the session
- `ssm_region` (String) AWS Region where the instance is located
- `target_host` (String) The DNS name or IP address of the remote host
- `target_port` (Number) The port number of the remote host

### Optional

- `ssm_profile` (String) AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.
- `ssm_region` (String) AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.

### Read-Only

- `local_host` (String) The DNS name or IP address of the local host
Expand Down
38 changes: 30 additions & 8 deletions internal/provider/data_source_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ type SSMDataSource struct{}

// SSMDataSourceModel describes the data source data model.
type SSMDataSourceModel struct {
TargetHost types.String `tfsdk:"target_host"`
TargetPort types.Int64 `tfsdk:"target_port"`
LocalHost types.String `tfsdk:"local_host"`
LocalPort types.Int64 `tfsdk:"local_port"`
SSMInstance types.String `tfsdk:"ssm_instance"`
SSMProfile types.String `tfsdk:"ssm_profile"`
SSMRegion types.String `tfsdk:"ssm_region"`
TargetHost types.String `tfsdk:"target_host"`
TargetPort types.Int64 `tfsdk:"target_port"`
}

func (d *SSMDataSource) Metadata(ctx context.Context, req datasource.MetadataRequest, resp *datasource.MetadataResponse) {
Expand All @@ -54,9 +55,15 @@ func (d *SSMDataSource) Schema(ctx context.Context, req datasource.SchemaRequest
MarkdownDescription: "Specify the exact Instance ID of the managed node to connect to for the session",
Required: true,
},
"ssm_profile": schema.StringAttribute{
MarkdownDescription: "AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.",
Optional: true,
Computed: true,
},
"ssm_region": schema.StringAttribute{
MarkdownDescription: "AWS Region where the instance is located",
Required: true,
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
Optional: true,
Computed: true,
},

// Computed attributes
Expand Down Expand Up @@ -94,13 +101,28 @@ func (d *SSMDataSource) Read(ctx context.Context, req datasource.ReadRequest, re
data.LocalHost = types.StringValue("localhost")
data.LocalPort = types.Int64Value(int64(localPort))

_, err = ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
SSMRegion: data.SSMRegion.ValueString(),
tunnelCfg := ssm.TunnelConfig{
LocalPort: strconv.Itoa(localPort),
SSMInstance: data.SSMInstance.ValueString(),
SSMProfile: data.SSMProfile.ValueString(),
SSMRegion: data.SSMRegion.ValueString(),
TargetHost: data.TargetHost.ValueString(),
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
LocalPort: strconv.Itoa(localPort),
})
}

awsCfg, err := ssm.GetNewSDKConfig(ctx, tunnelCfg)
if err != nil {
resp.Diagnostics.AddError("Failed to initialize AWS SDK", fmt.Sprintf("Error: %s", err))
return
}

tunnelCfg.SSMRegion = awsCfg.Region
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)

data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)

_, err = ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
if err != nil {
resp.Diagnostics.AddError("Failed to fork tunnel process", fmt.Sprintf("Error: %s", err))
return
Expand Down
38 changes: 30 additions & 8 deletions internal/provider/ephemeral_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ type SSMEphemeral struct{}

// SSMEphemeralModel describes the data source data model.
type SSMEphemeralModel struct {
TargetHost types.String `tfsdk:"target_host"`
TargetPort types.Int64 `tfsdk:"target_port"`
LocalHost types.String `tfsdk:"local_host"`
LocalPort types.Int64 `tfsdk:"local_port"`
SSMInstance types.String `tfsdk:"ssm_instance"`
SSMProfile types.String `tfsdk:"ssm_profile"`
SSMRegion types.String `tfsdk:"ssm_region"`
TargetHost types.String `tfsdk:"target_host"`
TargetPort types.Int64 `tfsdk:"target_port"`
}

func (d *SSMEphemeral) Metadata(ctx context.Context, req ephemeral.MetadataRequest, resp *ephemeral.MetadataResponse) {
Expand All @@ -54,9 +55,15 @@ func (d *SSMEphemeral) Schema(ctx context.Context, req ephemeral.SchemaRequest,
MarkdownDescription: "Specify the exact Instance ID of the managed node to connect to for the session",
Required: true,
},
"ssm_profile": schema.StringAttribute{
MarkdownDescription: "AWS profile name as set in credentials files. Can also be set using either the environment variables `AWS_PROFILE` or `AWS_DEFAULT_PROFILE`.",
Optional: true,
Computed: true,
},
"ssm_region": schema.StringAttribute{
MarkdownDescription: "AWS Region where the instance is located",
Required: true,
MarkdownDescription: "AWS Region where the instance is located. The Region must be set. Can also be set using either the environment variables `AWS_REGION` or `AWS_DEFAULT_REGION`.",
Optional: true,
Computed: true,
},

// Computed attributes
Expand Down Expand Up @@ -94,13 +101,28 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
data.LocalHost = types.StringValue("localhost")
data.LocalPort = types.Int64Value(int64(localPort))

cmd, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
SSMRegion: data.SSMRegion.ValueString(),
tunnelCfg := ssm.TunnelConfig{
LocalPort: strconv.Itoa(localPort),
SSMInstance: data.SSMInstance.ValueString(),
SSMProfile: data.SSMProfile.ValueString(),
SSMRegion: data.SSMRegion.ValueString(),
TargetHost: data.TargetHost.ValueString(),
TargetPort: strconv.Itoa(int(data.TargetPort.ValueInt64())),
LocalPort: strconv.Itoa(localPort),
})
}

awsCfg, err := ssm.GetNewSDKConfig(ctx, tunnelCfg)
if err != nil {
resp.Diagnostics.AddError("Failed to initialize AWS SDK", fmt.Sprintf("Error: %s", err))
return
}

tunnelCfg.SSMRegion = awsCfg.Region
tunnelCfg.SSMProfile = ssm.GetSDKConfigProfile(awsCfg)

data.SSMRegion = types.StringValue(tunnelCfg.SSMRegion)
data.SSMProfile = types.StringValue(tunnelCfg.SSMProfile)

cmd, err := ssm.ForkRemoteTunnel(ctx, awsCfg, tunnelCfg)
if err != nil {
resp.Diagnostics.AddError("Failed to fork tunnel process", fmt.Sprintf("Error: %s", err))
return
Expand Down
35 changes: 25 additions & 10 deletions internal/ssm/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
const DEFAULT_SSM_ENV_NAME = "AWS_SSM_START_SESSION_RESPONSE"

type TunnelConfig struct {
SSMRegion string
LocalPort string
SSMInstance string
SSMProfile string
SSMRegion string
TargetHost string
TargetPort string
LocalPort string
}

type SessionParams struct {
Expand All @@ -24,6 +25,27 @@ type SessionParams struct {
StreamUrl string
}

func GetNewSDKConfig(ctx context.Context, cfg TunnelConfig) (aws.Config, error) {
loadOptions := []func(*config.LoadOptions) error{}
if cfg.SSMRegion != "" {
loadOptions = append(loadOptions, config.WithRegion(cfg.SSMRegion))
}
if cfg.SSMProfile != "" {
loadOptions = append(loadOptions, config.WithSharedConfigProfile(cfg.SSMProfile))
}

return config.LoadDefaultConfig(ctx, loadOptions...)
}

func GetSDKConfigProfile(awsCfg aws.Config) string {
for _, cfg := range awsCfg.ConfigSources {
if p, ok := cfg.(config.SharedConfig); ok {
return p.Profile
}
}
return ""
}

func CreateSessionInput(cfg TunnelConfig) ssm.StartSessionInput {
reqParams := make(map[string][]string)
reqParams["portNumber"] = []string{cfg.TargetPort}
Expand All @@ -37,14 +59,7 @@ func CreateSessionInput(cfg TunnelConfig) ssm.StartSessionInput {
}
}

func StartTunnelSession(ctx context.Context, cfg TunnelConfig) (SessionParams, error) {
// Load AWS SDK config
awsCfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return SessionParams{}, err
}
awsCfg.Region = cfg.SSMRegion

func StartTunnelSession(ctx context.Context, awsCfg aws.Config, cfg TunnelConfig) (SessionParams, error) {
// Create SSM client
ssmClient := ssm.NewFromConfig(awsCfg)

Expand Down
8 changes: 4 additions & 4 deletions internal/ssm/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ssm"
pluginSession "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session"
_ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/portsession"
Expand All @@ -30,10 +31,10 @@ func GetEndpoint(ctx context.Context, region string) (string, error) {
return endpoint.URI.String(), nil
}

func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*exec.Cmd, error) {
func ForkRemoteTunnel(ctx context.Context, awsCfg aws.Config, cfg TunnelConfig) (*exec.Cmd, error) {
// First we start a session using AWS SDK
// see https://github.com/aws/aws-cli/blob/master/awscli/customizations/sessionmanager.py#L104
sessionParams, err := StartTunnelSession(ctx, cfg)
sessionParams, err := StartTunnelSession(ctx, awsCfg, cfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -102,7 +103,6 @@ func StartRemoteTunnel(ctx context.Context, cfgJson string, parentPid int) (err
return err
}

profileName := ""
endpointUrl, err := GetEndpoint(ctx, cfg.SSMRegion)
if err != nil {
return err
Expand All @@ -113,7 +113,7 @@ func StartRemoteTunnel(ctx context.Context, cfgJson string, parentPid int) (err
DEFAULT_SSM_ENV_NAME,
cfg.SSMRegion,
"StartSession",
profileName,
cfg.SSMProfile,
string(sessionInputJson),
endpointUrl,
}
Expand Down

0 comments on commit 8ec4082

Please sign in to comment.