diff --git a/.circleci/config.yml b/.circleci/config.yml
index 629c232..22d0e0a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,13 +1,33 @@
-version: 2
+version: 2.1
 jobs:
   py27:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:2.7.17
-        environment:
-          - TOXENV: py27
+        environment:
+          TOXENV: py27
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: rootpw
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
+          MYSQL_HOST: '%'
+
     steps: &common_steps
       - checkout
+      - run:
+          # The primary container isn't MySQL, so poll until MySQL is ready.
+          name: Waiting for MySQL to be ready
+          command: |
+            for i in `seq 1 10`;
+            do
+              nc -z 127.0.0.1 3306 && echo Success && exit 0
+              echo -n .
+              sleep 5
+            done
+            echo Failed waiting for MySQL && exit 1
      - run:
          command: |
            sudo pip install tox
@@ -19,7 +39,7 @@ jobs:
      - run:
          command: |
            mkdir -p /tmp/core_dumps
-           cp core.* /tmp/core_dumps
+           ls core.* && cp core.* /tmp/core_dumps
          when: on_fail
      - store_artifacts:
          # collect core dumps
@@ -35,48 +55,90 @@ jobs:
   py34:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.4.10
-        environment:
-          - TOXENV: py34
+        environment:
+          TOXENV: py34
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

   py35:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.5.9
-        environment:
-          - TOXENV: py35
+        environment:
+          TOXENV: py35
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

   py36:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.6.10
-        environment:
-          - TOXENV: py36
+        environment:
+          TOXENV: py36
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

   py37:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.7.7
-        environment:
-          - TOXENV: py37
+        environment:
+          TOXENV: py37
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

   py38:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.8.2
-        environment:
-          - TOXENV: py38
+        environment:
+          TOXENV: py38
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

   pep8:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.5.9
-        environment:
-          - TOXENV: pep8
+        environment:
+          TOXENV: pep8
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: rootpw
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

@@ -84,8 +146,15 @@ jobs:
     docker:
       # Primary container image where all steps run.
       - image: circleci/python:3.5.9
-        environment:
-          - TOXENV: cover
+        environment:
+          TOXENV: cover
+      # MySQL env for mysql queue tests
+      - image: circleci/mysql:8.0
+        environment:
+          MYSQL_ROOT_PASSWORD: 123456
+          MYSQL_DATABASE: testqueue
+          MYSQL_USER: user
+          MYSQL_PASSWORD: passw0rd
     steps: *common_steps

@@ -94,7 +163,7 @@ workflows:
   version: 2
   build:
     jobs:
      - pep8
      - py27
-     - py34
+#     - py34
      - py35
      - py36
      - py37
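The readiness check above has to poll because the MySQL sidecar starts in parallel with the primary container. For running the MySQL-backed tests locally, the same check is easy to reproduce. A minimal Python 3 sketch (stdlib only; host, port, and retry counts mirror the CI config, and `wait_for_mysql` is an illustrative name, not part of the library):

    import socket
    import time

    def wait_for_mysql(host="127.0.0.1", port=3306, retries=10, delay=5):
        """Poll the MySQL port until it accepts TCP connections."""
        for _ in range(retries):
            try:
                # create_connection raises OSError until the server is up
                with socket.create_connection((host, port), timeout=1):
                    return True
            except OSError:
                time.sleep(delay)
        return False

    assert wait_for_mysql(), "MySQL did not become ready in time"
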
diff --git a/README.rst b/README.rst
index 857e428..67ec6bf 100644
--- a/README.rst
+++ b/README.rst
@@ -41,7 +41,7 @@ Join `persist-queue `_ and
+- Since `DBUtils `_ ceased support for `Python 3.4`, `persist-queue` drops Python 3.4
+  support as of version `0.8.0`. Other queue implementations, such as the file based
+  queue and the sqlite3 based queue, still work.
+- `Python 2 was sunset on January 1, 2020 `_. `persist-queue` will drop all Python 2
+  support in the future version `1.0.0`; no new features will be developed under Python 2.

 Installation
 ------------
@@ -64,7 +69,7 @@ from pypi

 .. code-block:: console

     pip install persist-queue
-    # for msgpack support, use following command
+    # for msgpack and mysql support, use the following command
     pip install persist-queue[extra]

@@ -426,6 +431,42 @@ multi-thread usage for **Queue**

     q.join()  # block until all tasks are done

+Example usage with a MySQL based queue
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+*Available since: v0.8.0*
+
+.. code-block:: python
+
+    >>> import persistqueue
+    >>> db_conf = {
+    ...     "host": "127.0.0.1",
+    ...     "user": "user",
+    ...     "passwd": "passw0rd",
+    ...     "db_name": "testqueue",
+    ...     # "name": "",
+    ...     "port": 3306
+    ... }
+    >>> q = persistqueue.MySQLQueue(name="testtable", **db_conf)
+    >>> q.put('str1')
+    >>> q.put('str2')
+    >>> q.put('str3')
+    >>> q.get()
+    'str1'
+    >>> del q
+
+
+Close the console, and then recreate the queue:
+
+.. code-block:: python
+
+    >>> import persistqueue
+    >>> q = persistqueue.MySQLQueue(name="testtable", **db_conf)
+    >>> q.get()
+    'str2'
+    >>>
+

 **note**
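The README example above uses the default pickle serializer. Judging from the `serializer` parameter added in `persistqueue/mysqlqueue.py` below, the bundled json serializer should plug in the same way it does for the sqlite queues; a sketch under that assumption (`jsontable` is an arbitrary name; the credentials are the ones from the example):

    import persistqueue
    from persistqueue.serializers import json as serializers_json

    db_conf = {
        "host": "127.0.0.1",
        "user": "user",
        "passwd": "passw0rd",
        "db_name": "testqueue",
        "port": 3306,
    }
    # store items as JSON blobs instead of pickles
    q = persistqueue.MySQLQueue(name="jsontable",
                                serializer=serializers_json,
                                **db_conf)
    q.put({"event": "signup", "retries": 0})
    print(q.get())  # {'event': 'signup', 'retries': 0}
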
+ $cmd = '"C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -e "create database testqueue;" --user=root' + iex "& $cmd" - "%PYTHON%\\Scripts\\tox.exe" #on_success: diff --git a/extra-requirements.txt b/extra-requirements.txt index 275288c..2f225b1 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -1 +1,3 @@ -msgpack>=0.5.6 \ No newline at end of file +msgpack>=0.5.6 +PyMySQL +DBUtils<3.0.0 # since 3.0.0 no longer supports Python2.x \ No newline at end of file diff --git a/persistqueue/__init__.py b/persistqueue/__init__.py index 44efe78..5810b3b 100644 --- a/persistqueue/__init__.py +++ b/persistqueue/__init__.py @@ -1,7 +1,7 @@ # coding=utf-8 __author__ = 'Peter Wang' __license__ = 'BSD' -__version__ = '0.7.0' +__version__ = '0.8.0-alpha0' from .exceptions import Empty, Full # noqa from .queue import Queue # noqa @@ -11,6 +11,7 @@ from .sqlqueue import SQLiteQueue, FIFOSQLiteQueue, FILOSQLiteQueue, \ UniqueQ # noqa from .sqlackqueue import SQLiteAckQueue, UniqueAckQ + from .mysqlqueue import MySQLQueue except ImportError: import logging @@ -18,5 +19,5 @@ log.info("No sqlite3 module found, sqlite3 based queues are not available") __all__ = ["Queue", "SQLiteQueue", "FIFOSQLiteQueue", "FILOSQLiteQueue", - "UniqueQ", "PDict", "SQLiteAckQueue", "UniqueAckQ", "Empty", "Full", - "__author__", "__license__", "__version__"] + "UniqueQ", "PDict", "SQLiteAckQueue", "UniqueAckQ", "MySQLQueue", + "Empty", "Full", "__author__", "__license__", "__version__"] diff --git a/persistqueue/common.py b/persistqueue/common.py index 4d660fc..ad1ef1a 100644 --- a/persistqueue/common.py +++ b/persistqueue/common.py @@ -1,4 +1,4 @@ -#! coding = utf-8 +# coding=utf-8 import logging import pickle diff --git a/persistqueue/exceptions.py b/persistqueue/exceptions.py index ed675e7..dae5866 100644 --- a/persistqueue/exceptions.py +++ b/persistqueue/exceptions.py @@ -1,4 +1,4 @@ -#! 
diff --git a/persistqueue/mysqlqueue.py b/persistqueue/mysqlqueue.py
new file mode 100644
index 0000000..a23ba73
--- /dev/null
+++ b/persistqueue/mysqlqueue.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+import threading
+import time as _time
+
+from dbutils.pooled_db import PooledDB
+
+import persistqueue
+from .sqlbase import SQLBase
+
+
+class MySQLQueue(SQLBase):
+    """MySQL (or future standard DBMS) based FIFO queue."""
+    _TABLE_NAME = 'queue'
+    _KEY_COLUMN = '_id'  # the name of the key column, used in DB CRUD
+    # SQL to create a table
+    _SQL_CREATE = (
+        'CREATE TABLE IF NOT EXISTS {table_name} ('
+        '{key_column} INTEGER PRIMARY KEY AUTO_INCREMENT, '
+        'data BLOB, timestamp FLOAT)')
+    # SQL to insert a record
+    _SQL_INSERT = 'INSERT INTO {table_name} (data, timestamp) VALUES (%s, %s)'
+    # SQL to select a record
+    _SQL_SELECT_ID = (
+        'SELECT {key_column}, data, timestamp FROM {table_name} WHERE'
+        ' {key_column} = {rowid}'
+    )
+    _SQL_SELECT = (
+        'SELECT {key_column}, data, timestamp FROM {table_name} '
+        'ORDER BY {key_column} ASC LIMIT 1'
+    )
+    _SQL_SELECT_WHERE = (
+        'SELECT {key_column}, data, timestamp FROM {table_name} WHERE'
+        ' {column} {op} %s ORDER BY {key_column} ASC LIMIT 1 '
+    )
+    _SQL_UPDATE = 'UPDATE {table_name} SET data = %s WHERE {key_column} = %s'
+
+    _SQL_DELETE = 'DELETE FROM {table_name} WHERE {key_column} {op} %s'
+
+    def __init__(self, host, user, passwd, db_name, name=None,
+                 port=3306,
+                 charset='utf8mb4',
+                 auto_commit=True,
+                 serializer=persistqueue.serializers.pickle,
+                 ):
+        self.name = name if name else "sql"
+        self.host = host
+        self.user = user
+        self.passwd = passwd
+        self.db_name = db_name
+        self.port = port
+        self.charset = charset
+        self._serializer = serializer
+        self.auto_commit = auto_commit
+
+        # MySQL transaction lock
+        self.tran_lock = threading.Lock()
+        self.put_event = threading.Event()
+        # Action lock to ensure multiple actions are *atomic*
+        self.action_lock = threading.Lock()
+
+        self._connection_pool = None
+        self._getter = None
+        self._putter = None
+        self._new_db_connection()
+        self._init()
+
+    def _new_db_connection(self):
+        try:
+            import pymysql
+        except ImportError:
+            print(
+                "Please install the mysql library via "
+                "'pip install PyMySQL'")
+            raise
+        db_pool = PooledDB(pymysql, 2, 10, 5, 10, True,
+                           host=self.host, port=self.port, user=self.user,
+                           passwd=self.passwd, database=self.db_name,
+                           charset=self.charset
+                           )
+        self._connection_pool = db_pool
+        conn = db_pool.connection()
+        cursor = conn.cursor()
+        cursor.execute("SELECT VERSION()")
+        _ = cursor.fetchone()
+        # create the table automatically
+        cursor.execute(self._sql_create)
+        conn.commit()
+        # switch to the desired db
+        cursor.execute("use %s" % self.db_name)
+        self._putter = MySQLConn(queue=self)
+        self._getter = self._putter
+
+    def put(self, item, block=True):
+        # block kwarg is noop and only here to align with python's queue
+        obj = self._serializer.dumps(item)
+        _id = self._insert_into(obj, _time.time())
+        self.total += 1
+        self.put_event.set()
+        return _id
+
+    def put_nowait(self, item):
+        return self.put(item, block=False)
+
+    def _init(self):
+        # Action lock to ensure multiple actions are *atomic*
+        self.action_lock = threading.Lock()
+        if not self.auto_commit:
+            # Refresh the current cursor after a restart
+            head = self._select()
+            if head:
+                self.cursor = head[0] - 1
+            else:
+                self.cursor = 0
+        self.total = self._count()
+
+    def get_pooled_conn(self):
+        return self._connection_pool.connection()
+
+
+class MySQLConn(object):
+    """MySQLConn defines a common structure for
+    both mysql and sqlite3 connections,
+
+    used to mitigate the interface differences between drivers/db.
+    """
+
+    def __init__(self, queue=None, conn=None):
+        if queue is not None:
+            self._conn = queue.get_pooled_conn()
+        else:
+            self._conn = conn
+        self._queue = queue
+        self._cursor = None
+        self.closed = False
+
+    def __enter__(self):
+        self._cursor = self._conn.cursor()
+        return self._conn
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # close() rather than commit(); this keeps the same behavior
+        # as dbutils
+        self._cursor.close()
+
+    def execute(self, *args, **kwargs):
+        if self._queue is not None:
+            conn = self._queue.get_pooled_conn()
+        else:
+            conn = self._conn
+        cursor = conn.cursor()
+        cursor.execute(*args, **kwargs)
+        return cursor
+
+    def close(self):
+        if not self.closed:
+            self._conn.close()
+            self.closed = True
+
+    def commit(self):
+        if not self.closed:
+            self._conn.commit()
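`_new_db_connection` above builds the pool with positional arguments. Spelled out with the keyword names DBUtils documents for `PooledDB` (assuming DBUtils < 3.0, as pinned in extra-requirements.txt), the same call reads roughly:

    import pymysql
    from dbutils.pooled_db import PooledDB

    pool = PooledDB(
        pymysql,            # creator: any DB-API 2 module
        mincached=2,        # idle connections opened at startup
        maxcached=10,       # max idle connections kept in the pool
        maxshared=5,        # max shared connections
        maxconnections=10,  # hard ceiling on open connections
        blocking=True,      # wait instead of raising when exhausted
        host="127.0.0.1", port=3306, user="user",
        passwd="passw0rd", database="testqueue", charset="utf8mb4",
    )
    conn = pool.connection()  # borrow a connection
    conn.close()              # returns it to the pool, not a real close

Handing each put/get a pooled connection is what makes the queue safe to share across threads, since a single PyMySQL connection is not.
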
diff --git a/persistqueue/pdict.py b/persistqueue/pdict.py
index 5c98cdb..e5d714a 100644
--- a/persistqueue/pdict.py
+++ b/persistqueue/pdict.py
@@ -1,4 +1,4 @@
-#! coding = utf-8
+# coding=utf-8
 import logging
 import sqlite3

@@ -17,6 +17,8 @@ class PDict(sqlbase.SQLiteBase, dict):
                        'WHERE {key_column} = ?')
     _SQL_UPDATE = 'UPDATE {table_name} SET data = ? WHERE {key_column} = ?'

+    _SQL_DELETE = 'DELETE FROM {table_name} WHERE {key_column} {op} ?'
+
     def __init__(self, path, name, multithreading=False):
         # PDict is always auto_commit=True
         super(PDict, self).__init__(path, name=name,
diff --git a/persistqueue/serializers/__init__.py b/persistqueue/serializers/__init__.py
index 4dd3527..9bad579 100644
--- a/persistqueue/serializers/__init__.py
+++ b/persistqueue/serializers/__init__.py
@@ -1 +1 @@
-#! coding = utf-8
+# coding=utf-8
diff --git a/persistqueue/serializers/json.py b/persistqueue/serializers/json.py
index 30b3456..c63ed40 100644
--- a/persistqueue/serializers/json.py
+++ b/persistqueue/serializers/json.py
@@ -1,4 +1,4 @@
-#! coding = utf-8
+# coding=utf-8

 """
 A serializer that extends json to use bytes and uses newlines to store
diff --git a/persistqueue/serializers/msgpack.py b/persistqueue/serializers/msgpack.py
index abf5af7..d43b933 100644
--- a/persistqueue/serializers/msgpack.py
+++ b/persistqueue/serializers/msgpack.py
@@ -1,4 +1,4 @@
-#! coding = utf-8
+# coding=utf-8

 """
 A serializer that extends msgpack to specify recommended parameters and adds a
@@ -6,7 +6,11 @@
 """

 from __future__ import absolute_import
-import msgpack
+try:
+    import msgpack
+except ImportError:
+    pass
+

 import struct
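The sql-based queues only ever call `dumps(value)` and `loads(blob)` on a serializer (see `sqlbase.py` below), so any object exposing those two callables should work as a drop-in for them. A minimal sketch (the class name is illustrative; the file-based `Queue` additionally needs `dump(value, fp)`/`load(fp)`, which this sketch omits):

    import json

    class TextJson(object):
        """Store queue items as UTF-8 JSON blobs."""

        @staticmethod
        def dumps(value):
            return json.dumps(value).encode("utf-8")

        @staticmethod
        def loads(blob):
            return json.loads(bytes(blob).decode("utf-8"))

    # usage: persistqueue.SQLiteQueue('/tmp/q', serializer=TextJson)
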
diff --git a/persistqueue/serializers/pickle.py b/persistqueue/serializers/pickle.py
index 153cb8d..525fd71 100644
--- a/persistqueue/serializers/pickle.py
+++ b/persistqueue/serializers/pickle.py
@@ -1,4 +1,4 @@
-#! coding = utf-8
+# coding=utf-8

 """A serializer that extends pickle to change the default protocol"""

diff --git a/persistqueue/sqlbase.py b/persistqueue/sqlbase.py
index 395b5ea..607b3cb 100644
--- a/persistqueue/sqlbase.py
+++ b/persistqueue/sqlbase.py
@@ -1,22 +1,44 @@
+# coding=utf-8
 import logging
 import os
+import time as _time
 import sqlite3
 import threading

+from persistqueue.exceptions import Empty
+
 import persistqueue.serializers.pickle

 sqlite3.enable_callback_tracebacks(True)

 log = logging.getLogger(__name__)

+# 10 second interval for the `wait` on the put event
+TICK_FOR_WAIT = 10
+

 def with_conditional_transaction(func):
     def _execute(obj, *args, **kwargs):
+        # For MySQL, the connection pool should be used, since a single
+        # db connection is not thread-safe
+        _putter = obj._putter
+        if str(type(obj)).find("MySQLQueue") > 0:
+            # use a fresh connection from the pool, not the shared one
+            _putter = obj.get_pooled_conn()
         with obj.tran_lock:
-            with obj._putter as tran:
+            with _putter as tran:
+                # For sqlite3, commit() is called automatically afterwards,
+                # but for other DB APIs this is NOT true!
                 stat, param = func(obj, *args, **kwargs)
-                cur = tran.cursor()
-                cur.execute(stat, param)
+                s = str(type(tran))
+                if s.find("Cursor") > 0:
+                    cur = tran
+                    cur.execute(stat, param)
+                else:
+                    cur = tran.cursor()
+                    cur.execute(stat, param)
+                    cur.close()
+                    tran.commit()
             return cur.lastrowid

     return _execute
@@ -40,8 +62,8 @@ def commit_ignore_error(conn):
         raise


-class SQLiteBase(object):
-    """SQLite3 base class."""
+class SQLBase(object):
+    """SQL base class."""

     _TABLE_NAME = 'base'  # DB table name
     _KEY_COLUMN = ''  # the name of the key column, used in DB CRUD
@@ -51,95 +73,24 @@ class SQLiteBase(object):
     _SQL_SELECT = ''  # SQL to select a record
     _SQL_SELECT_ID = ''  # SQL to select a record with criteria
     _SQL_SELECT_WHERE = ''  # SQL to select a record with criteria
-    _MEMORY = ':memory:'  # flag indicating store DB in memory
-
-    def __init__(
-        self,
-        path,
-        name='default',
-        multithreading=False,
-        timeout=10.0,
-        auto_commit=True,
-        serializer=persistqueue.serializers.pickle,
-        db_file_name=None,
-    ):
-        """Initiate a queue in sqlite3 or memory.
-
-        :param path: path for storing DB file.
-        :param name: the suffix for the table name,
-                     table name would be ${_TABLE_NAME}_${name}
-        :param multithreading: if set to True, two db connections will be,
-                               one for **put** and one for **get**.
-        :param timeout: timeout in second waiting for the database lock.
-        :param auto_commit: Set to True, if commit is required on every
-                            INSERT/UPDATE action, otherwise False, whereas
-                            a **task_done** is required to persist changes
-                            after **put**.
-        :param serializer: The serializer parameter controls how enqueued data
-                           is serialized. It must have methods dump(value, fp)
-                           and load(fp). The dump method must serialize the
-                           value and write it to fp, and may be called for
-                           multiple values with the same fp. The load method
-                           must deserialize and return one value from fp,
-                           and may be called multiple times with the same fp
-                           to read multiple values.
-        :param db_file_name: set the db file name of the queue data, otherwise
-                             default to `data.db`
-        """
-        self.memory_sql = False
-        self.path = path
-        self.name = name
-        self.timeout = timeout
-        self.multithreading = multithreading
-        self.auto_commit = auto_commit
-        self._serializer = serializer
-        self.db_file_name = "data.db"
-        if db_file_name:
-            self.db_file_name = db_file_name
-        self._init()
-
-    def _init(self):
-        """Initialize the tables in DB."""
-        if self.path == self._MEMORY:
-            self.memory_sql = True
-            log.debug("Initializing Sqlite3 Queue in memory.")
-        elif not os.path.exists(self.path):
-            os.makedirs(self.path)
-            log.debug(
-                'Initializing Sqlite3 Queue with path {}'.format(self.path)
-            )
-        self._conn = self._new_db_connection(
-            self.path, self.multithreading, self.timeout
-        )
-        self._getter = self._conn
-        self._putter = self._conn
-
-        self._conn.execute(self._sql_create)
-        self._conn.commit()
-        # Setup another session only for disk-based queue.
-        if self.multithreading and not self.memory_sql:
-            self._putter = self._new_db_connection(
-                self.path, self.multithreading, self.timeout
-            )
-        self._conn.text_factory = str
-        self._putter.text_factory = str
-
-        # SQLite3 transaction lock
-        self.tran_lock = threading.Lock()
-        self.put_event = threading.Event()
-
-    def _new_db_connection(self, path, multithreading, timeout):
-        conn = None
-        if path == self._MEMORY:
-            conn = sqlite3.connect(path, check_same_thread=not multithreading)
-        else:
-            conn = sqlite3.connect(
-                '{}/{}'.format(path, self.db_file_name),
-                timeout=timeout,
-                check_same_thread=not multithreading,
-            )
-        conn.execute('PRAGMA journal_mode=WAL;')
-        return conn
+    _SQL_DELETE = ''  # SQL to delete a record
+    # _MEMORY = ':memory:'  # flag indicating store DB in memory
+
+    def __init__(self):
+        """Initiate a queue in db."""
+        self._serializer = None
+        self.auto_commit = None
+
+        # SQL transaction lock
+        self.tran_lock = threading.Lock()
+        self.put_event = threading.Event()
+        # Action lock to ensure multiple actions are *atomic*
+        self.action_lock = threading.Lock()
+        self.total = 0
+        self.cursor = 0
+        self._getter = None
+        self._putter = None

     @with_conditional_transaction
     def _insert_into(self, *record):
@@ -152,11 +103,128 @@ def _update(self, key, *args):

     @with_conditional_transaction
     def _delete(self, key, op='='):
-        sql = 'DELETE FROM {} WHERE {} {} ?'.format(
-            self._table_name, self._key_column, op
-        )
+        sql = self._SQL_DELETE.format(
+            table_name=self._table_name, key_column=self._key_column, op=op)
         return sql, (key,)

+    def _pop(self, rowid=None, raw=False):
+        with self.action_lock:
+            if self.auto_commit:
+                row = self._select(rowid=rowid)
+                # Perhaps a sqlite3 bug: sometimes (None, None) is returned
+                # by select; the check below skips such invalid records.
+                if row and row[0] is not None:
+                    self._delete(row[0])
+                    self.total -= 1
+                    item = self._serializer.loads(row[1])
+                    if raw:
+                        return {
+                            'pqid': row[0],
+                            'data': item,
+                            'timestamp': row[2],
+                        }
+                    else:
+                        return item
+            else:
+                row = self._select(
+                    self.cursor, op=">", column=self._KEY_COLUMN, rowid=rowid
+                )
+                if row and row[0] is not None:
+                    self.cursor = row[0]
+                    self.total -= 1
+                    item = self._serializer.loads(row[1])
+                    if raw:
+                        return {
+                            'pqid': row[0],
+                            'data': item,
+                            'timestamp': row[2],
+                        }
+                    else:
+                        return item
+            return None
+
+    def update(self, item, id=None):
+        if isinstance(item, dict) and "pqid" in item:
+            _id = item.get("pqid")
+            item = item.get("data")
+        if id is not None:
+            _id = id
+        if _id is None:
+            raise ValueError("Provide an id or raw item")
+        obj = self._serializer.dumps(item)
+        self._update(_id, obj)
+        return _id
+
+    def get(self, block=True, timeout=None, id=None, raw=False):
+        if isinstance(id, dict) and "pqid" in id:
+            rowid = id.get("pqid")
+        elif isinstance(id, int):
+            rowid = id
+        else:
+            rowid = None
+        if not block:
+            serialized = self._pop(raw=raw, rowid=rowid)
+            if serialized is None:
+                raise Empty
+        elif timeout is None:
+            # block until a put event.
+            serialized = self._pop(raw=raw, rowid=rowid)
+            while serialized is None:
+                self.put_event.clear()
+                self.put_event.wait(TICK_FOR_WAIT)
+                serialized = self._pop(raw=raw, rowid=rowid)
+        elif timeout < 0:
+            raise ValueError("'timeout' must be a non-negative number")
+        else:
+            # block until the timeout is reached
+            endtime = _time.time() + timeout
+            serialized = self._pop(raw=raw, rowid=rowid)
+            while serialized is None:
+                self.put_event.clear()
+                remaining = endtime - _time.time()
+                if remaining <= 0.0:
+                    raise Empty
+                self.put_event.wait(
+                    TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining
+                )
+                serialized = self._pop(raw=raw, rowid=rowid)
+        return serialized
+
+    def get_nowait(self, id=None, raw=False):
+        return self.get(block=False, id=id, raw=raw)
+
+    def task_done(self):
+        """Persist the current state if auto_commit=False."""
+        if not self.auto_commit:
+            self._delete(self.cursor, op='<=')
+            self._task_done()
+
+    def queue(self):
+        rows = self._sql_queue().fetchall()
+        datarows = []
+        for row in rows:
+            item = {
+                'id': row[0],
+                'data': self._serializer.loads(row[1]),
+                'timestamp': row[2],
+            }
+            datarows.append(item)
+        return datarows
+
+    @property
+    def size(self):
+        return self.total
+
+    def qsize(self):
+        return max(0, self.size)
+
+    def empty(self):
+        return self.size == 0
+
+    def __len__(self):
+        return self.size
+
     def _select(self, *args, **kwargs):
         start_key = self._start_key()
         op = kwargs.get('op', None)
@@ -181,9 +249,9 @@ def _select(self, *args, **kwargs):
                 self._sql_select(rowid), args
             ).fetchone()
         if (
-            next_in_order
-            and rowid != start_key
-            and (not result or len(result) == 0)
+                next_in_order
+                and rowid != start_key
+                and (not result or len(result) == 0)
         ):
             # sqlackqueue: if we're at the end, start over
             kwargs['rowid'] = start_key
@@ -260,6 +328,118 @@ def _sql_select_where(self, rowid, op, column):
             column=column,
         )

+    def __del__(self):
+        """Handles the db connection when the queue is deleted."""
+        if self._getter:
+            self._getter.close()
+        if self._putter:
+            self._putter.close()
+
+
+class SQLiteBase(SQLBase):
+    """SQLite3 base class."""
+
+    _TABLE_NAME = 'base'  # DB table name
+    _KEY_COLUMN = ''  # the name of the key column, used in DB CRUD
+    _SQL_CREATE = ''  # SQL to create a table
+    _SQL_UPDATE = ''  # SQL to update a record
+    _SQL_INSERT = ''  # SQL to insert a record
+    _SQL_SELECT = ''  # SQL to select a record
+    _SQL_SELECT_ID = ''  # SQL to select a record with criteria
+    _SQL_SELECT_WHERE = ''  # SQL to select a record with criteria
+    _SQL_DELETE = ''  # SQL to delete a record
+    _MEMORY = ':memory:'  # flag indicating store DB in memory
+
+    def __init__(
+        self,
+        path,
+        name='default',
+        multithreading=False,
+        timeout=10.0,
+        auto_commit=True,
+        serializer=persistqueue.serializers.pickle,
+        db_file_name=None,
+    ):
+        """Initiate a queue in sqlite3 or memory.
+
+        :param path: path for storing DB file.
+        :param name: the suffix for the table name,
+                     table name would be ${_TABLE_NAME}_${name}
+        :param multithreading: if set to True, two db connections will be,
+                               one for **put** and one for **get**.
+        :param timeout: timeout in second waiting for the database lock.
+        :param auto_commit: Set to True, if commit is required on every
+                            INSERT/UPDATE action, otherwise False, whereas
+                            a **task_done** is required to persist changes
+                            after **put**.
+        :param serializer: The serializer parameter controls how enqueued data
+                           is serialized. It must have methods dump(value, fp)
+                           and load(fp). The dump method must serialize the
+                           value and write it to fp, and may be called for
+                           multiple values with the same fp. The load method
+                           must deserialize and return one value from fp,
+                           and may be called multiple times with the same fp
+                           to read multiple values.
+        :param db_file_name: set the db file name of the queue data, otherwise
+                             default to `data.db`
+        """
+        super(SQLiteBase, self).__init__()
+        self.memory_sql = False
+        self.path = path
+        self.name = name
+        self.timeout = timeout
+        self.multithreading = multithreading
+        self.auto_commit = auto_commit
+        self._serializer = serializer
+        self.db_file_name = "data.db"
+        if db_file_name:
+            self.db_file_name = db_file_name
+        self._init()
+
+    def _init(self):
+        """Initialize the tables in DB."""
+        if self.path == self._MEMORY:
+            self.memory_sql = True
+            log.debug("Initializing Sqlite3 Queue in memory.")
+        elif not os.path.exists(self.path):
+            os.makedirs(self.path)
+            log.debug(
+                'Initializing Sqlite3 Queue with path {}'.format(self.path)
+            )
+        self._conn = self._new_db_connection(
+            self.path, self.multithreading, self.timeout
+        )
+        self._getter = self._conn
+        self._putter = self._conn
+
+        self._conn.execute(self._sql_create)
+        self._conn.commit()
+        # Setup another session only for disk-based queue.
+        if self.multithreading:
+            if not self.memory_sql:
+                self._putter = self._new_db_connection(
+                    self.path, self.multithreading, self.timeout
+                )
+        self._conn.text_factory = str
+        self._putter.text_factory = str
+
+        # SQLite3 transaction lock
+        self.tran_lock = threading.Lock()
+        self.put_event = threading.Event()
+
+    def _new_db_connection(self, path, multithreading, timeout):
+        conn = None
+        if path == self._MEMORY:
+            conn = sqlite3.connect(path, check_same_thread=not multithreading)
+        else:
+            conn = sqlite3.connect(
+                '{}/{}'.format(path, self.db_file_name),
+                timeout=timeout,
+                check_same_thread=not multithreading,
+            )
+        conn.execute('PRAGMA journal_mode=WAL;')
+        return conn
+
     def __del__(self):
         """Handles sqlite connection when queue was deleted"""
         self._getter.close()
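With `get()` hoisted into `SQLBase` above, every sql-backed queue keeps `queue.Queue` semantics: the call re-checks `_pop()` whenever the put event fires or `TICK_FOR_WAIT` elapses, and it honors timeouts. A short consumer sketch against the sqlite backend (the path is arbitrary):

    import persistqueue
    from persistqueue import Empty

    q = persistqueue.SQLiteQueue("/tmp/demo_queue", auto_commit=True)
    q.put("job-1")

    try:
        item = q.get(block=True, timeout=3.0)  # waits at most ~3 seconds
    except Empty:
        item = None
    print(item)  # 'job-1'
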
diff --git a/persistqueue/sqlqueue.py b/persistqueue/sqlqueue.py
index 56320f9..a379876 100644
--- a/persistqueue/sqlqueue.py
+++ b/persistqueue/sqlqueue.py
@@ -7,16 +7,13 @@
 import time as _time
 import threading

+
 from persistqueue import sqlbase
-from persistqueue.exceptions import Empty

 sqlite3.enable_callback_tracebacks(True)

 log = logging.getLogger(__name__)

-# 10 seconds internal for `wait` of event
-TICK_FOR_WAIT = 10
-

 class SQLiteQueue(sqlbase.SQLiteBase):
     """SQLite3 based FIFO queue."""
@@ -46,6 +43,8 @@ class SQLiteQueue(sqlbase.SQLiteBase):
     )
     _SQL_UPDATE = 'UPDATE {table_name} SET data = ? WHERE {key_column} = ?'

+    _SQL_DELETE = 'DELETE FROM {table_name} WHERE {key_column} {op} ?'
+
     def put(self, item, block=True):
         # block kwarg is noop and only here to align with python's queue
         obj = self._serializer.dumps(item)
@@ -70,123 +69,6 @@ def _init(self):
             self.cursor = 0
         self.total = self._count()

-    def _pop(self, rowid=None, raw=False):
-        with self.action_lock:
-            if self.auto_commit:
-                row = self._select(rowid=rowid)
-                # Perhaps a sqlite3 bug, sometimes (None, None) is returned
-                # by select, below can avoid these invalid records.
-                if row and row[0] is not None:
-                    self._delete(row[0])
-                    self.total -= 1
-                    item = self._serializer.loads(row[1])
-                    if raw:
-                        return {
-                            'pqid': row[0],
-                            'data': item,
-                            'timestamp': row[2],
-                        }
-                    else:
-                        return item
-            else:
-                row = self._select(
-                    self.cursor, op=">", column=self._KEY_COLUMN, rowid=rowid
-                )
-                if row and row[0] is not None:
-                    self.cursor = row[0]
-                    self.total -= 1
-                    item = self._serializer.loads(row[1])
-                    if raw:
-                        return {
-                            'pqid': row[0],
-                            'data': item,
-                            'timestamp': row[2],
-                        }
-                    else:
-                        return item
-            return None
-
-    def update(self, item, id=None):
-        if isinstance(item, dict) and "pqid" in item:
-            _id = item.get("pqid")
-            item = item.get("data")
-        if id is not None:
-            _id = id
-        if _id is None:
-            raise ValueError("Provide an id or raw item")
-        obj = self._serializer.dumps(item)
-        self._update(_id, obj)
-        return _id
-
-    def get(self, block=True, timeout=None, id=None, raw=False):
-        if isinstance(id, dict) and "pqid" in id:
-            rowid = id.get("pqid")
-        elif isinstance(id, int):
-            rowid = id
-        else:
-            rowid = None
-        if not block:
-            serialized = self._pop(raw=raw, rowid=rowid)
-            if serialized is None:
-                raise Empty
-        elif timeout is None:
-            # block until a put event.
-            serialized = self._pop(raw=raw, rowid=rowid)
-            while serialized is None:
-                self.put_event.clear()
-                self.put_event.wait(TICK_FOR_WAIT)
-                serialized = self._pop(raw=raw, rowid=rowid)
-        elif timeout < 0:
-            raise ValueError("'timeout' must be a non-negative number")
-        else:
-            # block until the timeout reached
-            endtime = _time.time() + timeout
-            serialized = self._pop(raw=raw, rowid=rowid)
-            while serialized is None:
-                self.put_event.clear()
-                remaining = endtime - _time.time()
-                if remaining <= 0.0:
-                    raise Empty
-                self.put_event.wait(
-                    TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining
-                )
-                serialized = self._pop(raw=raw, rowid=rowid)
-        return serialized
-
-    def get_nowait(self, id=None, raw=False):
-        return self.get(block=False, id=id, raw=raw)
-
-    def task_done(self):
-        """Persist the current state if auto_commit=False."""
-        if not self.auto_commit:
-            self._delete(self.cursor, op='<=')
-            self._task_done()
-
-    def queue(self):
-        rows = self._sql_queue()
-        datarows = []
-        for row in rows:
-            item = {
-                'id': row[0],
-                'data': self._serializer.loads(row[1]),
-                'timestamp': row[2],
-            }
-            datarows.append(item)
-        return datarows
-
-    @property
-    def size(self):
-        return self.total
-
-    def qsize(self):
-        return max(0, self.size)
-
-    def empty(self):
-        return self.size == 0
-
-    def __len__(self):
-        return self.size
-

 FIFOSQLiteQueue = SQLiteQueue
diff --git a/persistqueue/tests/__init__.py b/persistqueue/tests/__init__.py
index e69de29..9bad579 100644
--- a/persistqueue/tests/__init__.py
+++ b/persistqueue/tests/__init__.py
@@ -0,0 +1 @@
+# coding=utf-8
diff --git a/persistqueue/tests/test_mysqlqueue.py b/persistqueue/tests/test_mysqlqueue.py
new file mode 100644
index 0000000..50eaeb7
--- /dev/null
+++ b/persistqueue/tests/test_mysqlqueue.py
@@ -0,0 +1,314 @@
+# coding=utf-8
+import unittest
+import random
+from threading import Thread
+import time
+import sys
+
+from persistqueue.mysqlqueue import MySQLQueue
+from persistqueue import Empty
+
+# db config aligned with .circleci/config.yml
+db_conf = {
+    "host": "127.0.0.1",
+    "user": "user",
+    "passwd": "passw0rd",
+    "db_name": "testqueue",
+    # "name": "",
+    "port": 3306
+}
+# on appveyor (windows CI) the MySQL service cannot be customized,
+# so use the defaults, see
+# https://www.appveyor.com/docs/services-databases/#mysql
+if sys.platform.startswith('win32'):
+    db_conf = {
+        "host": "127.0.0.1",
+        "user": "root",
+        "passwd": "Password12!",
+        "db_name": "testqueue",
+        # "name": "",
+        "port": 3306
+    }
+
+
+class MySQLQueueTest(unittest.TestCase):
+    """Tests that focus on features specific to MySQL."""
+
+    def setUp(self):
+        _name = self.id().split(".")[-1:]
+        _name.append(str(time.time()).replace(".", "_"))
+        # dots are not valid in unquoted MySQL table names
+        self._table_name = "_".join(_name)
+        self.queue_class = MySQLQueue
+        self.mysql_queue = MySQLQueue(name=self._table_name,
+                                      **db_conf)
+        self.queue = self.mysql_queue
+
+    def tearDown(self):
+        tmp_conn = self.mysql_queue.get_pooled_conn()
+        tmp_conn.cursor().execute(
+            "drop table if exists %s" % self.mysql_queue._table_name)
+        tmp_conn.commit()
+
+    def test_raise_empty(self):
+        q = self.queue
+
+        q.put('first')
+        d = q.get()
+        self.assertEqual('first', d)
+        self.assertRaises(Empty, q.get, block=False)
+        self.assertRaises(Empty, q.get_nowait)
+
+        # assert with timeout
+        self.assertRaises(Empty, q.get, block=True, timeout=1.0)
+        # assert with negative timeout
+        self.assertRaises(ValueError, q.get, block=True, timeout=-1.0)
+        del q
+
+    def test_empty(self):
+        q = self.queue
+        self.assertEqual(q.empty(), True)
+
+        q.put('first')
+        self.assertEqual(q.empty(), False)
+
+        q.get()
+        self.assertEqual(q.empty(), True)
+
+    def test_open_close_single(self):
+        """Write 1 item, close, reopen checking if same item is there"""
+
+        q = self.queue
+        q.put(b'var1')
+        del q
+        q = MySQLQueue(name=self._table_name,
+                       **db_conf)
+        self.assertEqual(1, q.qsize())
+        self.assertEqual(b'var1', q.get())
+
+    def test_open_close_1000(self):
+        """Write 1000 items, close, reopen checking if all items are there"""
+
+        q = self.queue
+        for i in range(1000):
+            q.put('var%d' % i)
+        self.assertEqual(1000, q.qsize())
+        del q
+        q = MySQLQueue(name=self._table_name,
+                       **db_conf)
+        self.assertEqual(1000, q.qsize())
+        for i in range(1000):
+            data = q.get()
+            self.assertEqual('var%d' % i, data)
+        # assert adding another one still works
+        q.put('foobar')
+        data = q.get()
+        self.assertEqual('foobar', data)
+
+    def test_random_read_write(self):
+        """Test random read/write"""
+
+        q = self.queue
+        n = 0
+        for _ in range(1000):
+            if random.random() < 0.5:
+                if n > 0:
+                    q.get()
+                    n -= 1
+                else:
+                    self.assertRaises(Empty, q.get, block=False)
+            else:
+                q.put('var%d' % random.getrandbits(16))
+                n += 1
+
+    def test_multi_threaded_parallel(self):
+        """Create consumer and producer threads, check parallelism"""
+        m_queue = self.queue
+
+        def producer():
+            for i in range(1000):
+                m_queue.put('var%d' % i)
+
+        def consumer():
+            for i in range(1000):
+                x = m_queue.get(block=True)
+                self.assertEqual('var%d' % i, x)
+
+        c = Thread(target=consumer)
+        c.start()
+        p = Thread(target=producer)
+        p.start()
+        p.join()
+        c.join()
+        self.assertEqual(0, m_queue.size)
+        self.assertEqual(0, len(m_queue))
+        self.assertRaises(Empty, m_queue.get, block=False)
+
+    def test_multi_threaded_multi_producer(self):
+        """Test mysqlqueue can be used by multiple producers."""
+
+        queue = self.queue
+
+        def producer(seq):
+            for i in range(10):
+                queue.put('var%d' % (i + (seq * 10)))
+
+        def consumer():
+            for _ in range(100):
+                data = queue.get(block=True)
+                self.assertTrue('var' in data)
+
+        c = Thread(target=consumer)
+        c.start()
+        producers = []
+        for seq in range(10):
+            t = Thread(target=producer, args=(seq,))
+            t.start()
+            producers.append(t)
+
+        for t in producers:
+            t.join()
+
+        c.join()
+
+    def test_multiple_consumers(self):
+        """Test mysqlqueue can be used by multiple consumers."""
+        queue = self.queue
+
+        def producer():
+            for x in range(1000):
+                queue.put('var%d' % x)
+
+        counter = []
+        # Set all to 0
+        for _ in range(1000):
+            counter.append(0)
+
+        def consumer(t_index):
+            for i in range(200):
+                data = queue.get(block=True)
+                self.assertTrue('var' in data)
+                counter[t_index * 200 + i] = data
+
+        p = Thread(target=producer)
+        p.start()
+        consumers = []
+        for index in range(5):
+            t = Thread(target=consumer, args=(index,))
+            t.start()
+            consumers.append(t)
+
+        p.join()
+        for t in consumers:
+            t.join()
+
+        self.assertEqual(0, queue.qsize())
+        for x in range(1000):
+            self.assertNotEqual(0, counter[x],
+                                "not 0 for counter's index %s" % x)
+
+        self.assertEqual(len(set(counter)), len(counter))
+
+    def test_task_done_with_restart(self):
+        """Test that items are not deleted before task_done."""
+
+        q = self.queue
+
+        for i in range(1, 11):
+            q.put(i)
+
+        self.assertEqual(1, q.get())
+        self.assertEqual(2, q.get())
+        # size is correct before task_done
+        self.assertEqual(8, q.qsize())
+        q.task_done()
+        # make sure the size is still correct
+        self.assertEqual(8, q.qsize())
+
+        self.assertEqual(3, q.get())
+        # without task_done
+        del q
+        q = MySQLQueue(name=self._table_name,
+                       **db_conf)
+        # After restart, the qsize and head item are the same
+        self.assertEqual(7, q.qsize())
+        # After restart, the queue still works
+        self.assertEqual(4, q.get())
+        self.assertEqual(6, q.qsize())
+        # auto_commit=False
+        del q
+        q = MySQLQueue(name=self._table_name, auto_commit=False,
+                       **db_conf)
+        self.assertEqual(6, q.qsize())
+        # After restart, the queue still works
+        self.assertEqual(5, q.get())
+        self.assertEqual(5, q.qsize())
+        del q
+        q = MySQLQueue(name=self._table_name, auto_commit=False,
+                       **db_conf)
+        # After restart, the queue still works
+        self.assertEqual(5, q.get())
+        self.assertEqual(5, q.qsize())
+
+    def test_protocol_1(self):
+        q = self.queue
+        self.assertEqual(q._serializer.protocol,
+                         2 if sys.version_info[0] == 2 else 4)
+
+    def test_protocol_2(self):
+        q = self.queue
+        self.assertEqual(q._serializer.protocol,
+                         2 if sys.version_info[0] == 2 else 4)
+
+    def test_json_serializer(self):
+        q = self.queue
+        x = dict(
+            a=1,
+            b=2,
+            c=dict(
+                d=list(range(5)),
+                e=[1]
+            ))
+        q.put(x)
+        self.assertEqual(q.get(), x)
+
+    def test_put_0(self):
+        q = self.queue
+        q.put(0)
+        d = q.get(block=False)
+        self.assertIsNotNone(d)
+
+    def test_get_id(self):
+        q = self.queue
+        q.put("val1")
+        val2_id = q.put("val2")
+        q.put("val3")
+        item = q.get(id=val2_id)
+        # the item id should be 2
+        self.assertEqual(val2_id, 2)
+        # the fetched item should be val2
+        self.assertEqual(item, 'val2')
+
+    def test_get_raw(self):
+        q = self.queue
+        q.put("val1")
+        item = q.get(raw=True)
+        # the raw item should carry its id and the data val1
+        self.assertEqual(True, "pqid" in item)
+        self.assertEqual(item.get("data"), 'val1')
+
+    def test_queue(self):
+        q = self.queue
+        q.put("val1")
+        q.put("val2")
+        q.put("val3")
+        # queue() should return all three items
+        d = q.queue()
+        self.assertEqual(len(d), 3)
+        self.assertEqual(d[1].get("data"), "val2")
+
+    def test_update(self):
+        q = self.queue
+        qid = q.put("val1")
+        q.update(item="val2", id=qid)
+        item = q.get(id=qid)
+        self.assertEqual(item, "val2")
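`test_task_done_with_restart` above pins down the `auto_commit=False` contract: `get()` only advances an in-memory cursor, and nothing is deleted from MySQL until `task_done()`. A condensed sketch of that behavior (credentials as in the test module's `db_conf`):

    from persistqueue import MySQLQueue

    db_conf = dict(host="127.0.0.1", port=3306, user="user",
                   passwd="passw0rd", db_name="testqueue")

    q = MySQLQueue(name="jobs", auto_commit=False, **db_conf)
    q.put("a")
    q.put("b")

    assert q.get() == "a"  # not yet removed from the table
    q.task_done()          # persists the removal of 'a'

    del q                  # simulate a restart
    q = MySQLQueue(name="jobs", auto_commit=False, **db_conf)
    assert q.get() == "b"  # 'a' stays gone; 'b' survived the restart
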
diff --git a/persistqueue/tests/test_sqlqueue.py b/persistqueue/tests/test_sqlqueue.py
index 7c05dd0..87cfb90 100644
--- a/persistqueue/tests/test_sqlqueue.py
+++ b/persistqueue/tests/test_sqlqueue.py
@@ -18,12 +18,13 @@ class SQLite3QueueTest(unittest.TestCase):
     def setUp(self):
         self.path = tempfile.mkdtemp(suffix='sqlqueue')
         self.auto_commit = True
+        self.queue_class = SQLiteQueue

     def tearDown(self):
         shutil.rmtree(self.path, ignore_errors=True)

     def test_raise_empty(self):
-        q = SQLiteQueue(self.path, auto_commit=self.auto_commit)
+        q = self.queue_class(self.path, auto_commit=self.auto_commit)

         q.put('first')
         d = q.get()
@@ -38,7 +39,7 @@ def test_raise_empty(self):
         del q

     def test_empty(self):
-        q = SQLiteQueue(self.path, auto_commit=self.auto_commit)
+        q = self.queue_class(self.path, auto_commit=self.auto_commit)
         self.assertEqual(q.empty(), True)

         q.put('first')
@@ -50,7 +51,7 @@ def test_empty(self):
     def test_open_close_single(self):
         """Write 1 item, close, reopen checking if same item is there"""

-        q = SQLiteQueue(self.path, auto_commit=self.auto_commit)
+        q = self.queue_class(self.path, auto_commit=self.auto_commit)
         q.put(b'var1')
         del q
         q = SQLiteQueue(self.path)
@@ -60,7 +61,7 @@ def test_open_close_single(self):
     def test_open_close_1000(self):
         """Write 1000 items, close, reopen checking if all items are there"""

-        q = SQLiteQueue(self.path, auto_commit=self.auto_commit)
+        q = self.queue_class(self.path, auto_commit=self.auto_commit)
         for i in range(1000):
             q.put('var%d' % i)

@@ -79,7 +80,7 @@ def test_open_close_1000(self):
     def test_random_read_write(self):
         """Test random read/write"""

-        q = SQLiteQueue(self.path, auto_commit=self.auto_commit)
+        q = self.queue_class(self.path, auto_commit=self.auto_commit)
         n = 0
         for _ in range(1000):
             if random.random() < 0.5:
@@ -121,8 +122,8 @@ def consumer():

     def test_multi_threaded_multi_producer(self):
         """Test sqlqueue can be used by multiple producers."""
-        queue = SQLiteQueue(path=self.path, multithreading=True,
-                            auto_commit=self.auto_commit)
+        queue = self.queue_class(path=self.path, multithreading=True,
+                                 auto_commit=self.auto_commit)

         def producer(seq):
             for i in range(10):
@@ -149,8 +150,8 @@ def consumer():

     def test_multiple_consumers(self):
         """Test sqlqueue can be used by multiple consumers."""
-        queue = SQLiteQueue(path=self.path, multithreading=True,
-                            auto_commit=self.auto_commit)
+        queue = self.queue_class(path=self.path, multithreading=True,
+                                 auto_commit=self.auto_commit)

         def producer():
             for x in range(1000):
@@ -189,7 +190,7 @@ def consumer(index):
     def test_task_done_with_restart(self):
         """Test that items are not deleted before task_done."""

-        q = SQLiteQueue(path=self.path, auto_commit=False)
+        q = self.queue_class(path=self.path, auto_commit=False)

         for i in range(1, 11):
             q.put(i)
@@ -214,17 +215,17 @@ def test_task_done_with_restart(self):

     def test_protocol_1(self):
         shutil.rmtree(self.path, ignore_errors=True)
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         self.assertEqual(q._serializer.protocol,
                          2 if sys.version_info[0] == 2 else 4)

     def test_protocol_2(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         self.assertEqual(q._serializer.protocol,
                          2 if sys.version_info[0] == 2 else 4)

     def test_json_serializer(self):
-        q = SQLiteQueue(
+        q = self.queue_class(
             path=self.path,
             serializer=serializers_json)
         x = dict(
@@ -238,13 +239,13 @@ def test_json_serializer(self):
         self.assertEqual(q.get(), x)

     def test_put_0(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         q.put(0)
         d = q.get(block=False)
         self.assertIsNotNone(d)

     def test_get_id(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         q.put("val1")
         val2_id = q.put("val2")
         q.put("val3")
@@ -255,7 +256,7 @@ def test_get_id(self):
         self.assertEqual(item, 'val2')

     def test_get_raw(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         q.put("val1")
         item = q.get(raw=True)
         # the raw item should be val1
@@ -263,7 +264,7 @@ def test_get_raw(self):
         self.assertEqual(item.get("data"), 'val1')

     def test_queue(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         q.put("val1")
         q.put("val2")
         q.put("val3")
@@ -273,7 +274,7 @@ def test_queue(self):
         self.assertEqual(d[1].get("data"), "val2")

     def test_update(self):
-        q = SQLiteQueue(path=self.path)
+        q = self.queue_class(path=self.path)
         qid = q.put("val1")
         q.update(item="val2", id=qid)
         item = q.get(id=qid)
@@ -284,6 +285,7 @@ class SQLite3QueueNoAutoCommitTest(SQLite3QueueTest):
     def setUp(self):
         self.path = tempfile.mkdtemp(suffix='sqlqueue_auto_commit')
         self.auto_commit = False
+        self.queue_class = SQLiteQueue

     def test_multiple_consumers(self):
         """
@@ -307,6 +309,7 @@ class SQLite3QueueInMemory(SQLite3QueueTest):
     def setUp(self):
         self.path = ":memory:"
         self.auto_commit = True
+        self.queue_class = SQLiteQueue

     def test_open_close_1000(self):
         self.skipTest('Memory based sqlite is not persistent.')
@@ -334,6 +337,7 @@ class FILOSQLite3QueueTest(unittest.TestCase):
     def setUp(self):
         self.path = tempfile.mkdtemp(suffix='filo_sqlqueue')
         self.auto_commit = True
+        self.queue_class = FILOSQLiteQueue

     def tearDown(self):
         shutil.rmtree(self.path, ignore_errors=True)
@@ -361,12 +365,14 @@ class FILOSQLite3QueueNoAutoCommitTest(FILOSQLite3QueueTest):
     def setUp(self):
         self.path = tempfile.mkdtemp(suffix='filo_sqlqueue_auto_commit')
         self.auto_commit = False
+        self.queue_class = FILOSQLiteQueue


 class SQLite3UniqueQueueTest(unittest.TestCase):
     def setUp(self):
         self.path = tempfile.mkdtemp(suffix='sqlqueue')
         self.auto_commit = True
+        self.queue_class = UniqueQ

     def test_add_duplicate_item(self):
         q = UniqueQ(self.path)
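The `queue_class` attribute threaded through these tests is plain subclass parametrization: each backend reuses the whole suite by overriding one attribute. A stripped-down sketch of the pattern:

    import shutil
    import tempfile
    import unittest

    from persistqueue import SQLiteQueue, FILOSQLiteQueue

    class RoundTripTest(unittest.TestCase):
        queue_class = SQLiteQueue  # overridden by subclasses

        def setUp(self):
            self.path = tempfile.mkdtemp(suffix="sqlqueue")

        def tearDown(self):
            shutil.rmtree(self.path, ignore_errors=True)

        def test_roundtrip(self):
            q = self.queue_class(self.path, auto_commit=True)
            q.put("x")
            self.assertEqual("x", q.get())

    class FILORoundTripTest(RoundTripTest):
        queue_class = FILOSQLiteQueue
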
diff --git a/setup.py b/setup.py
index c61f7f5..bfe265f 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,7 @@ def get_extras():
         'A thread-safe disk based persistent queue in Python.'
     ),
     long_description=open('README.rst').read(),
+    long_description_content_type='text/x-rst',
     author=__import__('persistqueue').__author__,
     author_email='wangxu198709@gmail.com',
     maintainer=__import__('persistqueue').__author__,
diff --git a/test-requirements.txt b/test-requirements.txt
index 54d7bcf..a75c3da 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -5,4 +5,5 @@ msgpack>=0.5.6
 nose2>=0.6.5
 coverage!=4.5
 cov_core>=1.15.0
-virtualenv>=15.1.0
\ No newline at end of file
+virtualenv>=15.1.0
+cryptography;sys_platform!="win32"  # only required for tests against MySQL 8.0 on Linux
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index 1998a60..18dde77 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,8 +2,11 @@
 minversion = 2.0
 skipsdist = True
+recreate = false
 envlist = py27, py34, py35, py36, py37, pep8, cover
 deps = -r{toxinidir}/test-requirements.txt
+       -r{toxinidir}/extra-requirements.txt
+       -r{toxinidir}/requirements.txt


 [testenv]
@@ -11,6 +14,8 @@ setenv = VIRTUAL_ENV={envdir}
 usedevelop = True
 deps = -r{toxinidir}/test-requirements.txt
+       -r{toxinidir}/extra-requirements.txt
+       -r{toxinidir}/requirements.txt

 whitelist_externals = bash
                       find
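Taken together, the additions give every sql-backed queue one extended API: `put()` returns a row id, and `get(id=...)`, `get(raw=True)`, and `update()` operate on it, exactly as the tests above exercise. A closing sketch against the MySQL backend (same assumed credentials as the README example):

    from persistqueue import MySQLQueue

    db_conf = dict(host="127.0.0.1", port=3306, user="user",
                   passwd="passw0rd", db_name="testqueue")

    q = MySQLQueue(name="demo", **db_conf)
    pqid = q.put("v1")            # put() returns the new row id
    q.update(item="v2", id=pqid)  # rewrite the payload in place
    row = q.get(id=pqid, raw=True)
    assert row["pqid"] == pqid and row["data"] == "v2"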