Skip to content
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Masahiro Nakagawa <repeatedly _at_ gmail.com>
INADA Naoki <songofacandy _at_ gmail.com>
Harish Vishwanath <harish dot shastry at gmail dot com>
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
<!--
[![Build Status](https://travis-ci.org/msgpack/msgpack-rpc-python.png)](https://travis-ci.org/msgpack/msgpack-rpc-python)
-->
# Unix Domain Socket support
Unix domain socket support is now available for msgpack-rpc. Sample examples below.

## UDS examples

### Server

```python
import msgpackrpc.udsaddress
from msgpackrpc.transport import uds
class SumServer(object):
def sum(self, x, y):
return x + y

# Use builder as uds. default builder is tcp which creates tcp sockets
server = msgpackrpc.Server(SumServer(), builder=uds)
# Use UDSAddress instead of msgpackrpc.Address
server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc'))
server.start()
```

### Client
```python
import msgpackrpc.udsaddress
from msgpackrpc.transport import uds

#Use UDSAddress instead of default Address object
client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds)
result = client.call('sum', 1, 2) # = >
print "Sum of 1 and 2 : %d" % result
```

Go through the below sections for general usage of Message Pack RPC Library

# MessagePack RPC for Python

Expand Down
11 changes: 11 additions & 0 deletions example/uds_simpleclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
'''
@author: hvishwanath | [email protected]
'''

import msgpackrpc.udsaddress
from msgpackrpc.transport import uds

#Use UDSAddress instead of default Address object
client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds)
result = client.call('sum', 1, 2) # = >
print "Sum of 1 and 2 : %d" % result
15 changes: 15 additions & 0 deletions example/uds_simpleserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
'''
@author: hvishwanath | [email protected]
'''

import msgpackrpc.udsaddress
from msgpackrpc.transport import uds
class SumServer(object):
def sum(self, x, y):
return x + y

# Use builder as uds. default builder is tcp which creates tcp sockets
server = msgpackrpc.Server(SumServer(), builder=uds)
# Use UDSAddress instead of msgpackrpc.Address
server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc'))
server.start()
1 change: 1 addition & 0 deletions msgpackrpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from msgpackrpc.client import Client
from msgpackrpc.server import Server
from msgpackrpc.address import Address
from msgpackrpc.udsaddress import UDSAddress
53 changes: 53 additions & 0 deletions msgpackrpc/transport/uds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
'''
@author: hvishwanath | [email protected]
'''

import msgpackrpc.transport
from tornado.netutil import bind_unix_socket
from tornado import tcpserver
from tornado.iostream import IOStream

# Much of the implementation will be same as that of tcp module
# Changes required for unix domain socket support are done in this module
# Rest will be automatically used from tcp

# Create namespace equals
BaseSocket = msgpackrpc.transport.tcp.BaseSocket
ClientSocket = msgpackrpc.transport.tcp.ClientSocket
ClientTransport = msgpackrpc.transport.tcp.ClientTransport

ServerSocket = msgpackrpc.transport.tcp.ServerSocket
ServerTransport = msgpackrpc.transport.tcp.ServerTransport


class UDSServer(tcpserver.TCPServer):
"""Define a Unix domain socket server.
Instead of binding to TCP/IP socket, bind to UDS socket and listen"""

def __init__(self, io_loop=None, ssl_options=None):
tcpserver.TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options)

def listen(self, port, address=""):
"""Bind to a unix domain socket and add to self.
Note that port in our case actually contains the uds file name"""

# Create a Unix domain socket and bind
socket = bind_unix_socket(port)

# Add to self
self.add_socket(socket)

class MessagePackServer(UDSServer):
"""The MessagePackServer inherits from UDSServer
instead of tornado's TCP Server"""

def __init__(self, transport, io_loop=None, encodings=None):
self._transport = transport
self._encodings = encodings
UDSServer.__init__(self, io_loop=io_loop)

def handle_stream(self, stream, address):
ServerSocket(stream, self._transport, self._encodings)

#Monkey patch the MessagePackServer
msgpackrpc.transport.tcp.MessagePackServer = MessagePackServer
40 changes: 40 additions & 0 deletions msgpackrpc/udsaddress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
'''
@author: hvishwanath | [email protected]
'''

import socket
from tornado.platform.auto import set_close_exec

class UDSAddress(object):
"""This class abstracts Unix domain socket address.
For compatibility with other code in the library, port is always equal to host"""

def __init__(self, host, port=None):
self._host = host

# Passed value for port is ignored.
# Port is also made equal to host.
# This is because some of the code in transport.tcp uses address._port to connect.
# For a unix socket, there is no port. Hence if port = host, that code should work.
self._port = host

@property
def host(self):
return self._host

@property
def port(self):
return self._port

def unpack(self):
# Return only the host
return self._host

def socket(self, family=socket.AF_UNSPEC):
"""Return a Unix domain socket instead of tcp socket"""

sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
sock.setblocking(0)

return sock
202 changes: 202 additions & 0 deletions test/test_uds_msgpackrpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
'''
@author: hvishwanath | [email protected]
'''

from msgpackrpc.transport import uds
from time import sleep
import threading
try:
import unittest2 as unittest
except ImportError:
import unittest

import helper
import msgpackrpc
from msgpackrpc import error

class TestMessagePackRPC(unittest.TestCase):
ENABLE_TIMEOUT_TEST = False

class TestArg:
''' this class must know completely how to deserialize '''
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c

def to_msgpack(self):
return (self.a, self.b, self.c)

def add(self, rhs):
self.a += rhs.a
self.b -= rhs.b
self.c *= rhs.c
return self

def __eq__(self, rhs):
return (self.a == rhs.a and self.b == rhs.b and self.c == rhs.c)

@staticmethod
def from_msgpack(arg):
return TestMessagePackRPC.TestArg(arg[0], arg[1], arg[2])

class TestServer(object):
def hello(self):
return "world"

def sum(self, x, y):
return x + y

def nil(self):
return None

def add_arg(self, arg0, arg1):
lhs = TestMessagePackRPC.TestArg.from_msgpack(arg0)
rhs = TestMessagePackRPC.TestArg.from_msgpack(arg1)
return lhs.add(rhs)

def raise_error(self):
raise Exception('error')

def long_exec(self):
sleep(3)
return 'finish!'

def async_result(self):
ar = msgpackrpc.server.AsyncResult()
def do_async():
sleep(2)
ar.set_result("You are async!")
threading.Thread(target=do_async).start()
return ar

def setUp(self):
# Create UDSAddress
self._address = msgpackrpc.UDSAddress('/tmp/unusedsocket')

def setup_env(self):
def _on_started():
self._server._loop.dettach_periodic_callback()
lock.release()
def _start_server(server):
server._loop.attach_periodic_callback(_on_started, 1)
server.start()
server.close()

# Use builder=uds
self._server = msgpackrpc.Server(TestMessagePackRPC.TestServer(), builder=uds)
self._server.listen(self._address)
self._thread = threading.Thread(target=_start_server, args=(self._server,))

lock = threading.Lock()
self._thread.start()
lock.acquire()
lock.acquire() # wait for the server to start

self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8')
return self._client;

def tearDown(self):
self._client.close();
self._server.stop();
self._thread.join();

def test_call(self):
client = self.setup_env();

result1 = client.call('hello')
result2 = client.call('sum', 1, 2)
result3 = client.call('nil')

self.assertEqual(result1, "world", "'hello' result is incorrect")
self.assertEqual(result2, 3, "'sum' result is incorrect")
self.assertIsNone(result3, "'nil' result is incorrect")

def test_call_userdefined_arg(self):
client = self.setup_env();

arg = TestMessagePackRPC.TestArg(0, 1, 2)
arg2 = TestMessagePackRPC.TestArg(23, 3, -23)

result1 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg, arg2))
self.assertEqual(result1, arg.add(arg2))

result2 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg2, arg))
self.assertEqual(result2, arg2.add(arg))

result3 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', result1, result2))
self.assertEqual(result3, result1.add(result2))

def test_call_async(self):
client = self.setup_env();

future1 = client.call_async('hello')
future2 = client.call_async('sum', 1, 2)
future3 = client.call_async('nil')
future1.join()
future2.join()
future3.join()

self.assertEqual(future1.result, "world", "'hello' result is incorrect in call_async")
self.assertEqual(future2.result, 3, "'sum' result is incorrect in call_async")
self.assertIsNone(future3.result, "'nil' result is incorrect in call_async")

def test_notify(self):
client = self.setup_env();

result = True
try:
client.notify('hello')
client.notify('sum', 1, 2)
client.notify('nil')
except:
result = False

self.assertTrue(result)

def test_raise_error(self):
client = self.setup_env();
self.assertRaises(error.RPCError, lambda: client.call('raise_error'))

def test_unknown_method(self):
client = self.setup_env();
self.assertRaises(error.RPCError, lambda: client.call('unknown', True))
try:
client.call('unknown', True)
self.assertTrue(False)
except error.RPCError as e:
message = e.args[0]
self.assertEqual(message, "'unknown' method not found", "Error message mismatched")

def test_async_result(self):
client = self.setup_env();
self.assertEqual(client.call('async_result'), "You are async!")

def test_connect_failed(self):
client = self.setup_env();
port = helper.unused_port()
client = msgpackrpc.Client(msgpackrpc.Address('localhost', port), unpack_encoding='utf-8')
self.assertRaises(error.TransportError, lambda: client.call('hello'))

def test_timeout(self):
client = self.setup_env();

if self.__class__.ENABLE_TIMEOUT_TEST:
self.assertEqual(client.call('long_exec'), 'finish!', "'long_exec' result is incorrect")

client = msgpackrpc.Client(self._address, timeout=1, unpack_encoding='utf-8')
self.assertRaises(error.TimeoutError, lambda: client.call('long_exec'))
else:
print("Skip test_timeout")


if __name__ == '__main__':
import sys

try:
sys.argv.remove('--timeout-test')
TestMessagePackRPC.ENABLE_TIMEOUT_TEST = True
except:
pass

unittest.main()
Loading