Clean up localhost relay implementation to not rely on procfs parsing. (#13836)

* Clean up localhost relay implementation to not rely on procfs parsing.

* PR feedback

---------

Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>
This commit is contained in:
Ben Hillis 2025-12-12 14:18:07 -08:00 committed by GitHub
parent 7f8422654e
commit f1e20b21c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 117 deletions

View File

@ -12,6 +12,8 @@
#include <netinet/ip.h>
#include <sys/syscall.h>
#include <linux/unistd.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>
#include <lxwil.h>
#include <linux/if_tun.h>
@ -21,6 +23,8 @@
#include "SecCompDispatcher.h"
#include "seccomp_defs.h"
#include "CommandLine.h"
#include "NetlinkChannel.h"
#include "NetlinkTransactionError.h"
#define TCP_LISTEN 10
@ -145,79 +149,59 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
return;
}
std::vector<sockaddr_storage> ParseTcpFile(int family, FILE* file)
std::vector<sockaddr_storage> QueryListeningSockets(NetlinkChannel& channel)
{
char* line = nullptr;
auto freeLine = wil::scope_exit([&line]() { free(line); });
// Skip the first line which contains a header.
size_t lineLength = 0;
auto bytesRead = getline(&line, &lineLength, file);
THROW_LAST_ERROR_IF(bytesRead < 0);
// Each line contains information about TCP sockets on the system; the fields
// we are interested in are for sockets that are or have been listening:
// 1: Socket address and port number
// 3: Socket status
std::vector<sockaddr_storage> sockets{};
while ((bytesRead = getline(&line, &lineLength, file)) != -1)
try
{
inet_diag_req_v2 message{};
message.sdiag_protocol = IPPROTO_TCP;
message.idiag_states = (1 << TCP_LISTEN);
auto onMessage = [&](const NetlinkResponse& response) {
for (const auto& e : response.Messages<inet_diag_msg>(SOCK_DIAG_BY_FAMILY))
{
const auto* payload = e.Payload();
sockaddr_storage sock{};
int index = 0;
int status = 0;
for (char *sp, *field = strtok_r(line, " \n", &sp); field != nullptr; field = strtok_r(NULL, " \n", &sp))
if (payload->idiag_family == AF_INET)
{
if (index == 1)
{
int port;
const char* portString = strchr(field, ':');
if (portString == nullptr)
{
break;
auto* ipv4 = reinterpret_cast<sockaddr_in*>(&sock);
ipv4->sin_family = AF_INET;
ipv4->sin_addr.s_addr = payload->id.idiag_src[0];
ipv4->sin_port = payload->id.idiag_sport;
}
portString += 1;
port = static_cast<int>(strtol(portString, nullptr, 16));
if (port == 0)
else if (payload->idiag_family == AF_INET6)
{
break;
auto* ipv6 = reinterpret_cast<sockaddr_in6*>(&sock);
ipv6->sin6_family = AF_INET6;
static_assert(sizeof(ipv6->sin6_addr.s6_addr32) == sizeof(payload->id.idiag_src));
memcpy(ipv6->sin6_addr.s6_addr32, payload->id.idiag_src, sizeof(ipv6->sin6_addr.s6_addr32));
ipv6->sin6_port = payload->id.idiag_sport;
}
if (family == AF_INET)
{
sockaddr_in ipv4Sock{};
ipv4Sock.sin_family = family;
ipv4Sock.sin_addr.s_addr = strtol(field, nullptr, 16);
ipv4Sock.sin_port = port;
memcpy(&sock, &ipv4Sock, sizeof(ipv4Sock));
}
else if (family == AF_INET6)
{
sockaddr_in6 ipv6Sock{};
ipv6Sock.sin6_family = family;
ipv6Sock.sin6_port = port;
for (int part = 0; part < 4; ++part)
{
char next[5];
next[4] = 0;
memcpy(next, field + part * 4, 4);
ipv6Sock.sin6_addr.__in6_union.__s6_addr32[part] = strtol(next, nullptr, 16);
}
memcpy(&sock, &ipv6Sock, sizeof(ipv6Sock));
}
}
else if (index == 3)
{
status = static_cast<int>(strtol(field, nullptr, 16));
break;
}
index += 1;
}
if ((status == TCP_LISTEN) && (sock.ss_family != 0))
{
sockets.emplace_back(sock);
}
};
// Query IPv4 listening sockets.
{
message.sdiag_family = AF_INET;
auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
transaction.Execute(onMessage);
}
// Query IPv6 listening sockets.
{
message.sdiag_family = AF_INET6;
auto transaction = channel.CreateTransaction(message, SOCK_DIAG_BY_FAMILY, NLM_F_DUMP);
transaction.Execute(onMessage);
}
}
catch (const NetlinkTransactionError& e)
{
// Log but don't fail - network state might be temporarily unavailable
LOG_ERROR("Failed to query listening sockets via sock_diag: {}", e.what());
}
return sockets;
@ -246,12 +230,12 @@ LX_GNS_PORT_LISTENER_RELAY SockToRelayMessage(const sockaddr_storage& sock)
{
auto ipv4 = reinterpret_cast<const sockaddr_in*>(&sock);
message.Address[0] = ipv4->sin_addr.s_addr;
message.Port = ipv4->sin_port;
message.Port = ntohs(ipv4->sin_port);
}
else if (sock.ss_family == AF_INET6)
{
auto ipv6 = reinterpret_cast<const sockaddr_in6*>(&sock);
message.Port = ipv6->sin6_port;
message.Port = ntohs(ipv6->sin6_port);
memcpy(message.Address, ipv6->sin6_addr.__in6_union.__s6_addr, sizeof(message.Address));
}
return message;
@ -296,53 +280,23 @@ bool IsSameSockAddr(const sockaddr_storage& left, const sockaddr_storage& right)
{
auto leftIpv6 = reinterpret_cast<const sockaddr_in6*>(&left);
auto rightIpv6 = reinterpret_cast<const sockaddr_in6*>(&right);
if (leftIpv6->sin6_port != rightIpv6->sin6_port)
{
return false;
return (leftIpv6->sin6_port == rightIpv6->sin6_port && memcmp(&leftIpv6->sin6_addr, &rightIpv6->sin6_addr, sizeof(in6_addr)) == 0);
}
for (int part = 0; part < 4; ++part)
{
if (leftIpv6->sin6_addr.__in6_union.__s6_addr32[part] != rightIpv6->sin6_addr.__in6_union.__s6_addr32[part])
{
return false;
}
}
return true;
}
else
{
FATAL_ERROR("Unrecognized socket family {}", left.ss_family);
return false;
}
}
// Start looking for ports bound to localhost or wildcard.
int ScanProcNetTCP(wsl::shared::SocketChannel& channel)
// Monitor listening TCP sockets using sock_diag netlink interface.
int MonitorListeningSockets(wsl::shared::SocketChannel& channel)
{
// Periodically scan procfs for listening TCP sockets.
NetlinkChannel netlinkChannel(SOCK_RAW, NETLINK_SOCK_DIAG);
std::vector<sockaddr_storage> relays{};
int result = 0;
for (;;)
{
std::vector<sockaddr_storage> sockets;
wil::unique_file tcp4File{fopen("/proc/net/tcp", "r")};
if (tcp4File)
{
sockets = ParseTcpFile(AF_INET, tcp4File.get());
}
wil::unique_file tcp6File{fopen("/proc/net/tcp6", "r")};
if (tcp6File)
{
auto ipv6Sockets = ParseTcpFile(AF_INET6, tcp6File.get());
sockets.insert(sockets.end(), ipv6Sockets.begin(), ipv6Sockets.end());
}
if (!tcp4File && !tcp6File)
{
LOG_ERROR("Failed to open /proc/net/tcp and /proc/net/tcp6, closing port relay");
return 1;
}
auto sockets = QueryListeningSockets(netlinkChannel);
// Stop any relays that no longer match listening ports.
std::erase_if(relays, [&](const auto& entry) {
@ -386,9 +340,7 @@ int ScanProcNetTCP(wsl::shared::SocketChannel& channel)
}
// Sleep before scanning again.
//
// TODO: Investigate using EBPF notifications instead of a sleep.
sleep(1);
std::this_thread::sleep_for(std::chrono::seconds(1));
}
return result;
@ -432,7 +384,7 @@ try
if (ScanForPorts)
{
return ScanProcNetTCP(channel);
return MonitorListeningSockets(channel);
}
return 0;

View File

@ -2095,20 +2095,23 @@ class NetworkTests
VerifyNotBoundLoopback(port, false);
}
static void ValidateLocalhostRelayTraffic(bool ipv6)
static void ValidateLocalhostRelayTraffic(ADDRESS_FAMILY addressFamily)
{
THROW_HR_IF(E_INVALIDARG, addressFamily != AF_INET && addressFamily != AF_INET6);
// Bind a port in the guest.
auto [guestProcess, read] = BindGuestPort(ipv6 ? L"TCP6-LISTEN:1234,bind=::1" : L"TCP4-LISTEN:1234,bind=127.0.0.1", true);
auto [guestProcess, read] =
BindGuestPort(addressFamily == AF_INET6 ? L"TCP6-LISTEN:1234,bind=::1" : L"TCP4-LISTEN:1234,bind=127.0.0.1", true);
// Connect to the port via the localhost relay
wil::unique_socket hostSocket;
SOCKADDR_INET addr{};
addr.si_family = ipv6 ? AF_INET6 : AF_INET;
addr.si_family = addressFamily;
INETADDR_SETLOOPBACK((PSOCKADDR)&addr);
SS_PORT(&addr) = htons(1234);
auto pred = [&]() {
hostSocket.reset(socket(ipv6 ? AF_INET6 : AF_INET, SOCK_STREAM, IPPROTO_TCP));
hostSocket.reset(socket(addressFamily, SOCK_STREAM, IPPROTO_TCP));
THROW_HR_IF(E_ABORT, !hostSocket);
THROW_HR_IF(E_FAIL, connect(hostSocket.get(), reinterpret_cast<SOCKADDR*>(&addr), sizeof(addr)) == SOCKET_ERROR);
};
@ -2149,8 +2152,8 @@ class NetworkTests
WSL2_TEST_ONLY();
WslKeepAlive keepAlive;
ValidateLocalhostRelayTraffic(false);
ValidateLocalhostRelayTraffic(true);
ValidateLocalhostRelayTraffic(AF_INET);
ValidateLocalhostRelayTraffic(AF_INET6);
}
TEST_METHOD(NatLocalhostRelayNoIpv6)
@ -2161,7 +2164,7 @@ class NetworkTests
WslKeepAlive keepAlive;
VERIFY_ARE_EQUAL(LxsstuLaunchWsl(L"test -f /proc/net/tcp6"), 1L);
ValidateLocalhostRelayTraffic(false);
ValidateLocalhostRelayTraffic(AF_INET);
}
static void TestNonRootNamespaceEphemeralBind()