From b87b17a895383146cdec184a257a406fb93ab572 Mon Sep 17 00:00:00 2001 From: Sreeram Ramachandran Date: Tue, 13 May 2014 15:42:23 -0700 Subject: Mark sockets on accept(). (cherry picked from commit 1321cfecc6fb0e075796a34ed3c6a30d7ff3bea3) Change-Id: I5d09be413cf720fbed905f96313b007997ada76c --- libnetd_client/NetdClient.cpp | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'libnetd_client') diff --git a/libnetd_client/NetdClient.cpp b/libnetd_client/NetdClient.cpp index 1d8501a..f7e8cc2 100644 --- a/libnetd_client/NetdClient.cpp +++ b/libnetd_client/NetdClient.cpp @@ -18,12 +18,22 @@ #include "netd_client/FwmarkCommands.h" #include +#include + +#define CLOSE_FD_AND_RESTORE_ERRNO(fd) \ + do { \ + int error = errno; \ + close(fd); \ + errno = error; \ + } while (0) namespace { typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t); +typedef int (*AcceptFunctionType)(int, sockaddr*, socklen_t*); ConnectFunctionType libcConnect = 0; +AcceptFunctionType libcAccept = 0; int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) { if (FwmarkClient::shouldSetFwmark(sockfd, addr)) { @@ -35,6 +45,30 @@ int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) { return libcConnect(sockfd, addr, addrlen); } +int netdClientAccept(int sockfd, sockaddr* addr, socklen_t* addrlen) { + int acceptedSocket = libcAccept(sockfd, addr, addrlen); + if (acceptedSocket == -1) { + return -1; + } + sockaddr socketAddress; + if (!addr) { + socklen_t socketAddressLen = sizeof(socketAddress); + if (getsockname(acceptedSocket, &socketAddress, &socketAddressLen) == -1) { + CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket); + return -1; + } + addr = &socketAddress; + } + if (FwmarkClient::shouldSetFwmark(acceptedSocket, addr)) { + char data[] = {FWMARK_COMMAND_ON_ACCEPT}; + if (!FwmarkClient().send(data, sizeof(data), acceptedSocket)) { + CLOSE_FD_AND_RESTORE_ERRNO(acceptedSocket); + return -1; + } + } + return acceptedSocket; +} + } // namespace extern "C" void netdClientInitConnect(ConnectFunctionType* function) { @@ -43,3 +77,10 @@ extern "C" void netdClientInitConnect(ConnectFunctionType* function) { *function = netdClientConnect; } } + +extern "C" void netdClientInitAccept(AcceptFunctionType* function) { + if (function && *function) { + libcAccept = *function; + *function = netdClientAccept; + } +} -- cgit v1.1