diff --git a/lud4ik/command_client.py b/lud4ik/command_client.py index e3b320f..9029917 100644 --- a/lud4ik/command_client.py +++ b/lud4ik/command_client.py @@ -1,9 +1,11 @@ import os import signal import socket -import threading +import logging -from work.utils import format_reply +from work.protocol import Feeder, Packet +from work.models import cmd +from work.utils import configure_logging from work.cmdargs import get_cmd_args from work.exceptions import ClientFinishException @@ -15,9 +17,9 @@ def shutdown_handler(signum, frame): class CommandClient: session_id = None - TIMEOUT = 10.0 - reply_commands = ['connected', 'pong', 'pongd', 'ackquit', 'ackfinish'] - print_reply_commands = ['pong', 'pongd'] + TIMEOUT = 1.0 + CHUNK_SIZE = 1024 + commands = [cmd.CONNECTED, cmd.PONG, cmd.PONGD, cmd.ACKQUIT, cmd.ACKFINISH] def __init__(self, host, port): self.socket = socket.socket(socket.AF_INET, @@ -29,64 +31,61 @@ def __init__(self, host, port): def run_client(cls, host, port): client = cls(host, port) try: + handler = signal.signal(signal.SIGINT, shutdown_handler) client.run() - signal.signal(signal.SIGUSR1, shutdown_handler) - except ClientFinishException: + except (OSError, socket.timeout, ClientFinishException): client.shutdown() finally: - pass + signal.signal(signal.SIGINT, handler) def run(self): - self.thread = threading.Thread(target=self.recv_response) - self.thread.start() + self.feeder = Feeder(self.commands) while True: - command = input() - command_name = command.split()[0] - command = command.replace(' ', '\n') - self.socket.sendall(format_reply(command)) + command = input().split() + kwargs = {} + cmd_input = getattr(cmd, command[0].upper()) + if cmd_input == cmd.PINGD: + kwargs['data'] = command[1] + packet = eval('{}(**kwargs).pack()'.format(command[0])) + self.socket.sendall(packet) + self.recv_response() def recv_response(self): + tail = bytes() while True: - msg = self.get_reply() - parts = msg.split('\n') - command_name = parts[0] - if command_name in self.print_reply_commands: - print(msg) - elif command_name == 'connected': - if parts[-1].startswith('session'): - self.session_id = parts[-1][7:] - print(msg) - elif command_name == 'ackquit': - if parts[-1] == self.session_id: - self.close() - else: - print(msg) - elif command_name == 'ackfinish': - self.close() - - def get_reply(self): - msg = bytes() - msg_len = int(self.socket.recv(4)) - while len(msg) < msg_len: - try: - chunk = self.socket.recv(msg_len - len(msg)) - except socket.timeout: - self.close() - msg += chunk - msg = msg.decode('utf-8') - return msg - - def close(self): - os.kill(os.getpid(), signal.SIGUSR1) + chunk = tail + self.socket.recv(self.CHUNK_SIZE) + packet, tail = self.feeder.feed(chunk) + if not packet: + continue + else: + getattr(self, packet.__class__.__name__.lower())(packet) + break + + def connected(self, packet): + self.session = packet.session + print('{} {}'.format(packet.cmd, packet.session)) + + def pong(self, packet): + print(packet.cmd) + + def pongd(self, packet): + print('{} {}'.format(packet.cmd, packet.data)) + + def ackquit(self, packet): + print('{} {}'.format(packet.cmd, packet.session)) + self.shutdown() + + def ackfinish(self, packet): + print(packet.cmd) + self.shutdown() def shutdown(self): self.socket.close() - print('socket closed') - self.thread.join() - print('thread closed') + logging.info('socket closed') raise SystemExit() if __name__ == '__main__': + configure_logging('Client') args = get_cmd_args() CommandClient.run_client(args.host, args.port) \ No newline at end of file diff --git a/lud4ik/command_server.py b/lud4ik/command_server.py index 98b58c0..e46e4fe 100644 --- a/lud4ik/command_server.py +++ b/lud4ik/command_server.py @@ -1,13 +1,20 @@ import os +import os.path import socket import signal +import logging import threading from operator import attrgetter from collections import namedtuple +from work.protocol import Feeder +from work.models import cmd from work.cmdargs import get_cmd_args from work.exceptions import ServerFinishException -from work.utils import format_reply, get_random_hash +from work.utils import (get_random_hash, + handle_timeout, + get_keyword_args, + configure_logging) def shutdown_handler(signum, frame): @@ -17,14 +24,16 @@ def shutdown_handler(signum, frame): class CommandServer: MAX_CONN = 5 - TIMEOUT = 100.0 + TIMEOUT = 1.0 + CHUNK_SIZE = 1024 + PID_FILE = 'server.pid' clients = {} - commands = ['connect', 'ping', 'pingd', 'quit', 'finish'] - single_reply_commands = ['ping', 'pingd'] + commands = [cmd.CONNECT, cmd.PING, cmd.PINGD, cmd.QUIT, cmd.FINISH] templ = namedtuple('templ', 'addr, thread, session') def __init__(self, host, port): self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.settimeout(self.TIMEOUT) self.socket.bind((host, port)) @@ -32,90 +41,91 @@ def __init__(self, host, port): def run_server(cls, host, port): server = cls(host, port) try: + handler = signal.signal(signal.SIGINT, shutdown_handler) + with open(cls.PID_FILE, 'w') as f: + f.write(str(os.getpid())) server.run() - signal.signal(signal.SIGUSR1, shutdown_handler) - except ServerFinishException: + except (ServerFinishException, OSError): server.shutdown() finally: - pass + signal.signal(signal.SIGINT, handler) def run(self): + self.socket.listen(self.MAX_CONN) while True: - self.socket.listen(self.MAX_CONN) - conn, addr = self.socket.accept() - th = threading.Thread(target=self.run_client, args=(conn, )) - self.clients[conn] = self.templ(addr=addr, thread=th, - session=get_random_hash()) - th.start() + with handle_timeout(): + conn, addr = self.socket.accept() + th = threading.Thread(target=self.run_client, args=(conn, )) + self.clients[conn] = self.templ(addr=addr, + thread=th, + session=get_random_hash()) + th.start() def run_client(self, conn): + feeder = Feeder(self.commands) + tail = bytes() while True: - msg = bytes() - msg_len = int(conn.recv(4)) - while len(msg) < msg_len: - try: - chunk = conn.recv(msg_len - len(msg)) - except socket.timeout: - conn.close() - del self.clients[conn] - return - msg += chunk - - msg = msg.decode('utf-8').split('\n') - command_name = msg[0] - command = getattr(self, command_name) - args = [conn] - if len(msg) > 1: - args.append(msg[1]) - command(*args) - - def connect(self, conn): - self.condition_reply(conn, "connected", reply_templ="{}\nsession{}") - - def ping(self, conn): - reply = format_reply('pong') + try: + chunk = tail + conn.recv(self.CHUNK_SIZE) + packet, tail = feeder.feed(chunk) + if not packet: + continue + process = getattr(self, packet.__class__.__name__.lower()) + kwargs = {} + kw_only = get_keyword_args(process) + if 'conn' in kw_only: + kwargs['conn'] = conn + process(packet, **kwargs) + except (socket.timeout, OSError): + conn.close() + self.clients.pop(conn, None) + return + + def connect(self, packet, *, conn): + session = self.clients[conn].session + reply = packet.reply(session) + for client in list(self.clients.keys()): + conn.sendall(reply) + + def ping(self, packet, *, conn): + reply = packet.reply() conn.sendall(reply) - def pingd(self, conn, data): - reply = format_reply('{}\n{}'.format('pongd', data)) + def pingd(self, packet, *, conn): + reply = packet.reply() conn.sendall(reply) - def quit(self, conn): - self.condition_reply(conn, "ackquit", shared_templ="{}\n{} disconnected.") + def quit(self, packet, *, conn): + session = self.clients[conn].session + reply = packet.reply(session) + for client in list(self.clients.keys()): + conn.sendall(reply) conn.close() - del self.clients[conn] + self.clients.pop(conn, None) raise SystemExit() - def condition_reply(self, conn, reply_command, shared_templ="{}\n{}", reply_templ="{}\n{}"): - addr = self.clients[conn].addr - shared_reply = format_reply(shared_templ.format(reply_command, addr)) - session_id = self.clients[conn].session - reply = format_reply(reply_templ.format(reply_command, session_id)) - for client in self.clients.keys(): - if client == conn: - conn.sendall(reply) - else: - conn.sendall(shared_reply) - - def finish(self, conn): - addr = self.clients[conn].addr - reply = format_reply("{}\n{} finished server.".format('ackfinish', addr)) - for client in self.clients.keys(): + def finish(self, packet): + reply = packet.reply() + for client in list(self.clients.keys()): client.sendall(reply) - os.kill(os.getpid(), signal.SIGUSR1) + os.kill(os.getpid(), signal.SIGINT) + raise SystemExit() def shutdown(self): self.socket.close() - print('socket closed') - for conn in self.clients.keys(): + logging.info('socket closed') + for conn in list(self.clients.keys()): conn.close() - print('connections closed') - for th in map(attrgetter('thread'), self.clients.values()): + logging.info('connections closed') + for th in map(attrgetter('thread'), list(self.clients.values())): th.join() - print('threads closed') + logging.info('threads closed') + if os.path.exists(self.PID_FILE): + os.remove(self.PID_FILE) raise SystemExit() if __name__ == '__main__': + configure_logging('Server') args = get_cmd_args() CommandServer.run_server(args.host, args.port) \ No newline at end of file diff --git a/lud4ik/tests/__init__.py b/lud4ik/tests/__init__.py index 16599b1..ecdebe7 100644 --- a/lud4ik/tests/__init__.py +++ b/lud4ik/tests/__init__.py @@ -1,13 +1,2 @@ -import unittest - - -class BaseTestCase(unittest.TestCase): - - def setUp(self): - pass - - def tearDown(self): - pass - - def test_connect(self): - self.assertRaises(FileNotFoundError, open, '/doesnotexist.py') +from .test_server import ServerTestCase +from .test_command import CommandTestCase diff --git a/lud4ik/tests/test_command.py b/lud4ik/tests/test_command.py new file mode 100644 index 0000000..544aa21 --- /dev/null +++ b/lud4ik/tests/test_command.py @@ -0,0 +1,79 @@ +import unittest + +from work.protocol import Packet +from work.models import (cmd, Connected, Pong, PongD, AckQuit, AckFinish, + Connect, Ping, PingD, Quit, Finish) +from work.fields import Cmd, Str +from work.exceptions import FieldDeclarationError + + +class CommandTestCase(unittest.TestCase): + + LENGTH = 4 + + def test_connect(self): + packet = Connect() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), + Connect) + + def test_ping(self): + packet = Ping() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), Ping) + + def test_pingd(self): + packet = PingD(data='test_data') + unpacked = Packet.unpack(packet.pack()[self.LENGTH:]) + self.assertEqual(packet.data, unpacked.data) + self.assertIsInstance(unpacked, PingD) + + def test_quit(self): + packet = Quit() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), Quit) + + def test_finish(self): + packet = Finish() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), + Finish) + + def test_connected(self): + packet = Connected(session='test_session') + unpacked = Packet.unpack(packet.pack()[self.LENGTH:]) + self.assertEqual(packet.session, unpacked.session) + self.assertIsInstance(unpacked, Connected) + + def test_pong(self): + packet = Pong() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), Pong) + + def test_pongd(self): + packet = PongD(data='test_data') + unpacked = Packet.unpack(packet.pack()[self.LENGTH:]) + self.assertEqual(packet.data, unpacked.data) + self.assertIsInstance(unpacked, PongD) + + def test_ackquit(self): + packet = AckQuit(session='test_session') + unpacked = Packet.unpack(packet.pack()[self.LENGTH:]) + self.assertEqual(packet.session, unpacked.session) + self.assertIsInstance(unpacked, AckQuit) + + def test_ackfinish(self): + packet = AckFinish() + self.assertIsInstance(Packet.unpack(packet.pack()[self.LENGTH:]), + AckFinish) + + def test_without_fields(self): + with self.assertRaises(FieldDeclarationError): + class ErrorClass(Packet): + pass + + def test_without_cmd(self): + with self.assertRaises(FieldDeclarationError): + class ErrorClass(Packet): + data = Str(maxsize=256) + + def test_dublicate(self): + with self.assertRaises(FieldDeclarationError): + class ErrorClass(Packet): + cmd = Cmd(cmd.CONNECTED) + data = Str(maxsize=256) diff --git a/lud4ik/tests/test_server.py b/lud4ik/tests/test_server.py new file mode 100644 index 0000000..22d713f --- /dev/null +++ b/lud4ik/tests/test_server.py @@ -0,0 +1,73 @@ +import os +import os.path +import time +import socket +import signal +import unittest +import subprocess + +from work.utils import get_msg +from work.protocol import Packet +from command_server import CommandServer +from work.models import (cmd, Connected, Pong, PongD, AckQuit, AckFinish, + Connect, Ping, PingD, Quit, Finish) + + +class ServerTestCase(unittest.TestCase): + + HOST = '' + PORT = 50007 + PID_FILE = 'server.pid' + + def setUp(self): + self.server = subprocess.Popen(['python3.3', 'command_server.py']) + self.addCleanup(self.stop_server) + while True: + if os.path.exists(self.PID_FILE): + time.sleep(0.5) + break + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.connect((self.HOST, self.PORT)) + + def stop_server(self): + if self.server.poll() is None: + os.kill(self.server.pid, signal.SIGINT) + + def test_connect(self): + packet = Connect().pack() + self.socket.sendall(packet) + reply = Packet.unpack(get_msg(self.socket)) + self.assertIsInstance(reply, Connected) + self.assertTrue(hasattr(reply, 'session')) + + def test_ping(self): + packet = Ping().pack() + self.socket.sendall(packet) + reply = Packet.unpack(get_msg(self.socket)) + self.assertIsInstance(reply, Pong) + + def test_pingd(self): + packet = PingD(data='test_data') + serialized_packet = packet.pack() + self.socket.sendall(serialized_packet) + reply = Packet.unpack(get_msg(self.socket)) + self.assertIsInstance(reply, PongD) + self.assertEqual(packet.data, reply.data) + + def test_quit(self): + packet = Quit().pack() + self.socket.sendall(packet) + reply = Packet.unpack(get_msg(self.socket)) + self.assertIsInstance(reply, AckQuit) + self.assertTrue(hasattr(reply, 'session')) + + def test_finish(self): + packet = Finish().pack() + self.socket.sendall(packet) + reply = Packet.unpack(get_msg(self.socket)) + self.assertIsInstance(reply, AckFinish) + while True: + if not os.path.exists(self.PID_FILE): + time.sleep(0.5) + break + self.assertTrue(self.server.poll() is not None) \ No newline at end of file diff --git a/lud4ik/work/exceptions.py b/lud4ik/work/exceptions.py index ddf9e05..cf87061 100644 --- a/lud4ik/work/exceptions.py +++ b/lud4ik/work/exceptions.py @@ -5,4 +5,12 @@ class ServerFinishException(Exception): class ClientFinishException(Exception): + pass + + +class FieldDeclarationError(Exception): + pass + + +class ValidationError(Exception): pass \ No newline at end of file diff --git a/lud4ik/work/fields.py b/lud4ik/work/fields.py new file mode 100644 index 0000000..5753ec5 --- /dev/null +++ b/lud4ik/work/fields.py @@ -0,0 +1,54 @@ +from .exceptions import ValidationError + + +class Field: + + def __get__(self, instance, owner): + return instance.__dict__[self.name] + + def __set__(self, instance, value): + if value.__class__ != self._type: + raise ValidationError() + if hasattr(self, 'validate'): + self.validate(value) + instance.__dict__[self.name] = value + + +class Cmd(Field): + _type = int + serialize = staticmethod(lambda x: x.to_bytes(1, 'little')) + deserialize = staticmethod(lambda data: (data[0], data[1:])) + + def __init__(self, _id): + self.id = _id + + +class Str(Field): + _type = str + + def __init__(self, maxsize): + self.maxsize = maxsize + + def validate(self, value): + if len(value) > self.maxsize: + raise ValidationError() + + @staticmethod + def serialize(value): + return bytes(value, 'utf-8') + + @staticmethod + def deserialize(value): + return (value.decode('utf-8'), None) + + +class Int(Field): + _type = int + + @staticmethod + def serialize(value): + return value.to_bytes(4, 'little') + + @staticmethod + def deserialize(value): + return int.from_bytes(value, 'little') \ No newline at end of file diff --git a/lud4ik/work/models.py b/lud4ik/work/models.py new file mode 100644 index 0000000..513bb9e --- /dev/null +++ b/lud4ik/work/models.py @@ -0,0 +1,74 @@ +from .protocol import Packet +from .fields import Cmd, Str + + +class cmd: + CONNECT = 1 + PING = 2 + PINGD = 3 + QUIT = 4 + FINISH = 5 + CONNECTED = 6 + PONG = 7 + PONGD = 8 + ACKQUIT = 9 + ACKFINISH = 10 + + +class Connected(Packet): + cmd = Cmd(cmd.CONNECTED) + session = Str(maxsize=256) + + +class Pong(Packet): + cmd = Cmd(cmd.PONG) + + +class PongD(Packet): + cmd = Cmd(cmd.PONGD) + data = Str(maxsize=256) + + +class AckQuit(Packet): + cmd = Cmd(cmd.ACKQUIT) + session = Str(maxsize=256) + + +class AckFinish(Packet): + cmd = Cmd(cmd.ACKFINISH) + + +class Connect(Packet): + cmd = Cmd(cmd.CONNECT) + + def reply(self, session): + return Connected(session=session).pack() + + +class Ping(Packet): + cmd = Cmd(cmd.PING) + + def reply(self): + return Pong().pack() + + +class PingD(Packet): + cmd = Cmd(cmd.PINGD) + data = Str(maxsize=256) + + def reply(self): + return PongD(data=self.data).pack() + + +class Quit(Packet): + cmd = Cmd(cmd.QUIT) + + def reply(self, session): + return AckQuit(session=session).pack() + + +class Finish(Packet): + cmd = Cmd(cmd.FINISH) + + def reply(self): + return AckFinish().pack() \ No newline at end of file diff --git a/lud4ik/work/protocol.py b/lud4ik/work/protocol.py new file mode 100644 index 0000000..671d76f --- /dev/null +++ b/lud4ik/work/protocol.py @@ -0,0 +1,95 @@ +from collections import OrderedDict + +from .fields import Field, Int, Cmd +from .exceptions import FieldDeclarationError, ValidationError + + +class MetaPacket(type): + + packets = {} + + def __prepare__(name, bases): + return OrderedDict() + + def __init__(self, name, bases, dct): + if name == 'Packet': + return + + self.fields = OrderedDict() + for attr, value in dct.items(): + if isinstance(value, Cmd): + cmd = value + if cmd.id in self.__class__.packets: + raise FieldDeclarationError('Dublicate registered command.') + if isinstance(value, Field): + value.name = attr + self.fields[attr] = value + + if not (self.fields and isinstance(next(iter(self.fields.values())), Cmd)): + raise FieldDeclarationError('Command shoud be first field.') + + self.__class__.packets[cmd.id] = self + + +class Packet(metaclass=MetaPacket): + + def __init__(self, **kwargs): + names = list(self.fields.keys()) + cmd = self.fields[names[0]].id + setattr(self, names[0], cmd) + for attr in names[1:]: + value = kwargs.get(attr) + if value is None: + raise ValidationError() + setattr(self, attr, value) + + def pack(self): + result = bytes() + for attr, _type in self.fields.items(): + result += _type.serialize(getattr(self, attr)) + + return Int.serialize(len(result)) + result + + @classmethod + def unpack(cls, data: bytes): + kwargs = {} + pack_cls = cls.__class__.packets.get(data[0]) + if pack_cls is None: + raise ValidationError() + + tail = data + for attr, _type in pack_cls.fields.items(): + value, tail = _type.deserialize(tail) + kwargs[attr] = value + + return pack_cls(**kwargs) + + +class Feeder: + + LENGTH = 4 + + def __init__(self, commands): + self._len = None + self.commands = commands + + def feed(self, buffer): + if self._len is None: + if len(buffer) < self.LENGTH: + return None, buffer + self._len = Int.deserialize(buffer[:self.LENGTH]) + buffer = buffer[self.LENGTH:] + + if len(buffer) < self._len: + return None, buffer + + try: + if buffer[0] not in self.commands: + raise ValidationError() + packet = Packet.unpack(buffer[:self._len]) + except ValidationError: + packet = None + finally: + buffer = buffer[self._len:] + self._len = None + return packet, buffer \ No newline at end of file diff --git a/lud4ik/work/utils.py b/lud4ik/work/utils.py index b0d8c56..4a16d54 100644 --- a/lud4ik/work/utils.py +++ b/lud4ik/work/utils.py @@ -1,12 +1,57 @@ +import socket import random import hashlib +import logging +from inspect import signature +from contextlib import contextmanager def format_reply(reply): - return bytes("{:4}{}".format(len(reply), reply), 'utf-8') + byte_reply = bytes(reply, 'utf-8') + return len(byte_reply).to_bytes(4, 'little') + byte_reply def get_random_hash(n=10): _str = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' random_name = ''.join([random.choice(_str) for i in range(n)]) return hashlib.sha224(random_name.encode('utf-8')).hexdigest() + + +@contextmanager +def handle_timeout(): + try: + yield + except socket.timeout: + pass + + +def get_conn_data(conn, length): + msg = bytes() + while len(msg) < length: + chunk = conn.recv(length - len(msg)) + msg += chunk + return msg + + +def get_msg(conn): + MSG_LEN = 4 + msg_len = int.from_bytes(get_conn_data(conn, MSG_LEN), 'little') + msg = get_conn_data(conn, msg_len) + return msg + + +def get_keyword_args(function): + kwargs = [] + params = signature(function).parameters.values() + for i in params: + if i.kind == i.KEYWORD_ONLY: + kwargs.append(i.name) + return kwargs + + +def configure_logging(who): + logging.basicConfig( + filename = './tmp.log', + level=logging.INFO, + format= who +' [%(levelname)s] (%(threadName)s) %(message)s', + ) \ No newline at end of file