mirror of https://github.com/microsoft/WSL
Clean up localhost relay implementation to not rely on procfs parsing. (#13836)
* Clean up localhost relay implementation to not rely on procfs parsing. * pr feedback --------- Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>
This commit is contained in:
parent
7f8422654e
commit
f1e20b21c9
|
|
@ -12,6 +12,8 @@
|
|||
#include <netinet/ip.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <linux/unistd.h>
|
||||
#include <linux/sock_diag.h>
|
||||
#include <linux/inet_diag.h>
|
||||
#include <lxwil.h>
|
||||
#include <linux/if_tun.h>
|
||||
|
||||
|
|
@ -21,6 +23,8 @@
|
|||
#include "SecCompDispatcher.h"
|
||||
#include "seccomp_defs.h"
|
||||
#include "CommandLine.h"
|
||||
#include "NetlinkChannel.h"
|
||||
#include "NetlinkTransactionError.h"
|
||||
|
||||
#define TCP_LISTEN 10
|
||||
|
||||
|
|
@ -145,80 +149,60 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
|
|||
return;
|
||||
}
|
||||
|
||||
std::vector<sockaddr_storage> ParseTcpFile(int family, FILE* file)
|
||||
std::vector<sockaddr_storage> QueryListeningSockets(NetlinkChannel& channel)
|
||||
{
|
||||
char* line = nullptr;
|
||||
auto freeLine = wil::scope_exit([&line]() { free(line); });
|
||||
|
||||
// Skip the first line which contains a header.
|
||||
size_t lineLength = 0;
|
||||
auto bytesRead = getline(&line, &lineLength, file);
|
||||
THROW_LAST_ERROR_IF(bytesRead < 0);
|
||||
|
||||
// Each line contains information about TCP sockets on the system, the fields
|
||||
// we are interested are for sockets that are or have been listening:
|
||||
// 1: Socket address and port number
|
||||
// 3: Socket status
|
||||
std::vector<sockaddr_storage> sockets{};
|
||||
while ((bytesRead = getline(&line, &lineLength, file)) != -1)
|
||||
try
|
||||
{
|
||||
sockaddr_storage sock{};
|
||||
int index = 0;
|
||||
int status = 0;
|
||||
for (char *sp, *field = strtok_r(line, " \n", &sp); field != nullptr; field = strtok_r(NULL, " \n", &sp))
|
||||
inet_diag_req_v2 message{};
|
||||
message.sdiag_protocol = IPPROTO_TCP;
|
||||
message.idiag_states = (1 << TCP_LISTEN);
|
||||
|
||||
auto onMessage = [&](const NetlinkResponse& response) {
|
||||
for (const auto& e : response.Messages<inet_diag_msg>(SOCK_DIAG_BY_FAMILY))
|
||||
{
|
||||
const auto* payload = e.Payload();
|
||||
sockaddr_storage sock{};
|
||||
|
||||
if (payload->idiag_family == AF_INET)
|
||||
{
|
||||
auto* ipv4 = reinterpret_cast<sockaddr_in*>(&sock);
|
||||
ipv4->sin_family = AF_INET;
|
||||
ipv4->sin_addr.s_addr = payload->id.idiag_src[0];
|
||||
ipv4->sin_port = payload->id.idiag_sport;
|
||||
}
|
||||
else if (payload->idiag_family == AF_INET6)
|
||||
{
|
||||
auto* ipv6 = reinterpret_cast<sockaddr_in6*>(&sock);
|
||||
ipv6->sin6_family = AF_INET6;
|
||||
static_assert(sizeof(ipv6->sin6_addr.s6_addr32) == sizeof(payload->id.idiag_src));
|
||||
memcpy(ipv6->sin6_addr.s6_addr32, payload->id.idiag_src, sizeof(ipv6->sin6_addr.s6_addr32));
|
||||
ipv6->sin6_port = payload->id.idiag_sport;
|
||||
}
|
||||
|
||||
sockets.emplace_back(sock);
|
||||
}
|
||||
};
|
||||
|
||||
// Query IPv4 listening sockets.
|
||||
{
|
||||
if (index == 1)
|
||||
{
|
||||
int port;
|
||||
const char* portString = strchr(field, ':');
|
||||
if (portString == nullptr)
|
||||
{
|
||||
break;
|
||||
}
|
||||
portString += 1;
|
||||
port = static_cast<int>(strtol(portString, nullptr, 16));
|
||||
if (port == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (family == AF_INET)
|
||||
{
|
||||
sockaddr_in ipv4Sock{};
|
||||
ipv4Sock.sin_family = family;
|
||||
ipv4Sock.sin_addr.s_addr = strtol(field, nullptr, 16);
|
||||
ipv4Sock.sin_port = port;
|
||||
memcpy(&sock, &ipv4Sock, sizeof(ipv4Sock));
|
||||
}
|
||||
else if (family == AF_INET6)
|
||||
{
|
||||
sockaddr_in6 ipv6Sock{};
|
||||
ipv6Sock.sin6_family = family;
|
||||
ipv6Sock.sin6_port = port;
|
||||
for (int part = 0; part < 4; ++part)
|
||||
{
|
||||
char next[5];
|
||||
next[4] = 0;
|
||||
memcpy(next, field + part * 4, 4);
|
||||
ipv6Sock.sin6_addr.__in6_union.__s6_addr32[part] = strtol(next, nullptr, 16);
|
||||
}
|
||||
memcpy(&sock, &ipv6Sock, sizeof(ipv6Sock));
|
||||
}
|
||||
}
|
||||
else if (index == 3)
|
||||
{
|
||||
status = static_cast<int>(strtol(field, nullptr, 16));
|
||||
break;
|
||||
}
|
||||
|
||||
index += 1;
|
||||
message.sdiag_family = AF_INET;
|
||||
auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
|
||||
transaction.Execute(onMessage);
|
||||
}
|
||||
|
||||
if ((status == TCP_LISTEN) && (sock.ss_family != 0))
|
||||
// Query IPv6 listening sockets.
|
||||
{
|
||||
sockets.emplace_back(sock);
|
||||
message.sdiag_family = AF_INET6;
|
||||
auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
|
||||
transaction.Execute(onMessage);
|
||||
}
|
||||
}
|
||||
catch (const NetlinkTransactionError& e)
|
||||
{
|
||||
// Log but don't fail - network state might be temporarily unavailable
|
||||
LOG_ERROR("Failed to query listening sockets via sock_diag: {}", e.what());
|
||||
}
|
||||
|
||||
return sockets;
|
||||
}
|
||||
|
|
@ -246,12 +230,12 @@ LX_GNS_PORT_LISTENER_RELAY SockToRelayMessage(const sockaddr_storage& sock)
|
|||
{
|
||||
auto ipv4 = reinterpret_cast<const sockaddr_in*>(&sock);
|
||||
message.Address[0] = ipv4->sin_addr.s_addr;
|
||||
message.Port = ipv4->sin_port;
|
||||
message.Port = ntohs(ipv4->sin_port);
|
||||
}
|
||||
else if (sock.ss_family == AF_INET6)
|
||||
{
|
||||
auto ipv6 = reinterpret_cast<const sockaddr_in6*>(&sock);
|
||||
message.Port = ipv6->sin6_port;
|
||||
message.Port = ntohs(ipv6->sin6_port);
|
||||
memcpy(message.Address, ipv6->sin6_addr.__in6_union.__s6_addr, sizeof(message.Address));
|
||||
}
|
||||
return message;
|
||||
|
|
@ -296,53 +280,23 @@ bool IsSameSockAddr(const sockaddr_storage& left, const sockaddr_storage& right)
|
|||
{
|
||||
auto leftIpv6 = reinterpret_cast<const sockaddr_in6*>(&left);
|
||||
auto rightIpv6 = reinterpret_cast<const sockaddr_in6*>(&right);
|
||||
if (leftIpv6->sin6_port != rightIpv6->sin6_port)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
for (int part = 0; part < 4; ++part)
|
||||
{
|
||||
if (leftIpv6->sin6_addr.__in6_union.__s6_addr32[part] != rightIpv6->sin6_addr.__in6_union.__s6_addr32[part])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
FATAL_ERROR("Unrecognized socket family {}", left.ss_family);
|
||||
return false;
|
||||
return (leftIpv6->sin6_port == rightIpv6->sin6_port && memcmp(&leftIpv6->sin6_addr, &rightIpv6->sin6_addr, sizeof(in6_addr)) == 0);
|
||||
}
|
||||
|
||||
FATAL_ERROR("Unrecognized socket family {}", left.ss_family);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Start looking for ports bound to localhost or wildcard.
|
||||
int ScanProcNetTCP(wsl::shared::SocketChannel& channel)
|
||||
// Monitor listening TCP sockets using sock_diag netlink interface.
|
||||
int MonitorListeningSockets(wsl::shared::SocketChannel& channel)
|
||||
{
|
||||
// Periodically scan procfs for listening TCP sockets.
|
||||
NetlinkChannel netlinkChannel(SOCK_RAW, NETLINK_SOCK_DIAG);
|
||||
std::vector<sockaddr_storage> relays{};
|
||||
int result = 0;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
std::vector<sockaddr_storage> sockets;
|
||||
wil::unique_file tcp4File{fopen("/proc/net/tcp", "r")};
|
||||
if (tcp4File)
|
||||
{
|
||||
sockets = ParseTcpFile(AF_INET, tcp4File.get());
|
||||
}
|
||||
|
||||
wil::unique_file tcp6File{fopen("/proc/net/tcp6", "r")};
|
||||
if (tcp6File)
|
||||
{
|
||||
auto ipv6Sockets = ParseTcpFile(AF_INET6, tcp6File.get());
|
||||
sockets.insert(sockets.end(), ipv6Sockets.begin(), ipv6Sockets.end());
|
||||
}
|
||||
|
||||
if (!tcp4File && !tcp6File)
|
||||
{
|
||||
LOG_ERROR("Failed to open /proc/net/tcp and /proc/net/tcp6, closing port relay");
|
||||
return 1;
|
||||
}
|
||||
auto sockets = QueryListeningSockets(netlinkChannel);
|
||||
|
||||
// Stop any relays that no longer match listening ports.
|
||||
std::erase_if(relays, [&](const auto& entry) {
|
||||
|
|
@ -386,9 +340,7 @@ int ScanProcNetTCP(wsl::shared::SocketChannel& channel)
|
|||
}
|
||||
|
||||
// Sleep before scanning again.
|
||||
//
|
||||
// TODO: Investigate using EBPF notifications instead of a sleep.
|
||||
sleep(1);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
@ -432,7 +384,7 @@ try
|
|||
|
||||
if (ScanForPorts)
|
||||
{
|
||||
return ScanProcNetTCP(channel);
|
||||
return MonitorListeningSockets(channel);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -2095,20 +2095,23 @@ class NetworkTests
|
|||
VerifyNotBoundLoopback(port, false);
|
||||
}
|
||||
|
||||
static void ValidateLocalhostRelayTraffic(bool ipv6)
|
||||
static void ValidateLocalhostRelayTraffic(ADDRESS_FAMILY addressFamily)
|
||||
{
|
||||
THROW_HR_IF(E_INVALIDARG, addressFamily != AF_INET && addressFamily != AF_INET6);
|
||||
|
||||
// Bind a port in the guest.
|
||||
auto [guestProcess, read] = BindGuestPort(ipv6 ? L"TCP6-LISTEN:1234,bind=::1" : L"TCP4-LISTEN:1234,bind=127.0.0.1", true);
|
||||
auto [guestProcess, read] =
|
||||
BindGuestPort(addressFamily == AF_INET6 ? L"TCP6-LISTEN:1234,bind=::1" : L"TCP4-LISTEN:1234,bind=127.0.0.1", true);
|
||||
|
||||
// Connect to the port via the localhost relay
|
||||
wil::unique_socket hostSocket;
|
||||
SOCKADDR_INET addr{};
|
||||
addr.si_family = ipv6 ? AF_INET6 : AF_INET;
|
||||
addr.si_family = addressFamily;
|
||||
INETADDR_SETLOOPBACK((PSOCKADDR)&addr);
|
||||
SS_PORT(&addr) = htons(1234);
|
||||
|
||||
auto pred = [&]() {
|
||||
hostSocket.reset(socket(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
hostSocket.reset(socket(addressFamily, SOCK_STREAM, IPPROTO_TCP));
|
||||
THROW_HR_IF(E_ABORT, !hostSocket);
|
||||
THROW_HR_IF(E_FAIL, connect(hostSocket.get(), reinterpret_cast<SOCKADDR*>(&addr), sizeof(addr)) == SOCKET_ERROR);
|
||||
};
|
||||
|
|
@ -2149,8 +2152,8 @@ class NetworkTests
|
|||
WSL2_TEST_ONLY();
|
||||
WslKeepAlive keepAlive;
|
||||
|
||||
ValidateLocalhostRelayTraffic(false);
|
||||
ValidateLocalhostRelayTraffic(true);
|
||||
ValidateLocalhostRelayTraffic(AF_INET);
|
||||
ValidateLocalhostRelayTraffic(AF_INET6);
|
||||
}
|
||||
|
||||
TEST_METHOD(NatLocalhostRelayNoIpv6)
|
||||
|
|
@ -2161,7 +2164,7 @@ class NetworkTests
|
|||
WslKeepAlive keepAlive;
|
||||
|
||||
VERIFY_ARE_EQUAL(LxsstuLaunchWsl(L"test -f /proc/net/tcp6"), 1L);
|
||||
ValidateLocalhostRelayTraffic(false);
|
||||
ValidateLocalhostRelayTraffic(AF_INET);
|
||||
}
|
||||
|
||||
static void TestNonRootNamespaceEphemeralBind()
|
||||
|
|
|
|||
Loading…
Reference in New Issue