Skip to content

Commit cb18e67

Browse files
committed
Support AWS_MSK_IAM authentication
Adds an AWS_MSK_IAM authentication mechanism which is described here: * https://github.com/aws/aws-msk-iam-auth#uriencode To use the mechanism pass the following keyword arguments when initializing a class: ``` security_protocol='SASL_SSL', sasl_mechanism='AWS_MSK_IAM', bootstrap_servers=[ 'b-1.cluster.x.y.kafka.region.amazonaws.com:9088', ... ], ``` The credentials and region will be pulled using `botocore.session.Session`. Using the mechanism requires the `botocore` library which can be installed with: ```sh pip install botocore ``` **TODO:** - [ ] Documentation - [ ] Tests - [ ] Refresh mechanism for temporary credentials?
1 parent f0a57a6 commit cb18e67

File tree

2 files changed

+234
-1
lines changed

2 files changed

+234
-1
lines changed

kafka/conn.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import kafka.errors as Errors
2424
from kafka.future import Future
2525
from kafka.metrics.stats import Avg, Count, Max, Rate
26+
from kafka.msk import AwsMskIamClient
2627
from kafka.oauth.abstract import AbstractTokenProvider
2728
from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest
2829
from kafka.protocol.commit import OffsetFetchRequest
@@ -83,6 +84,12 @@ class SSLWantWriteError(Exception):
8384
gssapi = None
8485
GSSError = None
8586

87+
# needed for AWS_MSK_IAM authentication:
88+
try:
89+
from botocore.session import Session as BotoSession
90+
except ImportError:
91+
# no botocore available, will disable AWS_MSK_IAM mechanism
92+
BotoSession = None
8693

8794
AFI_NAMES = {
8895
socket.AF_UNSPEC: "unspecified",
@@ -227,7 +234,7 @@ class BrokerConnection(object):
227234
'sasl_oauth_token_provider': None
228235
}
229236
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
230-
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")
237+
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512", 'AWS_MSK_IAM')
231238

232239
def __init__(self, host, port, afi, **configs):
233240
self.host = host
@@ -276,6 +283,9 @@ def __init__(self, host, port, afi, **configs):
276283
token_provider = self.config['sasl_oauth_token_provider']
277284
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
278285
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
286+
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
287+
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
288+
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
279289
# This is not a general lock / this class is not generally thread-safe yet
280290
# However, to avoid pushing responsibility for maintaining
281291
# per-connection locks to the upstream client, we will use this lock to
@@ -561,6 +571,8 @@ def _handle_sasl_handshake_response(self, future, response):
561571
return self._try_authenticate_oauth(future)
562572
elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
563573
return self._try_authenticate_scram(future)
574+
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
575+
return self._try_authenticate_aws_msk_iam(future)
564576
else:
565577
return future.failure(
566578
Errors.UnsupportedSaslMechanismError(
@@ -661,6 +673,44 @@ def _try_authenticate_plain(self, future):
661673
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
662674
return future.success(True)
663675

676+
def _try_authenticate_aws_msk_iam(self, future):
    """Authenticate this connection using the AWS_MSK_IAM SASL mechanism.

    Credentials and region are resolved through a botocore session, a
    signed payload is built with ``AwsMskIamClient`` and sent to the
    broker as a size-prefixed frame, and the size-prefixed reply is read
    back.

    Arguments:
        future (Future): completed with ``True`` on success or failed
            with an error describing why authentication did not finish.

    Returns:
        The result of ``future.success(...)`` / ``future.failure(...)``.
    """
    session = BotoSession()

    # botocore returns None when no credential source could be resolved;
    # fail the future instead of raising AttributeError on None.
    credentials = session.get_credentials()
    if credentials is None:
        return future.failure(
            Errors.AuthenticationFailedError(
                '%s: No AWS credentials found for AWS_MSK_IAM' % (self,)))
    credentials = credentials.get_frozen_credentials()

    # The region is part of the signing scope, so it must be known.
    region = session.get_config_variable('region')
    if not region:
        return future.failure(
            Errors.AuthenticationFailedError(
                '%s: No AWS region configured for AWS_MSK_IAM' % (self,)))

    client = AwsMskIamClient(
        host=self.host,
        access_key=credentials.access_key,
        secret_key=credentials.secret_key,
        region=region,
        token=credentials.token,
    )

    msg = client.first_message()
    size = Int32.encode(len(msg))

    err = None
    close = False
    with self._lock:
        if not self._can_send_recv():
            err = Errors.NodeNotReadyError(str(self))
        else:
            try:
                self._send_bytes_blocking(size + msg)
                # The reply is framed with a 4-byte big-endian length.
                # Decode all four bytes: using only the last byte would
                # truncate any reply of 256 bytes or more.
                (length,) = struct.unpack('>i', self._recv_bytes_blocking(4))
                data = self._recv_bytes_blocking(length)
            except (ConnectionError, TimeoutError) as e:
                log.exception("%s: Error receiving reply from server", self)
                err = Errors.KafkaConnectionError("%s: %s" % (self, e))
                close = True

    # Closing and completing the future happen after the lock is released,
    # matching the original control flow.
    if err is not None:
        if close:
            self.close(error=err)
        return future.failure(err)

    log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
    return future.success(True)
713+
664714
def _try_authenticate_scram(self, future):
665715
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
666716
log.warning('%s: Exchanging credentials in the clear', self)

kafka/msk.py

+183
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import datetime
2+
import hashlib
3+
import hmac
4+
import json
5+
import string
6+
import urllib.parse
7+
8+
9+
class AwsMskIamClient:
    """Build the signed authentication payload for the AWS_MSK_IAM mechanism.

    The payload is an AWS Signature Version 4 style signature over a
    synthetic ``GET /`` request whose query string carries the
    ``kafka-cluster:Connect`` action; see
    https://github.com/aws/aws-msk-iam-auth#uriencode for the encoding rules.
    The signing date/time is captured once, when the client is created.
    """

    # Characters that are never percent-encoded by _uriencode.
    UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

    def __init__(self, host, access_key, secret_key, region, token=None):
        """
        Arguments:
            host (str): The hostname of the broker.
            access_key (str): An AWS_ACCESS_KEY_ID.
            secret_key (str): An AWS_SECRET_ACCESS_KEY.
            region (str): An AWS_REGION.
            token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
                credentials.

        Raises:
            ValueError: if region is empty or None. The region is part of
                both the credential scope and the signing key, so a
                signature cannot be produced without it.
        """
        if not region:
            raise ValueError('AWS_MSK_IAM requires an AWS region')

        self.algorithm = 'AWS4-HMAC-SHA256'
        self.expires = '900'
        self.hashfunc = hashlib.sha256
        self.headers = [
            ('host', host),
        ]
        self.version = '2020_10_22'

        self.service = 'kafka-cluster'
        self.action = '{}:Connect'.format(self.service)

        # Timezone-aware replacement for the deprecated utcnow(); the
        # formatted strings are identical.
        now = datetime.datetime.now(datetime.timezone.utc)
        self.datestamp = now.strftime('%Y%m%d')
        self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

        self.host = host
        self.access_key = access_key
        self.secret_key = secret_key
        self.region = region
        self.token = token

    @property
    def _credential(self):
        """
        Returns (str):
            The access key followed by the credential scope,
            '/'-separated.
        """
        return '{0.access_key}/{0._scope}'.format(self)

    @property
    def _scope(self):
        """
        Returns (str):
            The credential scope in the format:
                <Datestamp>/<Region>/<Service>/aws4_request
        """
        return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

    @property
    def _signed_headers(self):
        """
        Returns (str):
            An alphabetically sorted, semicolon-delimited list of lowercase
            request header names.
        """
        return ';'.join(sorted(k.lower() for k, _ in self.headers))

    @property
    def _canonical_headers(self):
        """
        Returns (str):
            A newline-delimited list of 'name:value' header entries,
            terminated by a trailing newline. Header names are lowercased,
            values are stripped of surrounding whitespace, and entries are
            sorted by name so they line up with _signed_headers.
        """
        entries = sorted((k.lower(), v.strip()) for k, v in self.headers)
        return '\n'.join(map(':'.join, entries)) + '\n'

    @property
    def _canonical_request(self):
        """
        Returns (str):
            An AWS Signature Version 4 canonical request in the format:
                <Method>\n
                <Path>\n
                <CanonicalQueryString>\n
                <CanonicalHeaders>\n
                <SignedHeaders>\n
                <HashedPayload>
        """
        # The hashed_payload is always an empty string for MSK.
        hashed_payload = self.hashfunc(b'').hexdigest()
        return '\n'.join((
            'GET',
            '/',
            self._canonical_querystring,
            self._canonical_headers,
            self._signed_headers,
            hashed_payload,
        ))

    @property
    def _canonical_querystring(self):
        """
        Returns (str):
            A '&'-separated list of URI-encoded key/value pairs.
        """
        # The parameters are appended in ascending name order; keep that
        # order when adding new ones.
        params = [
            ('Action', self.action),
            ('X-Amz-Algorithm', self.algorithm),
            ('X-Amz-Credential', self._credential),
            ('X-Amz-Date', self.timestamp),
            ('X-Amz-Expires', self.expires),
        ]
        if self.token:
            params.append(('X-Amz-Security-Token', self.token))
        params.append(('X-Amz-SignedHeaders', self._signed_headers))

        return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

    @property
    def _signing_key(self):
        """
        Returns (bytes):
            An AWS Signature V4 signing key generated from the secret_key, date,
            region, service, and request type.
        """
        key = ('AWS4' + self.secret_key).encode('utf-8')
        # Each step keys the next HMAC with the previous digest.
        for part in (self.datestamp, self.region, self.service, 'aws4_request'):
            key = self._hmac(key, part)
        return key

    @property
    def _signing_str(self):
        """
        Returns (str):
            A string used to sign the AWS Signature V4 payload in the format:
                <Algorithm>\n
                <Timestamp>\n
                <Scope>\n
                <CanonicalRequestHash>
        """
        canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
        return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

    def _uriencode(self, msg):
        """
        Arguments:
            msg (str): A string to URI-encode.

        Returns (str):
            The URI-encoded version of the provided msg, following the encoding
            rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
        """
        return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

    def _hmac(self, key, msg):
        """
        Arguments:
            key (bytes): A key to use for the HMAC digest.
            msg (str): A value to include in the HMAC digest.

        Returns (bytes):
            An HMAC digest of the given key and msg.
        """
        return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

    def first_message(self):
        """
        Returns (bytes):
            An encoded JSON authentication payload that can be sent to the
            broker.
        """
        signature = hmac.new(
            self._signing_key,
            self._signing_str.encode('utf-8'),
            digestmod=self.hashfunc,
        ).hexdigest()
        msg = {
            'version': self.version,
            'host': self.host,
            'user-agent': 'kafka-python',
            'action': self.action,
            'x-amz-algorithm': self.algorithm,
            'x-amz-credential': self._credential,
            'x-amz-date': self.timestamp,
            'x-amz-signedheaders': self._signed_headers,
            'x-amz-expires': self.expires,
            'x-amz-signature': signature,
        }
        # The session token is only present for temporary credentials.
        if self.token:
            msg['x-amz-security-token'] = self.token

        return json.dumps(msg, separators=(',', ':')).encode('utf-8')

0 commit comments

Comments
 (0)