diff --git a/lib/bindings/python/src/dynamo/nixl_connect/__init__.py b/lib/bindings/python/src/dynamo/nixl_connect/__init__.py index bf523daa2a..c4478367a0 100644 --- a/lib/bindings/python/src/dynamo/nixl_connect/__init__.py +++ b/lib/bindings/python/src/dynamo/nixl_connect/__init__.py @@ -16,6 +16,7 @@ from __future__ import annotations import asyncio +import base64 import logging import socket import uuid @@ -1185,7 +1186,7 @@ async def _wait_for_completion_(self) -> None: case _: return - def metadata(self) -> RdmaMetadata: + def metadata(self, hex_encode: bool = False) -> RdmaMetadata: """ Gets the request descriptor for the operation. """ @@ -1209,9 +1210,14 @@ def metadata(self) -> RdmaMetadata: f"dynamo.nixl_connect.{self.__class__.__name__}: Compressed NIXL metadata is larger than original ({compressed_len} > {original_len})." ) + if not hex_encode: + encoded_metadata = base64.b64encode(nixl_metadata).decode("utf-8") + encoded_metadata = "b64:" + encoded_metadata + else: + encoded_metadata = nixl_metadata.hex() self._serialized_request = RdmaMetadata( descriptors=descriptors, - nixl_metadata=nixl_metadata.hex(), + nixl_metadata=encoded_metadata, notification_key=self._notification_key, operation_kind=int(self._operation_kind), ) @@ -1471,11 +1477,15 @@ def __init__( self._connector = connector # When `nixl_metadata` is a string, it is assumed to have come from a remote worker - # via a `RdmaMetadata` object and therefore can assumed be a hex-encoded, compressed + # via a `RdmaMetadata` object and therefore can assumed be a b64-encoded, compressed # representation of the NIXL metadata. if isinstance(nixl_metadata, str): - # Decode the hex-encoded string into bytes. - nixl_metadata = bytes.fromhex(nixl_metadata) + if nixl_metadata.startswith("b64:"): + # Decode the b64-encoded string into bytes. + nixl_metadata = base64.b64decode(nixl_metadata.lstrip("b64:")) + else: + # fallback for earlier versions of nixl connect + nixl_metadata = bytes.fromhex(nixl_metadata) # Decompress the NIXL metadata. nixl_metadata = zlib.decompress(nixl_metadata)