aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ssh-agent-filter.C50
1 files changed, 41 insertions, 9 deletions
diff --git a/ssh-agent-filter.C b/ssh-agent-filter.C
index ed0d7a7..5aedb35 100644
--- a/ssh-agent-filter.C
+++ b/ssh-agent-filter.C
@@ -62,6 +62,9 @@ using std::move;
using std::count;
#include <thread>
+#include <mutex>
+using std::mutex;
+using std::lock_guard;
#include <cerrno>
#include <csignal>
@@ -70,6 +73,7 @@ using std::count;
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
+#include <sys/select.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <sysexits.h>
@@ -97,6 +101,7 @@ bool debug{false};
bool all_confirmed{false};
string saf_name;
fs::path path;
+mutex fd_fork_mutex;
string md5_hex (string const & s) {
@@ -136,9 +141,12 @@ int make_upstream_agent_conn () {
if (!(path = getenv("SSH_AUTH_SOCK")))
throw invalid_argument("no $SSH_AUTH_SOCK");
- if ((sock = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)) == -1)
- throw system_error(errno, system_category(), "socket");
- cloexec(sock);
+ {
+ lock_guard<mutex> lock{fd_fork_mutex};
+ if ((sock = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)) == -1)
+ throw system_error(errno, system_category(), "socket");
+ cloexec(sock);
+ }
addr.sun_family = AF_UNIX;
@@ -157,9 +165,15 @@ int make_listen_sock () {
int sock;
struct sockaddr_un addr;
- if ((sock = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)) == -1)
- throw system_error(errno, system_category(), "socket");
- cloexec(sock);
+ {
+ lock_guard<mutex> lock{fd_fork_mutex};
+ if ((sock = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)) == -1)
+ throw system_error(errno, system_category(), "socket");
+ cloexec(sock);
+ }
+
+ if (fcntl(sock, F_SETFL, fcntl(sock, F_GETFL) | O_NONBLOCK))
+ throw system_error(errno, system_category(), "fcntl");
addr.sun_family = AF_UNIX;
@@ -289,7 +303,11 @@ bool confirm (string const & question) {
char const * sap;
if (!(sap = getenv("SSH_ASKPASS")))
sap = "ssh-askpass";
- pid_t pid = fork();
+ pid_t pid;
+ {
+ lock_guard<mutex> lock{fd_fork_mutex};
+ pid = fork();
+ }
if (pid < 0)
throw runtime_error("fork()");
if (pid == 0) {
@@ -484,8 +502,22 @@ int main (int const argc, char const * const * const argv) {
signal(SIGHUP, sighandler);
signal(SIGTERM, sighandler);
- int client_sock;
- while ((client_sock = accept(listen_sock, nullptr, nullptr)) != -1) {
+ for (;;) {
+ fd_set fds;
+ FD_ZERO(&fds);
+ FD_SET(listen_sock, &fds);
+ select(listen_sock + 1, &fds, nullptr, nullptr, nullptr);
+ int client_sock;
+ {
+ lock_guard<mutex> lock{fd_fork_mutex};
+ if ((client_sock = accept(listen_sock, nullptr, nullptr)) == -1) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK)
+ continue;
+ else
+ break;
+ }
+ cloexec(client_sock);
+ }
std::thread t{handle_client, client_sock};
t.detach();
}