198 lines
6.4 KiB
Python
198 lines
6.4 KiB
Python
"""Tests for the MCP SSH server."""
|
|
|
|
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from mcpssh.server import SSHServerMCP, SSHSession, SSHConnectionParams, CommandParams
|
|
|
|
|
|
@patch('paramiko.SSHClient')
class TestSSHSession(unittest.TestCase):
    """Exercise ``SSHSession`` against a mocked ``paramiko.SSHClient``.

    ``patch`` is applied at class level, which is equivalent to decorating
    every ``test*`` method individually: each test receives the patched
    ``SSHClient`` class as its trailing positional argument. Helpers whose
    names do not start with ``test`` are not patched themselves, but they
    only run inside a test, where the patch is already active.
    """

    @staticmethod
    def _new_session():
        """Create the session under test with a fixed host/port/user/key."""
        return SSHSession("example.com", 22, "username", "/path/to/key")

    def test_connect_success(self, client_cls):
        """Test successful SSH connection.

        ``connect`` should report success, set ``connected`` and forward the
        constructor arguments to ``SSHClient.connect`` as keyword arguments.
        """
        fake_client = MagicMock()
        # Configure the factory before the session exists, in case the
        # session instantiates its client during construction.
        client_cls.return_value = fake_client

        session = self._new_session()
        outcome = session.connect()

        self.assertTrue(outcome)
        self.assertTrue(session.connected)
        fake_client.connect.assert_called_once_with(
            hostname="example.com",
            port=22,
            username="username",
            key_filename="/path/to/key",
        )

    def test_connect_failure(self, client_cls):
        """Test failed SSH connection.

        An exception raised by ``SSHClient.connect`` must come back as a
        falsy result with ``connected`` left false, not as a propagated error.
        """
        fake_client = MagicMock()
        fake_client.connect.side_effect = Exception("Connection failed")
        client_cls.return_value = fake_client

        session = self._new_session()

        self.assertFalse(session.connect())
        self.assertFalse(session.connected)

    def test_execute_command_success(self, client_cls):
        """Test successful command execution.

        The connect step is bypassed by marking the session connected and
        injecting the mock client. ``exec_command`` returns the tuple
        ``(None, stdout, stderr)``; the byte payloads should come back as
        text under the ``stdout``/``stderr`` keys with the exit status under
        ``exit_code``.
        """
        fake_client = MagicMock()
        fake_stdout = MagicMock()
        fake_stdout.configure_mock(**{
            "read.return_value": b"command output",
            "channel.recv_exit_status.return_value": 0,
        })
        fake_stderr = MagicMock()
        fake_stderr.configure_mock(**{"read.return_value": b""})
        fake_client.exec_command.return_value = (None, fake_stdout, fake_stderr)
        client_cls.return_value = fake_client

        session = self._new_session()
        session.connected = True  # Skip connection
        session.client = fake_client
        result = session.execute_command("ls -la")

        expected = (
            ("stdout", "command output"),
            ("stderr", ""),
            ("exit_code", 0),
        )
        for key, value in expected:
            self.assertEqual(result[key], value)

    def test_close(self, client_cls):
        """Test closing SSH connection.

        ``close`` should clear ``connected`` and close the injected client
        exactly once.
        """
        fake_client = MagicMock()
        client_cls.return_value = fake_client

        session = self._new_session()
        session.connected, session.client = True, fake_client
        session.close()

        self.assertFalse(session.connected)
        fake_client.close.assert_called_once()
|
|
|
|
|
|
class TestSSHServerMCP(unittest.TestCase):
    """Test SSH server MCP implementation.

    ``SSHSession`` is either patched in ``mcpssh.server`` (connect tests) or
    replaced by a ``MagicMock`` assigned directly to ``server.ssh_session``
    (execute/disconnect tests), so no real SSH traffic occurs.
    """

    def setUp(self):
        """Set up test environment."""
        self.server = SSHServerMCP()

    @staticmethod
    def _connection_params():
        """Return the connection parameters shared by the ``ssh_connect`` tests."""
        return SSHConnectionParams(
            hostname="example.com",
            port=22,
            username="username",
            key_filename="/path/to/key",
        )

    @patch('mcpssh.server.SSHSession')
    def test_ssh_connect_success(self, mock_ssh_session):
        """Test successful SSH connection."""
        # Setup
        mock_session = MagicMock()
        mock_session.connect.return_value = True
        mock_ssh_session.return_value = mock_session

        # Execute
        result = self.server.ssh_connect(self._connection_params())

        # Verify
        self.assertTrue(result["success"])
        self.assertEqual(result["message"], "Connected to example.com")
        # A success result is only meaningful if the server actually built a
        # session and opened it.
        mock_ssh_session.assert_called_once()
        mock_session.connect.assert_called_once()

    @patch('mcpssh.server.SSHSession')
    def test_ssh_connect_failure(self, mock_ssh_session):
        """Test failed SSH connection."""
        # Setup
        mock_session = MagicMock()
        mock_session.connect.return_value = False
        mock_ssh_session.return_value = mock_session

        # Execute
        result = self.server.ssh_connect(self._connection_params())

        # Verify
        self.assertFalse(result["success"])
        self.assertEqual(result["message"], "Failed to connect to SSH server")
        # The failure must come from the session's connect attempt.
        mock_session.connect.assert_called_once()

    def test_ssh_execute_not_connected(self):
        """Test command execution when not connected."""
        # Execute
        result = self.server.ssh_execute(CommandParams(command="ls -la"))

        # Verify
        self.assertFalse(result["success"])
        self.assertEqual(result["message"], "Not connected to SSH server")

    def test_ssh_execute_success(self):
        """Test successful command execution."""
        # Setup: inject an already-connected session directly. Patching
        # SSHSession here would be dead code, because nothing constructs one.
        mock_session = MagicMock()
        mock_session.connected = True
        mock_session.execute_command.return_value = {
            "stdout": "command output",
            "stderr": "",
            "exit_code": 0,
        }
        self.server.ssh_session = mock_session

        # Execute
        result = self.server.ssh_execute(CommandParams(command="ls -la"))

        # Verify
        self.assertEqual(result["stdout"], "command output")
        self.assertEqual(result["stderr"], "")
        self.assertEqual(result["exit_code"], 0)
        # The canned result above proves nothing unless the requested command
        # actually reached the session. Accept it positionally or by keyword.
        mock_session.execute_command.assert_called_once()
        call_args, call_kwargs = mock_session.execute_command.call_args
        self.assertIn("ls -la", (*call_args, *call_kwargs.values()))

    def test_ssh_disconnect_not_connected(self):
        """Test disconnection when not connected."""
        # Execute
        result = self.server.ssh_disconnect()

        # Verify
        self.assertTrue(result["success"])
        self.assertEqual(result["message"], "Not connected to SSH server")

    def test_ssh_disconnect_success(self):
        """Test successful disconnection."""
        # Setup
        mock_session = MagicMock()
        self.server.ssh_session = mock_session

        # Execute
        result = self.server.ssh_disconnect()

        # Verify
        self.assertTrue(result["success"])
        self.assertEqual(result["message"], "Disconnected from SSH server")
        mock_session.close.assert_called_once()
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly: `python test_server.py`.
    unittest.main()