"""
Unit tests for cancel/interrupt functionality.

Tests the integration between cancel_stream() and agent.interrupt().
"""

import pytest
import queue
import threading
from unittest.mock import Mock

from api.streaming import cancel_stream
from api.config import AGENT_INSTANCES, STREAMS, CANCEL_FLAGS


class TestCancelInterrupt:
    """Test suite for cancel/interrupt functionality."""

    # Message these tests expect cancel_stream() to pass to agent.interrupt()
    # and to place in the queued 'cancel' event.
    CANCEL_MESSAGE = "Cancelled by user"

    @staticmethod
    def _reset_global_state():
        """Empty the module-level registries imported from api.config."""
        AGENT_INSTANCES.clear()
        STREAMS.clear()
        CANCEL_FLAGS.clear()

    @staticmethod
    def _register_stream(stream_id, agent=None):
        """Register a fresh queue and cancel flag for ``stream_id``.

        If ``agent`` is given it is also stored in AGENT_INSTANCES; leaving it
        as None simulates a stream whose agent has not been stored yet.

        Returns:
            The ``(queue, cancel_event)`` pair that was registered, so tests
            can inspect them directly.
        """
        stream_queue = queue.Queue()
        cancel_event = threading.Event()
        STREAMS[stream_id] = stream_queue
        CANCEL_FLAGS[stream_id] = cancel_event
        if agent is not None:
            AGENT_INSTANCES[stream_id] = agent
        return stream_queue, cancel_event

    def setup_method(self):
        """Clean up before each test."""
        self._reset_global_state()

    def teardown_method(self):
        """Clean up after each test."""
        self._reset_global_state()

    def test_cancel_calls_agent_interrupt(self):
        """Verify that cancel_stream() calls agent.interrupt() when agent exists."""
        stream_id = "test_stream_123"
        # Mock auto-creates a call-recording ``interrupt`` attribute.
        mock_agent = Mock()
        self._register_stream(stream_id, agent=mock_agent)

        result = cancel_stream(stream_id)

        assert result is True
        mock_agent.interrupt.assert_called_once_with(self.CANCEL_MESSAGE)
        assert CANCEL_FLAGS[stream_id].is_set()

    def test_cancel_handles_interrupt_exception(self):
        """Verify that cancel_stream() handles interrupt() exceptions gracefully."""
        stream_id = "test_stream_456"
        mock_agent = Mock()
        mock_agent.interrupt.side_effect = RuntimeError("Agent error")
        self._register_stream(stream_id, agent=mock_agent)

        # Should not raise even though interrupt() does.
        result = cancel_stream(stream_id)

        assert result is True
        # Pin the argument too, so the error path is held to the same contract
        # as test_cancel_calls_agent_interrupt.
        mock_agent.interrupt.assert_called_once_with(self.CANCEL_MESSAGE)
        assert CANCEL_FLAGS[stream_id].is_set()

    def test_cancel_before_agent_ready(self):
        """Test cancel when agent not yet stored in AGENT_INSTANCES (race condition)."""
        stream_id = "test_stream_789"
        # No agent registered: simulates cancel arriving before the agent is
        # stored in AGENT_INSTANCES.
        self._register_stream(stream_id)

        # Should succeed even without an agent.
        result = cancel_stream(stream_id)

        assert result is True
        # Presumably the agent checks this flag when it starts.
        assert CANCEL_FLAGS[stream_id].is_set()

    def test_cancel_nonexistent_stream(self):
        """Test cancel for a stream that doesn't exist."""
        result = cancel_stream("nonexistent_stream")

        assert result is False

    def test_cancel_sets_cancel_event(self):
        """Verify that cancel_stream() sets the cancel_event flag."""
        stream_id = "test_stream_event"
        _, cancel_event = self._register_stream(stream_id)

        result = cancel_stream(stream_id)

        assert result is True
        assert cancel_event.is_set()

    def test_cancel_puts_sentinel_in_queue(self):
        """Verify that cancel_stream() puts cancel sentinel in queue."""
        stream_id = "test_stream_queue"
        stream_queue, _ = self._register_stream(stream_id)

        result = cancel_stream(stream_id)

        assert result is True
        # Check that the cancel message was queued.
        assert not stream_queue.empty()
        event_type, data = stream_queue.get_nowait()
        assert event_type == 'cancel'
        assert data['message'] == self.CANCEL_MESSAGE