Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python3 port for protocol.py & test_protocol.py #1786

Closed
wants to merge 12 commits into from
14 changes: 10 additions & 4 deletions src/highlevelcrypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@

from binascii import hexlify

import pyelliptic
from bmconfigparser import BMConfigParser
from pyelliptic import OpenSSL
from pyelliptic import arithmetic as a
try:
import pyelliptic
from bmconfigparser import BMConfigParser
from pyelliptic import OpenSSL
from pyelliptic import arithmetic as a
except ImportError:
from . import pyelliptic
from .bmconfigparser import BMConfigParser
from .pyelliptic import OpenSSL
from .pyelliptic import arithmetic as a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that pyelliptic may eventually become a separate package



def makeCryptor(privkey):
Expand Down
50 changes: 27 additions & 23 deletions src/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ def isBitSetWithinBitfield(fourByteString, n):
return x & 2**n != 0


# ip addresses
# IP addresses


def encodeHost(host):
"""Encode a given host to be used in low-level socket operations"""
if host.find('.onion') > -1:
return '\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode(
return b'\xfd\x87\xd8\x7e\xeb\x43' + base64.b32decode(
host.split(".")[0], True)
elif host.find(':') == -1:
return '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \
return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + \
socket.inet_aton(host)
return socket.inet_pton(socket.AF_INET6, host)

Expand Down Expand Up @@ -147,10 +147,10 @@ def checkIPAddress(host, private=False):
Returns hostStandardFormat if it is a valid IP address,
otherwise returns False
"""
if host[0:12] == '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
if host[0:12] == b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF':
hostStandardFormat = socket.inet_ntop(socket.AF_INET, host[12:])
return checkIPv4Address(host[12:], hostStandardFormat, private)
elif host[0:6] == '\xfd\x87\xd8\x7e\xeb\x43':
elif host[0:6] == b'\xfd\x87\xd8\x7e\xeb\x43':
# Onion, based on BMD/bitcoind
hostStandardFormat = base64.b32encode(host[6:]).lower() + ".onion"
if private:
Expand All @@ -161,7 +161,7 @@ def checkIPAddress(host, private=False):
hostStandardFormat = socket.inet_ntop(socket.AF_INET6, host)
except ValueError:
return False
if hostStandardFormat == "":
if len(hostStandardFormat) == 0:
# This can happen on Windows systems which are
# not 64-bit compatible so let us drop the IPv6 address.
return False
Expand All @@ -173,23 +173,23 @@ def checkIPv4Address(host, hostStandardFormat, private=False):
Returns hostStandardFormat if it is an IPv4 address,
otherwise returns False
"""
if host[0] == '\x7F': # 127/8
if host[0:1] == b'\x7F': # 127/8
if not private:
logger.debug(
'Ignoring IP address in loopback range: %s',
hostStandardFormat)
return hostStandardFormat if private else False
if host[0] == '\x0A': # 10/8
if host[0:1] == b'\x0A': # 10/8
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
return hostStandardFormat if private else False
if host[0:2] == '\xC0\xA8': # 192.168/16
if host[0:2] == b'\xC0\xA8': # 192.168/16
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
return hostStandardFormat if private else False
if host[0:2] >= '\xAC\x10' and host[0:2] < '\xAC\x20': # 172.16/12
if host[0:2] >= b'\xAC\x10' and host[0:2] < b'\xAC\x20': # 172.16/12
if not private:
logger.debug(
'Ignoring IP address in private range: %s', hostStandardFormat)
Expand All @@ -202,15 +202,19 @@ def checkIPv6Address(host, hostStandardFormat, private=False):
Returns hostStandardFormat if it is an IPv6 address,
otherwise returns False
"""
if host == ('\x00' * 15) + '\x01':
if host == b'\x00' * 15 + b'\x01':
if not private:
logger.debug('Ignoring loopback address: %s', hostStandardFormat)
return False
if host[0] == '\xFE' and (ord(host[1]) & 0xc0) == 0x80:
try:
host = [ord(c) for c in host[:2]]
except TypeError: # python3 has ints already
pass
if host[0:1] == b'\xFE' and host[1] & 0xc0 == 0x80:
if not private:
logger.debug('Ignoring local address: %s', hostStandardFormat)
return hostStandardFormat if private else False
if (ord(host[0]) & 0xfe) == 0xfc:
if host[0] & 0xfe == 0xfc:
if not private:
logger.debug(
'Ignoring unique local address: %s', hostStandardFormat)
Expand Down Expand Up @@ -280,7 +284,7 @@ def isProofOfWorkSufficient(
# Packet creation


def CreatePacket(command, payload=''):
def CreatePacket(command, payload=b''):
"""Construct and return a packet"""
payload_length = len(payload)
checksum = hashlib.sha512(payload).digest()[0:4]
Expand All @@ -298,14 +302,14 @@ def assembleVersionMessage(
Construct the payload of a version message,
return the resulting bytes of running `CreatePacket` on it
"""
payload = ''
payload = b''
payload += pack('>L', 3) # protocol version.
# bitflags of the services I offer.
payload += pack(
'>q',
NODE_NETWORK |
(NODE_SSL if haveSSL(server) else 0) |
(NODE_DANDELION if state.dandelion else 0)
NODE_NETWORK
| (NODE_SSL if haveSSL(server) else 0)
| (NODE_DANDELION if state.dandelion else 0)
)
payload += pack('>q', int(time.time()))

Expand All @@ -327,13 +331,13 @@ def assembleVersionMessage(
# bitflags of the services I offer.
payload += pack(
'>q',
NODE_NETWORK |
(NODE_SSL if haveSSL(server) else 0) |
(NODE_DANDELION if state.dandelion else 0)
NODE_NETWORK
| (NODE_SSL if haveSSL(server) else 0)
| (NODE_DANDELION if state.dandelion else 0)
)
# = 127.0.0.1. This will be ignored by the remote host.
# The actual remote connected IP will be used.
payload += '\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + pack(
payload += b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF' + pack(
'>L', 2130706433)
# we have a separate extPort and incoming over clearnet
# or outgoing through clearnet
Expand All @@ -355,7 +359,7 @@ def assembleVersionMessage(
payload += nodeid[0:8]
else:
payload += eightBytesOfRandomDataUsedToDetectConnectionsToSelf
userAgent = '/PyBitmessage:' + softwareVersion + '/'
userAgent = ('/PyBitmessage:%s/' % softwareVersion).encode('utf-8')
payload += encodeVarint(len(userAgent))
payload += userAgent

Expand Down
5 changes: 4 additions & 1 deletion src/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
def cleanup(home=None, files=_files):
"""Cleanup application files"""
if not home:
import state
try:
import state
except ImportError:
from pybitmessage import state
home = state.appdata
for pfile in files:
try:
Expand Down
82 changes: 77 additions & 5 deletions src/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,97 @@
Tests for common protocol functions
"""

import sys
import unittest

from .common import skip_python3

skip_python3()
from pybitmessage import protocol, state


class TestProtocol(unittest.TestCase):
"""Main protocol test case"""

def test_checkIPv4Address(self):
"""Check the results of protocol.checkIPv4Address()"""
token = 'HELLO'
# checking protocol.encodeHost()[12:]
self.assertEqual( # 127.0.0.1
token, protocol.checkIPv4Address(b'\x7f\x00\x00\x01', token, True))
self.assertEqual( # 10.42.43.1
token, protocol.checkIPv4Address(b'\n*+\x01', token, True))
self.assertEqual( # 192.168.0.254
token, protocol.checkIPv4Address(b'\xc0\xa8\x00\xfe', token, True))
self.assertEqual( # 172.31.255.254
token, protocol.checkIPv4Address(b'\xac\x1f\xff\xfe', token, True))
self.assertFalse( # 8.8.8.8
protocol.checkIPv4Address(b'\x08\x08\x08\x08', token, True))

def test_checkIPv6Address(self):
"""Check the results of protocol.checkIPv6Address()"""
test_ip = '2001:db8::ff00:42:8329'
self.assertEqual(
'test', protocol.checkIPv6Address(
protocol.encodeHost(test_ip), 'test'))
self.assertFalse(
protocol.checkIPv6Address(
protocol.encodeHost(test_ip), 'test', True))

def test_check_local(self):
"""Check the logic of TCPConnection.local"""
from pybitmessage import protocol, state

self.assertTrue(
protocol.checkIPAddress(protocol.encodeHost('127.0.0.1'), True))
self.assertTrue(
protocol.checkIPAddress(protocol.encodeHost('192.168.0.1'), True))
self.assertTrue(
protocol.checkIPAddress(protocol.encodeHost('10.42.43.1'), True))
self.assertTrue(
protocol.checkIPAddress(protocol.encodeHost('172.31.255.2'), True))
self.assertFalse(protocol.checkIPAddress(
protocol.encodeHost('2001:db8::ff00:42:8329'), True))

globalhost = protocol.encodeHost('8.8.8.8')
self.assertFalse(protocol.checkIPAddress(globalhost, True))
self.assertEqual(protocol.checkIPAddress(globalhost), '8.8.8.8')

@unittest.skipIf(
sys.hexversion >= 0x3000000, 'this is still not working with python3')
def test_check_local_socks(self):
"""The SOCKS part of the local check"""
self.assertTrue(
not protocol.checkSocksIP('127.0.0.1')
or state.socksIP)

def test_network_group(self):
"""Test various types of network groups"""

test_ip = '1.2.3.4'
self.assertEqual(b'\x01\x02', protocol.network_group(test_ip))

test_ip = '127.0.0.1'
self.assertEqual('IPv4', protocol.network_group(test_ip))

self.assertEqual(
protocol.network_group('8.8.8.8'),
protocol.network_group('8.8.4.4'))
self.assertNotEqual(
protocol.network_group('1.1.1.1'),
protocol.network_group('8.8.8.8'))

test_ip = '0102:0304:0506:0708:090A:0B0C:0D0E:0F10'
self.assertEqual(
b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C',
protocol.network_group(test_ip))

test_ip = 'bootstrap8444.bitmessage.org'
self.assertEqual(
'bootstrap8444.bitmessage.org',
protocol.network_group(test_ip))

test_ip = 'quzwelsuziwqgpt2.onion'
self.assertEqual(
test_ip,
protocol.network_group(test_ip))

test_ip = None
self.assertEqual(
None,
protocol.network_group(test_ip))