diff --git a/internal/examples/sqs/consumer/consumer.go b/internal/examples/sqs/consumer/consumer.go index f63bacf..02e2fe0 100644 --- a/internal/examples/sqs/consumer/consumer.go +++ b/internal/examples/sqs/consumer/consumer.go @@ -46,6 +46,7 @@ func main() { d, err := sqs.New( sqs.WithUrl(os.Getenv("AWS_QUEUE_URL")), sqs.WithRegion(os.Getenv("AWS_REGION")), + sqs.WithSharedCredentials(os.Getenv("AWS_SHARED_CREDENTIALS_FILE"), "default"), sqs.AutoTestConnection(), ) if err != nil { @@ -70,14 +71,19 @@ func main() { // Ping the sqs to check connectivity with AWS go func() { - for { - <-time.After(time.Second * 10) + ping := func() { if err := sub.Ping(); err != nil { log.Println("Ping failed: " + err.Error()) } else { log.Println("Ping OK") } } + + ping() + for { + <-time.After(time.Second * 10) + ping() + } }() log.Print("Waiting for consuming") diff --git a/internal/examples/sqs/producer/producer.go b/internal/examples/sqs/producer/producer.go index 06ebf73..825009f 100644 --- a/internal/examples/sqs/producer/producer.go +++ b/internal/examples/sqs/producer/producer.go @@ -33,6 +33,7 @@ func main() { d, err := sqs.New( sqs.WithUrl(os.Getenv("AWS_QUEUE_URL")), sqs.WithRegion(os.Getenv("AWS_REGION")), + sqs.WithSharedCredentials(os.Getenv("AWS_SHARED_CREDENTIALS_FILE"), "default"), ) if err != nil { diff --git a/sqs/options.go b/sqs/options.go index 5001153..52949f5 100644 --- a/sqs/options.go +++ b/sqs/options.go @@ -25,3 +25,9 @@ func WithRegion(region string) Option { d.region = region } } + +func WithSharedCredentials(filename, profile string) Option { + return func(d *Driver) { + d.SetSharedCredentials(filename, profile) + } +} diff --git a/sqs/sqs.go b/sqs/sqs.go index de2afdd..6d1b6a9 100644 --- a/sqs/sqs.go +++ b/sqs/sqs.go @@ -25,6 +25,18 @@ type Driver struct { url string sqsClient sqsClient testConnectionOnStartup bool + + sharedCredentials *struct { + filename string + profile string + } +} + +func (d *Driver) SetSharedCredentials(filename, profile string) { + d.sharedCredentials = &struct { + filename string + profile string + }{filename: filename, profile: profile} } func New(options ...Option) (*Driver, error) { @@ -37,7 +49,7 @@ func New(options ...Option) (*Driver, error) { } if driver.sqsClient == nil { - clientCredentials, err := getCredentials() + clientCredentials, err := driver.getCredentials() if err != nil { return nil, err } @@ -59,15 +71,20 @@ func New(options ...Option) (*Driver, error) { return driver, nil } -func getCredentials() (*credentials.Credentials, error) { - if os.Getenv("AWS_SHARED_CREDENTIALS_FILE") != "" { - return credentials.NewSharedCredentials("", ""), nil - } else if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { +func (d *Driver) getCredentials() (*credentials.Credentials, error) { + if d.sharedCredentials != nil { + return credentials.NewSharedCredentials( + d.sharedCredentials.filename, + d.sharedCredentials.profile, + ), nil + } + + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { return credentials.NewEnvCredentials(), nil } return nil, errors.New( - "missing AWS_SHARED_CREDENTIALS_FILE and AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY env vars", + "missing shared credentials and AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY env vars", ) } @@ -85,5 +102,10 @@ func createClient(queueUrl string, region string, clientCredentials *credentials }, } - return sqs.New(session.Must(session.NewSessionWithOptions(options))), nil + sqsSession, err := session.NewSessionWithOptions(options) + if err != nil { + return nil, fmt.Errorf("error creating sqs session: %w", err) + } + + return sqs.New(sqsSession), nil } diff --git a/sqs/sqs_test.go b/sqs/sqs_test.go index 7adf42f..ab77c30 100644 --- a/sqs/sqs_test.go +++ b/sqs/sqs_test.go @@ -41,11 +41,10 @@ func (suite *SQSTestSuite) TearDownTest() { } func (suite *SQSTestSuite) TestNewWIthUrlAndRegionOption() { - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "test") - _, err := sqs.New( sqs.WithUrl("https://sqs.eu-central-1.amazonaws.com"), sqs.WithRegion("us-east-1"), + sqs.WithSharedCredentials("test", "default"), ) suite.Nil(err) @@ -59,8 +58,8 @@ func (suite *SQSTestSuite) TestNewWithDefaultOptions() { } func (suite *SQSTestSuite) TestNew_InvalidQueueURL() { - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "/a/file") _, err := sqs.New( + sqs.WithSharedCredentials("test", "default"), sqs.WithUrl("-"), )