"""Tests for the MCP SSH server.""" import unittest from unittest.mock import patch, MagicMock from mcpssh.server import SSHServerMCP, SSHSession, SSHConnectionParams, CommandParams class TestSSHSession(unittest.TestCase): """Test SSH session class.""" @patch('paramiko.SSHClient') def test_connect_success(self, mock_ssh_client): """Test successful SSH connection.""" # Setup mock_client = MagicMock() mock_ssh_client.return_value = mock_client # Execute session = SSHSession("example.com", 22, "username", "/path/to/key") result = session.connect() # Verify self.assertTrue(result) self.assertTrue(session.connected) mock_client.connect.assert_called_once_with( hostname="example.com", port=22, username="username", key_filename="/path/to/key" ) @patch('paramiko.SSHClient') def test_connect_failure(self, mock_ssh_client): """Test failed SSH connection.""" # Setup mock_client = MagicMock() mock_client.connect.side_effect = Exception("Connection failed") mock_ssh_client.return_value = mock_client # Execute session = SSHSession("example.com", 22, "username", "/path/to/key") result = session.connect() # Verify self.assertFalse(result) self.assertFalse(session.connected) @patch('paramiko.SSHClient') def test_execute_command_success(self, mock_ssh_client): """Test successful command execution.""" # Setup mock_client = MagicMock() mock_stdout = MagicMock() mock_stdout.read.return_value = b"command output" mock_stdout.channel.recv_exit_status.return_value = 0 mock_stderr = MagicMock() mock_stderr.read.return_value = b"" mock_client.exec_command.return_value = (None, mock_stdout, mock_stderr) mock_ssh_client.return_value = mock_client # Execute session = SSHSession("example.com", 22, "username", "/path/to/key") session.connected = True # Skip connection session.client = mock_client result = session.execute_command("ls -la") # Verify self.assertEqual(result["stdout"], "command output") self.assertEqual(result["stderr"], "") self.assertEqual(result["exit_code"], 0) @patch('paramiko.SSHClient') def test_close(self, mock_ssh_client): """Test closing SSH connection.""" # Setup mock_client = MagicMock() mock_ssh_client.return_value = mock_client # Execute session = SSHSession("example.com", 22, "username", "/path/to/key") session.connected = True session.client = mock_client session.close() # Verify self.assertFalse(session.connected) mock_client.close.assert_called_once() class TestSSHServerMCP(unittest.TestCase): """Test SSH server MCP implementation.""" def setUp(self): """Set up test environment.""" self.server = SSHServerMCP() @patch('mcpssh.server.SSHSession') def test_ssh_connect_success(self, mock_ssh_session): """Test successful SSH connection.""" # Setup mock_session = MagicMock() mock_session.connect.return_value = True mock_ssh_session.return_value = mock_session # Execute params = { "hostname": "example.com", "port": 22, "username": "username", "key_filename": "/path/to/key" } result = self.server.ssh_connect(SSHConnectionParams(**params)) # Verify self.assertTrue(result["success"]) self.assertEqual(result["message"], "Connected to example.com") @patch('mcpssh.server.SSHSession') def test_ssh_connect_failure(self, mock_ssh_session): """Test failed SSH connection.""" # Setup mock_session = MagicMock() mock_session.connect.return_value = False mock_ssh_session.return_value = mock_session # Execute params = { "hostname": "example.com", "port": 22, "username": "username", "key_filename": "/path/to/key" } result = self.server.ssh_connect(SSHConnectionParams(**params)) # Verify self.assertFalse(result["success"]) self.assertEqual(result["message"], "Failed to connect to SSH server") def test_ssh_execute_not_connected(self): """Test command execution when not connected.""" # Execute params = {"command": "ls -la"} result = self.server.ssh_execute(CommandParams(**params)) # Verify self.assertFalse(result["success"]) self.assertEqual(result["message"], "Not connected to SSH server") @patch('mcpssh.server.SSHSession') def test_ssh_execute_success(self, mock_ssh_session): """Test successful command execution.""" # Setup mock_session = MagicMock() mock_session.connected = True mock_session.execute_command.return_value = { "stdout": "command output", "stderr": "", "exit_code": 0 } self.server.ssh_session = mock_session # Execute params = {"command": "ls -la"} result = self.server.ssh_execute(CommandParams(**params)) # Verify self.assertEqual(result["stdout"], "command output") self.assertEqual(result["stderr"], "") self.assertEqual(result["exit_code"], 0) def test_ssh_disconnect_not_connected(self): """Test disconnection when not connected.""" # Execute result = self.server.ssh_disconnect() # Verify self.assertTrue(result["success"]) self.assertEqual(result["message"], "Not connected to SSH server") def test_ssh_disconnect_success(self): """Test successful disconnection.""" # Setup mock_session = MagicMock() self.server.ssh_session = mock_session # Execute result = self.server.ssh_disconnect() # Verify self.assertTrue(result["success"]) self.assertEqual(result["message"], "Disconnected from SSH server") mock_session.close.assert_called_once() if __name__ == "__main__": unittest.main()