From e696d384c2e2c296c5884d4c93faae05e25d055f Mon Sep 17 00:00:00 2001 From: WerWolv Date: Tue, 16 Dec 2025 20:25:46 +0100 Subject: [PATCH] feat: Add initial MCP Server support --- lib/external/libwolv | 2 +- lib/libimhex/CMakeLists.txt | 3 + .../communication_interface.hpp | 17 ++ lib/libimhex/include/hex/mcp/client.hpp | 15 ++ lib/libimhex/include/hex/mcp/server.hpp | 38 +++ lib/libimhex/source/api/content_registry.cpp | 34 +++ lib/libimhex/source/mcp/client.cpp | 38 +++ lib/libimhex/source/mcp/server.cpp | 226 ++++++++++++++++++ .../content/command_line_interface.hpp | 1 + plugins/builtin/romfs/lang/en_US.json | 1 + .../source/content/background_services.cpp | 17 ++ .../source/content/command_line_interface.cpp | 10 + .../source/content/settings_entries.cpp | 2 + plugins/builtin/source/content/ui_items.cpp | 15 ++ plugins/builtin/source/plugin_builtin.cpp | 5 +- 15 files changed, 421 insertions(+), 3 deletions(-) create mode 100644 lib/libimhex/include/hex/mcp/client.hpp create mode 100644 lib/libimhex/include/hex/mcp/server.hpp create mode 100644 lib/libimhex/source/mcp/client.cpp create mode 100644 lib/libimhex/source/mcp/server.cpp diff --git a/lib/external/libwolv b/lib/external/libwolv index c9108712b..b0c941656 160000 --- a/lib/external/libwolv +++ b/lib/external/libwolv @@ -1 +1 @@ -Subproject commit c9108712ba16a024dc7aee75f3fc9882629b2703 +Subproject commit b0c9416568475a784838e3af8b0021a542e47cd7 diff --git a/lib/libimhex/CMakeLists.txt b/lib/libimhex/CMakeLists.txt index 1977cf830..19803d993 100644 --- a/lib/libimhex/CMakeLists.txt +++ b/lib/libimhex/CMakeLists.txt @@ -57,6 +57,9 @@ set(LIBIMHEX_SOURCES source/ui/toast.cpp source/ui/banner.cpp + source/mcp/client.cpp + source/mcp/server.cpp + source/subcommands/subcommands.cpp ) diff --git a/lib/libimhex/include/hex/api/content_registry/communication_interface.hpp b/lib/libimhex/include/hex/api/content_registry/communication_interface.hpp index 5ab9eb421..dbd709249 100644 --- a/lib/libimhex/include/hex/api/content_registry/communication_interface.hpp +++ b/lib/libimhex/include/hex/api/content_registry/communication_interface.hpp @@ -7,6 +7,8 @@ #include #include +#include + EXPORT_MODULE namespace hex { /* Network Communication Interface Registry. Allows adding new communication interface endpoints */ @@ -22,4 +24,19 @@ EXPORT_MODULE namespace hex { } + namespace ContentRegistry::MCP { + + namespace impl { + mcp::Server& getMcpServerInstance(); + + void setEnabled(bool enabled); + } + + bool isEnabled(); + bool isConnected(); + + void registerTool(std::string_view capabilities, std::function function); + + } + } \ No newline at end of file diff --git a/lib/libimhex/include/hex/mcp/client.hpp b/lib/libimhex/include/hex/mcp/client.hpp new file mode 100644 index 000000000..141da108c --- /dev/null +++ b/lib/libimhex/include/hex/mcp/client.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace hex::mcp { + + class Client { + public: + Client() = default; + ~Client() = default; + + int run(std::istream &input, std::ostream &output); + }; + +} diff --git a/lib/libimhex/include/hex/mcp/server.hpp b/lib/libimhex/include/hex/mcp/server.hpp new file mode 100644 index 000000000..2b62ee82c --- /dev/null +++ b/lib/libimhex/include/hex/mcp/server.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +namespace hex::mcp { + + class Server { + public: + constexpr static auto McpInternalPort = 19743; + + Server(); + ~Server(); + + void listen(); + void shutdown(); + void disconnect(); + bool isConnected(); + + void addPrimitive(std::string type, std::string_view capabilities, std::function function); + + private: + nlohmann::json handleInitialize(); + void handleNotifications(const std::string &method, const nlohmann::json ¶ms); + + struct Primitive { + nlohmann::json capabilities; + std::function function; + }; + + std::map> m_primitives; + + wolv::net::SocketServer m_server; + bool m_connected = false; + }; + +} diff --git a/lib/libimhex/source/api/content_registry.cpp b/lib/libimhex/source/api/content_registry.cpp index f0f3de479..59fe1a63c 100644 --- a/lib/libimhex/source/api/content_registry.cpp +++ b/lib/libimhex/source/api/content_registry.cpp @@ -1413,6 +1413,40 @@ namespace hex { } + namespace ContentRegistry::MCP { + + namespace impl { + + mcp::Server& getMcpServerInstance() { + static AutoReset> server; + + if (*server == nullptr) + server = std::make_unique(); + + return **server; + } + + static bool s_mcpEnabled = false; + void setEnabled(bool enabled) { + s_mcpEnabled = enabled; + } + + } + + bool isEnabled() { + return impl::s_mcpEnabled; + } + + bool isConnected() { + return impl::getMcpServerInstance().isConnected(); + } + + void registerTool(std::string_view capabilities, std::function function) { + impl::getMcpServerInstance().addPrimitive("tools", capabilities, function); + } + + } + namespace ContentRegistry::Experiments { namespace impl { diff --git a/lib/libimhex/source/mcp/client.cpp b/lib/libimhex/source/mcp/client.cpp new file mode 100644 index 000000000..7df8d839e --- /dev/null +++ b/lib/libimhex/source/mcp/client.cpp @@ -0,0 +1,38 @@ +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include + +namespace hex::mcp { + + int Client::run(std::istream &input, std::ostream &output) { + wolv::net::SocketClient client(wolv::net::SocketClient::Type::TCP, true); + client.connect("127.0.0.1", Server::McpInternalPort); + + if (!client.isConnected()) { + log::resumeLogging(); + log::error("Cannot connect to ImHex. Do you have an instance running and is the MCP server enabled?"); + return EXIT_FAILURE; + } + + while (true) { + std::string request; + std::getline(input, request); + + client.writeString(request); + auto response = client.readString(); + if (!response.empty() && response.front() != 0x00) + output << response << std::endl; + } + + return EXIT_SUCCESS; + } +} diff --git a/lib/libimhex/source/mcp/server.cpp b/lib/libimhex/source/mcp/server.cpp new file mode 100644 index 000000000..86aaa2088 --- /dev/null +++ b/lib/libimhex/source/mcp/server.cpp @@ -0,0 +1,226 @@ +#include + +#include + +#include +#include +#include +#include +#include + + +namespace hex::mcp { + + class JsonRpc { + public: + explicit JsonRpc(const std::string &request) : m_request(request) { } + + struct MethodNotFoundException : std::exception {}; + struct InvalidParametersException : std::exception {}; + + std::optional execute(auto callback) { + try { + auto requestJson = nlohmann::json::parse(m_request); + + if (requestJson.is_array()) { + return handleBatchedMessages(requestJson, callback).transform([](const auto &response) { return response.dump(); }); + } else { + return handleMessage(requestJson, callback).transform([](const auto &response) { return response.dump(); }); + } + } catch (const MethodNotFoundException &) { + return createErrorMessage(ErrorCode::MethodNotFound, "Method not found").dump(); + } catch (const InvalidParametersException &) { + return createErrorMessage(ErrorCode::InvalidParams, "Invalid params").dump(); + } catch (const nlohmann::json::parse_error &) { + return createErrorMessage(ErrorCode::ParseError, "Parse error").dump(); + } catch (const std::exception &e) { + return createErrorMessage(ErrorCode::InternalError, e.what()).dump(); + } + } + + private: + std::optional handleMessage(const nlohmann::json &request, auto callback) { + // Validate JSON-RPC request + if (!request.contains("jsonrpc") || request["jsonrpc"] != "2.0" || + !request.contains("method") || !request["method"].is_string()) { + m_id = request.contains("id") ? std::optional(request["id"].get()) : std::nullopt; + + return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request").dump(); + } + + m_id = request.contains("id") ? std::optional(request["id"].get()) : std::nullopt; + + // Execute the method + auto result = callback(request["method"].get(), request.value("params", nlohmann::json::object())); + + if (!m_id.has_value()) + return std::nullopt; + + return createResponseMessage(result); + } + + std::optional handleBatchedMessages(const nlohmann::json &request, auto callback) { + if (!request.is_array()) { + return createErrorMessage(ErrorCode::InvalidRequest, "Invalid Request").dump(); + } + + nlohmann::json responses = nlohmann::json::array(); + for (const auto &message : request) { + auto response = handleMessage(message, callback); + if (response.has_value()) + responses.push_back(*response); + } + + if (responses.empty()) + return std::nullopt; + + return responses.dump(); + } + + enum class ErrorCode { + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + }; + + nlohmann::json createDefaultMessage() { + nlohmann::json message; + message["jsonrpc"] = "2.0"; + if (m_id.has_value()) + message["id"] = m_id.value(); + else + message["id"] = nullptr; + + return message; + } + + nlohmann::json createErrorMessage(ErrorCode code, const std::string &message) { + auto json = createDefaultMessage(); + json["error"] = { + { "code", int(code) }, + { "message", message } + }; + return json; + } + + nlohmann::json createResponseMessage(const nlohmann::json &result) { + auto json = createDefaultMessage(); + json["result"] = result; + return json; + } + + private: + std::string m_request; + std::optional m_id; + }; + + Server::Server() : m_server(McpInternalPort, 1024, 1, true) { + + } + + Server::~Server() { + this->shutdown(); + } + + void Server::listen() { + m_server.accept([this](auto, const std::vector &data) -> std::vector { + std::string request(data.begin(), data.end()); + + log::debug("MCP ----> {}", request); + + JsonRpc rpc(request); + auto response = rpc.execute([this](const std::string &method, const nlohmann::json ¶ms) -> nlohmann::json { + if (method == "initialize") { + return handleInitialize(); + } else if (method.starts_with("notifications/")) { + handleNotifications(method.substr(14), params); + return {}; + } else if (method.ends_with("/list")) { + auto primitiveName = method.substr(0, method.size() - 5); + if (m_primitives.contains(primitiveName)) { + nlohmann::json capabilitiesList = nlohmann::json::array(); + for (const auto &[name, primitive] : m_primitives[primitiveName]) { + capabilitiesList.push_back(primitive.capabilities); + } + + nlohmann::json result; + result[primitiveName] = capabilitiesList; + return result; + } + } else if (method.ends_with("/call")) { + auto primitive = method.substr(0, method.size() - 5); + if (auto primitiveIt = m_primitives.find(primitive); primitiveIt != m_primitives.end()) { + auto name = params.value("name", ""); + if (auto functionIt = primitiveIt->second.find(name); functionIt != primitiveIt->second.end()) { + return functionIt->second.function(params.value("arguments", nlohmann::json::object())); + } + } + } + + throw JsonRpc::MethodNotFoundException(); + }); + + log::debug("MCP <---- {}", response.value_or("")); + + if (response.has_value()) + return { response->begin(), response->end() }; + else + return std::vector{ 0x00 }; + }, [this](auto) { + log::info("MCP client disconnected"); + m_connected = false; + }, true); + } + + void Server::shutdown() { + m_server.shutdown(); + } + + void Server::disconnect() { + m_server.disconnectClients(); + } + + void Server::addPrimitive(std::string type, std::string_view capabilities, std::function function) { + auto json = nlohmann::json::parse(capabilities); + auto name = json["name"].get(); + + m_primitives[type][name] = { + json, + function + }; + } + + + nlohmann::json Server::handleInitialize() { + constexpr static auto ServerName = "ImHex"; + constexpr static auto ProtocolVersion = "2025-06-18"; + + return { + { "protocolVersion", ProtocolVersion }, + { + "capabilities", + { + { "tools", nlohmann::json::object() }, + }, + }, + { + "serverInfo", { + { "name", ServerName }, + { "version", ImHexApi::System::getImHexVersion().get() } + } + } + }; + } + + void Server::handleNotifications(const std::string &method, [[maybe_unused]] const nlohmann::json ¶ms) { + if (method == "initialized") { + m_connected = true; + } + } + + bool Server::isConnected() { + return m_connected; + } +} diff --git a/plugins/builtin/include/content/command_line_interface.hpp b/plugins/builtin/include/content/command_line_interface.hpp index 9718f519a..afca28951 100644 --- a/plugins/builtin/include/content/command_line_interface.hpp +++ b/plugins/builtin/include/content/command_line_interface.hpp @@ -30,6 +30,7 @@ namespace hex::plugin::builtin { void handleValidatePluginCommand(const std::vector &args); void handleSaveEditorCommand(const std::vector &args); void handleFileInfoCommand(const std::vector &args); + void handleMCPCommand(const std::vector &args); void registerCommandForwarders(); diff --git a/plugins/builtin/romfs/lang/en_US.json b/plugins/builtin/romfs/lang/en_US.json index f7fe56ca0..72b87d32f 100644 --- a/plugins/builtin/romfs/lang/en_US.json +++ b/plugins/builtin/romfs/lang/en_US.json @@ -54,6 +54,7 @@ "hex.builtin.achievement.misc.download_from_store.name": "There's an app for that", "hex.builtin.achievement.misc.download_from_store.desc": "Download any item from the Content Store", "hex.builtin.background_service.network_interface": "Network Interface", + "hex.builtin.setting.general.mcp_server": "MCP Server support", "hex.builtin.background_service.auto_backup": "Auto Backup", "hex.builtin.command.calc.desc": "Calculator", "hex.builtin.command.convert.desc": "Unit conversion", diff --git a/plugins/builtin/source/content/background_services.cpp b/plugins/builtin/source/content/background_services.cpp index 4be50c547..44c2d31c4 100644 --- a/plugins/builtin/source/content/background_services.cpp +++ b/plugins/builtin/source/content/background_services.cpp @@ -17,6 +17,8 @@ #include #include +#include +#include namespace hex::plugin::builtin { @@ -103,6 +105,16 @@ namespace hex::plugin::builtin { std::this_thread::sleep_for(std::chrono::seconds(1)); } + void handleMCPServer() { + if (!ContentRegistry::MCP::isEnabled()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + ContentRegistry::MCP::impl::getMcpServerInstance().disconnect(); + return; + } + + ContentRegistry::MCP::impl::getMcpServerInstance().listen(); + } + } void registerBackgroundServices() { @@ -110,12 +122,17 @@ namespace hex::plugin::builtin { s_networkInterfaceServiceEnabled = value.get(false); }); + ContentRegistry::Settings::onChange("hex.builtin.setting.general", "hex.builtin.setting.general.mcp_server", [](const ContentRegistry::Settings::SettingsValue &value) { + ContentRegistry::MCP::impl::setEnabled(value.get(false)); + }); + ContentRegistry::Settings::onChange("hex.builtin.setting.general", "hex.builtin.setting.general.backups.auto_backup_time", [](const ContentRegistry::Settings::SettingsValue &value) { s_autoBackupTime = value.get(0) * 30; }); ContentRegistry::BackgroundServices::registerService("hex.builtin.background_service.network_interface", handleNetworkInterfaceService); ContentRegistry::BackgroundServices::registerService("hex.builtin.background_service.auto_backup", handleAutoBackup); + ContentRegistry::BackgroundServices::registerService("hex.builtin.background_service.mcp", handleMCPServer); EventProviderDirtied::subscribe([](prv::Provider *) { s_dataDirty = true; diff --git a/plugins/builtin/source/content/command_line_interface.cpp b/plugins/builtin/source/content/command_line_interface.cpp index e34bda651..208ffd47c 100644 --- a/plugins/builtin/source/content/command_line_interface.cpp +++ b/plugins/builtin/source/content/command_line_interface.cpp @@ -1,4 +1,6 @@ +#include #include +#include #include #include @@ -530,6 +532,14 @@ namespace hex::plugin::builtin { ContentRegistry::Views::setFullScreenView(path); } + void handleMCPCommand(const std::vector &) { + mcp::Client client; + + auto result = client.run(std::cin, std::cout); + std::fprintf(stderr, "MCP Client disconnected!\n"); + std::exit(result); + } + void registerCommandForwarders() { hex::subcommands::registerSubCommand("open", [](const std::vector &args){ diff --git a/plugins/builtin/source/content/settings_entries.cpp b/plugins/builtin/source/content/settings_entries.cpp index 36c2b8306..356b991d8 100644 --- a/plugins/builtin/source/content/settings_entries.cpp +++ b/plugins/builtin/source/content/settings_entries.cpp @@ -768,6 +768,8 @@ namespace hex::plugin::builtin { ContentRegistry::Settings::add("hex.builtin.setting.general", "hex.builtin.setting.general.network", "hex.builtin.setting.general.network_interface", false); + ContentRegistry::Settings::add("hex.builtin.setting.general", "hex.builtin.setting.general.network", "hex.builtin.setting.general.mcp_server", false); + #if !defined(OS_WEB) ContentRegistry::Settings::add("hex.builtin.setting.general", "hex.builtin.setting.general.network", "hex.builtin.setting.general.server_contact"); ContentRegistry::Settings::add("hex.builtin.setting.general", "hex.builtin.setting.general.network", "hex.builtin.setting.general.upload_crash_logs", true); diff --git a/plugins/builtin/source/content/ui_items.cpp b/plugins/builtin/source/content/ui_items.cpp index 1ccce8d72..8f00e4c9e 100644 --- a/plugins/builtin/source/content/ui_items.cpp +++ b/plugins/builtin/source/content/ui_items.cpp @@ -29,6 +29,7 @@ #include #include +#include namespace hex::plugin::builtin { @@ -239,6 +240,20 @@ namespace hex::plugin::builtin { }); } + ContentRegistry::UserInterface::addFooterItem([] { + if (ContentRegistry::MCP::isConnected()) { + ImGui::PushStyleColor(ImGuiCol_Text, ImGuiExt::GetCustomColorU32(ImGuiCustomCol_Highlight)); + } else { + ImGui::PushStyleColor(ImGuiCol_Text, ImGui::GetColorU32(ImGuiCol_TextDisabled)); + } + + if (ContentRegistry::MCP::isEnabled()) { + ImGui::TextUnformatted(ICON_VS_MCP); + } + + ImGui::PopStyleColor(); + }); + if (dbg::debugModeEnabled()) { ContentRegistry::UserInterface::addFooterItem([] { static float framerate = 0; diff --git a/plugins/builtin/source/plugin_builtin.cpp b/plugins/builtin/source/plugin_builtin.cpp index 17145dae6..99e3a245e 100644 --- a/plugins/builtin/source/plugin_builtin.cpp +++ b/plugins/builtin/source/plugin_builtin.cpp @@ -73,8 +73,8 @@ IMHEX_PLUGIN_SUBCOMMANDS() { { "open", "o", "Open files passed as argument. [default]", hex::plugin::builtin::handleOpenCommand }, { "new", "n", "Create a new empty file", hex::plugin::builtin::handleNewCommand }, - { "select", "s", "Select a range of bytes in the Hex Editor", hex::plugin::builtin::handleSelectCommand }, - { "pattern", "p", "Sets the loaded pattern", hex::plugin::builtin::handlePatternCommand }, + { "select", "s", "Select a range of bytes in the Hex Editor", hex::plugin::builtin::handleSelectCommand }, + { "pattern", "p", "Sets the loaded pattern", hex::plugin::builtin::handlePatternCommand }, { "calc", "", "Evaluate a mathematical expression", hex::plugin::builtin::handleCalcCommand }, { "hash", "", "Calculate the hash of a file", hex::plugin::builtin::handleHashCommand }, { "encode", "", "Encode a string", hex::plugin::builtin::handleEncodeCommand }, @@ -88,6 +88,7 @@ IMHEX_PLUGIN_SUBCOMMANDS() { { "validate-plugin", "", "Validates that a plugin can be loaded", hex::plugin::builtin::handleValidatePluginCommand }, { "save-editor", "", "Opens a pattern file for save file editing", hex::plugin::builtin::handleSaveEditorCommand }, { "file-info", "i", "Displays information about a file", hex::plugin::builtin::handleFileInfoCommand }, + { "mcp", "", "Starts a MCP Server for AI to interact with", hex::plugin::builtin::handleMCPCommand }, }; IMHEX_PLUGIN_SETUP_BUILTIN("Built-in", "WerWolv", "Default ImHex functionality") {