From b0ce0752c16e51800cec88fd742c76d8fa1933ca Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 12:51:54 +0000 Subject: [PATCH 1/2] Initial plan From 5c79c7125cba4325f37ac076f61ac641283a2ec0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:07:47 +0000 Subject: [PATCH 2/2] Add comprehensive test coverage for core utility modules Co-authored-by: francescortu <90980154+francescortu@users.noreply.github.com> --- test/test_hooks.py | 269 ++++++++++++++++++++++++++ test/test_inference.py | 327 ++++++++++++++++++++++++++++++++ test/test_logger.py | 274 +++++++++++++++++++++++++++ test/test_progress.py | 417 +++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 174 +++++++++++++++++ 5 files changed, 1461 insertions(+) create mode 100644 test/test_hooks.py create mode 100644 test/test_inference.py create mode 100644 test/test_logger.py create mode 100644 test/test_progress.py create mode 100644 test/test_utils.py diff --git a/test/test_hooks.py b/test/test_hooks.py new file mode 100644 index 0000000..88c2e00 --- /dev/null +++ b/test/test_hooks.py @@ -0,0 +1,269 @@ +import unittest +import torch +from unittest.mock import MagicMock, patch, call +from easyroutine.interpretability.hooks import ( + process_args_kwargs_output, + restore_same_args_kwargs_output, + multiply_pattern, + compute_statistics +) + + +class TestHooksUtilityFunctions(unittest.TestCase): + """Test suite for utility functions in hooks module.""" + + def test_process_args_kwargs_output_with_output(self): + """Test process_args_kwargs_output when output is provided.""" + args = (torch.tensor([1, 2, 3]),) + kwargs = {"hidden_states": torch.tensor([4, 5, 6])} + output = torch.tensor([7, 8, 9]) + + result = process_args_kwargs_output(args, kwargs, output) + self.assertTrue(torch.equal(result, output)) + + def test_process_args_kwargs_output_with_tuple_output(self): + """Test process_args_kwargs_output when output is a tuple.""" + args = (torch.tensor([1, 2, 3]),) + kwargs = {} + output = (torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])) + + result = process_args_kwargs_output(args, kwargs, output) + self.assertTrue(torch.equal(result, output[0])) + + def test_process_args_kwargs_output_with_args(self): + """Test process_args_kwargs_output when output is None but args exist.""" + args = (torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) + kwargs = {} + output = None + + result = process_args_kwargs_output(args, kwargs, output) + self.assertTrue(torch.equal(result, args[0])) + + def test_process_args_kwargs_output_with_kwargs(self): + """Test process_args_kwargs_output when output is None, no args, but kwargs exist.""" + args = () + kwargs = {"hidden_states": torch.tensor([4, 5, 6])} + output = None + + result = process_args_kwargs_output(args, kwargs, output) + self.assertTrue(torch.equal(result, kwargs["hidden_states"])) + + def test_process_args_kwargs_output_no_hidden_states(self): + """Test process_args_kwargs_output when no hidden_states in kwargs.""" + args = () + kwargs = {"other_param": torch.tensor([4, 5, 6])} + output = None + + # Based on the actual implementation, this raises UnboundLocalError + # Let's test that this edge case is handled + with self.assertRaises(UnboundLocalError): + process_args_kwargs_output(args, kwargs, output) + + def test_restore_same_args_kwargs_output_with_tuple_output(self): + """Test restore_same_args_kwargs_output with tuple output.""" + b = torch.tensor([7, 8, 9]) + args = (torch.tensor([1, 2, 3]),) + kwargs = {} + output = (torch.tensor([10, 11, 12]), torch.tensor([13, 14, 15])) + + result = restore_same_args_kwargs_output(b, args, kwargs, output) + + # Should return tuple with b as first element + self.assertIsInstance(result, tuple) + self.assertTrue(torch.equal(result[0], b)) + self.assertTrue(torch.equal(result[1], output[1])) + + def test_restore_same_args_kwargs_output_with_args(self): + """Test restore_same_args_kwargs_output with args but no output.""" + b = torch.tensor([7, 8, 9]) + args = (torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])) + kwargs = {} + output = None + + new_args, new_kwargs = restore_same_args_kwargs_output(b, args, kwargs, output) + + # Should return modified args with b as first element + self.assertTrue(torch.equal(new_args[0], b)) + self.assertTrue(torch.equal(new_args[1], args[1])) + self.assertEqual(new_kwargs, kwargs) + + def test_restore_same_args_kwargs_output_with_kwargs(self): + """Test restore_same_args_kwargs_output with kwargs but no output or args.""" + b = torch.tensor([7, 8, 9]) + args = () + kwargs = {"hidden_states": torch.tensor([1, 2, 3])} + output = None + + new_args, new_kwargs = restore_same_args_kwargs_output(b, args, kwargs, output) + + # Should return modified kwargs with b as hidden_states + self.assertEqual(new_args, args) + self.assertTrue(torch.equal(new_kwargs["hidden_states"], b)) + + def test_multiply_pattern(self): + """Test multiply_pattern function.""" + tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + multiplication_value = 0.5 + + result = multiply_pattern(tensor, multiplication_value) + expected = torch.tensor([[0.5, 1.0], [1.5, 2.0]]) + + self.assertTrue(torch.allclose(result, expected)) + + def test_multiply_pattern_zero(self): + """Test multiply_pattern with zero multiplication (ablation).""" + tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + multiplication_value = 0.0 + + result = multiply_pattern(tensor, multiplication_value) + expected = torch.zeros_like(tensor) + + self.assertTrue(torch.equal(result, expected)) + + def test_compute_statistics_basic(self): + """Test compute_statistics with basic tensor.""" + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1) + + # Expected values for each row + expected_mean = torch.tensor([2.0, 5.0]) # (1+2+3)/3, (4+5+6)/3 + expected_second_moment = torch.tensor([14.0/3, 77.0/3]) # (1+4+9)/3, (16+25+36)/3 + expected_variance = expected_second_moment - expected_mean.pow(2) + + self.assertTrue(torch.allclose(mean, expected_mean, atol=1e-6)) + self.assertTrue(torch.allclose(variance, expected_variance, atol=1e-6)) + self.assertTrue(torch.allclose(second_moment, expected_second_moment, atol=1e-6)) + + def test_compute_statistics_different_dim(self): + """Test compute_statistics with different dimension.""" + tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=0) + + # Computing along dim=0 (across rows) + expected_mean = torch.tensor([3.0, 4.0]) # (1+3+5)/3, (2+4+6)/3 + expected_second_moment = torch.tensor([35.0/3, 56.0/3]) # (1+9+25)/3, (4+16+36)/3 + expected_variance = expected_second_moment - expected_mean.pow(2) + + self.assertTrue(torch.allclose(mean, expected_mean, atol=1e-6)) + self.assertTrue(torch.allclose(variance, expected_variance, atol=1e-6)) + self.assertTrue(torch.allclose(second_moment, expected_second_moment, atol=1e-6)) + + def test_compute_statistics_keepdim_false(self): + """Test compute_statistics with keepdim=False.""" + tensor = torch.tensor([[1.0, 2.0, 3.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1, keepdim=False) + + # Should squeeze the dimension + self.assertEqual(mean.shape, torch.Size([])) + self.assertEqual(variance.shape, torch.Size([])) + self.assertEqual(second_moment.shape, torch.Size([])) + + def test_compute_statistics_single_value(self): + """Test compute_statistics with single value tensor.""" + tensor = torch.tensor([[5.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1) + + expected_mean = torch.tensor([5.0]) + expected_variance = torch.tensor([0.0]) # Variance of single value is 0 + expected_second_moment = torch.tensor([25.0]) + + self.assertTrue(torch.allclose(mean, expected_mean)) + self.assertTrue(torch.allclose(variance, expected_variance)) + self.assertTrue(torch.allclose(second_moment, expected_second_moment)) + + def test_compute_statistics_zero_variance(self): + """Test compute_statistics with constant values (zero variance).""" + tensor = torch.tensor([[2.0, 2.0, 2.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1) + + expected_mean = torch.tensor([2.0]) + expected_variance = torch.tensor([0.0]) + expected_second_moment = torch.tensor([4.0]) + + self.assertTrue(torch.allclose(mean, expected_mean)) + self.assertTrue(torch.allclose(variance, expected_variance, atol=1e-6)) + self.assertTrue(torch.allclose(second_moment, expected_second_moment)) + + +class TestHooksEdgeCases(unittest.TestCase): + """Test edge cases and error conditions in hooks.""" + + def test_process_args_kwargs_output_empty_inputs(self): + """Test process_args_kwargs_output with empty inputs.""" + # Based on the implementation, this raises UnboundLocalError + with self.assertRaises(UnboundLocalError): + process_args_kwargs_output((), {}, None) + + def test_restore_same_args_kwargs_output_edge_case(self): + """Test restore_same_args_kwargs_output with single output (not tuple).""" + b = torch.tensor([7, 8, 9]) + args = (torch.tensor([1, 2, 3]),) + kwargs = {} + output = torch.tensor([10, 11, 12]) # Single tensor, not tuple + + result = restore_same_args_kwargs_output(b, args, kwargs, output) + + # Should return b directly since output is not a tuple + self.assertTrue(torch.equal(result, b)) + + def test_multiply_pattern_preserves_shape(self): + """Test that multiply_pattern preserves tensor shape.""" + tensor = torch.randn(3, 4, 5) + multiplication_value = 0.7 + + result = multiply_pattern(tensor, multiplication_value) + + self.assertEqual(result.shape, tensor.shape) + self.assertTrue(torch.allclose(result, tensor * multiplication_value)) + + def test_multiply_pattern_with_negative_values(self): + """Test multiply_pattern with negative multiplication values.""" + tensor = torch.tensor([[1.0, -2.0], [-3.0, 4.0]]) + multiplication_value = -1.5 + + result = multiply_pattern(tensor, multiplication_value) + expected = tensor * multiplication_value + + self.assertTrue(torch.allclose(result, expected)) + + def test_compute_statistics_with_nan(self): + """Test compute_statistics behavior with NaN values.""" + tensor = torch.tensor([[1.0, float('nan'), 3.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1) + + # Results should contain NaN + self.assertTrue(torch.isnan(mean).any()) + self.assertTrue(torch.isnan(variance).any()) + self.assertTrue(torch.isnan(second_moment).any()) + + def test_compute_statistics_with_inf(self): + """Test compute_statistics behavior with infinite values.""" + tensor = torch.tensor([[1.0, float('inf'), 3.0]]) + + mean, variance, second_moment = compute_statistics(tensor, dim=-1) + + # Results should contain inf + self.assertTrue(torch.isinf(mean).any()) + self.assertTrue(torch.isinf(second_moment).any()) + + def test_compute_statistics_empty_tensor(self): + """Test compute_statistics with empty tensor.""" + tensor = torch.empty(0, 3) + + # This should not crash but may produce NaN results + mean, variance, second_moment = compute_statistics(tensor, dim=0) + + # Results should be NaN for empty tensor + self.assertTrue(torch.isnan(mean).all()) + self.assertTrue(torch.isnan(variance).all()) + self.assertTrue(torch.isnan(second_moment).all()) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_inference.py b/test/test_inference.py new file mode 100644 index 0000000..28081c4 --- /dev/null +++ b/test/test_inference.py @@ -0,0 +1,327 @@ +import unittest +from abc import ABC +from unittest.mock import MagicMock, patch +import sys + +# Import directly to avoid the __init__.py that imports vllm +from easyroutine.inference.base_model_interface import ( + BaseInferenceModelConfig, + BaseInferenceModel +) + + +class TestBaseInferenceModelConfig(unittest.TestCase): + """Test suite for BaseInferenceModelConfig dataclass.""" + + def test_config_creation_with_defaults(self): + """Test creating config with only required parameters.""" + config = BaseInferenceModelConfig(model_name="test-model") + + # Check required parameter + self.assertEqual(config.model_name, "test-model") + + # Check default values + self.assertEqual(config.n_gpus, 1) + self.assertEqual(config.dtype, 'bfloat16') + self.assertEqual(config.temperature, 0) + self.assertEqual(config.top_p, 0.95) + self.assertEqual(config.max_new_tokens, 5000) + + def test_config_creation_with_custom_values(self): + """Test creating config with custom parameters.""" + config = BaseInferenceModelConfig( + model_name="custom-model", + n_gpus=4, + dtype="float16", + temperature=0.7, + top_p=0.9, + max_new_tokens=1000 + ) + + self.assertEqual(config.model_name, "custom-model") + self.assertEqual(config.n_gpus, 4) + self.assertEqual(config.dtype, "float16") + self.assertEqual(config.temperature, 0.7) + self.assertEqual(config.top_p, 0.9) + self.assertEqual(config.max_new_tokens, 1000) + + def test_config_is_dataclass(self): + """Test that config behaves as a dataclass.""" + config1 = BaseInferenceModelConfig(model_name="test") + config2 = BaseInferenceModelConfig(model_name="test") + config3 = BaseInferenceModelConfig(model_name="different") + + # Same parameters should be equal + self.assertEqual(config1, config2) + # Different parameters should not be equal + self.assertNotEqual(config1, config3) + + +class TestBaseInferenceModel(unittest.TestCase): + """Test suite for BaseInferenceModel abstract class.""" + + def setUp(self): + """Set up test environment.""" + # Create a concrete implementation for testing + class ConcreteInferenceModel(BaseInferenceModel): + def convert_chat_messages_to_custom_format(self, chat_messages): + return f"Converted: {chat_messages}" + + def generate(self, inputs): + return "Generated response" + + self.ConcreteModel = ConcreteInferenceModel + + def test_base_model_is_abstract(self): + """Test that BaseInferenceModel is abstract and cannot be instantiated.""" + config = BaseInferenceModelConfig(model_name="test") + + with self.assertRaises(TypeError): + BaseInferenceModel(config) + + def test_concrete_model_initialization(self): + """Test initialization of concrete model.""" + config = BaseInferenceModelConfig(model_name="test-model") + model = self.ConcreteModel(config) + + self.assertEqual(model.config, config) + self.assertEqual(model.config.model_name, "test-model") + + def test_init_model_class_method(self): + """Test the init_model class method.""" + model = self.ConcreteModel.init_model( + model_name="test-model", + n_gpus=2, + dtype="float32" + ) + + self.assertIsInstance(model, self.ConcreteModel) + self.assertEqual(model.config.model_name, "test-model") + self.assertEqual(model.config.n_gpus, 2) + self.assertEqual(model.config.dtype, "float32") + + def test_init_model_with_defaults(self): + """Test init_model with default parameters.""" + model = self.ConcreteModel.init_model(model_name="test-model") + + self.assertEqual(model.config.model_name, "test-model") + self.assertEqual(model.config.n_gpus, 1) # default + self.assertEqual(model.config.dtype, 'bfloat16') # default + + def test_append_with_chat_template_empty_history(self): + """Test appending message to empty chat history.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + result = model.append_with_chat_template("Hello world") + + expected = [{'role': 'user', 'content': 'Hello world'}] + self.assertEqual(result, expected) + + def test_append_with_chat_template_with_role(self): + """Test appending message with specific role.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + result = model.append_with_chat_template( + "System message", + role='system' + ) + + expected = [{'role': 'system', 'content': 'System message'}] + self.assertEqual(result, expected) + + def test_append_with_chat_template_existing_history(self): + """Test appending message to existing chat history.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + existing_history = [ + {'role': 'system', 'content': 'System message'}, + {'role': 'user', 'content': 'First user message'} + ] + + result = model.append_with_chat_template( + "Second user message", + role='user', + chat_history=existing_history + ) + + expected = existing_history + [{'role': 'user', 'content': 'Second user message'}] + self.assertEqual(result, expected) + + def test_append_with_chat_template_assistant_role(self): + """Test appending assistant message.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + chat_history = [{'role': 'user', 'content': 'Question?'}] + + result = model.append_with_chat_template( + "Assistant response", + role='assistant', + chat_history=chat_history + ) + + expected = [ + {'role': 'user', 'content': 'Question?'}, + {'role': 'assistant', 'content': 'Assistant response'} + ] + self.assertEqual(result, expected) + + def test_append_with_chat_template_validates_history(self): + """Test that invalid chat history raises assertion error.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + # Missing 'role' key + invalid_history = [{'content': 'message without role'}] + + with self.assertRaises(AssertionError): + model.append_with_chat_template( + "New message", + chat_history=invalid_history + ) + + def test_append_with_chat_template_validates_history_missing_content(self): + """Test validation with missing 'content' key.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + # Missing 'content' key + invalid_history = [{'role': 'user'}] + + with self.assertRaises(AssertionError): + model.append_with_chat_template( + "New message", + chat_history=invalid_history + ) + + def test_append_with_chat_template_empty_list_is_valid(self): + """Test that empty chat history list is valid.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + # Should not raise assertion error + result = model.append_with_chat_template( + "First message", + chat_history=[] + ) + + expected = [{'role': 'user', 'content': 'First message'}] + self.assertEqual(result, expected) + + def test_convert_chat_messages_to_custom_format_is_abstract(self): + """Test that convert_chat_messages_to_custom_format must be implemented.""" + class IncompleteModel(BaseInferenceModel): + pass # Missing implementation + + config = BaseInferenceModelConfig(model_name="test") + + with self.assertRaises(TypeError): + IncompleteModel(config) + + def test_concrete_implementation_of_convert_method(self): + """Test the concrete implementation of convert_chat_messages_to_custom_format.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + messages = [{'role': 'user', 'content': 'Hello'}] + result = model.convert_chat_messages_to_custom_format(messages) + + self.assertEqual(result, f"Converted: {messages}") + + def test_role_validation_types(self): + """Test that role parameter accepts valid literal types.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + # These should all work without error + user_result = model.append_with_chat_template("Message", role='user') + assistant_result = model.append_with_chat_template("Message", role='assistant') + system_result = model.append_with_chat_template("Message", role='system') + + self.assertEqual(user_result[0]['role'], 'user') + self.assertEqual(assistant_result[0]['role'], 'assistant') + self.assertEqual(system_result[0]['role'], 'system') + + def test_chat_history_immutability(self): + """Test that original chat history is not modified.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + original_history = [{'role': 'user', 'content': 'Original message'}] + original_copy = original_history.copy() + + result = model.append_with_chat_template( + "New message", + chat_history=original_history + ) + + # Original history should be unchanged + self.assertEqual(original_history, original_copy) + # Result should have the new message + self.assertEqual(len(result), 2) + self.assertEqual(result[1]['content'], 'New message') + + def test_multiple_append_operations(self): + """Test multiple sequential append operations.""" + config = BaseInferenceModelConfig(model_name="test") + model = self.ConcreteModel(config) + + # Build conversation step by step + history = [] + history = model.append_with_chat_template( + "Hello", role='user', chat_history=history + ) + history = model.append_with_chat_template( + "Hi there!", role='assistant', chat_history=history + ) + history = model.append_with_chat_template( + "How are you?", role='user', chat_history=history + ) + + expected = [ + {'role': 'user', 'content': 'Hello'}, + {'role': 'assistant', 'content': 'Hi there!'}, + {'role': 'user', 'content': 'How are you?'} + ] + + self.assertEqual(history, expected) + + +class TestAbstractMethodEnforcement(unittest.TestCase): + """Test that abstract methods are properly enforced.""" + + def test_missing_convert_method_prevents_instantiation(self): + """Test that missing convert method prevents instantiation.""" + class PartialModel(BaseInferenceModel): + # Missing convert_chat_messages_to_custom_format implementation + def some_other_method(self): + pass + + config = BaseInferenceModelConfig(model_name="test") + + with self.assertRaises(TypeError) as context: + PartialModel(config) + + # Check that the error mentions the missing abstract method + error_msg = str(context.exception) + self.assertIn("abstract", error_msg.lower()) + + def test_all_methods_implemented_allows_instantiation(self): + """Test that implementing all abstract methods allows instantiation.""" + class CompleteModel(BaseInferenceModel): + def convert_chat_messages_to_custom_format(self, chat_messages): + return "implemented" + + config = BaseInferenceModelConfig(model_name="test") + + # Should not raise any errors + model = CompleteModel(config) + self.assertIsInstance(model, CompleteModel) + self.assertIsInstance(model, BaseInferenceModel) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_logger.py b/test/test_logger.py new file mode 100644 index 0000000..b2a6be2 --- /dev/null +++ b/test/test_logger.py @@ -0,0 +1,274 @@ +import unittest +import logging +import tempfile +import os +from io import StringIO +from unittest.mock import patch, MagicMock +from easyroutine.logger import ( + logger, + warning_once, + setup_logging, + enable_debug_logging, + enable_info_logging, + enable_warning_logging, + disable_logging, + setup_default_logging +) + + +class TestLogger(unittest.TestCase): + """Test suite for easyroutine.logger module.""" + + def setUp(self): + """Set up test environment.""" + # Store original logger state + self.original_level = logger.level + self.original_handlers = logger.handlers[:] + self.original_propagate = logger.propagate + + # Clear any existing warning_once messages + from easyroutine.logger import _logged_once_messages + _logged_once_messages.clear() + + def tearDown(self): + """Restore original logger state.""" + logger.setLevel(self.original_level) + logger.handlers.clear() + logger.handlers.extend(self.original_handlers) + logger.propagate = self.original_propagate + + # Clear warning_once messages + from easyroutine.logger import _logged_once_messages + _logged_once_messages.clear() + + def test_logger_exists(self): + """Test that the logger object exists and has correct name.""" + self.assertEqual(logger.name, "easyroutine") + self.assertIsInstance(logger, logging.Logger) + + def test_warning_once_function_exists(self): + """Test that warning_once is available as a function.""" + # Test that we can call warning_once + with patch.object(logger, 'warning') as mock_warning: + warning_once("Test message") + mock_warning.assert_called_once_with("Test message") + + def test_warning_once_method_exists(self): + """Test that warning_once is attached to logger as a method.""" + self.assertTrue(hasattr(logger, 'warning_once')) + self.assertTrue(callable(logger.warning_once)) + + def test_warning_once_logs_first_time(self): + """Test that warning_once logs the first occurrence of a message.""" + with patch.object(logger, 'warning') as mock_warning: + logger.warning_once("First time message") + mock_warning.assert_called_once_with("First time message") + + def test_warning_once_ignores_duplicate(self): + """Test that warning_once ignores subsequent identical messages.""" + with patch.object(logger, 'warning') as mock_warning: + # Log the same message twice + logger.warning_once("Duplicate message") + logger.warning_once("Duplicate message") + + # Should only be called once + mock_warning.assert_called_once_with("Duplicate message") + + def test_warning_once_different_messages(self): + """Test that warning_once logs different messages separately.""" + with patch.object(logger, 'warning') as mock_warning: + logger.warning_once("Message A") + logger.warning_once("Message B") + logger.warning_once("Message A") # Should not log again + + # Should be called twice (once for each unique message) + self.assertEqual(mock_warning.call_count, 2) + mock_warning.assert_any_call("Message A") + mock_warning.assert_any_call("Message B") + + def test_setup_default_logging(self): + """Test that setup_default_logging configures the logger correctly.""" + # Clear existing handlers + logger.handlers.clear() + + setup_default_logging() + + # Check that logger has handlers + self.assertTrue(len(logger.handlers) > 0) + self.assertEqual(logger.level, logging.INFO) + self.assertFalse(logger.propagate) + + def test_setup_default_logging_no_duplicate_handlers(self): + """Test that setup_default_logging doesn't add duplicate handlers.""" + # Call setup_default_logging multiple times + setup_default_logging() + handler_count_after_first = len(logger.handlers) + + setup_default_logging() + handler_count_after_second = len(logger.handlers) + + # Should not add duplicate handlers + self.assertEqual(handler_count_after_first, handler_count_after_second) + + def test_setup_logging_file_only(self): + """Test setup_logging with file output only.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_filename = tmp_file.name + + try: + setup_logging(level="DEBUG", file=tmp_filename, console=False) + + # Test logging to file + logger.debug("Debug message") + logger.info("Info message") + + # Check file contents + with open(tmp_filename, 'r') as f: + content = f.read() + self.assertIn("Debug message", content) + self.assertIn("Info message", content) + + # Check logger configuration + self.assertEqual(logger.level, logging.DEBUG) + + finally: + os.unlink(tmp_filename) + + def test_setup_logging_console_only(self): + """Test setup_logging with console output only.""" + setup_logging(level="WARNING", file=None, console=True) + + # Should have at least one handler (console) + self.assertTrue(len(logger.handlers) > 0) + self.assertEqual(logger.level, logging.WARNING) + + def test_setup_logging_both_file_and_console(self): + """Test setup_logging with both file and console output.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_filename = tmp_file.name + + try: + setup_logging(level="INFO", file=tmp_filename, console=True) + + # Should have at least two handlers + self.assertTrue(len(logger.handlers) >= 2) + self.assertEqual(logger.level, logging.INFO) + + finally: + os.unlink(tmp_filename) + + def test_enable_debug_logging(self): + """Test enable_debug_logging function.""" + enable_debug_logging() + + self.assertEqual(logger.level, logging.DEBUG) + # Check that all handlers have debug level + for handler in logger.handlers: + self.assertEqual(handler.level, logging.DEBUG) + + def test_enable_info_logging(self): + """Test enable_info_logging function.""" + enable_info_logging() + + self.assertEqual(logger.level, logging.INFO) + for handler in logger.handlers: + self.assertEqual(handler.level, logging.INFO) + + def test_enable_warning_logging(self): + """Test enable_warning_logging function.""" + enable_warning_logging() + + self.assertEqual(logger.level, logging.WARNING) + for handler in logger.handlers: + self.assertEqual(handler.level, logging.WARNING) + + def test_disable_logging(self): + """Test disable_logging function.""" + disable_logging() + + # Logger level should be set higher than CRITICAL + self.assertGreater(logger.level, logging.CRITICAL) + for handler in logger.handlers: + self.assertGreater(handler.level, logging.CRITICAL) + + def test_setup_logging_custom_format(self): + """Test setup_logging with custom format.""" + custom_format = "%(levelname)s: %(message)s" + + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_filename = tmp_file.name + + try: + setup_logging( + level="INFO", + file=tmp_filename, + console=False, + fmt=custom_format + ) + + logger.info("Test message") + + # Check that custom format is used + with open(tmp_filename, 'r') as f: + content = f.read() + self.assertIn("INFO: Test message", content) + + finally: + os.unlink(tmp_filename) + + def test_setup_logging_invalid_level(self): + """Test setup_logging with invalid level defaults to INFO.""" + setup_logging(level="INVALID_LEVEL", console=True, file=None) + + # Should default to INFO level + self.assertEqual(logger.level, logging.INFO) + + def test_logging_level_hierarchy(self): + """Test that logging level hierarchy works correctly.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False) as tmp_file: + tmp_filename = tmp_file.name + + try: + # Set to WARNING level + setup_logging(level="WARNING", file=tmp_filename, console=False) + + # Log messages at different levels + logger.debug("Debug message") # Should not appear + logger.info("Info message") # Should not appear + logger.warning("Warning message") # Should appear + logger.error("Error message") # Should appear + + # Check file contents + with open(tmp_filename, 'r') as f: + content = f.read() + self.assertNotIn("Debug message", content) + self.assertNotIn("Info message", content) + self.assertIn("Warning message", content) + self.assertIn("Error message", content) + + finally: + os.unlink(tmp_filename) + + def test_logger_propagation(self): + """Test that logger propagation is set correctly.""" + setup_default_logging() + self.assertFalse(logger.propagate) + + def test_warning_once_persistence(self): + """Test that warning_once messages persist across logger reconfigurations.""" + # Log a message with warning_once + with patch.object(logger, 'warning') as mock_warning: + logger.warning_once("Persistent message") + mock_warning.assert_called_once() + mock_warning.reset_mock() + + # Reconfigure logger + setup_logging(level="INFO", console=True, file=None) + + # Try to log the same message again + logger.warning_once("Persistent message") + mock_warning.assert_not_called() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_progress.py b/test/test_progress.py new file mode 100644 index 0000000..d2fa89f --- /dev/null +++ b/test/test_progress.py @@ -0,0 +1,417 @@ +import unittest +import os +import sys +import time +from io import StringIO +from unittest.mock import patch, MagicMock +from easyroutine.console.progress import ( + LoggingProgress, + format_time, + is_non_interactive_batch, + get_progress_bar, + progress, + _NoOpProgress +) + + +class TestProgressModule(unittest.TestCase): + """Test suite for easyroutine.console.progress module.""" + + def test_format_time_seconds(self): + """Test format_time with seconds.""" + self.assertEqual(format_time(30.5), "30.5s") + self.assertEqual(format_time(45.0), "45.0s") + self.assertEqual(format_time(59.9), "59.9s") + + def test_format_time_minutes(self): + """Test format_time with minutes.""" + self.assertEqual(format_time(60), "1.0m") + self.assertEqual(format_time(90), "1.5m") + self.assertEqual(format_time(3540), "59.0m") # Just under an hour + + def test_format_time_hours(self): + """Test format_time with hours.""" + self.assertEqual(format_time(3600), "1.0h") + self.assertEqual(format_time(5400), "1.5h") + self.assertEqual(format_time(7200), "2.0h") + + def test_format_time_edge_cases(self): + """Test format_time with edge cases.""" + self.assertEqual(format_time(0), "0.0s") + self.assertEqual(format_time(0.1), "0.1s") + + +class TestLoggingProgress(unittest.TestCase): + """Test suite for LoggingProgress class.""" + + def setUp(self): + """Set up test environment.""" + self.progress = LoggingProgress(log_interval=0.1, update_frequency=2) + + def test_init(self): + """Test LoggingProgress initialization.""" + progress = LoggingProgress(log_interval=5, update_frequency=10) + self.assertEqual(progress.log_interval, 5) + self.assertEqual(progress.update_frequency, 10) + self.assertEqual(progress.tasks, {}) + + def test_context_manager(self): + """Test LoggingProgress as context manager.""" + with LoggingProgress() as progress: + self.assertIsInstance(progress, LoggingProgress) + + @patch('builtins.print') + def test_add_task(self, mock_print): + """Test adding a task.""" + task_id = self.progress.add_task("Test task", total=100) + + self.assertEqual(task_id, 0) + self.assertIn(task_id, self.progress.tasks) + + task = self.progress.tasks[task_id] + self.assertEqual(task["description"], "Test task") + self.assertEqual(task["total"], 100) + self.assertEqual(task["completed"], 0) + + # Check print was called + mock_print.assert_called_once() + args = mock_print.call_args[0][0] + self.assertIn("Starting: Test task", args) + self.assertIn("Total: 100", args) + + @patch('builtins.print') + def test_add_task_no_total(self, mock_print): + """Test adding a task without total.""" + task_id = self.progress.add_task("Test task without total") + + task = self.progress.tasks[task_id] + self.assertIsNone(task["total"]) + + args = mock_print.call_args[0][0] + self.assertIn("Total: unknown", args) + + @patch('builtins.print') + @patch('time.time') + def test_update_by_time_interval(self, mock_time, mock_print): + """Test update triggering by time interval.""" + # Mock time progression + mock_time.side_effect = [0, 0, 0.2] # start, add_task, update + + task_id = self.progress.add_task("Test task", total=100) + mock_print.reset_mock() + + self.progress.update(task_id, advance=5) + + # Should print because time interval exceeded + mock_print.assert_called_once() + args = mock_print.call_args[0][0] + self.assertIn("Test task: 5/100", args) + + @patch('builtins.print') + def test_update_by_frequency(self, mock_print): + """Test update triggering by item frequency.""" + task_id = self.progress.add_task("Test task", total=100) + mock_print.reset_mock() + + # Update with enough items to trigger frequency-based logging + self.progress.update(task_id, advance=2) + + # Should print because we reached update_frequency=2 + mock_print.assert_called_once() + + @patch('builtins.print') + def test_update_nonexistent_task(self, mock_print): + """Test updating a nonexistent task.""" + self.progress.update(999, advance=1) + + # Should not crash or print anything + mock_print.assert_not_called() + + @patch('builtins.print') + @patch('time.time') + def test_update_with_total_calculates_percentage(self, mock_time, mock_print): + """Test that update calculates percentage and remaining time correctly.""" + mock_time.side_effect = [0, 0, 1.0] # start, add_task, update + + task_id = self.progress.add_task("Test task", total=100) + mock_print.reset_mock() + + self.progress.update(task_id, advance=25) + + args = mock_print.call_args[0][0] + self.assertIn("25/100", args) + self.assertIn("(25.0%)", args) + self.assertIn("Elapsed:", args) + self.assertIn("Remaining:", args) + + @patch('builtins.print') + @patch('time.time') + def test_update_without_total(self, mock_time, mock_print): + """Test update for task without total.""" + mock_time.side_effect = [0, 0, 1.0] + + task_id = self.progress.add_task("Test task") # No total + mock_print.reset_mock() + + self.progress.update(task_id, advance=25) + + args = mock_print.call_args[0][0] + self.assertIn("25 items", args) + self.assertIn("Elapsed:", args) + self.assertNotIn("Remaining:", args) + + @patch('builtins.print') + def test_track_with_known_length(self, mock_print): + """Test track method with iterable of known length.""" + items = [1, 2, 3, 4, 5] + + result = list(self.progress.track(items, description="Test tracking")) + + self.assertEqual(result, items) + # Should have printed start and completion messages + self.assertTrue(mock_print.called) + + @patch('builtins.print') + def test_track_with_unknown_length(self, mock_print): + """Test track method with generator (unknown length).""" + def generate_items(): + for i in range(3): + yield i + + result = list(self.progress.track(generate_items(), description="Test generator")) + + self.assertEqual(result, [0, 1, 2]) + self.assertTrue(mock_print.called) + + @patch('builtins.print') + def test_track_with_explicit_total(self, mock_print): + """Test track method with explicitly provided total.""" + items = [1, 2, 3] + + result = list(self.progress.track(items, total=10, description="Test explicit total")) + + self.assertEqual(result, items) + # Check that the explicit total was used + start_call = mock_print.call_args_list[0][0][0] + self.assertIn("Total: 10", start_call) + + +class TestIsNonInteractiveBatch(unittest.TestCase): + """Test suite for is_non_interactive_batch function.""" + + def test_with_slurm_job_id_and_no_tty(self): + """Test detection with SLURM_JOB_ID and no TTY.""" + with patch.dict(os.environ, {'SLURM_JOB_ID': '12345'}): + with patch('sys.stdout.isatty', return_value=False): + self.assertTrue(is_non_interactive_batch()) + + def test_with_slurm_job_id_and_tty(self): + """Test with SLURM_JOB_ID but with TTY (interactive session).""" + with patch.dict(os.environ, {'SLURM_JOB_ID': '12345'}): + with patch('sys.stdout.isatty', return_value=True): + self.assertFalse(is_non_interactive_batch()) + + def test_with_pbs_job_id(self): + """Test detection with PBS_JOBID.""" + with patch.dict(os.environ, {'PBS_JOBID': '12345.server'}): + with patch('sys.stdout.isatty', return_value=False): + self.assertTrue(is_non_interactive_batch()) + + def test_with_dumb_terminal(self): + """Test detection with TERM=dumb.""" + with patch.dict(os.environ, {'TERM': 'dumb'}, clear=True): + self.assertTrue(is_non_interactive_batch()) + + def test_with_output_redirection(self): + """Test detection with output redirection (no TTY).""" + with patch('sys.stdout.isatty', return_value=False): + # Remove batch environment variables + env_vars = ['SLURM_JOB_ID', 'PBS_JOBID', 'LSB_JOBID', 'SGE_TASK_ID'] + env_patches = {var: patch.dict(os.environ, {}, clear=False) for var in env_vars} + + # Remove each variable if it exists + for var in env_vars: + if var in os.environ: + del os.environ[var] + + try: + self.assertTrue(is_non_interactive_batch()) + finally: + # Restore any variables that might have been removed + pass + + def test_slurm_pty_exception(self): + """Test SLURM PTY exception case.""" + with patch.dict(os.environ, {'SLURM_PTY_PORT': '12345'}): + with patch('sys.stdout.isatty', return_value=False): + # Should return False due to SLURM_PTY_PORT exception + self.assertFalse(is_non_interactive_batch()) + + def test_interactive_session(self): + """Test normal interactive session.""" + with patch('sys.stdout.isatty', return_value=True): + # Remove batch environment variables + env_vars = ['SLURM_JOB_ID', 'PBS_JOBID', 'LSB_JOBID', 'SGE_TASK_ID'] + original_values = {} + + # Store original values and remove variables + for var in env_vars: + if var in os.environ: + original_values[var] = os.environ[var] + del os.environ[var] + + try: + with patch.dict(os.environ, {'TERM': 'xterm'}, clear=False): + self.assertFalse(is_non_interactive_batch()) + finally: + # Restore original values + for var, value in original_values.items(): + os.environ[var] = value + + +class TestGetProgressBar(unittest.TestCase): + """Test suite for get_progress_bar function.""" + + def test_disabled_progress_bar(self): + """Test that disabled progress bar returns NoOpProgress.""" + progress_bar = get_progress_bar(disable=True) + self.assertIsInstance(progress_bar, _NoOpProgress) + + @patch('easyroutine.console.progress.is_non_interactive_batch') + def test_batch_mode_progress_bar(self, mock_batch_check): + """Test that batch mode returns LoggingProgress.""" + mock_batch_check.return_value = True + progress_bar = get_progress_bar() + self.assertIsInstance(progress_bar, LoggingProgress) + + @patch('easyroutine.console.progress.is_non_interactive_batch') + def test_interactive_mode_progress_bar(self, mock_batch_check): + """Test that interactive mode returns Rich Progress.""" + mock_batch_check.return_value = False + progress_bar = get_progress_bar() + # Rich Progress class name varies, so check it's not our custom classes + self.assertNotIsInstance(progress_bar, LoggingProgress) + self.assertNotIsInstance(progress_bar, _NoOpProgress) + + def test_force_batch_mode(self): + """Test forcing batch mode regardless of environment.""" + progress_bar = get_progress_bar(force_batch_mode=True) + self.assertIsInstance(progress_bar, LoggingProgress) + + @patch.dict(os.environ, {'SLURM_JOB_ID': '12345'}) + def test_slurm_environment_settings(self): + """Test that SLURM environment gets special default settings.""" + with patch('easyroutine.console.progress.is_non_interactive_batch', return_value=True): + progress_bar = get_progress_bar() + self.assertIsInstance(progress_bar, LoggingProgress) + # In SLURM environment, update_frequency should be set + self.assertEqual(progress_bar.update_frequency, 1) + + +class TestNoOpProgress(unittest.TestCase): + """Test suite for _NoOpProgress class.""" + + def setUp(self): + self.progress = _NoOpProgress() + + def test_track_yields_items(self): + """Test that track method yields items without modification.""" + items = [1, 2, 3, 4, 5] + result = list(self.progress.track(items)) + self.assertEqual(result, items) + + def test_context_manager(self): + """Test _NoOpProgress as context manager.""" + with _NoOpProgress() as progress: + self.assertIsInstance(progress, _NoOpProgress) + + def test_add_task_returns_dummy_id(self): + """Test that add_task returns a dummy task ID.""" + task_id = self.progress.add_task("Test task", total=100) + self.assertEqual(task_id, 0) + + def test_update_does_nothing(self): + """Test that update method does nothing.""" + # Should not raise any exceptions + self.progress.update(0, advance=1) + self.progress.update(999, advance=100) # Invalid task ID should also work + + +class TestProgressFunction(unittest.TestCase): + """Test suite for the progress function.""" + + @patch('easyroutine.console.progress.get_progress_bar') + def test_progress_function_with_list(self, mock_get_progress_bar): + """Test progress function with a list.""" + mock_progress_bar = MagicMock() + mock_get_progress_bar.return_value.__enter__.return_value = mock_progress_bar + + items = [1, 2, 3] + list(progress(items, description="Test")) + + # Check that get_progress_bar was called with correct parameters + mock_get_progress_bar.assert_called_once() + + # Check that track was called on the progress bar + mock_progress_bar.track.assert_called_once_with(items, total=3, description="Test") + + @patch('easyroutine.console.progress.get_progress_bar') + def test_progress_function_with_generator(self, mock_get_progress_bar): + """Test progress function with a generator.""" + mock_progress_bar = MagicMock() + mock_get_progress_bar.return_value.__enter__.return_value = mock_progress_bar + + def gen(): + yield from [1, 2, 3] + + list(progress(gen(), description="Test generator")) + + # Generator length can't be determined, so total should be None + mock_progress_bar.track.assert_called_once() + args, kwargs = mock_progress_bar.track.call_args + self.assertIsNone(kwargs['total']) + + @patch('easyroutine.console.progress.get_progress_bar') + def test_progress_function_with_explicit_total(self, mock_get_progress_bar): + """Test progress function with explicit total.""" + mock_progress_bar = MagicMock() + mock_get_progress_bar.return_value.__enter__.return_value = mock_progress_bar + + items = [1, 2, 3] + list(progress(items, total=10, description="Test explicit")) + + # Should use the explicit total + mock_progress_bar.track.assert_called_once_with(items, total=10, description="Test explicit") + + @patch('easyroutine.console.progress.get_progress_bar') + def test_progress_function_desc_parameter(self, mock_get_progress_bar): + """Test progress function with desc parameter (alternative to description).""" + mock_progress_bar = MagicMock() + mock_get_progress_bar.return_value.__enter__.return_value = mock_progress_bar + + items = [1, 2, 3] + list(progress(items, desc="Test desc")) + + # desc should override description + mock_progress_bar.track.assert_called_once_with(items, total=3, description="Test desc") + + @patch('easyroutine.console.progress.get_progress_bar') + def test_progress_function_forwards_parameters(self, mock_get_progress_bar): + """Test that progress function forwards parameters to get_progress_bar.""" + mock_progress_bar = MagicMock() + mock_get_progress_bar.return_value.__enter__.return_value = mock_progress_bar + + items = [1, 2, 3] + list(progress(items, disable=True, force_batch_mode=True, log_interval=5)) + + # Check that parameters were forwarded + mock_get_progress_bar.assert_called_once_with( + disable=True, + force_batch_mode=True, + log_interval=5, + update_frequency=0 + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..4b3aa90 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,174 @@ +import unittest +import tempfile +import os +from unittest.mock import patch +from easyroutine.utils import path_to_parents, path_to_relative + + +class TestUtils(unittest.TestCase): + """Test suite for easyroutine.utils module.""" + + def setUp(self): + """Set up test environment with a temporary directory structure.""" + self.temp_dir = tempfile.mkdtemp() + self.original_cwd = os.getcwd() + + # Create a nested directory structure for testing + self.test_dir = os.path.join(self.temp_dir, "level1", "level2", "level3") + os.makedirs(self.test_dir, exist_ok=True) + + # Create a parallel directory for relative path testing + self.parallel_dir = os.path.join(self.temp_dir, "level1", "parallel") + os.makedirs(self.parallel_dir, exist_ok=True) + + def tearDown(self): + """Clean up by returning to original directory.""" + os.chdir(self.original_cwd) + + def test_path_to_parents_single_level(self): + """Test going up one directory level.""" + # Start in the deepest directory + os.chdir(self.test_dir) + initial_dir = os.getcwd() + + with patch('builtins.print') as mock_print: + path_to_parents(1) + + # Should be one level up + expected_dir = os.path.dirname(initial_dir) + self.assertEqual(os.getcwd(), expected_dir) + + # Check that print was called with the correct message + mock_print.assert_called_once() + args = mock_print.call_args[0][0] + self.assertIn("Changed working directory to:", args) + self.assertIn(expected_dir, args) + + def test_path_to_parents_multiple_levels(self): + """Test going up multiple directory levels.""" + # Start in the deepest directory + os.chdir(self.test_dir) + initial_dir = os.getcwd() + + with patch('builtins.print') as mock_print: + path_to_parents(2) + + # Should be two levels up + expected_dir = os.path.dirname(os.path.dirname(initial_dir)) + self.assertEqual(os.getcwd(), expected_dir) + + # Check that print was called + mock_print.assert_called_once() + args = mock_print.call_args[0][0] + self.assertIn("Changed working directory to:", args) + self.assertIn(expected_dir, args) + + def test_path_to_parents_default_level(self): + """Test default behavior (going up one level).""" + os.chdir(self.test_dir) + initial_dir = os.getcwd() + + with patch('builtins.print') as mock_print: + path_to_parents() # Default to 1 level + + expected_dir = os.path.dirname(initial_dir) + self.assertEqual(os.getcwd(), expected_dir) + mock_print.assert_called_once() + + def test_path_to_parents_zero_levels(self): + """Test edge case of going up zero levels.""" + os.chdir(self.test_dir) + initial_dir = os.getcwd() + + with patch('builtins.print') as mock_print: + path_to_parents(0) + + # Based on the implementation, even with 0 levels, it goes up one level first + # Then the loop for additional levels doesn't run + expected_dir = os.path.dirname(initial_dir) + self.assertEqual(os.getcwd(), expected_dir) + + def test_path_to_relative_valid_path(self): + """Test changing to a valid relative path.""" + # Start in level2 directory + level2_dir = os.path.dirname(self.test_dir) + os.chdir(level2_dir) + + with patch('builtins.print') as mock_print: + path_to_relative("level3") + + # Should now be in level3 + self.assertEqual(os.getcwd(), self.test_dir) + + # Check print output + mock_print.assert_called_once() + args = mock_print.call_args[0][0] + self.assertIn("Changed working directory to:", args) + self.assertIn(self.test_dir, args) + + def test_path_to_relative_nested_path(self): + """Test changing to a nested relative path.""" + # Start in temp_dir + os.chdir(self.temp_dir) + + with patch('builtins.print') as mock_print: + path_to_relative(os.path.join("level1", "level2")) + + # Should now be in level2 + expected_dir = os.path.join(self.temp_dir, "level1", "level2") + self.assertEqual(os.getcwd(), expected_dir) + mock_print.assert_called_once() + + def test_path_to_relative_current_directory(self): + """Test changing to current directory (edge case).""" + os.chdir(self.test_dir) + initial_dir = os.getcwd() + + with patch('builtins.print') as mock_print: + path_to_relative(".") + + # Should stay in the same directory + self.assertEqual(os.getcwd(), initial_dir) + mock_print.assert_called_once() + + def test_path_to_relative_parent_directory(self): + """Test using relative path to go to parent.""" + os.chdir(self.test_dir) + + with patch('builtins.print') as mock_print: + path_to_relative("..") + + # Should be in parent directory + expected_dir = os.path.dirname(self.test_dir) + self.assertEqual(os.getcwd(), expected_dir) + mock_print.assert_called_once() + + def test_path_to_relative_invalid_path(self): + """Test behavior with non-existent relative path.""" + os.chdir(self.temp_dir) + + # This should raise an exception since the path doesn't exist + with self.assertRaises(FileNotFoundError): + path_to_relative("nonexistent_directory") + + def test_path_operations_integration(self): + """Test combining path_to_relative and path_to_parents operations.""" + # Start in temp_dir, go to nested directory, then back up + os.chdir(self.temp_dir) + + # Go to nested path + with patch('builtins.print'): + path_to_relative(os.path.join("level1", "level2", "level3")) + + self.assertEqual(os.getcwd(), self.test_dir) + + # Go back up two levels + with patch('builtins.print'): + path_to_parents(2) + + expected_dir = os.path.join(self.temp_dir, "level1") + self.assertEqual(os.getcwd(), expected_dir) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file