Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 163 additions & 18 deletions src/linux/init/localhost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <vector>
#include <iostream>

#include <fcntl.h>
#include <libgen.h>
#include <sys/epoll.h>
#include <sys/socket.h>
Expand All @@ -30,6 +31,45 @@

namespace {

// State for one direction of the non-blocking socket relay: bytes are read
// from srcFd into buf and later written out to dstFd.
//
// The occupied region of buf is [head, tail). Only the space after tail is
// offered to the next read, so data left over from an incomplete write shrinks
// the read limit. That gives back-pressure through SOCK_STREAM flow control,
// which throttles an abusive peer and bounds memory usage.
struct RelayDirection
{
    int srcFd;                  // descriptor data is read from
    int dstFd;                  // descriptor data is written to
    std::vector<gsl::byte> buf; // fixed-size staging buffer
    size_t head;                // offset of the first byte not yet written
    size_t tail;                // offset one past the last byte read
    bool srcEof;                // the source will produce no more data
    bool done;                  // this direction has finished relaying

    // Number of buffered bytes still waiting to be written to dstFd.
    size_t Pending() const
    {
        return tail - head;
    }

    // Number of bytes that can still be read into the buffer after tail.
    size_t Available() const
    {
        return buf.size() - tail;
    }

    // Moves the pending bytes to the front of the buffer so that all free
    // space sits after tail. Does nothing when the data already starts at
    // offset zero.
    void Compact()
    {
        if (head == 0)
        {
            return;
        }

        const size_t unsent = tail - head;
        if (unsent != 0)
        {
            // Source and destination ranges may overlap, hence memmove.
            memmove(buf.data(), buf.data() + head, unsent);
        }

        tail = unsent;
        head = 0;
    }
};

void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
{
pollfd pollDescriptors[] = {{listenSocket, POLLIN}};
Expand Down Expand Up @@ -103,40 +143,145 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
return;
}

// Resize the buffer to be the requested size.
buffer.resize(message->BufferSize);
// Switch both sockets to non-blocking for the relay loop.
for (int fd : {tcpSocket.get(), relaySocket.get()})
{
int flags = fcntl(fd, F_GETFL, 0);
THROW_LAST_ERROR_IF(flags < 0);
THROW_LAST_ERROR_IF(fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0);
}

const auto bufferSize = message->BufferSize;
buffer.resize(bufferSize);
RelayDirection dirs[2] = {
{relaySocket.get(), tcpSocket.get(), std::move(buffer), 0, 0, false, false},
{tcpSocket.get(), relaySocket.get(), std::vector<gsl::byte>(bufferSize), 0, 0, false, false},
};

// Begin relaying data.
int outFd[2] = {tcpSocket.get(), relaySocket.get()};
pollfd pollDescriptors[] = {{relaySocket.get(), POLLIN}, {tcpSocket.get(), POLLIN}};
pollfd pfds[4] = {};
int pollDirIndex[4] = {};

for (;;)
{
if ((pollDescriptors[0].fd == -1) || (pollDescriptors[1].fd == -1))
// Complete directions where the source hit EOF and all
// pending data has been flushed to the destination.
for (auto& d : dirs)
{
if (!d.done && d.srcEof && d.Pending() == 0)
{
shutdown(d.dstFd, SHUT_WR);
d.done = true;
}
}

if (dirs[0].done && dirs[1].done)
{
return;
}

// Build the poll set based on current state.
int nfds = 0;

for (int i = 0; i < 2; i++)
{
auto& d = dirs[i];
if (d.done)
{
continue;
}

// Poll for read when the source is open and the buffer has space.
if (!d.srcEof && d.Available() > 0)
{
pfds[nfds] = {d.srcFd, POLLIN, 0};
pollDirIndex[nfds] = i;
nfds++;
}

// Poll for write when there is data waiting to go out.
if (d.Pending() > 0)
{
pfds[nfds] = {d.dstFd, POLLOUT, 0};
pollDirIndex[nfds] = i;
nfds++;
}
}

if (nfds == 0)
{
return;
}

THROW_LAST_ERROR_IF(poll(pollDescriptors, COUNT_OF(pollDescriptors), -1) < 0);
THROW_LAST_ERROR_IF(poll(pfds, nfds, -1) < 0);

bytesRead = 0;
for (int Index = 0; Index < COUNT_OF(pollDescriptors); Index += 1)
for (int j = 0; j < nfds; j++)
{
if (pollDescriptors[Index].revents & POLLIN)
auto& d = dirs[pollDirIndex[j]];

if (pfds[j].events & POLLOUT)
{
bytesRead = UtilReadBuffer(pollDescriptors[Index].fd, buffer);
if (bytesRead == 0)
// The destination reported an error or hang-up, so nothing more can be written to dstFd.
if (pfds[j].revents & (POLLERR | POLLHUP))
{
pollDescriptors[Index].fd = -1;
shutdown(outFd[Index], SHUT_WR);
d.done = true;
continue;
}
else if (bytesRead < 0)

if (!(pfds[j].revents & POLLOUT))
{
return;
continue;
}

auto written = TEMP_FAILURE_RETRY(write(d.dstFd, d.buf.data() + d.head, d.Pending()));
if (written < 0)
{
if (errno == EAGAIN || errno == EWOULDBLOCK)
{
continue;
}

d.done = true;
continue;
}

d.head += written;
if (d.Pending() == 0)
{
d.head = d.tail = 0;
}
}
else
{
if (!(pfds[j].revents & POLLIN))
{
// No data to read; if the source is gone, mark EOF.
if (pfds[j].revents & (POLLERR | POLLHUP))
{
d.srcEof = true;
}

continue;
}

d.Compact();
auto nread = TEMP_FAILURE_RETRY(read(d.srcFd, d.buf.data() + d.tail, d.Available()));
if (nread == 0)
{
d.srcEof = true;
}
else if (nread < 0)
{
if (errno == EAGAIN || errno == EWOULDBLOCK)
{
continue;
}

d.srcEof = true;
continue;
}
else if (UtilWriteBuffer(outFd[Index], buffer.data(), bytesRead) < 0)
else
{
return;
d.tail += nread;
}
}
}
Expand Down
Loading
Loading