From fe36687366dfdf20b1e148ec49120b8a099b5fc7 Mon Sep 17 00:00:00 2001 From: Tyler Wilding Date: Tue, 26 Apr 2022 21:32:46 -0400 Subject: [PATCH] common: make a common interface for creating a server socket --- common/CMakeLists.txt | 6 +- common/cross_os_debug/xdbg.cpp | 1 + common/cross_os_debug/xdbg.h | 1 + common/cross_sockets/XSocketServer.cpp | 134 +++++++++++++++ common/cross_sockets/XSocketServer.h | 50 ++++++ common/cross_sockets/xsocket.cpp | 33 +++- common/cross_sockets/xsocket.h | 7 +- common/global_profiler/GlobalProfiler.cpp | 1 + common/log/log.cpp | 2 + common/util/FileUtil.cpp | 2 + common/util/FrameLimiter.cpp | 3 +- common/util/Timer.cpp | 2 + game/kernel/kboot.cpp | 2 + game/runtime.cpp | 10 +- game/system/Deci2Server.cpp | 196 ++++------------------ game/system/Deci2Server.h | 52 ++---- goalc/compiler/nrepl/ReplServer.cpp | 130 +++++++++----- goalc/compiler/nrepl/ReplServer.h | 39 ++--- test/test_listener_deci2.cpp | 70 ++++---- 19 files changed, 416 insertions(+), 325 deletions(-) create mode 100644 common/cross_sockets/XSocketServer.cpp create mode 100644 common/cross_sockets/XSocketServer.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c11ce8d6f2..c867ac74eb 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -1,7 +1,8 @@ add_library(common audio/audio_formats.cpp cross_os_debug/xdbg.cpp - cross_sockets/xsocket.cpp + cross_sockets/XSocket.cpp + cross_sockets/XSocketServer.cpp custom_data/TFrag3Data.cpp dma/dma.cpp dma/dma_copy.cpp @@ -44,8 +45,7 @@ add_library(common util/FrameLimiter.cpp util/image_loading.cpp goos/Printer.cpp - goos/PrettyPrinter2.cpp - ) + goos/PrettyPrinter2.cpp) target_link_libraries(common fmt lzokay replxx libzstd_static) diff --git a/common/cross_os_debug/xdbg.cpp b/common/cross_os_debug/xdbg.cpp index 0a336115be..ba591240ed 100644 --- a/common/cross_os_debug/xdbg.cpp +++ b/common/cross_os_debug/xdbg.cpp @@ -22,6 +22,7 @@ #include #elif _WIN32 #define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #include #include diff --git a/common/cross_os_debug/xdbg.h b/common/cross_os_debug/xdbg.h index 479c4b0133..4629fd54f4 100644 --- a/common/cross_os_debug/xdbg.h +++ b/common/cross_os_debug/xdbg.h @@ -14,6 +14,7 @@ #include #elif _WIN32 #define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #endif diff --git a/common/cross_sockets/XSocketServer.cpp b/common/cross_sockets/XSocketServer.cpp new file mode 100644 index 0000000000..5a70e85af0 --- /dev/null +++ b/common/cross_sockets/XSocketServer.cpp @@ -0,0 +1,134 @@ +#include "XSocketServer.h" + +#include "third-party/fmt/core.h" + +#include "common/cross_sockets/XSocket.h" + +#ifdef _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN +#include +#include +#include +#endif + +XSocketServer::XSocketServer(std::function shutdown_callback, + int _tcp_port, + int _buffer_size) + : want_exit_callback(std::move(shutdown_callback)) { + tcp_port = _tcp_port; + buffer_size = _buffer_size; + buffer = new char[_buffer_size]; +} + +XSocketServer::~XSocketServer() { + shutdown_server(); +} + +void XSocketServer::shutdown_server() { + // Close the listening and accepted socket socket + close_server_socket(); + close_socket(accepted_socket); + + // If the accept thread is still running (nothing ever connected) + // kill it and clean it up + if (accept_thread_running) { + kill_accept_thread = true; + accept_thread.join(); + accept_thread_running = false; + } + + // Cleanup our buffer + delete[] buffer; +} + +bool XSocketServer::init_server() { + listening_socket = open_socket(AF_INET, SOCK_STREAM, 0); + if (listening_socket < 0) { + listening_socket = -1; + return false; + } + +#ifdef __linux + int server_socket_opt = SO_REUSEADDR | SO_REUSEPORT; +#elif _WIN32 + int server_socket_opt = SO_EXCLUSIVEADDRUSE; +#endif + + int opt = 1; + if (set_socket_option(listening_socket, SOL_SOCKET, server_socket_opt, &opt, sizeof(opt)) < 0) { + close_server_socket(); + return false; + }; + + if (set_socket_option(listening_socket, TCP_SOCKET_LEVEL, TCP_NODELAY, &opt, sizeof(opt)) < 0) { + close_server_socket(); + return false; + } + + if (set_socket_timeout(listening_socket, 100000) < 0) { + close_server_socket(); + return false; + } + + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(tcp_port); + + if (bind(listening_socket, (sockaddr*)&addr, sizeof(addr)) < 0) { + fmt::print("[XSocketServer:{}] failed to bind\n", tcp_port); + close_server_socket(); + return false; + } + + if (listen(listening_socket, 0) < 0) { + fmt::print("[XSocketServer:{}] failed to listen\n", tcp_port); + close_server_socket(); + return false; + } + + server_initialized = true; + accept_thread_running = true; + kill_accept_thread = false; + accept_thread = std::thread(&XSocketServer::accept_thread_func, this); + fmt::print("[XSocketServer:{}] awaiting connections\n", tcp_port); + return true; +} + +void XSocketServer::close_server_socket() { + close_socket(listening_socket); + listening_socket = -1; +} + +void XSocketServer::accept_thread_func() { + socklen_t l = sizeof(addr); + while (!kill_accept_thread) { + accepted_socket = accept_socket(listening_socket, (sockaddr*)&addr, &l); + if (accepted_socket >= 0) { + set_socket_timeout(accepted_socket, 100000); + write_on_accept(); + client_connected = true; + return; + } + } +} + +bool XSocketServer::wait_for_connection() { + if (client_connected) { + if (accept_thread_running) { + accept_thread.join(); + accept_thread_running = false; + } + return true; + } else { + return false; + } +} + +void XSocketServer::lock() { + server_mutex.lock(); +} + +void XSocketServer::unlock() { + server_mutex.unlock(); +} diff --git a/common/cross_sockets/XSocketServer.h b/common/cross_sockets/XSocketServer.h new file mode 100644 index 0000000000..3240e93dca --- /dev/null +++ b/common/cross_sockets/XSocketServer.h @@ -0,0 +1,50 @@ +#pragma once + +#include "common/cross_sockets/XSocket.h" + +#include +#include "common/common_types.h" +#include +#include + +/// @brief A cross platform generic socket server implementation +class XSocketServer { + public: + static constexpr int DEF_BUFFER_SIZE = 32 * 1024 * 1024; + XSocketServer(std::function shutdown_callback, + int _tcp_port, + int _buffer_size = DEF_BUFFER_SIZE); + ~XSocketServer(); + bool init_server(); + void shutdown_server(); + void close_server_socket(); + + bool wait_for_connection(); + void lock(); + void unlock(); + + // Abstract methods -- use-case dependent + virtual void write_on_accept() = 0; + virtual void read_data() = 0; + virtual void send_data(void* buf, u16 len) = 0; + + protected: + int buffer_size; + int tcp_port; + struct sockaddr_in addr = {}; + int listening_socket = -1; + int accepted_socket = -1; + char* buffer = nullptr; + + bool kill_accept_thread = false; + bool server_initialized = false; + bool accept_thread_running = false; + bool client_connected = false; + + std::function want_exit_callback; + std::thread accept_thread; + + std::mutex server_mutex; + + void accept_thread_func(); +}; diff --git a/common/cross_sockets/xsocket.cpp b/common/cross_sockets/xsocket.cpp index 14979020da..be31cdb5b1 100644 --- a/common/cross_sockets/xsocket.cpp +++ b/common/cross_sockets/xsocket.cpp @@ -16,6 +16,8 @@ #include #include +#include "third-party/fmt/core.h" + int open_socket(int af, int type, int protocol) { #ifdef __linux return socket(af, type, protocol); @@ -32,6 +34,20 @@ int open_socket(int af, int type, int protocol) { #endif } +int accept_socket(int socket, sockaddr* addr, int* addrLen) { +#ifdef _WIN32 + WSADATA wsaData = {0}; + int iResult = 0; + // Initialize Winsock + iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (iResult != 0) { + printf("WSAStartup failed: %d\n", iResult); + return 1; + } +#endif + return accept(socket, addr, addrLen); +} + void close_socket(int sock) { if (sock < 0) { return; @@ -77,19 +93,26 @@ int set_socket_timeout(int socket, long microSeconds) { } int write_to_socket(int socket, const char* buf, int len) { + int bytes_wrote; #ifdef __linux - return write(socket, buf, len); + bytes_wrote = write(socket, buf, len); #elif _WIN32 - return send(socket, buf, len, 0); + bytes_wrote = send(socket, buf, len, 0); #endif + if (bytes_wrote < 0) { + fmt::print(stderr, "[XSocket:{}] Error writing to socket\n", socket); + } + return bytes_wrote; } int read_from_socket(int socket, char* buf, int len) { + int bytes_read; #ifdef __linux - return read(socket, buf, len); + bytes_read = read(socket, buf, len); #elif _WIN32 - return recv(socket, buf, len, 0); + bytes_read = recv(socket, buf, len, 0); #endif + return bytes_read; } bool socket_timed_out() { @@ -99,4 +122,4 @@ bool socket_timed_out() { auto err = WSAGetLastError(); return err == WSAETIMEDOUT; #endif -} \ No newline at end of file +} diff --git a/common/cross_sockets/xsocket.h b/common/cross_sockets/xsocket.h index b6d99b980e..bd18776be7 100644 --- a/common/cross_sockets/xsocket.h +++ b/common/cross_sockets/xsocket.h @@ -9,7 +9,11 @@ #include #include #include +#include #elif _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN +#include #include #endif @@ -20,9 +24,10 @@ const int TCP_SOCKET_LEVEL = IPPROTO_TCP; #endif int open_socket(int af, int type, int protocol); +int accept_socket(int socket, sockaddr* addr, int* addrLen); void close_socket(int sock); int set_socket_option(int socket, int level, int optname, const void* optval, int optlen); int set_socket_timeout(int socket, long microSeconds); int write_to_socket(int socket, const char* buf, int len); int read_from_socket(int socket, char* buf, int len); -bool socket_timed_out(); \ No newline at end of file +bool socket_timed_out(); diff --git a/common/global_profiler/GlobalProfiler.cpp b/common/global_profiler/GlobalProfiler.cpp index c2c509d43d..3e713a7f5e 100644 --- a/common/global_profiler/GlobalProfiler.cpp +++ b/common/global_profiler/GlobalProfiler.cpp @@ -14,6 +14,7 @@ u32 get_current_tid() { } #else #define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #include "Processthreadsapi.h" u32 get_current_tid() { diff --git a/common/log/log.cpp b/common/log/log.cpp index 3d6b95b0f2..4539bb91da 100644 --- a/common/log/log.cpp +++ b/common/log/log.cpp @@ -4,6 +4,8 @@ #include "third-party/fmt/color.h" #include "log.h" #ifdef _WIN32 // see lg::initialize +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #endif #include "common/util/Assert.h" diff --git a/common/util/FileUtil.cpp b/common/util/FileUtil.cpp index 5f93c3d18b..dc4c73c86b 100644 --- a/common/util/FileUtil.cpp +++ b/common/util/FileUtil.cpp @@ -24,6 +24,8 @@ #include "third-party/lzokay/lzokay.hpp" #ifdef _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #else #include diff --git a/common/util/FrameLimiter.cpp b/common/util/FrameLimiter.cpp index 136ca6aae1..a4f6ed5732 100644 --- a/common/util/FrameLimiter.cpp +++ b/common/util/FrameLimiter.cpp @@ -41,6 +41,7 @@ void FrameLimiter::run(double target_fps, #else +#define NOMINMAX #include FrameLimiter::FrameLimiter() { @@ -74,4 +75,4 @@ void FrameLimiter::run(double target_fps, m_timer.start(); } -#endif \ No newline at end of file +#endif diff --git a/common/util/Timer.cpp b/common/util/Timer.cpp index 1195f29906..6050c990a0 100644 --- a/common/util/Timer.cpp +++ b/common/util/Timer.cpp @@ -1,6 +1,8 @@ #include "Timer.h" #ifdef _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #define MS_PER_SEC 1000ULL // MS = milliseconds #define US_PER_MS 1000ULL // US = microseconds diff --git a/game/kernel/kboot.cpp b/game/kernel/kboot.cpp index ef65bef02e..94b3de9482 100644 --- a/game/kernel/kboot.cpp +++ b/game/kernel/kboot.cpp @@ -22,6 +22,8 @@ #include "kprint.h" #ifdef _WIN32 +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include "Windows.h" #include #elif __linux__ diff --git a/game/runtime.cpp b/game/runtime.cpp index bcd433eccd..27514391d9 100644 --- a/game/runtime.cpp +++ b/game/runtime.cpp @@ -9,6 +9,8 @@ #elif _WIN32 #include #include "third-party/mman/mman.h" +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN #include #endif @@ -74,7 +76,7 @@ void deci2_runner(SystemThreadInterface& iface) { std::function shutdown_callback = [&]() { return iface.get_want_exit(); }; // create and register server - Deci2Server server(shutdown_callback); + Deci2Server server(shutdown_callback, DECI2_PORT); ee::LIBRARY_sceDeci2_register(&server); // now its ok to continue with initialization @@ -84,20 +86,20 @@ void deci2_runner(SystemThreadInterface& iface) { lg::debug("[DECI2] Waiting for EE to register protos"); server.wait_for_protos_ready(); // then allow the server to accept connections - if (!server.init()) { + if (!server.init_server()) { ASSERT(false); } lg::debug("[DECI2] Waiting for listener..."); bool saw_listener = false; while (!iface.get_want_exit()) { - if (server.check_for_listener()) { + if (server.wait_for_connection()) { if (!saw_listener) { lg::debug("[DECI2] Connected!"); } saw_listener = true; // we have a listener, run! - server.run(); + server.read_data(); } else { // no connection yet. Do a sleep so we don't spam checking the listener. std::this_thread::sleep_for(std::chrono::microseconds(50000)); diff --git a/game/system/Deci2Server.cpp b/game/system/Deci2Server.cpp index a67d5cbe56..82143d7426 100644 --- a/game/system/Deci2Server.cpp +++ b/game/system/Deci2Server.cpp @@ -4,154 +4,19 @@ * Works with deci2.cpp (sceDeci2) to implement the networking on target */ -#include -#include - -// TODO - i think im not including the dependency right..? -#include "common/cross_sockets/xsocket.h" - -#ifdef __linux -#include -#include -#include -#elif _WIN32 -#include -#include -#include -#endif - -#include "common/listener_common.h" -#include "common/versions.h" #include "Deci2Server.h" -#include "common/util/Assert.h" -Deci2Server::Deci2Server(std::function shutdown_callback) - : want_exit(std::move(shutdown_callback)) { - buffer = new char[BUFFER_SIZE]; -} +#include "common/cross_sockets/XSocket.h" -Deci2Server::~Deci2Server() { - close_server_socket(); - close_socket(new_sock); +#include "common/versions.h" +#include +#include - // if accept thread is running, kill it - if (accept_thread_running) { - kill_accept_thread = true; - accept_thread.join(); - accept_thread_running = false; - } +#include "third-party/fmt/core.h" - delete[] buffer; -} - -/*! - * Start waiting for the Listener to connect - */ -bool Deci2Server::init() { - server_socket = open_socket(AF_INET, SOCK_STREAM, 0); - if (server_socket < 0) { - server_socket = -1; - return false; - } - -#ifdef __linux - int server_socket_opt = SO_REUSEADDR | SO_REUSEPORT; -#elif _WIN32 - int server_socket_opt = SO_EXCLUSIVEADDRUSE; -#endif - - int opt = 1; - if (set_socket_option(server_socket, SOL_SOCKET, server_socket_opt, &opt, sizeof(opt)) < 0) { - close_server_socket(); - return false; - }; - - if (set_socket_option(server_socket, TCP_SOCKET_LEVEL, TCP_NODELAY, &opt, sizeof(opt)) < 0) { - close_server_socket(); - return false; - } - - if (set_socket_timeout(server_socket, 100000) < 0) { - close_server_socket(); - return false; - } - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = htons(DECI2_PORT); - - if (bind(server_socket, (sockaddr*)&addr, sizeof(addr)) < 0) { - printf("[Deci2Server] Failed to bind\n"); - close_server_socket(); - return false; - } - - if (listen(server_socket, 0) < 0) { - printf("[Deci2Server] Failed to listen\n"); - close_server_socket(); - return false; - } - - server_initialized = true; - accept_thread_running = true; - kill_accept_thread = false; - accept_thread = std::thread(&Deci2Server::accept_thread_func, this); - return true; -} - -void Deci2Server::close_server_socket() { - close_socket(server_socket); - server_socket = -1; -} - -/*! - * Return true if the listener is connected. - */ -bool Deci2Server::check_for_listener() { - if (server_connected) { - if (accept_thread_running) { - accept_thread.join(); - accept_thread_running = false; - } - return true; - } else { - return false; - } -} - -/*! - * Send data from buffer. User must provide appropriate headers. - */ -void Deci2Server::send_data(void* buf, u16 len) { - lock(); - if (!server_connected) { - printf("[DECI2] send while not connected, not sending!\n"); - } else { - uint16_t prog = 0; - while (prog < len) { - int wrote = write_to_socket(new_sock, (char*)(buf) + prog, len - prog); - prog += wrote; - if (!server_connected || want_exit()) { - unlock(); - return; - } - } - } - unlock(); -} - -/*! - * Lock the DECI mutex. Should be done before modifying protocols. - */ -void Deci2Server::lock() { - deci_mutex.lock(); -} - -/*! - * Unlock the DECI mutex. Should be done after modifying protocols. - */ -void Deci2Server::unlock() { - deci_mutex.unlock(); +void Deci2Server::write_on_accept() { + u32 versions[2] = {versions::GOAL_VERSION_MAJOR, versions::GOAL_VERSION_MINOR}; + write_to_socket(accepted_socket, (char*)&versions, 8); } /*! @@ -161,7 +26,7 @@ void Deci2Server::unlock() { void Deci2Server::wait_for_protos_ready() { if (protocols_ready) return; - std::unique_lock lk(deci_mutex); + std::unique_lock lk(server_mutex); cv.wait(lk, [&] { return protocols_ready; }); } @@ -180,14 +45,14 @@ void Deci2Server::send_proto_ready(Deci2Driver* drivers, int* driver_count) { cv.notify_all(); } -void Deci2Server::run() { +void Deci2Server::read_data() { int desired_size = (int)sizeof(Deci2Header); int got = 0; while (got < desired_size) { - ASSERT(got + desired_size < BUFFER_SIZE); - auto x = read_from_socket(new_sock, buffer + got, desired_size - got); - if (want_exit()) { + ASSERT(got + desired_size < buffer_size); + auto x = read_from_socket(accepted_socket, buffer + got, desired_size - got); + if (want_exit_callback()) { return; } got += x > 0 ? x : 0; @@ -222,7 +87,7 @@ void Deci2Server::run() { auto& driver = d2_drivers[handler]; u32 sent_to_program = 0; - while (!want_exit() && (hdr->rsvd < hdr->len || sent_to_program < hdr->rsvd)) { + while (!want_exit_callback() && (hdr->rsvd < hdr->len || sent_to_program < hdr->rsvd)) { // send what we have to the program if (sent_to_program < hdr->rsvd) { // driver.next_recv_size = 0; @@ -236,8 +101,8 @@ void Deci2Server::run() { // receive from network if (hdr->rsvd < hdr->len) { - auto x = read_from_socket(new_sock, buffer + hdr->rsvd, hdr->len - hdr->rsvd); - if (want_exit()) { + auto x = read_from_socket(accepted_socket, buffer + hdr->rsvd, hdr->len - hdr->rsvd); + if (want_exit_callback()) { return; } got += x > 0 ? x : 0; @@ -249,21 +114,20 @@ void Deci2Server::run() { unlock(); } -/*! - * Background thread for waiting for the listener. - */ -void Deci2Server::accept_thread_func() { - socklen_t l = sizeof(addr); - while (!kill_accept_thread) { - // TODO - might want to do a WSAStartUp call here as well, else it won't be balanced on the - // close - new_sock = accept(server_socket, (sockaddr*)&addr, &l); - if (new_sock >= 0) { - set_socket_timeout(new_sock, 100000); - u32 versions[2] = {versions::GOAL_VERSION_MAJOR, versions::GOAL_VERSION_MINOR}; - write_to_socket(new_sock, (char*)&versions, 8); // todo, check result? - server_connected = true; - return; +void Deci2Server::send_data(void* buf, u16 len) { + lock(); + if (!client_connected) { + printf("[DECI2] send while not connected, not sending!\n"); + } else { + uint16_t prog = 0; + while (prog < len) { + int wrote = write_to_socket(accepted_socket, (char*)(buf) + prog, len - prog); + prog += wrote; + if (!client_connected || want_exit_callback()) { + unlock(); + return; + } } } + unlock(); } diff --git a/game/system/Deci2Server.h b/game/system/Deci2Server.h index 5f11188057..a5d54fedf1 100644 --- a/game/system/Deci2Server.h +++ b/game/system/Deci2Server.h @@ -1,55 +1,25 @@ #pragma once -/*! - * @file Deci2Server.h - * Basic implementation of a DECI2 server. - * Works with deci2.cpp (sceDeci2) to implement the networking on target - */ +#include "common/cross_sockets/XSocketServer.h" -#ifdef __linux -#include -#elif _WIN32 -#include -#endif -#include -#include -#include -#include -#include "game/system/deci_common.h" +#include "deci_common.h" -class Deci2Server { +/// @brief Basic implementation of a DECI2 server. +/// Works with deci2.cpp(sceDeci2) to implement the networking on target +class Deci2Server : public XSocketServer { public: - static constexpr int BUFFER_SIZE = 32 * 1024 * 1024; - Deci2Server(std::function shutdown_callback); - ~Deci2Server(); - bool init(); - bool check_for_listener(); - void send_data(void* buf, u16 len); + using XSocketServer::XSocketServer; + + void write_on_accept() override; + void read_data() override; + void send_data(void* buf, u16 len) override; - void lock(); - void unlock(); void wait_for_protos_ready(); void send_proto_ready(Deci2Driver* drivers, int* driver_count); - void run(); - private: - void close_server_socket(); - void accept_thread_func(); - bool kill_accept_thread = false; - char* buffer = nullptr; - int server_socket = -1; - struct sockaddr_in addr = {}; - int new_sock = -1; - bool server_initialized = false; - bool accept_thread_running = false; - bool server_connected = false; - std::function want_exit; - std::thread accept_thread; - - std::condition_variable cv; bool protocols_ready = false; - std::mutex deci_mutex; + std::condition_variable cv; Deci2Driver* d2_drivers = nullptr; int* d2_driver_count = nullptr; }; diff --git a/goalc/compiler/nrepl/ReplServer.cpp b/goalc/compiler/nrepl/ReplServer.cpp index 338b906fb5..b5e80cc90f 100644 --- a/goalc/compiler/nrepl/ReplServer.cpp +++ b/goalc/compiler/nrepl/ReplServer.cpp @@ -1,6 +1,9 @@ #include "ReplServer.h" +#include "common/cross_sockets/XSocket.h" + #include "third-party/fmt/core.h" +#include // TODO - basically REPL to listen and inject commands into a running REPL // - we will need a C++ side client as well which will let us communicate with the repl via for @@ -8,53 +11,90 @@ // // TODO - The server also needs to eventually return the result of the evaluation -ReplSession::ReplSession(tcp::socket socket, Compiler* repl) : socket_(std::move(socket)) { - m_repl = repl; +// Known Issues: +// - doesn't handle disconnects/reconnects + +void ReplServer::write_on_accept() { + ping_response(); } -void ReplSession::start() { - fmt::print("[nREPL]: Client Connected!\n\r"); - do_read(); -} +void ReplServer::read_data() { + int desired_size = 1; + int got = 0; -void ReplSession::do_read() { - auto self(shared_from_this()); - socket_.async_read_some(asio::buffer(data_, max_length), - [this, self](std::error_code ec, std::size_t length) { - if (!ec) { - auto input = std::string(data_, length); - if (!input.empty()) { - m_repl->read_eval_print(input); - } - // TODO - i think this is kinda a hack, but its to keep the server - // cycling - do_write(0); - } - }); -} - -void ReplSession::do_write(std::size_t length) { - auto self(shared_from_this()); - asio::async_write(socket_, asio::buffer(data_, length), - [this, self](std::error_code ec, std::size_t /*length*/) { - if (!ec) { - do_read(); - } - }); -} - -ReplServer::ReplServer(asio::io_context& io_context, Compiler* repl) - : acceptor_(io_context, tcp::endpoint(tcp::v4(), repl->m_nrepl_port)), socket_(io_context) { - m_repl = repl; - m_port = repl->m_nrepl_port; - do_accept(); -} - -void ReplServer::do_accept() { - acceptor_.async_accept(socket_, [this](std::error_code ec) { - if (!ec) { - std::make_shared(std::move(socket_), this->m_repl)->start(); + while (got < desired_size) { + ASSERT(got + desired_size < buffer_size); + int sock = accepted_socket; + auto x = read_from_socket(sock, buffer + got, desired_size - got); + if (want_exit_callback()) { + return; } - do_accept(); - }); + got += x > 0 ? x : 0; + } + + auto* header = (ReplServerHeader*)(buffer); + + lock(); + + // get the body of the message + desired_size = header->length; + got = 0; + while (got < desired_size) { + ASSERT(got + desired_size < buffer_size); + auto x = read_from_socket(accepted_socket, buffer + got, desired_size - got); + if (want_exit_callback()) { + return; + } + got += x > 0 ? x : 0; + } + + auto* body = (char*)(buffer); + + switch (header->type) { + case ReplServerMessageType::PING: + ping_response(); + break; + case ReplServerMessageType::EVAL: + compile_msg("(repl-help)"); + break; + } + + unlock(); +} + +void ReplServer::send_data(void* buf, u16 len) { + lock(); + if (client_connected) { + int bytes_sent = 0; + while (bytes_sent < len) { + int wrote = write_to_socket(accepted_socket, (char*)(buf) + bytes_sent, len - bytes_sent); + bytes_sent += wrote; + if (!client_connected || want_exit_callback()) { + unlock(); + return; + } + } + } + unlock(); +} + +void ReplServer::set_compiler(std::shared_ptr _compiler) { + compiler = std::move(_compiler); +} + +void ReplServer::ping_response() { + u32 versions[2] = {versions::GOAL_VERSION_MAJOR, versions::GOAL_VERSION_MINOR}; + char* ye = "sanity"; + lock(); + write_to_socket(accepted_socket, (char*)&ye, 6); + unlock(); +} + +void ReplServer::compile_msg(const std::string_view& msg) { + if (compiler == nullptr) { + return; + } + compiler->lock(); + compiler->read_eval_print(msg.data()); + compiler->unlock(); } diff --git a/goalc/compiler/nrepl/ReplServer.h b/goalc/compiler/nrepl/ReplServer.h index 5190552fd5..2aae8cc2fb 100644 --- a/goalc/compiler/nrepl/ReplServer.h +++ b/goalc/compiler/nrepl/ReplServer.h @@ -1,38 +1,29 @@ #pragma once -#include -#include +#include "common/cross_sockets/XSocketServer.h" #include "goalc/compiler/Compiler.h" -using asio::ip::tcp; +enum ReplServerMessageType { PING = 0, EVAL = 10, SHUTDOWN = 20 }; -class ReplSession : public std::enable_shared_from_this { - public: - ReplSession(tcp::socket socket, Compiler* repl); - - void start(); - - private: - Compiler* m_repl; - tcp::socket socket_; - enum { max_length = 1024 * 20 }; - char data_[max_length]; - - void do_read(); - void do_write(std::size_t length); +struct ReplServerHeader { + u32 length; + u32 type; }; -class ReplServer { +class ReplServer : public XSocketServer { public: - ReplServer(asio::io_context& io_context, Compiler* repl); + using XSocketServer::XSocketServer; - int m_port; + void write_on_accept() override; + void read_data() override; + void send_data(void* buf, u16 len) override; + + void set_compiler(std::shared_ptr _compiler); private: - Compiler* m_repl; - tcp::acceptor acceptor_; - tcp::socket socket_; + std::shared_ptr compiler = nullptr; - void do_accept(); + void ping_response(); + void compile_msg(const std::string_view& msg); }; diff --git a/test/test_listener_deci2.cpp b/test/test_listener_deci2.cpp index 6037806977..45cd828496 100644 --- a/test/test_listener_deci2.cpp +++ b/test/test_listener_deci2.cpp @@ -15,12 +15,12 @@ TEST(Listener, ListenerCreation) { } TEST(Listener, DeciCreation) { - Deci2Server s(always_false); + Deci2Server s(always_false, DECI2_PORT); } TEST(Listener, DeciInit) { - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); } /*! @@ -38,61 +38,61 @@ TEST(Listener, ListenToNothing) { } TEST(Listener, DeciCheckNoListener) { - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); } TEST(Listener, CheckConnectionStaysAlive) { - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); - EXPECT_FALSE(s.check_for_listener()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); + EXPECT_FALSE(s.wait_for_connection()); Listener l; - EXPECT_FALSE(s.check_for_listener()); + EXPECT_FALSE(s.wait_for_connection()); bool connected = l.connect_to_target(); EXPECT_TRUE(connected); // TODO - some sort of backoff and retry would be better - while (connected && !s.check_for_listener()) { + while (connected && !s.wait_for_connection()) { } - EXPECT_TRUE(s.check_for_listener()); + EXPECT_TRUE(s.wait_for_connection()); std::this_thread::sleep_for(std::chrono::milliseconds(500)); - EXPECT_TRUE(s.check_for_listener()); + EXPECT_TRUE(s.wait_for_connection()); EXPECT_TRUE(l.is_connected()); } TEST(Listener, DeciThenListener) { for (int i = 0; i < 3; i++) { - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); Listener l; - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); bool connected = l.connect_to_target(); EXPECT_TRUE(connected); // TODO - some sort of backoff and retry would be better - while (connected && !s.check_for_listener()) { + while (connected && !s.wait_for_connection()) { } - EXPECT_TRUE(s.check_for_listener()); + EXPECT_TRUE(s.wait_for_connection()); } } TEST(Listener, DeciThenListener2) { for (int i = 0; i < 3; i++) { - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); Listener l; - EXPECT_FALSE(s.check_for_listener()); - EXPECT_FALSE(s.check_for_listener()); + EXPECT_FALSE(s.wait_for_connection()); + EXPECT_FALSE(s.wait_for_connection()); EXPECT_TRUE(l.connect_to_target()); } } @@ -101,13 +101,13 @@ TEST(Listener, ListenerThenDeci) { for (int i = 0; i < 3; i++) { Listener l; EXPECT_FALSE(l.connect_to_target()); - Deci2Server s(always_false); - EXPECT_TRUE(s.init()); - EXPECT_FALSE(s.check_for_listener()); + Deci2Server s(always_false, DECI2_PORT); + EXPECT_TRUE(s.init_server()); + EXPECT_FALSE(s.wait_for_connection()); bool connected = l.connect_to_target(); EXPECT_TRUE(connected); // TODO - some sort of backoff and retry would be better - while (connected && !s.check_for_listener()) { + while (connected && !s.wait_for_connection()) { } } }