diff options
Diffstat (limited to 'libnetd_client/NetdClient.cpp')
-rw-r--r-- | libnetd_client/NetdClient.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/libnetd_client/NetdClient.cpp b/libnetd_client/NetdClient.cpp index 1d8501a..f7e8cc2 100644 --- a/libnetd_client/NetdClient.cpp +++ b/libnetd_client/NetdClient.cpp @@ -18,12 +18,22 @@ #include "netd_client/FwmarkCommands.h" #include <sys/socket.h> +#include <unistd.h> + +#define CLOSE_FD_AND_RESTORE_ERRNO(fd) \ + do { \ + int error = errno; \ + close(fd); \ + errno = error; \ + } while (0) namespace { typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t); +typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*); ConnectFunctionType libcConnect = 0; +AcceptFunctionType libcAccept = 0; int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) { if (FwmarkClient::shouldSetFwmark(sockfd, addr)) { @@ -35,6 +45,30 @@ int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) { return libcConnect(sockfd, addr, addrlen); } +int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) { + int acceptedSocket = libcAccept(sockfd, addr, addrlen); + if (acceptedSocket == -1) { + return -1; + } + sockaddr socketAddress; + if (!addr) { + socklen_t socketAddressLen = sizeof(socketAddress); + if (getsockname(acceptedSocket, &socketAddress, &socketAddressLen) == -1) { + CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket); + return -1; + } + addr = &socketAddress; + } + if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) { + char data[] = {FWMARK_COMMAND_ON_ACCEPT}; + if (!FwmarkClient().send(data, sizeof(data), acceptedSocket)) { + CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket); + return -1; + } + } + return acceptedSocket; +} + } // namespace extern "C" void netdClientInitConnect(ConnectFunctionType* function) { @@ -43,3 +77,10 @@ extern "C" void netdClientInitConnect(ConnectFunctionType* function) { *function = netdClientConnect; } } + +extern "C" void netdClientInitAccept(AcceptFunctionType* function) { + if (function && *function) { + libcAccept = *function; + *function = netdClientAccept; + } +} |