Skip to content

Commit

Permalink
added in changes from PR: dpkp#2255 to support AWS_MSK_IAM authentica…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
kenna-bmcdevitt committed Jun 15, 2023
1 parent 7ac6c6e commit fdb1f4b
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 1 deletion.
52 changes: 51 additions & 1 deletion kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.msk import AwsMskIamClient
from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest
from kafka.protocol.commit import OffsetFetchRequest
Expand Down Expand Up @@ -83,6 +84,12 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
Expand Down Expand Up @@ -227,7 +234,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512", 'AWS_MSK_IAM')

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -276,6 +283,9 @@ def __init__(self, host, port, afi, **configs):
token_provider = self.config['sasl_oauth_token_provider']
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
Expand Down Expand Up @@ -561,6 +571,8 @@ def _handle_sasl_handshake_response(self, future, response):
return self._try_authenticate_oauth(future)
elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
return self._try_authenticate_scram(future)
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
return self._try_authenticate_aws_msk_iam(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
Expand Down Expand Up @@ -660,6 +672,44 @@ def _try_authenticate_plain(self, future):

log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_aws_msk_iam(self, future):
session = BotoSession()
credentials = session.get_credentials().get_frozen_credentials()
client = AwsMskIamClient(
host=self.host,
access_key=credentials.access_key,
secret_key=credentials.secret_key,
region=session.get_config_variable('region'),
token=credentials.token,
)

msg = client.first_message()
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)
data = self._recv_bytes_blocking(4)
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
return future.success(True)

def _try_authenticate_scram(self, future):
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
Expand Down
183 changes: 183 additions & 0 deletions kafka/msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import datetime
import hashlib
import hmac
import json
import string

from kafka.vendor.six.moves import urllib


class AwsMskIamClient:
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

def __init__(self, host, access_key, secret_key, region, token=None):
"""
Arguments:
host (str): The hostname of the broker.
access_key (str): An AWS_ACCESS_KEY_ID.
secret_key (str): An AWS_SECRET_ACCESS_KEY.
region (str): An AWS_REGION.
token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
credentials.
"""
self.algorithm = 'AWS4-HMAC-SHA256'
self.expires = '900'
self.hashfunc = hashlib.sha256
self.headers = [
('host', host)
]
self.version = '2020_10_22'

self.service = 'kafka-cluster'
self.action = '{}:Connect'.format(self.service)

now = datetime.datetime.utcnow()
self.datestamp = now.strftime('%Y%m%d')
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

self.host = host
self.access_key = access_key
self.secret_key = secret_key
self.region = region
self.token = token

@property
def _credential(self):
return '{0.access_key}/{0._scope}'.format(self)

@property
def _scope(self):
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

@property
def _signed_headers(self):
"""
Returns (str):
An alphabetically sorted, semicolon-delimited list of lowercase
request header names.
"""
return ';'.join(sorted(k.lower() for k, _ in self.headers))

@property
def _canonical_headers(self):
"""
Returns (str):
A newline-delited list of header names and values.
Header names are lowercased.
"""
return '\n'.join(map(':'.join, self.headers)) + '\n'

@property
def _canonical_request(self):
"""
Returns (str):
An AWS Signature Version 4 canonical request in the format:
<Method>\n
<Path>\n
<CanonicalQueryString>\n
<CanonicalHeaders>\n
<SignedHeaders>\n
<HashedPayload>
"""
# The hashed_payload is always an empty string for MSK.
hashed_payload = self.hashfunc(b'').hexdigest()
return '\n'.join((
'GET',
'/',
self._canonical_querystring,
self._canonical_headers,
self._signed_headers,
hashed_payload,
))

@property
def _canonical_querystring(self):
"""
Returns (str):
A '&'-separated list of URI-encoded key/value pairs.
"""
params = []
params.append(('Action', self.action))
params.append(('X-Amz-Algorithm', self.algorithm))
params.append(('X-Amz-Credential', self._credential))
params.append(('X-Amz-Date', self.timestamp))
params.append(('X-Amz-Expires', self.expires))
if self.token:
params.append(('X-Amz-Security-Token', self.token))
params.append(('X-Amz-SignedHeaders', self._signed_headers))

return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

@property
def _signing_key(self):
"""
Returns (bytes):
An AWS Signature V4 signing key generated from the secret_key, date,
region, service, and request type.
"""
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
key = self._hmac(key, self.region)
key = self._hmac(key, self.service)
key = self._hmac(key, 'aws4_request')
return key

@property
def _signing_str(self):
"""
Returns (str):
A string used to sign the AWS Signature V4 payload in the format:
<Algorithm>\n
<Timestamp>\n
<Scope>\n
<CanonicalRequestHash>
"""
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

def _uriencode(self, msg):
"""
Arguments:
msg (str): A string to URI-encode.
Returns (str):
The URI-encoded version of the provided msg, following the encoding
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
"""
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

def _hmac(self, key, msg):
"""
Arguments:
key (bytes): A key to use for the HMAC digest.
msg (str): A value to include in the HMAC digest.
Returns (bytes):
An HMAC digest of the given key and msg.
"""
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

def first_message(self):
"""
Returns (bytes):
An encoded JSON authentication payload that can be sent to the
broker.
"""
signature = hmac.new(
self._signing_key,
self._signing_str.encode('utf-8'),
digestmod=self.hashfunc,
).hexdigest()
msg = {
'version': self.version,
'host': self.host,
'user-agent': 'kafka-python',
'action': self.action,
'x-amz-algorithm': self.algorithm,
'x-amz-credential': self._credential,
'x-amz-date': self.timestamp,
'x-amz-signedheaders': self._signed_headers,
'x-amz-expires': self.expires,
'x-amz-signature': signature,
}
if self.token:
msg['x-amz-security-token'] = self.token

return json.dumps(msg, separators=(',', ':')).encode('utf-8')
67 changes: 67 additions & 0 deletions test/test_msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
mport datetime
import json

from kafka.msk import AwsMskIamClient

try:
from unittest import mock
except ImportError:
import mock


def client_factory(token=None):
now = datetime.datetime.utcfromtimestamp(1629321911)
with mock.patch('kafka.msk.datetime') as mock_dt:
mock_dt.datetime.utcnow = mock.Mock(return_value=now)
return AwsMskIamClient(
host='localhost',
access_key='XXXXXXXXXXXXXXXXXXXX',
secret_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX',
region='us-east-1',
token=token,
)


def test_aws_msk_iam_client_permanent_credentials():
client = client_factory(token=None)
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': '0fa42ae3d5693777942a7a4028b564f0b372bafa2f71c1a19ad60680e6cb994b',
}
assert actual == expected


def test_aws_msk_iam_client_temporary_credentials():
client = client_factory(token='XXXXX')
msg = client.first_message()
assert msg
assert isinstance(msg, bytes)
actual = json.loads(msg)

expected = {
'version': '2020_10_22',
'host': 'localhost',
'user-agent': 'kafka-python',
'action': 'kafka-cluster:Connect',
'x-amz-algorithm': 'AWS4-HMAC-SHA256',
'x-amz-credential': 'XXXXXXXXXXXXXXXXXXXX/20210818/us-east-1/kafka-cluster/aws4_request',
'x-amz-date': '20210818T212511Z',
'x-amz-signedheaders': 'host',
'x-amz-expires': '900',
'x-amz-signature': 'b0619c50b7ecb4a7f6f92bd5f733770df5710e97b25146f97015c0b1db783b05',
'x-amz-security-token': 'XXXXX',
}
assert actual == expected

0 comments on commit fdb1f4b

Please sign in to comment.