]> xenbits.xensource.com Git - libvirt.git/commitdiff
Add client side support for FD passing
authorDaniel P. Berrange <berrange@redhat.com>
Fri, 21 Oct 2011 10:48:03 +0000 (11:48 +0100)
committerDaniel P. Berrange <berrange@redhat.com>
Fri, 28 Oct 2011 09:42:54 +0000 (10:42 +0100)
Extend the RPC client code to allow file descriptors to be sent
to the server with calls, and received back with replies.

* src/remote/remote_driver.c: Stub extra args
* src/libvirt_private.syms, src/rpc/virnetclient.c,
  src/rpc/virnetclient.h, src/rpc/virnetclientprogram.c,
  src/rpc/virnetclientprogram.h: Extend APIs to allow
  FD passing

src/libvirt_private.syms
src/remote/remote_driver.c
src/rpc/virnetclient.c
src/rpc/virnetclient.h
src/rpc/virnetclientprogram.c
src/rpc/virnetclientprogram.h

index 8c74f56bce819e71423620e04203dd14b790f4f5..d5368877d67bd503def52b547b1c0f732b9485e1 100644 (file)
@@ -1184,6 +1184,10 @@ virFileFdopen;
 virFileRewrite;
 
 
+# virnetclient.h
+virNetClientHasPassFD;
+
+
 # virnetmessage.h
 virNetMessageClear;
 virNetMessageDecodeNumFDs;
index e98ebd737e235580029ce30eb0fc77080b60887c..382bb421af025f6e4ee3abe58d7ede6803d918b8 100644 (file)
@@ -4152,6 +4152,7 @@ call (virConnectPtr conn ATTRIBUTE_UNUSED,
                                  client,
                                  counter,
                                  proc_nr,
+                                 0, NULL, NULL, NULL,
                                  args_filter, args,
                                  ret_filter, ret);
     remoteDriverLock(priv);
index 085dc8d97e05bacb93fc8d47b2f102bc0e30ce86..2b5f67c4d3573ba1538e4dd8c80e660730c11950 100644 (file)
@@ -258,6 +258,16 @@ int virNetClientDupFD(virNetClientPtr client, bool cloexec)
 }
 
 
+bool virNetClientHasPassFD(virNetClientPtr client)
+{
+    bool hasPassFD;
+    virNetClientLock(client);
+    hasPassFD = virNetSocketHasPassFD(client->sock);
+    virNetClientUnlock(client);
+    return hasPassFD;
+}
+
+
 void virNetClientFree(virNetClientPtr client)
 {
     int i;
@@ -684,6 +694,7 @@ static int virNetClientCallDispatchStream(virNetClientPtr client)
 static int
 virNetClientCallDispatch(virNetClientPtr client)
 {
+    size_t i;
     if (virNetMessageDecodeHeader(&client->msg) < 0)
         return -1;
 
@@ -697,6 +708,15 @@ virNetClientCallDispatch(virNetClientPtr client)
     case VIR_NET_REPLY: /* Normal RPC replies */
         return virNetClientCallDispatchReply(client);
 
+    case VIR_NET_REPLY_WITH_FDS: /* Normal RPC replies with FDs */
+        if (virNetMessageDecodeNumFDs(&client->msg) < 0)
+            return -1;
+        for (i = 0 ; i < client->msg.nfds ; i++) {
+            if ((client->msg.fds[i] = virNetSocketRecvFD(client->sock)) < 0)
+                return -1;
+        }
+        return virNetClientCallDispatchReply(client);
+
     case VIR_NET_MESSAGE: /* Async notifications */
         return virNetClientCallDispatchMessage(client);
 
@@ -728,6 +748,11 @@ virNetClientIOWriteMessage(virNetClientPtr client,
     thecall->msg->bufferOffset += ret;
 
     if (thecall->msg->bufferOffset == thecall->msg->bufferLength) {
+        size_t i;
+        for (i = 0 ; i < thecall->msg->nfds ; i++) {
+            if (virNetSocketSendFD(client->sock, thecall->msg->fds[i]) < 0)
+                return -1;
+        }
         thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
         if (thecall->expectReply)
             thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
index 1fabcfde86b52eace924cd6648582583a1d241e5..fb679e897318c948b0ebb8af8a1e604ac52210f0 100644 (file)
@@ -56,6 +56,8 @@ void virNetClientRef(virNetClientPtr client);
 int virNetClientGetFD(virNetClientPtr client);
 int virNetClientDupFD(virNetClientPtr client, bool cloexec);
 
+bool virNetClientHasPassFD(virNetClientPtr client);
+
 int virNetClientAddProgram(virNetClientPtr client,
                            virNetClientProgramPtr prog);
 
index 33fa5078b78ba39941e72b57dcd2269d242aa691..36e23841e1021822179310afcf797be70f651231 100644 (file)
@@ -22,6 +22,8 @@
 
 #include <config.h>
 
+#include <unistd.h>
+
 #include "virnetclientprogram.h"
 #include "virnetclient.h"
 #include "virnetprotocol.h"
@@ -29,6 +31,8 @@
 #include "memory.h"
 #include "virterror_internal.h"
 #include "logging.h"
+#include "util.h"
+#include "virfile.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 #define virNetError(code, ...)                                    \
@@ -267,10 +271,20 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
                             virNetClientPtr client,
                             unsigned serial,
                             int proc,
+                            size_t noutfds,
+                            int *outfds,
+                            size_t *ninfds,
+                            int **infds,
                             xdrproc_t args_filter, void *args,
                             xdrproc_t ret_filter, void *ret)
 {
     virNetMessagePtr msg;
+    size_t i;
+
+    if (infds)
+        *infds = NULL;
+    if (ninfds)
+        *ninfds = 0;
 
     if (!(msg = virNetMessageNew(false)))
         return -1;
@@ -278,13 +292,38 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
     msg->header.prog = prog->program;
     msg->header.vers = prog->version;
     msg->header.status = VIR_NET_OK;
-    msg->header.type = VIR_NET_CALL;
+    msg->header.type = noutfds ? VIR_NET_CALL_WITH_FDS : VIR_NET_CALL;
     msg->header.serial = serial;
     msg->header.proc = proc;
+    msg->nfds = noutfds;
+    if (VIR_ALLOC_N(msg->fds, msg->nfds) < 0) {
+        virReportOOMError();
+        goto error;
+    }
+    for (i = 0 ; i < msg->nfds ; i++)
+        msg->fds[i] = -1;
+    for (i = 0 ; i < msg->nfds ; i++) {
+        if ((msg->fds[i] = dup(outfds[i])) < 0) {
+            virReportSystemError(errno,
+                                 _("Cannot duplicate FD %d"),
+                                 outfds[i]);
+            goto error;
+        }
+        if (virSetInherit(msg->fds[i], false) < 0) {
+            virReportSystemError(errno,
+                                 _("Cannot set close-on-exec %d"),
+                                 msg->fds[i]);
+            goto error;
+        }
+    }
 
     if (virNetMessageEncodeHeader(msg) < 0)
         goto error;
 
+    if (msg->nfds &&
+        virNetMessageEncodeNumFDs(msg) < 0)
+        goto error;
+
     if (virNetMessageEncodePayload(msg, args_filter, args) < 0)
         goto error;
 
@@ -295,7 +334,8 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
      * virNetClientSend should have validated the reply,
      * but it doesn't hurt to check again.
      */
-    if (msg->header.type != VIR_NET_REPLY) {
+    if (msg->header.type != VIR_NET_REPLY &&
+        msg->header.type != VIR_NET_REPLY_WITH_FDS) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("Unexpected message type %d"), msg->header.type);
         goto error;
@@ -315,6 +355,30 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
 
     switch (msg->header.status) {
     case VIR_NET_OK:
+        if (infds && ninfds) {
+            *ninfds = msg->nfds;
+            if (VIR_ALLOC_N(*infds, *ninfds) < 0) {
+                virReportOOMError();
+                goto error;
+            }
+            for (i = 0 ; i < *ninfds ; i++)
+                *infds[i] = -1;
+            for (i = 0 ; i < *ninfds ; i++) {
+                if ((*infds[i] = dup(msg->fds[i])) < 0) {
+                    virReportSystemError(errno,
+                                         _("Cannot duplicate FD %d"),
+                                         msg->fds[i]);
+                    goto error;
+                }
+                if (virSetInherit(*infds[i], false) < 0) {
+                    virReportSystemError(errno,
+                                         _("Cannot set close-on-exec %d"),
+                                         *infds[i]);
+                    goto error;
+                }
+            }
+
+        }
         if (virNetMessageDecodePayload(msg, ret_filter, ret) < 0)
             goto error;
         break;
@@ -335,5 +399,9 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
 
 error:
     virNetMessageFree(msg);
+    if (infds && ninfds) {
+        for (i = 0 ; i < *ninfds ; i++)
+            VIR_FORCE_CLOSE(*infds[i]);
+    }
     return -1;
 }
index 82ae2c66fbb05a7b09da7db423d73e89814579a5..14a4c9650077711317638f12decad52d66913166 100644 (file)
@@ -77,6 +77,10 @@ int virNetClientProgramCall(virNetClientProgramPtr prog,
                             virNetClientPtr client,
                             unsigned serial,
                             int proc,
+                            size_t noutfds,
+                            int *outfds,
+                            size_t *ninfds,
+                            int **infds,
                             xdrproc_t args_filter, void *args,
                             xdrproc_t ret_filter, void *ret);