Skip to content

Support AWS_MSK_IAM authentication #2519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kafka/sasl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

from kafka.sasl.gssapi import SaslMechanismGSSAPI
from kafka.sasl.msk import SaslMechanismAwsMskIam
from kafka.sasl.oauth import SaslMechanismOAuth
from kafka.sasl.plain import SaslMechanismPlain
from kafka.sasl.scram import SaslMechanismScram
Expand All @@ -24,3 +25,4 @@ def get_sasl_mechanism(name):
register_sasl_mechanism('PLAIN', SaslMechanismPlain)
register_sasl_mechanism('SCRAM-SHA-256', SaslMechanismScram)
register_sasl_mechanism('SCRAM-SHA-512', SaslMechanismScram)
register_sasl_mechanism('AWS_MSK_IAM', SaslMechanismAwsMskIam)
233 changes: 233 additions & 0 deletions kafka/sasl/msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
from __future__ import absolute_import

import datetime
import hashlib
import hmac
import json
import string

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

from kafka.sasl.abc import SaslMechanism
from kafka.vendor.six.moves import urllib


class SaslMechanismAwsMskIam(SaslMechanism):
def __init__(self, **config):
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert config.get('security_protocol', '') == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
assert 'host' in config, 'AWS_MSK_IAM requires host configuration'
self.host = config['host']
self._auth = None
self._is_done = False
self._is_authenticated = False

def auth_bytes(self):
session = BotoSession()
credentials = session.get_credentials().get_frozen_credentials()
client = AwsMskIamClient(
host=self.host,
access_key=credentials.access_key,
secret_key=credentials.secret_key,
region=session.get_config_variable('region'),
token=credentials.token,
)
return client.first_message()

def receive(self, auth_bytes):
self._is_done = True
self._is_authenticated = auth_bytes != b''
self._auth = auth_bytes.deode('utf-8')

def is_done(self):
return self._is_done

def is_authenticated(self):
return self._is_authenticated

def auth_details(self):
if not self.is_authenticated:
raise RuntimeError('Not authenticated yet!')
return 'Authenticated via SASL / AWS_MSK_IAM %s' % (self._auth,)


class AwsMskIamClient:
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

def __init__(self, host, access_key, secret_key, region, token=None):
"""
Arguments:
host (str): The hostname of the broker.
access_key (str): An AWS_ACCESS_KEY_ID.
secret_key (str): An AWS_SECRET_ACCESS_KEY.
region (str): An AWS_REGION.
token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
credentials.
"""
self.algorithm = 'AWS4-HMAC-SHA256'
self.expires = '900'
self.hashfunc = hashlib.sha256
self.headers = [
('host', host)
]
self.version = '2020_10_22'

self.service = 'kafka-cluster'
self.action = '{}:Connect'.format(self.service)

now = datetime.datetime.utcnow()
self.datestamp = now.strftime('%Y%m%d')
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

self.host = host
self.access_key = access_key
self.secret_key = secret_key
self.region = region
self.token = token

@property
def _credential(self):
return '{0.access_key}/{0._scope}'.format(self)

@property
def _scope(self):
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

@property
def _signed_headers(self):
"""
Returns (str):
An alphabetically sorted, semicolon-delimited list of lowercase
request header names.
"""
return ';'.join(sorted(k.lower() for k, _ in self.headers))

@property
def _canonical_headers(self):
"""
Returns (str):
A newline-delited list of header names and values.
Header names are lowercased.
"""
return '\n'.join(map(':'.join, self.headers)) + '\n'

@property
def _canonical_request(self):
"""
Returns (str):
An AWS Signature Version 4 canonical request in the format:
<Method>\n
<Path>\n
<CanonicalQueryString>\n
<CanonicalHeaders>\n
<SignedHeaders>\n
<HashedPayload>
"""
# The hashed_payload is always an empty string for MSK.
hashed_payload = self.hashfunc(b'').hexdigest()
return '\n'.join((
'GET',
'/',
self._canonical_querystring,
self._canonical_headers,
self._signed_headers,
hashed_payload,
))

@property
def _canonical_querystring(self):
"""
Returns (str):
A '&'-separated list of URI-encoded key/value pairs.
"""
params = []
params.append(('Action', self.action))
params.append(('X-Amz-Algorithm', self.algorithm))
params.append(('X-Amz-Credential', self._credential))
params.append(('X-Amz-Date', self.timestamp))
params.append(('X-Amz-Expires', self.expires))
if self.token:
params.append(('X-Amz-Security-Token', self.token))
params.append(('X-Amz-SignedHeaders', self._signed_headers))

return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

@property
def _signing_key(self):
"""
Returns (bytes):
An AWS Signature V4 signing key generated from the secret_key, date,
region, service, and request type.
"""
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
key = self._hmac(key, self.region)
key = self._hmac(key, self.service)
key = self._hmac(key, 'aws4_request')
return key

@property
def _signing_str(self):
"""
Returns (str):
A string used to sign the AWS Signature V4 payload in the format:
<Algorithm>\n
<Timestamp>\n
<Scope>\n
<CanonicalRequestHash>
"""
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

def _uriencode(self, msg):
"""
Arguments:
msg (str): A string to URI-encode.

Returns (str):
The URI-encoded version of the provided msg, following the encoding
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
"""
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

def _hmac(self, key, msg):
"""
Arguments:
key (bytes): A key to use for the HMAC digest.
msg (str): A value to include in the HMAC digest.
Returns (bytes):
An HMAC digest of the given key and msg.
"""
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

def first_message(self):
"""
Returns (bytes):
An encoded JSON authentication payload that can be sent to the
broker.
"""
signature = hmac.new(
self._signing_key,
self._signing_str.encode('utf-8'),
digestmod=self.hashfunc,
).hexdigest()
msg = {
'version': self.version,
'host': self.host,
'user-agent': 'kafka-python',
'action': self.action,
'x-amz-algorithm': self.algorithm,
'x-amz-credential': self._credential,
'x-amz-date': self.timestamp,
'x-amz-signedheaders': self._signed_headers,
'x-amz-expires': self.expires,
'x-amz-signature': signature,
}
if self.token:
msg['x-amz-security-token'] = self.token

return json.dumps(msg, separators=(',', ':')).encode('utf-8')
67 changes: 67 additions & 0 deletions test/sasl/test_msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import datetime
import json

from kafka.sasl.msk import AwsMskIamClient

try:
from unittest import mock
except ImportError:
import mock


def client_factory(token=None):
now = datetime.datetime.utcfromtimestamp(1629321911)
with mock.patch('kafka.sasl.msk.datetime') as mock_dt:
mock_dt.datetime.utcnow = mock.Mock(return_value=now)
return AwsMskIamClient(
host='localhost',
access_key='XXXXXXXXXXXXXXXXXXXX',
secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX',
region='us-east-1',
token=token,
)


def test_aws_msk_iam_client_permanent_credentials():
client = client_factory(token=None)
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b',
}
assert actual == expected


def test_aws_msk_iam_client_temporary_credentials():
client = client_factory(token='XXXXX')
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05',
'x-amz-security-token': 'XXXXX',
}
assert actual == expected