--- /dev/null
+/*
+ * virnetsocket.c: generic network socket handling
+ *
+ * Copyright (C) 2006-2011 Red Hat, Inc.
+ * Copyright (C) 2006 Daniel P. Berrange
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include <sys/stat.h>
+#include <sys/socket.h>
+#include <unistd.h>
+#include <sys/wait.h>
+
+#ifdef HAVE_NETINET_TCP_H
+# include <netinet/tcp.h>
+#endif
+
+#include "virnetsocket.h"
+#include "util.h"
+#include "memory.h"
+#include "virterror_internal.h"
+#include "logging.h"
+#include "files.h"
+#include "event.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+#define virNetError(code, ...) \
+ virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \
+ __FUNCTION__, __LINE__, __VA_ARGS__)
+
+
+struct _virNetSocket {
+ int fd;
+ int watch;
+ pid_t pid;
+ int errfd;
+ bool client;
+ virNetSocketIOFunc func;
+ void *opaque;
+ virSocketAddr localAddr;
+ virSocketAddr remoteAddr;
+ char *localAddrStr;
+ char *remoteAddrStr;
+};
+
+
+#ifndef WIN32
+static int virNetSocketForkDaemon(const char *binary)
+{
+ int ret;
+ virCommandPtr cmd = virCommandNewArgList(binary,
+ "--timeout=30",
+ NULL);
+
+ virCommandAddEnvPassCommon(cmd);
+ virCommandClearCaps(cmd);
+ virCommandDaemonize(cmd);
+ ret = virCommandRun(cmd, NULL);
+ virCommandFree(cmd);
+ return ret;
+}
+#endif
+
+
+static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
+ virSocketAddrPtr remoteAddr,
+ bool isClient,
+ int fd, int errfd, pid_t pid)
+{
+ virNetSocketPtr sock;
+ int no_slow_start = 1;
+
+ VIR_DEBUG("localAddr=%p remoteAddr=%p fd=%d errfd=%d pid=%d",
+ localAddr, remoteAddr,
+ fd, errfd, pid);
+
+ if (virSetCloseExec(fd) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Unable to set close-on-exec flag"));
+ return NULL;
+ }
+ if (virSetNonBlock(fd) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Unable to enable non-blocking flag"));
+ return NULL;
+ }
+
+ if (VIR_ALLOC(sock) < 0) {
+ virReportOOMError();
+ return NULL;
+ }
+
+ if (localAddr)
+ sock->localAddr = *localAddr;
+ if (remoteAddr)
+ sock->remoteAddr = *remoteAddr;
+ sock->fd = fd;
+ sock->errfd = errfd;
+ sock->pid = pid;
+
+ /* Disable nagle for TCP sockets */
+ if (sock->localAddr.data.sa.sa_family == AF_INET ||
+ sock->localAddr.data.sa.sa_family == AF_INET6) {
+ if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY,
+ &no_slow_start,
+ sizeof(no_slow_start)) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Unable to disable nagle algorithm"));
+ goto error;
+ }
+ }
+
+
+ if (localAddr &&
+ !(sock->localAddrStr = virSocketFormatAddrFull(localAddr, true, ";")))
+ goto error;
+
+ if (remoteAddr &&
+ !(sock->remoteAddrStr = virSocketFormatAddrFull(remoteAddr, true, ";")))
+ goto error;
+
+ sock->client = isClient;
+
+ VIR_DEBUG("sock=%p localAddrStr=%s remoteAddrStr=%s",
+ sock, NULLSTR(sock->localAddrStr), NULLSTR(sock->remoteAddrStr));
+
+ return sock;
+
+error:
+ sock->fd = sock->errfd = -1; /* Caller owns fd/errfd on failure */
+ virNetSocketFree(sock);
+ return NULL;
+}
+
+
+int virNetSocketNewListenTCP(const char *nodename,
+ const char *service,
+ virNetSocketPtr **retsocks,
+ size_t *nretsocks)
+{
+ virNetSocketPtr *socks = NULL;
+ size_t nsocks = 0;
+ struct addrinfo *ai = NULL;
+ struct addrinfo hints;
+ int fd = -1;
+ int i;
+
+ *retsocks = NULL;
+ *nretsocks = 0;
+
+ memset(&hints, 0, sizeof hints);
+ hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+ hints.ai_socktype = SOCK_STREAM;
+
+ int e = getaddrinfo(nodename, service, &hints, &ai);
+ if (e != 0) {
+ virNetError(VIR_ERR_SYSTEM_ERROR,
+ _("Unable to resolve address '%s' service '%s': %s"),
+ nodename, service, gai_strerror(e));
+ return -1;
+ }
+
+ struct addrinfo *runp = ai;
+ while (runp) {
+ virSocketAddr addr;
+
+ memset(&addr, 0, sizeof(addr));
+
+ if ((fd = socket(runp->ai_family, runp->ai_socktype,
+ runp->ai_protocol)) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to create socket"));
+ goto error;
+ }
+
+ int opt = 1;
+ if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to enable port reuse"));
+ goto error;
+ }
+
+#ifdef IPV6_V6ONLY
+ if (runp->ai_family == PF_INET6) {
+ int on = 1;
+ /*
+ * Normally on Linux an INET6 socket will bind to the INET4
+ * address too. If getaddrinfo returns results with INET4
+ * first though, this will result in INET6 binding failing.
+ * We can trivially cope with multiple server sockets, so
+ * we force it to only listen on IPv6
+ */
+ if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY,
+ (void*)&on, sizeof on) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Unable to force bind to IPv6 only"));
+ goto error;
+ }
+ }
+#endif
+
+ if (bind(fd, runp->ai_addr, runp->ai_addrlen) < 0) {
+ if (errno != EADDRINUSE) {
+ virReportSystemError(errno, "%s", _("Unable to bind to port"));
+ goto error;
+ }
+ VIR_FORCE_CLOSE(fd);
+ continue;
+ }
+
+ addr.len = sizeof(addr.data);
+ if (getsockname(fd, &addr.data.sa, &addr.len) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to get local socket name"));
+ goto error;
+ }
+
+ VIR_DEBUG("%p f=%d f=%d", &addr, runp->ai_family, addr.data.sa.sa_family);
+
+ if (VIR_EXPAND_N(socks, nsocks, 1) < 0) {
+ virReportOOMError();
+ goto error;
+ }
+
+ if (!(socks[nsocks-1] = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+ goto error;
+ runp = runp->ai_next;
+ fd = -1;
+ }
+
+ freeaddrinfo(ai);
+
+ *retsocks = socks;
+ *nretsocks = nsocks;
+ return 0;
+
+error:
+ for (i = 0 ; i < nsocks ; i++)
+ virNetSocketFree(socks[i]);
+ VIR_FREE(socks);
+ freeaddrinfo(ai);
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+}
+
+
+#if HAVE_SYS_UN_H
+int virNetSocketNewListenUNIX(const char *path,
+ mode_t mask,
+ gid_t grp,
+ virNetSocketPtr *retsock)
+{
+ virSocketAddr addr;
+ mode_t oldmask;
+ int fd;
+
+ *retsock = NULL;
+
+ memset(&addr, 0, sizeof(addr));
+
+ addr.len = sizeof(addr.data.un);
+
+ if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
+ virReportSystemError(errno, "%s", _("Failed to create socket"));
+ goto error;
+ }
+
+ addr.data.un.sun_family = AF_UNIX;
+ if (virStrcpyStatic(addr.data.un.sun_path, path) == NULL) {
+ virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path);
+ goto error;
+ }
+ if (addr.data.un.sun_path[0] == '@')
+ addr.data.un.sun_path[0] = '\0';
+ else
+ unlink(addr.data.un.sun_path);
+
+ oldmask = umask(~mask);
+
+ if (bind(fd, &addr.data.sa, addr.len) < 0) {
+ umask(oldmask);
+ virReportSystemError(errno,
+ _("Failed to bind socket to '%s'"),
+ path);
+ goto error;
+ }
+ umask(oldmask);
+
+ /* chown() doesn't work for abstract sockets but we use them only
+ * if libvirtd runs unprivileged
+ */
+ if (grp != 0 && chown(path, -1, grp)) {
+ virReportSystemError(errno,
+ _("Failed to change group ID of '%s' to %d"),
+ path, grp);
+ goto error;
+ }
+
+ if (!(*retsock = virNetSocketNew(&addr, NULL, false, fd, -1, 0)))
+ goto error;
+
+ return 0;
+
+error:
+ if (path[0] != '@')
+ unlink(path);
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+}
+#else
+int virNetSocketNewListenUNIX(const char *path ATTRIBUTE_UNUSED,
+ mode_t mask ATTRIBUTE_UNUSED,
+ gid_t grp ATTRIBUTE_UNUSED,
+ virNetSocketPtr *retsock ATTRIBUTE_UNUSED)
+{
+ virReportSystemError(ENOSYS, "%s",
+ _("UNIX sockets are not supported on this platform"));
+ return -1;
+}
+#endif
+
+
+int virNetSocketNewConnectTCP(const char *nodename,
+ const char *service,
+ virNetSocketPtr *retsock)
+{
+ struct addrinfo *ai = NULL;
+ struct addrinfo hints;
+ int fd = -1;
+ virSocketAddr localAddr;
+ virSocketAddr remoteAddr;
+ struct addrinfo *runp;
+ int savedErrno = ENOENT;
+
+ *retsock = NULL;
+
+ memset(&localAddr, 0, sizeof(localAddr));
+ memset(&remoteAddr, 0, sizeof(remoteAddr));
+
+ memset(&hints, 0, sizeof hints);
+ hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+ hints.ai_socktype = SOCK_STREAM;
+
+ int e = getaddrinfo(nodename, service, &hints, &ai);
+ if (e != 0) {
+ virNetError(VIR_ERR_SYSTEM_ERROR,
+ _("Unable to resolve address '%s' service '%s': %s"),
+ nodename, service, gai_strerror (e));
+ return -1;
+ }
+
+ runp = ai;
+ while (runp) {
+ int opt = 1;
+
+ if ((fd = socket(runp->ai_family, runp->ai_socktype,
+ runp->ai_protocol)) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to create socket"));
+ goto error;
+ }
+
+ setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof opt);
+
+ if (connect(fd, runp->ai_addr, runp->ai_addrlen) >= 0)
+ break;
+
+ savedErrno = errno;
+ VIR_FORCE_CLOSE(fd);
+ runp = runp->ai_next;
+ }
+
+ if (fd == -1) {
+ virReportSystemError(savedErrno,
+ _("unable to connect to server at '%s:%s'"),
+ nodename, service);
+ goto error;
+ }
+
+ localAddr.len = sizeof(localAddr.data);
+ if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to get local socket name"));
+ goto error;
+ }
+
+ remoteAddr.len = sizeof(remoteAddr.data);
+ if (getpeername(fd, &remoteAddr.data.sa, &remoteAddr.len) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to get remote socket name"));
+ goto error;
+ }
+
+ if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0)))
+ goto error;
+
+ freeaddrinfo(ai);
+
+ return 0;
+
+error:
+ freeaddrinfo(ai);
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+}
+
+
+#if HAVE_SYS_UN_H
+int virNetSocketNewConnectUNIX(const char *path,
+ bool spawnDaemon,
+ const char *binary,
+ virNetSocketPtr *retsock)
+{
+ virSocketAddr localAddr;
+ virSocketAddr remoteAddr;
+ int fd;
+ int retries = 0;
+
+ memset(&localAddr, 0, sizeof(localAddr));
+ memset(&remoteAddr, 0, sizeof(remoteAddr));
+
+ remoteAddr.len = sizeof(remoteAddr.data.un);
+
+ if ((fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
+ virReportSystemError(errno, "%s", _("Failed to create socket"));
+ goto error;
+ }
+
+ remoteAddr.data.un.sun_family = AF_UNIX;
+ if (virStrcpyStatic(remoteAddr.data.un.sun_path, path) == NULL) {
+ virReportSystemError(ENOMEM, _("Path %s too long for unix socket"), path);
+ goto error;
+ }
+ if (remoteAddr.data.un.sun_path[0] == '@')
+ remoteAddr.data.un.sun_path[0] = '\0';
+
+retry:
+ if (connect(fd, &remoteAddr.data.sa, remoteAddr.len) < 0) {
+ if (errno == ECONNREFUSED && spawnDaemon && retries < 20) {
+ if (retries == 0 &&
+ virNetSocketForkDaemon(binary) < 0)
+ goto error;
+
+ retries++;
+ usleep(1000 * 100 * retries);
+ goto retry;
+ }
+
+ virReportSystemError(errno,
+ _("Failed to connect socket to '%s'"),
+ path);
+ goto error;
+ }
+
+ localAddr.len = sizeof(localAddr.data);
+ if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to get local socket name"));
+ goto error;
+ }
+
+ if (!(*retsock = virNetSocketNew(&localAddr, &remoteAddr, true, fd, -1, 0)))
+ goto error;
+
+ return 0;
+
+error:
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+}
+#else
+int virNetSocketNewConnectUNIX(const char *path ATTRIBUTE_UNUSED,
+ bool spawnDaemon ATTRIBUTE_UNUSED,
+ const char *binary ATTRIBUTE_UNUSED,
+ virNetSocketPtr *retsock ATTRIBUTE_UNUSED)
+{
+ virReportSystemError(ENOSYS, "%s",
+ _("UNIX sockets are not supported on this platform"));
+ return -1;
+}
+#endif
+
+
+#ifndef WIN32
+int virNetSocketNewConnectCommand(virCommandPtr cmd,
+ virNetSocketPtr *retsock)
+{
+ pid_t pid = 0;
+ int sv[2];
+ int errfd[2];
+
+ *retsock = NULL;
+
+ /* Fork off the external process. Use socketpair to create a private
+ * (unnamed) Unix domain socket to the child process so we don't have
+ * to faff around with two file descriptors (a la 'pipe(2)').
+ */
+ if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) < 0) {
+ virReportSystemError(errno, "%s",
+ _("unable to create socket pair"));
+ goto error;
+ }
+
+ if (pipe(errfd) < 0) {
+ virReportSystemError(errno, "%s",
+ _("unable to create socket pair"));
+ goto error;
+ }
+
+ virCommandSetInputFD(cmd, sv[1]);
+ virCommandSetOutputFD(cmd, &sv[1]);
+ virCommandSetErrorFD(cmd, &errfd[1]);
+
+ if (virCommandRunAsync(cmd, &pid) < 0)
+ goto error;
+
+ /* Parent continues here. */
+ VIR_FORCE_CLOSE(sv[1]);
+ VIR_FORCE_CLOSE(errfd[1]);
+
+ if (!(*retsock = virNetSocketNew(NULL, NULL, true, sv[0], errfd[0], pid)))
+ goto error;
+
+ virCommandFree(cmd);
+
+ return 0;
+
+error:
+ VIR_FORCE_CLOSE(sv[0]);
+ VIR_FORCE_CLOSE(sv[1]);
+ VIR_FORCE_CLOSE(errfd[0]);
+ VIR_FORCE_CLOSE(errfd[1]);
+
+ virCommandAbort(cmd);
+ virCommandFree(cmd);
+
+ return -1;
+}
+#else
+int virNetSocketNewConnectCommand(virCommandPtr cmd ATTRIBUTE_UNUSED,
+ virNetSocketPtr *retsock ATTRIBUTE_UNUSED)
+{
+ virReportSystemError(errno, "%s",
+ _("Tunnelling sockets not supported on this platform"));
+ return -1;
+}
+#endif
+
+int virNetSocketNewConnectSSH(const char *nodename,
+ const char *service,
+ const char *binary,
+ const char *username,
+ bool noTTY,
+ const char *netcat,
+ const char *path,
+ virNetSocketPtr *retsock)
+{
+ virCommandPtr cmd;
+ *retsock = NULL;
+
+ cmd = virCommandNew(binary ? binary : "ssh");
+ virCommandAddEnvPassCommon(cmd);
+ virCommandAddEnvPass(cmd, "SSH_AUTH_SOCK");
+ virCommandAddEnvPass(cmd, "SSH_ASKPASS");
+ virCommandClearCaps(cmd);
+
+ if (service)
+ virCommandAddArgList(cmd, "-p", service, NULL);
+ if (username)
+ virCommandAddArgList(cmd, "-l", username, NULL);
+ if (noTTY)
+ virCommandAddArgList(cmd, "-T", "-o", "BatchMode=yes",
+ "-e", "none", NULL);
+ virCommandAddArgList(cmd, nodename,
+ netcat ? netcat : "nc",
+ "-U", path, NULL);
+
+ return virNetSocketNewConnectCommand(cmd, retsock);
+}
+
+
+int virNetSocketNewConnectExternal(const char **cmdargv,
+ virNetSocketPtr *retsock)
+{
+ virCommandPtr cmd;
+
+ *retsock = NULL;
+
+ cmd = virCommandNewArgs(cmdargv);
+ virCommandAddEnvPassCommon(cmd);
+ virCommandClearCaps(cmd);
+
+ return virNetSocketNewConnectCommand(cmd, retsock);
+}
+
+
+void virNetSocketFree(virNetSocketPtr sock)
+{
+ if (!sock)
+ return;
+
+ VIR_DEBUG("sock=%p fd=%d", sock, sock->fd);
+ if (sock->watch > 0) {
+ virEventRemoveHandle(sock->watch);
+ sock->watch = -1;
+ }
+
+#ifdef HAVE_SYS_UN_H
+ /* If a server socket, then unlink UNIX path */
+ if (!sock->client &&
+ sock->localAddr.data.sa.sa_family == AF_UNIX &&
+ sock->localAddr.data.un.sun_path[0] != '\0')
+ unlink(sock->localAddr.data.un.sun_path);
+#endif
+
+ VIR_FORCE_CLOSE(sock->fd);
+ VIR_FORCE_CLOSE(sock->errfd);
+
+#ifndef WIN32
+ if (sock->pid > 0) {
+ pid_t reap;
+ kill(sock->pid, SIGTERM);
+ do {
+retry:
+ reap = waitpid(sock->pid, NULL, 0);
+ if (reap == -1 && errno == EINTR)
+ goto retry;
+ } while (reap != -1 && reap != sock->pid);
+ }
+#endif
+
+ VIR_FREE(sock->localAddrStr);
+ VIR_FREE(sock->remoteAddrStr);
+
+ VIR_FREE(sock);
+}
+
+
+int virNetSocketGetFD(virNetSocketPtr sock)
+{
+ return sock->fd;
+}
+
+
+bool virNetSocketIsLocal(virNetSocketPtr sock)
+{
+ if (sock->localAddr.data.sa.sa_family == AF_UNIX)
+ return true;
+ return false;
+}
+
+
+#ifdef SO_PEERCRED
+int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
+ uid_t *uid,
+ pid_t *pid)
+{
+ struct ucred cr;
+ unsigned int cr_len = sizeof (cr);
+
+ if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) {
+ virReportSystemError(errno, "%s",
+ _("Failed to get client socket identity"));
+ return -1;
+ }
+
+ *pid = cr.pid;
+ *uid = cr.uid;
+ return 0;
+}
+#else
+int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED,
+ uid_t *uid ATTRIBUTE_UNUSED,
+ pid_t *pid ATTRIBUTE_UNUSED)
+{
+ /* XXX Many more OS support UNIX socket credentials we could port to. See dbus ....*/
+ virReportSystemError(ENOSYS, "%s",
+ _("Client socket identity not available"));
+ return -1;
+}
+#endif
+
+
+int virNetSocketSetBlocking(virNetSocketPtr sock,
+ bool blocking)
+{
+ return virSetBlocking(sock->fd, blocking);
+}
+
+
+const char *virNetSocketLocalAddrString(virNetSocketPtr sock)
+{
+ return sock->localAddrStr;
+}
+
+const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
+{
+ return sock->remoteAddrStr;
+}
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+ char *errout = NULL;
+ ssize_t ret;
+reread:
+ ret = read(sock->fd, buf, len);
+
+ if ((ret < 0) && (errno == EINTR))
+ goto reread;
+ if ((ret < 0) && (errno == EAGAIN))
+ return 0;
+ if (ret <= 0 &&
+ sock->errfd != -1 &&
+ virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
+ errout != NULL) {
+ size_t elen = strlen(errout);
+ if (elen && errout[elen-1] == '\n')
+ errout[elen-1] = '\0';
+ }
+
+ if (ret < 0) {
+ if (errout)
+ virReportSystemError(errno,
+ _("Cannot recv data: %s"), errout);
+ else
+ virReportSystemError(errno, "%s",
+ _("Cannot recv data"));
+ ret = -1;
+ } else if (ret == 0) {
+ if (errout)
+ virReportSystemError(EIO,
+ _("End of file while reading data: %s"), errout);
+ else
+ virReportSystemError(EIO, "%s",
+ _("End of file while reading data"));
+ ret = -1;
+ }
+
+ VIR_FREE(errout);
+ return ret;
+}
+
+ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
+{
+ ssize_t ret;
+rewrite:
+ ret = write(sock->fd, buf, len);
+
+ if (ret < 0) {
+ if (errno == EINTR)
+ goto rewrite;
+ if (errno == EAGAIN)
+ return 0;
+
+ virReportSystemError(errno, "%s",
+ _("Cannot write data"));
+ return -1;
+ }
+ if (ret == 0) {
+ virReportSystemError(EIO, "%s",
+ _("End of file while writing data"));
+ return -1;
+ }
+
+ return ret;
+}
+
+
+int virNetSocketListen(virNetSocketPtr sock)
+{
+ if (listen(sock->fd, 30) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to listen on socket"));
+ return -1;
+ }
+ return 0;
+}
+
+int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
+{
+ int fd;
+ virSocketAddr localAddr;
+ virSocketAddr remoteAddr;
+
+ *clientsock = NULL;
+
+ memset(&localAddr, 0, sizeof(localAddr));
+ memset(&remoteAddr, 0, sizeof(remoteAddr));
+
+ remoteAddr.len = sizeof(remoteAddr.data.stor);
+ if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) {
+ if (errno == ECONNABORTED ||
+ errno == EAGAIN)
+ return 0;
+
+ virReportSystemError(errno, "%s",
+ _("Unable to accept client"));
+ return -1;
+ }
+
+ localAddr.len = sizeof(localAddr.data);
+ if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
+ virReportSystemError(errno, "%s", _("Unable to get local socket name"));
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+ }
+
+ if (!(*clientsock = virNetSocketNew(&localAddr,
+ &remoteAddr,
+ true,
+ fd, -1, 0))) {
+ VIR_FORCE_CLOSE(fd);
+ return -1;
+ }
+
+ return 0;
+}
+
+
+static void virNetSocketEventHandle(int fd ATTRIBUTE_UNUSED,
+ int watch ATTRIBUTE_UNUSED,
+ int events,
+ void *opaque)
+{
+ virNetSocketPtr sock = opaque;
+
+ sock->func(sock, events, sock->opaque);
+}
+
+int virNetSocketAddIOCallback(virNetSocketPtr sock,
+ int events,
+ virNetSocketIOFunc func,
+ void *opaque)
+{
+ if (sock->watch > 0) {
+ VIR_DEBUG("Watch already registered on socket %p", sock);
+ return -1;
+ }
+
+ if ((sock->watch = virEventAddHandle(sock->fd,
+ events,
+ virNetSocketEventHandle,
+ sock,
+ NULL)) < 0) {
+ VIR_WARN("Failed to register watch on socket %p", sock);
+ return -1;
+ }
+ sock->func = func;
+ sock->opaque = opaque;
+
+ return 0;
+}
+
+void virNetSocketUpdateIOCallback(virNetSocketPtr sock,
+ int events)
+{
+ if (sock->watch <= 0) {
+ VIR_DEBUG("Watch not registered on socket %p", sock);
+ return;
+ }
+
+ virEventUpdateHandle(sock->watch, events);
+}
+
+void virNetSocketRemoveIOCallback(virNetSocketPtr sock)
+{
+ if (sock->watch <= 0) {
+ VIR_DEBUG("Watch not registered on socket %p", sock);
+ return;
+ }
+
+ virEventRemoveHandle(sock->watch);
+ sock->watch = 0;
+}
--- /dev/null
+/*
+ * Copyright (C) 2011 Red Hat, Inc.
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ *
+ * Author: Daniel P. Berrange <berrange@redhat.com>
+ */
+
+#include <config.h>
+
+#include <stdlib.h>
+#include <signal.h>
+#ifdef HAVE_IFADDRS_H
+# include <ifaddrs.h>
+#endif
+
+#include "testutils.h"
+#include "util.h"
+#include "virterror_internal.h"
+#include "memory.h"
+#include "logging.h"
+#include "files.h"
+
+#include "rpc/virnetsocket.h"
+
+#define VIR_FROM_THIS VIR_FROM_RPC
+
+#if HAVE_IFADDRS_H
+# define BASE_PORT 5672
+
+static int
+checkProtocols(bool *hasIPv4, bool *hasIPv6,
+ int *freePort)
+{
+ struct ifaddrs *ifaddr = NULL, *ifa;
+ struct sockaddr_in in4;
+ struct sockaddr_in6 in6;
+ int s4 = -1, s6 = -1;
+ int i;
+ int ret = -1;
+
+ *hasIPv4 = *hasIPv6 = false;
+ *freePort = 0;
+
+ if (getifaddrs(&ifaddr) < 0)
+ goto cleanup;
+
+ for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) {
+ if (!ifa->ifa_addr)
+ continue;
+
+ if (ifa->ifa_addr->sa_family == AF_INET)
+ *hasIPv4 = true;
+ if (ifa->ifa_addr->sa_family == AF_INET6)
+ *hasIPv6 = true;
+ }
+
+ VIR_DEBUG("Protocols: v4 %d v6 %d\n", *hasIPv4, *hasIPv6);
+
+ freeifaddrs(ifaddr);
+
+ for (i = 0 ; i < 50 ; i++) {
+ int only = 1;
+ if ((s4 = socket(AF_INET, SOCK_STREAM, 0)) < 0)
+ goto cleanup;
+
+ if ((s6 = socket(AF_INET6, SOCK_STREAM, 0)) < 0)
+ goto cleanup;
+
+ if (setsockopt(s6, IPPROTO_IPV6, IPV6_V6ONLY, &only, sizeof(only)) < 0)
+ goto cleanup;
+
+ memset(&in4, 0, sizeof(in4));
+ memset(&in6, 0, sizeof(in6));
+
+ in4.sin_family = AF_INET;
+ in4.sin_port = htons(BASE_PORT + i);
+ in4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ in6.sin6_family = AF_INET6;
+ in6.sin6_port = htons(BASE_PORT + i);
+ in6.sin6_addr = in6addr_loopback;
+
+ if (bind(s4, (struct sockaddr *)&in4, sizeof(in4)) < 0) {
+ if (errno == EADDRINUSE) {
+ VIR_FORCE_CLOSE(s4);
+ VIR_FORCE_CLOSE(s6);
+ continue;
+ }
+ goto cleanup;
+ }
+ if (bind(s6, (struct sockaddr *)&in6, sizeof(in6)) < 0) {
+ if (errno == EADDRINUSE) {
+ VIR_FORCE_CLOSE(s4);
+ VIR_FORCE_CLOSE(s6);
+ continue;
+ }
+ goto cleanup;
+ }
+
+ *freePort = BASE_PORT + i;
+ break;
+ }
+
+ VIR_DEBUG("Choose port %d\n", *freePort);
+
+ ret = 0;
+
+cleanup:
+ VIR_FORCE_CLOSE(s4);
+ VIR_FORCE_CLOSE(s6);
+ return ret;
+}
+
+
+struct testTCPData {
+ const char *lnode;
+ int port;
+ const char *cnode;
+};
+
+static int testSocketTCPAccept(const void *opaque)
+{
+ virNetSocketPtr *lsock = NULL; /* Listen socket */
+ size_t nlsock = 0, i;
+ virNetSocketPtr ssock = NULL; /* Server socket */
+ virNetSocketPtr csock = NULL; /* Client socket */
+ const struct testTCPData *data = opaque;
+ int ret = -1;
+ char portstr[100];
+
+ snprintf(portstr, sizeof(portstr), "%d", data->port);
+
+ if (virNetSocketNewListenTCP(data->lnode, portstr, &lsock, &nlsock) < 0)
+ goto cleanup;
+
+ for (i = 0 ; i < nlsock ; i++) {
+ if (virNetSocketListen(lsock[i]) < 0)
+ goto cleanup;
+ }
+
+ if (virNetSocketNewConnectTCP(data->cnode, portstr, &csock) < 0)
+ goto cleanup;
+
+ virNetSocketFree(csock);
+
+ for (i = 0 ; i < nlsock ; i++) {
+ if (virNetSocketAccept(lsock[i], &ssock) != -1 && ssock) {
+ char c = 'a';
+ if (virNetSocketWrite(ssock, &c, 1) != -1 &&
+ virNetSocketRead(ssock, &c, 1) != -1) {
+ VIR_DEBUG("Unexpected client socket present");
+ goto cleanup;
+ }
+ }
+ virNetSocketFree(ssock);
+ ssock = NULL;
+ }
+
+ ret = 0;
+
+cleanup:
+ virNetSocketFree(ssock);
+ for (i = 0 ; i < nlsock ; i++)
+ virNetSocketFree(lsock[i]);
+ VIR_FREE(lsock);
+ return ret;
+}
+#endif
+
+
+#ifndef WIN32
+static int testSocketUNIXAccept(const void *data ATTRIBUTE_UNUSED)
+{
+ virNetSocketPtr lsock = NULL; /* Listen socket */
+ virNetSocketPtr ssock = NULL; /* Server socket */
+ virNetSocketPtr csock = NULL; /* Client socket */
+ int ret = -1;
+
+ char *path;
+ if (progname[0] == '/') {
+ if (virAsprintf(&path, "%s-test.sock", progname) < 0) {
+ virReportOOMError();
+ goto cleanup;
+ }
+ } else {
+ if (virAsprintf(&path, "%s/%s-test.sock", abs_builddir, progname) < 0) {
+ virReportOOMError();
+ goto cleanup;
+ }
+ }
+
+ if (virNetSocketNewListenUNIX(path, 0700, getgid(), &lsock) < 0)
+ goto cleanup;
+
+ if (virNetSocketListen(lsock) < 0)
+ goto cleanup;
+
+ if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
+ goto cleanup;
+
+ virNetSocketFree(csock);
+
+ if (virNetSocketAccept(lsock, &ssock) != -1) {
+ char c = 'a';
+ if (virNetSocketWrite(ssock, &c, 1) != -1) {
+ VIR_DEBUG("Unexpected client socket present");
+ goto cleanup;
+ }
+ }
+
+ ret = 0;
+
+cleanup:
+ VIR_FREE(path);
+ virNetSocketFree(lsock);
+ virNetSocketFree(ssock);
+ return ret;
+}
+
+
+static int testSocketUNIXAddrs(const void *data ATTRIBUTE_UNUSED)
+{
+ virNetSocketPtr lsock = NULL; /* Listen socket */
+ virNetSocketPtr ssock = NULL; /* Server socket */
+ virNetSocketPtr csock = NULL; /* Client socket */
+ int ret = -1;
+
+ char *path;
+ if (progname[0] == '/') {
+ if (virAsprintf(&path, "%s-test.sock", progname) < 0) {
+ virReportOOMError();
+ goto cleanup;
+ }
+ } else {
+ if (virAsprintf(&path, "%s/%s-test.sock", abs_builddir, progname) < 0) {
+ virReportOOMError();
+ goto cleanup;
+ }
+ }
+
+ if (virNetSocketNewListenUNIX(path, 0700, getgid(), &lsock) < 0)
+ goto cleanup;
+
+ if (STRNEQ(virNetSocketLocalAddrString(lsock), "127.0.0.1;0")) {
+ VIR_DEBUG("Unexpected local address");
+ goto cleanup;
+ }
+
+ if (virNetSocketRemoteAddrString(lsock) != NULL) {
+ VIR_DEBUG("Unexpected remote address");
+ goto cleanup;
+ }
+
+ if (virNetSocketListen(lsock) < 0)
+ goto cleanup;
+
+ if (virNetSocketNewConnectUNIX(path, false, NULL, &csock) < 0)
+ goto cleanup;
+
+ if (STRNEQ(virNetSocketLocalAddrString(csock), "127.0.0.1;0")) {
+ VIR_DEBUG("Unexpected local address");
+ goto cleanup;
+ }
+
+ if (STRNEQ(virNetSocketRemoteAddrString(csock), "127.0.0.1;0")) {
+ VIR_DEBUG("Unexpected local address");
+ goto cleanup;
+ }
+
+
+ if (virNetSocketAccept(lsock, &ssock) < 0) {
+ VIR_DEBUG("Unexpected client socket missing");
+ goto cleanup;
+ }
+
+
+ if (STRNEQ(virNetSocketLocalAddrString(ssock), "127.0.0.1;0")) {
+ VIR_DEBUG("Unexpected local address");
+ goto cleanup;
+ }
+
+ if (STRNEQ(virNetSocketRemoteAddrString(ssock), "127.0.0.1;0")) {
+ VIR_DEBUG("Unexpected local address");
+ goto cleanup;
+ }
+
+
+ ret = 0;
+
+cleanup:
+ VIR_FREE(path);
+ virNetSocketFree(lsock);
+ virNetSocketFree(ssock);
+ virNetSocketFree(csock);
+ return ret;
+}
+
+static int testSocketCommandNormal(const void *data ATTRIBUTE_UNUSED)
+{
+ virNetSocketPtr csock = NULL; /* Client socket */
+ char buf[100];
+ size_t i;
+ int ret = -1;
+ virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/zero", NULL);
+ virCommandAddEnvPassCommon(cmd);
+
+ if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
+ goto cleanup;
+
+ virNetSocketSetBlocking(csock, true);
+
+ if (virNetSocketRead(csock, buf, sizeof(buf)) < 0)
+ goto cleanup;
+
+ for (i = 0 ; i < sizeof(buf) ; i++)
+ if (buf[i] != '\0')
+ goto cleanup;
+
+ ret = 0;
+
+cleanup:
+ virNetSocketFree(csock);
+ return ret;
+}
+
+static int testSocketCommandFail(const void *data ATTRIBUTE_UNUSED)
+{
+ virNetSocketPtr csock = NULL; /* Client socket */
+ char buf[100];
+ int ret = -1;
+ virCommandPtr cmd = virCommandNewArgList("/bin/cat", "/dev/does-not-exist", NULL);
+ virCommandAddEnvPassCommon(cmd);
+
+ if (virNetSocketNewConnectCommand(cmd, &csock) < 0)
+ goto cleanup;
+
+ virNetSocketSetBlocking(csock, true);
+
+ if (virNetSocketRead(csock, buf, sizeof(buf)) == 0)
+ goto cleanup;
+
+ ret = 0;
+
+cleanup:
+ virNetSocketFree(csock);
+ return ret;
+}
+
+struct testSSHData {
+ const char *nodename;
+ const char *service;
+ const char *binary;
+ const char *username;
+ bool noTTY;
+ const char *netcat;
+ const char *path;
+
+ const char *expectOut;
+ bool failConnect;
+ bool dieEarly;
+};
+
+static int testSocketSSH(const void *opaque)
+{
+ const struct testSSHData *data = opaque;
+ virNetSocketPtr csock = NULL; /* Client socket */
+ int ret = -1;
+ char buf[1024];
+
+ if (virNetSocketNewConnectSSH(data->nodename,
+ data->service,
+ data->binary,
+ data->username,
+ data->noTTY,
+ data->netcat,
+ data->path,
+ &csock) < 0)
+ goto cleanup;
+
+ virNetSocketSetBlocking(csock, true);
+
+ if (data->failConnect) {
+ if (virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
+ VIR_DEBUG("Expected connect failure, but got some socket data");
+ goto cleanup;
+ }
+ } else {
+ ssize_t rv;
+ if ((rv = virNetSocketRead(csock, buf, sizeof(buf)-1)) < 0) {
+ VIR_DEBUG("Didn't get any socket data");
+ goto cleanup;
+ }
+ buf[rv] = '\0';
+
+ if (!STREQ(buf, data->expectOut)) {
+ virtTestDifference(stderr, data->expectOut, buf);
+ goto cleanup;
+ }
+
+ if (data->dieEarly &&
+ virNetSocketRead(csock, buf, sizeof(buf)-1) >= 0) {
+ VIR_DEBUG("Got too much socket data");
+ goto cleanup;
+ }
+ }
+
+ ret = 0;
+
+cleanup:
+ virNetSocketFree(csock);
+ return ret;
+}
+
+#endif
+
+
+static int
+mymain(void)
+{
+ int ret = 0;
+#ifdef HAVE_IFADDRS_H
+ bool hasIPv4, hasIPv6;
+ int freePort;
+#endif
+
+ signal(SIGPIPE, SIG_IGN);
+
+#ifdef HAVE_IFADDRS_H
+ if (checkProtocols(&hasIPv4, &hasIPv6, &freePort) < 0) {
+ fprintf(stderr, "Cannot identify IPv4/6 availability\n");
+ return (EXIT_FAILURE);
+ }
+
+ if (hasIPv4) {
+ struct testTCPData tcpData = { "127.0.0.1", freePort, "127.0.0.1" };
+ if (virtTestRun("Socket TCP/IPv4 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+ ret = -1;
+ }
+ if (hasIPv6) {
+ struct testTCPData tcpData = { "::1", freePort, "::1" };
+ if (virtTestRun("Socket TCP/IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+ ret = -1;
+ }
+ if (hasIPv6 && hasIPv4) {
+ struct testTCPData tcpData = { NULL, freePort, "127.0.0.1" };
+ if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+ ret = -1;
+
+ tcpData.cnode = "::1";
+ if (virtTestRun("Socket TCP/IPv4+IPv6 Accept", 1, testSocketTCPAccept, &tcpData) < 0)
+ ret = -1;
+ }
+#endif
+
+#ifndef WIN32
+ if (virtTestRun("Socket UNIX Accept", 1, testSocketUNIXAccept, NULL) < 0)
+ ret = -1;
+
+ if (virtTestRun("Socket UNIX Addrs", 1, testSocketUNIXAddrs, NULL) < 0)
+ ret = -1;
+
+ if (virtTestRun("Socket External Command /dev/zero", 1, testSocketCommandNormal, NULL) < 0)
+ ret = -1;
+ if (virtTestRun("Socket External Command /dev/does-not-exist", 1, testSocketCommandFail, NULL) < 0)
+ ret = -1;
+
+ struct testSSHData sshData1 = {
+ .nodename = "somehost",
+ .path = "/tmp/socket",
+ .expectOut = "somehost nc -U /tmp/socket\n",
+ };
+ if (virtTestRun("SSH test 1", 1, testSocketSSH, &sshData1) < 0)
+ ret = -1;
+
+ struct testSSHData sshData2 = {
+ .nodename = "somehost",
+ .service = "9000",
+ .username = "fred",
+ .netcat = "netcat",
+ .noTTY = true,
+ .path = "/tmp/socket",
+ .expectOut = "-p 9000 -l fred -T -o BatchMode=yes -e none somehost netcat -U /tmp/socket\n",
+ };
+ if (virtTestRun("SSH test 2", 1, testSocketSSH, &sshData2) < 0)
+ ret = -1;
+
+ struct testSSHData sshData3 = {
+ .nodename = "nosuchhost",
+ .path = "/tmp/socket",
+ .failConnect = true,
+ };
+ if (virtTestRun("SSH test 3", 1, testSocketSSH, &sshData3) < 0)
+ ret = -1;
+
+ struct testSSHData sshData4 = {
+ .nodename = "crashyhost",
+ .path = "/tmp/socket",
+ .expectOut = "crashyhost nc -U /tmp/socket\n",
+ .dieEarly = true,
+ };
+ if (virtTestRun("SSH test 4", 1, testSocketSSH, &sshData4) < 0)
+ ret = -1;
+
+#endif
+
+ return (ret==0 ? EXIT_SUCCESS : EXIT_FAILURE);
+}
+
+VIRT_TEST_MAIN(mymain)