
Commit 9660132

Implement format autodetection for decompression
1 parent 0262f89 commit 9660132

5 files changed: +277 -33 lines changed

snappy/__main__.py

Lines changed: 21 additions & 29 deletions
@@ -2,27 +2,8 @@
 import io
 import sys
 
-from .snappy import stream_compress, stream_decompress
-from .hadoop_snappy import (
-    stream_compress as hadoop_stream_compress,
-    stream_decompress as hadoop_stream_decompress)
-
-
-FRAMING_FORMAT = 'framing'
-
-HADOOP_FORMAT = 'hadoop_snappy'
-
-DEFAULT_FORMAT = FRAMING_FORMAT
-
-COMPRESS_METHODS = {
-    FRAMING_FORMAT: stream_compress,
-    HADOOP_FORMAT: hadoop_stream_compress,
-}
-
-DECOMPRESS_METHODS = {
-    FRAMING_FORMAT: stream_decompress,
-    HADOOP_FORMAT: hadoop_stream_decompress,
-}
+from . import snappy_formats as formats
+from .snappy import UncompressError
 
 
 def cmdline_main():
@@ -58,9 +39,11 @@ def cmdline_main():
     parser.add_argument(
         '-t',
         dest='target_format',
-        default=DEFAULT_FORMAT,
-        choices=[FRAMING_FORMAT, HADOOP_FORMAT],
-        help='Target format, default is {}'.format(DEFAULT_FORMAT)
+        default=formats.DEFAULT_FORMAT,
+        choices=formats.ALL_SUPPORTED_FORMATS,
+        help=(
+            'Target format, default is "{}"'.format(formats.DEFAULT_FORMAT)
+        )
     )
 
     parser.add_argument(
@@ -79,18 +62,27 @@ def cmdline_main():
     )
 
     args = parser.parse_args()
-    if args.compress:
-        method = COMPRESS_METHODS[args.target_format]
-    else:
-        method = DECOMPRESS_METHODS[args.target_format]
 
     # workaround for https://bugs.python.org/issue14156
     if isinstance(args.infile, io.TextIOWrapper):
         args.infile = stdin
     if isinstance(args.outfile, io.TextIOWrapper):
         args.outfile = stdout
 
-    method(args.infile, args.outfile)
+    additional_args = {}
+    if args.compress:
+        method = formats.get_compress_function(args.target_format)
+    else:
+        try:
+            method, read_chunk = formats.get_decompress_function(
+                args.target_format,
+                args.infile
+            )
+        except UncompressError as err:
+            sys.exit("Failed to get decompress function: {}".format(err))
+        additional_args['start_chunk'] = read_chunk
+
+    method(args.infile, args.outfile, **additional_args)
 
 
 if __name__ == "__main__":
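For reference, a minimal sketch (not part of the commit) of the decompression path that cmdline_main now takes; the file names here are hypothetical:

# Hedged sketch of the new decompression flow, assuming the module layout
# above; 'example.snappy' and 'example' are hypothetical file names.
from snappy import snappy_formats as formats

with open("example.snappy", "rb") as infile, open("example", "wb") as outfile:
    # 'auto' makes get_decompress_function peek at the stream header to pick
    # a format; read_chunk holds the bytes consumed during detection.
    method, read_chunk = formats.get_decompress_function('auto', infile)
    method(infile, outfile, start_chunk=read_chunk)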

snappy/hadoop_snappy.py

Lines changed: 31 additions & 1 deletion
@@ -26,6 +26,7 @@
     _compress, _uncompress,
     stream_compress as _stream_compress,
     stream_decompress as _stream_decompress,
+    check_format as _check_format,
    UncompressError,
     _CHUNK_MAX)
 
@@ -112,6 +113,26 @@ def __init__(self):
         # total uncompressed data length of the current block
         self._uncompressed_length = 0
 
+    @staticmethod
+    def check_format(data):
+        """Checks that the given data block is long enough to hold the two
+        big-endian four-byte integers that start every hadoop-snappy block:
+        the uncompressed block length as the first integer, and the
+        compressed subblock length as the second integer.
+        This is a simple plausibility check that the data is the start of
+        a hadoop-snappy stream; the relation between the two integers is
+        not validated.
+        Raises UncompressError if the data is too short.
+        :return: None
+        """
+        int_size = _INT_SIZE
+        if len(data) < int_size * 2:
+            raise UncompressError("Too short data length")
+        # We can't actually be sure about the format here: the assumption
+        # that the compressed length is less than the uncompressed length
+        # is not true in general. So don't check anything beyond the length.
+        return
+
     def decompress(self, data):
         """Decompress 'data', returning a string containing the uncompressed
         data corresponding to at least part of the data in string. This data
@@ -178,8 +199,17 @@ def stream_compress(src, dst, blocksize=SNAPPY_BUFFER_SIZE_DEFAULT):
     )
 
 
-def stream_decompress(src, dst, blocksize=_STREAM_TO_STREAM_BLOCK_SIZE):
+def stream_decompress(src, dst, blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
+                      start_chunk=None):
     return _stream_decompress(
         src, dst, blocksize=blocksize,
+        decompressor_cls=StreamDecompressor,
+        start_chunk=start_chunk
+    )
+
+
+def check_format(fin=None, chunk=None, blocksize=_STREAM_TO_STREAM_BLOCK_SIZE):
+    return _check_format(
+        fin=fin, chunk=chunk, blocksize=blocksize,
         decompressor_cls=StreamDecompressor
     )
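The length check above leans on the hadoop-snappy block layout. A small illustrative sketch, assuming _INT_SIZE == 4, i.e. two big-endian 32-bit integers open every block (the lengths below are made up):

# Illustrative sketch of a hadoop-snappy block header, assuming
# _INT_SIZE == 4: a big-endian uncompressed block length followed by the
# big-endian length of the first compressed subblock.
import struct

header = struct.pack(">II", 65536, 4321)  # hypothetical lengths
uncompressed_len, subblock_len = struct.unpack(">II", header)
assert (uncompressed_len, subblock_len) == (65536, 4321)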

snappy/snappy.py

Lines changed: 46 additions & 3 deletions
@@ -195,6 +195,25 @@ def __init__(self):
         self._buf = b""
         self._header_found = False
 
+    @staticmethod
+    def check_format(data):
+        """Checks that the given data starts with the snappy framing format
+        stream identifier.
+        Raises UncompressError if it doesn't start with the identifier.
+        :return: None
+        """
+        if len(data) < 6:
+            raise UncompressError("Too short data length")
+        chunk_type = struct.unpack("<L", data[:4])[0]
+        size = (chunk_type >> 8)
+        chunk_type &= 0xff
+        if (chunk_type != _IDENTIFIER_CHUNK or
+                size != len(_STREAM_IDENTIFIER)):
+            raise UncompressError("stream missing snappy identifier")
+        chunk = data[4:4 + size]
+        if chunk != _STREAM_IDENTIFIER:
+            raise UncompressError("stream has invalid snappy identifier")
+
     def decompress(self, data):
         """Decompress 'data', returning a string containing the uncompressed
         data corresponding to at least part of the data in string. This data
@@ -279,17 +298,41 @@ def stream_compress(src,
 def stream_decompress(src,
                       dst,
                       blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
-                      decompressor_cls=StreamDecompressor):
+                      decompressor_cls=StreamDecompressor,
+                      start_chunk=None):
     """Takes an incoming file-like object and an outgoing file-like object,
     reads data from src, decompresses it, and writes it to dst. 'src' should
     support the read method, and 'dst' should support the write method.
 
     The default blocksize is good for almost every scenario.
+    :param decompressor_cls: class that implements the `decompress` method,
+        like StreamDecompressor in this module
+    :param start_chunk: start block of data that has already been read from
+        the input stream (to detect the format, for example)
     """
     decompressor = decompressor_cls()
     while True:
-        buf = src.read(blocksize)
-        if not buf: break
+        if start_chunk:
+            buf = start_chunk
+            start_chunk = None
+        else:
+            buf = src.read(blocksize)
+            if not buf: break
         buf = decompressor.decompress(buf)
         if buf: dst.write(buf)
     decompressor.flush() # makes sure the stream ended well
+
+
+def check_format(fin=None, chunk=None,
+                 blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
+                 decompressor_cls=StreamDecompressor):
+    ok = True
+    if chunk is None:
+        chunk = fin.read(blocksize)
+        if not chunk:
+            raise UncompressError("Empty input stream")
+    try:
+        decompressor_cls.check_format(chunk)
+    except UncompressError as err:
+        ok = False
+    return ok, chunk
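For reference, a sketch of the header this check accepts, assuming _IDENTIFIER_CHUNK == 0xff and _STREAM_IDENTIFIER == b"sNaPpY" as defined elsewhere in snappy.py: one chunk-type byte packed together with a 3-byte little-endian length, followed by the magic bytes.

# Sketch of the framing-format stream identifier that check_format above
# expects, under the constant-value assumptions stated in the lead-in.
import struct

magic = b"sNaPpY"
header = struct.pack("<L", (len(magic) << 8) | 0xff) + magic
# StreamDecompressor.check_format(header) should return without raising
# UncompressError.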

snappy/snappy_formats.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+"""Constants and functions to handle the target format.
+ALL_SUPPORTED_FORMATS - list of supported formats
+get_decompress_function - returns the stream decompress function for the
+given format (specified or autodetected)
+get_compress_function - returns the compress function for the given
+format (specified or default)
+"""
+from .snappy import (
+    stream_compress, stream_decompress, check_format, UncompressError)
+from .hadoop_snappy import (
+    stream_compress as hadoop_stream_compress,
+    stream_decompress as hadoop_stream_decompress,
+    check_format as hadoop_check_format)
+
+
+FRAMING_FORMAT = 'framing'
+
+HADOOP_FORMAT = 'hadoop_snappy'
+
+# Means format autodetection.
+# For compression, the framing format will be used.
+# For decompression, we will try to detect the format from the input
+# stream header.
+FORMAT_AUTO = 'auto'
+
+DEFAULT_FORMAT = FORMAT_AUTO
+
+ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, HADOOP_FORMAT, FORMAT_AUTO]
+
+_COMPRESS_METHODS = {
+    FRAMING_FORMAT: stream_compress,
+    HADOOP_FORMAT: hadoop_stream_compress,
+}
+
+_DECOMPRESS_METHODS = {
+    FRAMING_FORMAT: stream_decompress,
+    HADOOP_FORMAT: hadoop_stream_decompress,
+}
+
+# We will use the framing format as the default for compression.
+# For decompression, if the format is not given explicitly, we will try
+# to guess it from the file header.
+_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT
+
+# The tuple contains an ordered sequence of format-checking functions
+# paired with format-specific decompression functions.
+# The framing format has a header that can be recognized.
+# The hadoop-snappy format has no special header; it contains only the
+# uncompressed block length integer and the compressed subblock length.
+# So we check the framing format first and, if it does not match, then
+# check for the hadoop-snappy format.
+_DECOMPRESS_FORMAT_FUNCS = (
+    (check_format, stream_decompress),
+    (hadoop_check_format, hadoop_stream_decompress),
+)
+
+
+def guess_format_by_header(fin):
+    """Tries to guess the compression format of the given input file by its
+    header.
+    :return: tuple of the decompression method and the chunk that was taken
+    from the input for format detection.
+    """
+    chunk = None
+    for check_method, decompress_func in _DECOMPRESS_FORMAT_FUNCS:
+        ok, chunk = check_method(fin=fin, chunk=chunk)
+        if not ok:
+            continue
+        return decompress_func, chunk
+    raise UncompressError("Can't detect archive format")
+
+
+def get_decompress_function(specified_format, fin):
+    if specified_format == FORMAT_AUTO:
+        decompress_func, read_chunk = guess_format_by_header(fin)
+        return decompress_func, read_chunk
+    return _DECOMPRESS_METHODS[specified_format], None
+
+
+def get_compress_function(specified_format):
+    if specified_format == FORMAT_AUTO:
+        return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT]
+    return _COMPRESS_METHODS[specified_format]
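A hedged round-trip sketch of the API this new module exposes, using only in-memory streams:

# Round-trip sketch using the functions above with io.BytesIO streams.
import io
from snappy import snappy_formats as formats

src, compressed = io.BytesIO(b"hello world"), io.BytesIO()
formats.get_compress_function(formats.FORMAT_AUTO)(src, compressed)
compressed.seek(0)

# With FORMAT_AUTO the framing header is detected; read_chunk carries the
# bytes already consumed from the stream during detection.
decompress, read_chunk = formats.get_decompress_function(
    formats.FORMAT_AUTO, compressed)
out = io.BytesIO()
decompress(compressed, out, start_chunk=read_chunk)
assert out.getvalue() == b"hello world"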

test_formats.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import io
+import os
+from unittest import TestCase
+
+from snappy import snappy_formats as formats
+from snappy.snappy import _CHUNK_MAX, UncompressError
+
+
+class TestFormatBase(TestCase):
+    compress_format = formats.FORMAT_AUTO
+    decompress_format = formats.FORMAT_AUTO
+    success = True
+
+    def runTest(self):
+        data = os.urandom(1024 * 256 * 2) + os.urandom(13245 * 2)
+        compress_func = formats.get_compress_function(self.compress_format)
+        instream = io.BytesIO(data)
+        compressed_stream = io.BytesIO()
+        compress_func(instream, compressed_stream)
+        compressed_stream.seek(0)
+        if not self.success:
+            with self.assertRaises(UncompressError) as err:
+                decompress_func, read_chunk = formats.get_decompress_function(
+                    self.decompress_format, compressed_stream
+                )
+                decompressed_stream = io.BytesIO()
+                decompress_func(
+                    compressed_stream,
+                    decompressed_stream,
+                    start_chunk=read_chunk
+                )
+            return
+        decompress_func, read_chunk = formats.get_decompress_function(
+            self.decompress_format, compressed_stream
+        )
+        decompressed_stream = io.BytesIO()
+        decompress_func(
+            compressed_stream,
+            decompressed_stream,
+            start_chunk=read_chunk
+        )
+        decompressed_stream.seek(0)
+        self.assertEqual(data, decompressed_stream.read())
+
+
+class TestFormatFramingFraming(TestFormatBase):
+    compress_format = formats.FRAMING_FORMAT
+    decompress_format = formats.FRAMING_FORMAT
+    success = True
+
+
+class TestFormatFramingHadoop(TestFormatBase):
+    compress_format = formats.FRAMING_FORMAT
+    decompress_format = formats.HADOOP_FORMAT
+    success = False
+
+
+class TestFormatFramingAuto(TestFormatBase):
+    compress_format = formats.FRAMING_FORMAT
+    decompress_format = formats.FORMAT_AUTO
+    success = True
+
+
+class TestFormatHadoopHadoop(TestFormatBase):
+    compress_format = formats.HADOOP_FORMAT
+    decompress_format = formats.HADOOP_FORMAT
+    success = True
+
+
+class TestFormatHadoopFraming(TestFormatBase):
+    compress_format = formats.HADOOP_FORMAT
+    decompress_format = formats.FRAMING_FORMAT
+    success = False
+
+
+class TestFormatHadoopAuto(TestFormatBase):
+    compress_format = formats.HADOOP_FORMAT
+    decompress_format = formats.FORMAT_AUTO
+    success = True
+
+
+class TestFormatAutoFraming(TestFormatBase):
+    compress_format = formats.FORMAT_AUTO
+    decompress_format = formats.FRAMING_FORMAT
+    success = True
+
+
+class TestFormatAutoHadoop(TestFormatBase):
+    compress_format = formats.FORMAT_AUTO
+    decompress_format = formats.HADOOP_FORMAT
+    success = False
+
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
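Further pairs can be covered with subclasses in the same pattern; for example, a hypothetical auto-to-auto case (not in this commit) should succeed, since auto compression emits the framing format and auto decompression detects it:

# Hypothetical extra case, not part of this commit.
class TestFormatAutoAuto(TestFormatBase):
    compress_format = formats.FORMAT_AUTO
    decompress_format = formats.FORMAT_AUTO
    success = True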
