diff --git a/codewars_test/test_framework.py b/codewars_test/test_framework.py index 50b9f85..cfad8ed 100644 --- a/codewars_test/test_framework.py +++ b/codewars_test/test_framework.py @@ -1,4 +1,5 @@ from __future__ import print_function +import inspect class AssertException(Exception): @@ -10,18 +11,50 @@ def format_message(message): def display(type, message, label="", mode=""): - print("\n<{0}:{1}:{2}>{3}".format( - type.upper(), mode.upper(), label, format_message(message))) + print( + "\n<{0}:{1}:{2}>{3}".format( + type.upper(), mode.upper(), label, format_message(message) + ) + ) + + +# TODO Currently this only works if assertion functions are written directly in the test case. +def _is_in_test_case(): + frame = inspect.currentframe() + caller_frame = frame.f_back + test_case_frame = caller_frame.f_back + decorator_frame = test_case_frame.f_back + if not decorator_frame: + return False + if not "func" in decorator_frame.f_locals: + return False + func = decorator_frame.f_locals["func"] + code = test_case_frame.f_code + if func and func.__code__ == code and func.test_case_func: + return True + return False + + +def _handle_test_result(passed, message=None, allow_raise=False, in_test_case=False): + if passed: + if not in_test_case: + display("PASSED", "Test Passed") + else: + if not message: + message = "Value is not what was expected" + if in_test_case: + raise AssertionError(message) + else: + display("FAILED", message) + if allow_raise: + # TODO Use AssertionError? + raise AssertException(message) def expect(passed=None, message=None, allow_raise=False): - if passed: - display('PASSED', 'Test Passed') - else: - message = message or "Value is not what was expected" - display('FAILED', message) - if allow_raise: - raise AssertException(message) + _handle_test_result( + passed, message, allow_raise, _is_in_test_case(), + ) def assert_equals(actual, expected, message=None, allow_raise=False): @@ -31,7 +64,9 @@ def assert_equals(actual, expected, message=None, allow_raise=False): else: message += ": " + equals_msg - expect(actual == expected, message, allow_raise) + _handle_test_result( + actual == expected, message, allow_raise, _is_in_test_case(), + ) def assert_not_equals(actual, expected, message=None, allow_raise=False): @@ -42,7 +77,9 @@ def assert_not_equals(actual, expected, message=None, allow_raise=False): else: message += ": " + equals_msg - expect(not (actual == expected), message, allow_raise) + _handle_test_result( + not (actual == expected), message, allow_raise, _is_in_test_case(), + ) def expect_error(message, function, exception=Exception): @@ -53,28 +90,40 @@ def expect_error(message, function, exception=Exception): passed = True except Exception: pass - expect(passed, message) + _handle_test_result( + passed, message, False, _is_in_test_case(), + ) def expect_no_error(message, function, exception=BaseException): + passed = True try: function() except exception as e: - fail("{}: {}".format(message or "Unexpected exception", repr(e))) - return + passed = False + message = "{}: {}".format(message or "Unexpected exception", repr(e)) except Exception: pass - pass_() + _handle_test_result( + passed, message, False, _is_in_test_case(), + ) -def pass_(): expect(True) +def pass_(): + if not _is_in_test_case(): + display("PASSED", "Test Passed") -def fail(message): expect(False, message) +def fail(message): + if _is_in_test_case(): + raise AssertionError(message) + else: + display("FAILED", message) def assert_approx_equals( - actual, expected, margin=1e-9, message=None, allow_raise=False): + actual, expected, margin=1e-9, message=None, allow_raise=False +): msg = "{0} should be close to {1} with absolute or relative margin of {2}" equals_msg = msg.format(repr(actual), repr(expected), repr(margin)) if message is None: @@ -82,17 +131,22 @@ def assert_approx_equals( else: message += ": " + equals_msg div = max(abs(actual), abs(expected), 1) - expect(abs((actual - expected) / div) < margin, message, allow_raise) + _handle_test_result( + abs((actual - expected) / div) < margin, + message, + allow_raise, + _is_in_test_case(), + ) -''' +""" Usage: @describe('describe text') def describe1(): @it('it text') def it1(): # some test cases... -''' +""" def _timed_block_factory(opening_text): @@ -102,52 +156,62 @@ def _timed_block_factory(opening_text): def _timed_block_decorator(s, before=None, after=None): display(opening_text, s) + is_test_case = opening_text == "IT" def wrapper(func): if callable(before): before() time = timer() + if is_test_case: + func.test_case_func = True try: func() + if is_test_case: + display("PASSED", "Test Passed") except AssertionError as e: - display('FAILED', str(e)) + display("FAILED", str(e)) except Exception: - fail('Unexpected exception raised') - tb_str = ''.join(format_exception(*exc_info())) - display('ERROR', tb_str) - display('COMPLETEDIN', '{:.2f}'.format((timer() - time) * 1000)) + fail("Unexpected exception raised") + tb_str = "".join(format_exception(*exc_info())) + display("ERROR", tb_str) + display("COMPLETEDIN", "{:.2f}".format((timer() - time) * 1000)) if callable(after): after() + return wrapper + return _timed_block_decorator -describe = _timed_block_factory('DESCRIBE') -it = _timed_block_factory('IT') +describe = _timed_block_factory("DESCRIBE") +it = _timed_block_factory("IT") -''' +""" Timeout utility Usage: @timeout(sec) def some_tests(): any code block... Note: Timeout value can be a float. -''' +""" def timeout(sec): def wrapper(func): from multiprocessing import Process - msg = 'Should not throw any exceptions inside timeout' + + msg = "Should not throw any exceptions inside timeout" def wrapped(): expect_no_error(msg, func) + process = Process(target=wrapped) process.start() process.join(sec) if process.is_alive(): - fail('Exceeded time limit of {:.3f} seconds'.format(sec)) + fail("Exceeded time limit of {:.3f} seconds".format(sec)) process.terminate() process.join() + return wrapper diff --git a/tests/fixtures/custom_assertion.expected.txt b/tests/fixtures/custom_assertion.expected.txt new file mode 100644 index 0000000..8c0f14b --- /dev/null +++ b/tests/fixtures/custom_assertion.expected.txt @@ -0,0 +1,22 @@ + +group 1 + +test 1 + +Test Passed + +0.00 + +test 2 + +Expected 1 to equal 2 + +0.01 + +test 3 + +using assert + +0.00 + +0.03 diff --git a/tests/fixtures/custom_assertion.py b/tests/fixtures/custom_assertion.py new file mode 100644 index 0000000..5db6634 --- /dev/null +++ b/tests/fixtures/custom_assertion.py @@ -0,0 +1,21 @@ +import codewars_test as test + + +def custom_assert_equal(a, b): + if a != b: + raise AssertionError("Expected {} to equal {}".format(a, b)) + + +@test.describe("group 1") +def group_1(): + @test.it("test 1") + def test_1(): + custom_assert_equal(1, 1) + + @test.it("test 2") + def test_2(): + custom_assert_equal(1, 2) + + @test.it("test 3") + def test_3(): + assert 1 == 2, "using assert" diff --git a/tests/fixtures/expect_error_sample.expected.txt b/tests/fixtures/expect_error_sample.expected.txt index b2c7f56..5968378 100644 --- a/tests/fixtures/expect_error_sample.expected.txt +++ b/tests/fixtures/expect_error_sample.expected.txt @@ -5,72 +5,24 @@ f0 did not raise any exception -f0 did not raise Exception - -f0 did not raise ArithmeticError - -f0 did not raise ZeroDivisionError - -f0 did not raise LookupError - -f0 did not raise KeyError - -f0 did not raise OSError - -0.03 +0.02 f1 raises Exception -Test Passed - -Test Passed - f1 did not raise ArithmeticError -f1 did not raise ZeroDivisionError - -f1 did not raise LookupError - -f1 did not raise KeyError - -f1 did not raise OSError - 0.02 f2 raises Exception >> ArithmeticError >> ZeroDivisionError -Test Passed - -Test Passed - -Test Passed - -Test Passed - f2 did not raise LookupError -f2 did not raise KeyError - -f2 did not raise OSError - 0.02 f3 raises Exception >> LookupError >> KeyError -Test Passed - -Test Passed - f3 did not raise ArithmeticError -f3 did not raise ZeroDivisionError - -Test Passed - -Test Passed - -f3 did not raise OSError - 0.02 -0.11 +0.10 diff --git a/tests/test_outputs.py b/tests/test_outputs.py index bfe8396..43c40e9 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -23,7 +23,12 @@ def test(self): expected = re.sub( r"(?<=)\d+(?:\.\d+)?", r"\\d+(?:\\.\\d+)?", r.read() ) - self.assertRegex(result.stdout.decode("utf-8"), expected) + actual = result.stdout.decode("utf-8") + self.assertRegex( + actual, + expected, + "Expected Pattern:\n{}\n\nGot:\n{}\n".format(expected, actual), + ) return test