diff --git a/setup.py b/setup.py
index 348c635e..cde27f61 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 import re
 import sys
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 from setuptools.command.test import test as TestCommand
 
 long_description = open("README.rst", "r").read()
@@ -47,6 +47,7 @@ def run_tests(self):
     "wrapt",
     "six>=1.5",
     "yarl",
+    "requests_toolbelt",
 ]
 
 setup(
diff --git a/tests/unit/test_matchers.py b/tests/unit/test_matchers.py
index 5e45ab66..3f169428 100644
--- a/tests/unit/test_matchers.py
+++ b/tests/unit/test_matchers.py
@@ -3,8 +3,7 @@
 
 import pytest
 
-from vcr import matchers
-from vcr import request
+from vcr import matchers, request
 
 # the dict contains requests with corresponding to its key difference
 # with 'base' request.
@@ -66,6 +65,36 @@ def test_uri_matcher():
 }
 
 
+def make_multipart_data(boundary, name):
+    """Return a ``(body, headers)`` pair simulating a multipart/form-data request.
+
+    The body contains a "name" text field and a "file" binary part (1x1 white PNG pixel).
+    """
+    delimiter = b"--" + boundary
+    png_pixel = (
+        b"\x89PNG\r\n\x1a\n"
+        b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
+        b"\x00\x00\x00\x0cIDATx\x9cc\xf8\xff\xff?\x00\x05\xfe\x02\xfe\r\xefF\xb8"
+        b"\x00\x00\x00\x00IEND\xaeB`\x82"
+    )
+    payload = b"".join(
+        (
+            delimiter,
+            b'\r\nContent-Disposition: form-data; name="name"\r\n\r\n',
+            name.encode("utf-8"),
+            b"\r\n",
+            delimiter,
+            b'\r\nContent-Disposition: form-data; name="file"; filename="file"\r\n\r\n',
+            png_pixel,
+            b"\r\n",
+            delimiter,
+            b"--\r\n",
+        )
+    )
+    content_type = f"multipart/form-data; boundary={boundary.decode('utf-8')}"
+    return payload, {"Content-Type": content_type}
+
+
 @pytest.mark.parametrize(
     "r1, r2",
     [
@@ -125,6 +154,18 @@ def test_uri_matcher():
             request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
             request.Request("POST", "http://aws.custom.com/", b"123", boto3_bytes_headers),
         ),
+        (
+            request.Request(
+                "POST",
+                "http://host.com/",
+                *make_multipart_data(b"1ef6f9a15c8b2da0dce8f3d2804bf01a", "a.png"),
+            ),
+            request.Request(
+                "POST",
+                "http://host.com/",
+                *make_multipart_data(b"e60f58697d0dcf44d44c50a64de09dbb", "a.png"),
+            ),
+        ),
     ],
 )
 def test_body_matcher_does_match(r1, r2):
@@ -150,6 +191,28 @@ def test_body_matcher_does_match(r1, r2):
             request.Request("POST", "http://host.com/", req1_body, {"Content-Type": "text/xml"}),
             request.Request("POST", "http://host.com/", req2_body, {"content-type": "text/xml"}),
         ),
+        (
+            request.Request(
+                "POST",
+                "http://host.com/",
+                *make_multipart_data(b"1ef6f9a15c8b2da0dce8f3d2804bf01a", "a.png"),
+            ),
+            request.Request(
+                "POST",
+                "http://host.com/",
+                *make_multipart_data(b"e60f58697d0dcf44d44c50a64de09dbb", "b.png"),
+            ),
+        ),
+        (
+            request.Request(
+                "POST",
+                "http://host.com/",
+                *make_multipart_data(b"1ef6f9a15c8b2da0dce8f3d2804bf01a", "a.png"),
+            ),
+            request.Request(
+                "POST", "http://host.com/", '{"b": 2, "a": 1}', {"content-type": "application/json"}
+            ),
+        ),
     ],
 )
 def test_body_match_does_not_match(r1, r2):
diff --git a/tox.ini b/tox.ini
index 0a5519a9..96fa3680 100644
--- a/tox.ini
+++ b/tox.ini
@@ -79,6 +79,7 @@ deps =
     PyYAML
     ipaddress
     requests: requests>=2.22.0
+    requests_toolbelt
     httplib2: httplib2
     urllib3: urllib3
     boto3: boto3
diff --git a/vcr/matchers.py b/vcr/matchers.py
index 3dd48726..86787889 100644
--- a/vcr/matchers.py
+++ b/vcr/matchers.py
@@ -1,9 +1,12 @@
 import json
+import logging
+import re
 import urllib
 import xmlrpc.client
-from .util import read_body
-import logging
 
+from requests_toolbelt.multipart import decoder
+
+from .util import read_body
 
 log = logging.getLogger(__name__)
 
@@ -45,7 +48,7 @@ def body(r1, r2):
     r2_transformer = _get_transformer(r2)
     if transformer != r2_transformer:
         transformer = _identity
-    assert transformer(read_body(r1)) == transformer(read_body(r2))
+    assert transformer(r1.headers, read_body(r1)) == transformer(r2.headers, read_body(r2))
 
 
 def headers(r1, r2):
@@ -62,7 +65,7 @@ def checker(headers):
     return checker
 
 
-def _transform_json(body):
+def _transform_json(headers, body):
     # Request body is always a byte string, but json.loads() wants a text
     # string. RFC 7159 says the default encoding is UTF-8 (although UTF-16
     # and UTF-32 are also allowed: hmmmmm).
@@ -70,20 +73,44 @@ def _transform_json(body):
         return json.loads(body.decode("utf-8"))
 
 
+def _transform_multipart_form_data(headers, body):
+    """Decode a multipart/form-data body so it can be compared independently of its boundary.
+
+    :param headers: request headers; the Content-Type value names the boundary.
+    :param body: raw multipart request body.
+    :return: ``(encoding, [(part_headers, part_content), ...])`` as yielded by the decoder.
+    """
+    content_type = headers["content-type"]
+    # Some clients send header values as bytes; MultipartDecoder parses the
+    # content type as text and would raise TypeError, so normalise it first.
+    if isinstance(content_type, bytes):
+        content_type = content_type.decode("utf-8")
+    decoded_data = decoder.MultipartDecoder(content=body, content_type=content_type)
+    return (
+        decoded_data.encoding,
+        [(part.headers, part.content) for part in decoded_data.parts],
+    )
+
+
 _xml_header_checker = _header_checker("text/xml")
 _xmlrpc_header_checker = _header_checker("xmlrpc", header="User-Agent")
 _checker_transformer_pairs = (
     (
         _header_checker("application/x-www-form-urlencoded"),
-        lambda body: urllib.parse.parse_qs(body.decode("ascii")),
+        lambda headers, body: urllib.parse.parse_qs(body.decode("ascii")),
     ),
     (_header_checker("application/json"), _transform_json),
-    (lambda request: _xml_header_checker(request) and _xmlrpc_header_checker(request), xmlrpc.client.loads),
+    (
+        lambda request: _xml_header_checker(request) and _xmlrpc_header_checker(request),
+        lambda headers, body: xmlrpc.client.loads(body),
+    ),
+    (_header_checker("multipart/form-data"), _transform_multipart_form_data),
 )
 
 
-def _identity(x):
-    return x
+def _identity(headers, body):
+    """Return the body unchanged (headers are ignored)."""
+    return body
 
 
 def _get_transformer(request):