]> xenbits.xensource.com Git - libvirt.git/commitdiff
rpc: avoid libvirtd crash on unexpected client close
authorEric Blake <eblake@redhat.com>
Mon, 1 Aug 2011 19:41:38 +0000 (13:41 -0600)
committerEric Blake <eblake@redhat.com>
Tue, 2 Aug 2011 13:46:37 +0000 (07:46 -0600)
Steps to reproduce this problem (vm1 is not running):
for i in `seq 50`; do virsh managedsave vm1& done; killall virsh

Pre-patch, virNetServerClientClose could end up setting client->sock
to NULL prior to other cleanup functions trying to use client->sock.
This fixes things by checking for NULL in more places, and by deferring
the cleanup until after all queued messages have been served.

* src/rpc/virnetserverclient.c (virNetServerClientRegisterEvent)
(virNetServerClientGetFD, virNetServerClientIsSecure)
(virNetServerClientLocalAddrString)
(virNetServerClientRemoteAddrString): Check for closed socket.
(virNetServerClientClose): Rearrange close sequence.
Analysis from Wen Congyang.

src/rpc/virnetserverclient.c

index 3c0dba8d65f9f20e7a454a852ef082c1fbf8fea9..2f6c04076261ed920181a7520a60b85be6aa362b 100644 (file)
@@ -177,7 +177,8 @@ static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
 
     client->refs++;
     VIR_DEBUG("Registering client event callback %d", mode);
-    if (virNetSocketAddIOCallback(client->sock,
+    if (!client->sock ||
+        virNetSocketAddIOCallback(client->sock,
                                   mode,
                                   virNetServerClientDispatchEvent,
                                   client,
@@ -386,9 +387,10 @@ int virNetServerClientGetTLSKeySize(virNetServerClientPtr client)
 
 int virNetServerClientGetFD(virNetServerClientPtr client)
 {
-    int fd = 0;
+    int fd = -1;
     virNetServerClientLock(client);
-    fd = virNetSocketGetFD(client->sock);
+    if (client->sock)
+        fd = virNetSocketGetFD(client->sock);
     virNetServerClientUnlock(client);
     return fd;
 }
@@ -396,9 +398,10 @@ int virNetServerClientGetFD(virNetServerClientPtr client)
 int virNetServerClientGetLocalIdentity(virNetServerClientPtr client,
                                        uid_t *uid, pid_t *pid)
 {
-    int ret;
+    int ret = -1;
     virNetServerClientLock(client);
-    ret = virNetSocketGetLocalIdentity(client->sock, uid, pid);
+    if (client->sock)
+        ret = virNetSocketGetLocalIdentity(client->sock, uid, pid);
     virNetServerClientUnlock(client);
     return ret;
 }
@@ -413,7 +416,7 @@ bool virNetServerClientIsSecure(virNetServerClientPtr client)
     if (client->sasl)
         secure = true;
 #endif
-    if (virNetSocketIsLocal(client->sock))
+    if (client->sock && virNetSocketIsLocal(client->sock))
         secure = true;
     virNetServerClientUnlock(client);
     return secure;
@@ -502,12 +505,16 @@ void virNetServerClientSetDispatcher(virNetServerClientPtr client,
 
 const char *virNetServerClientLocalAddrString(virNetServerClientPtr client)
 {
+    if (!client->sock)
+        return NULL;
     return virNetSocketLocalAddrString(client->sock);
 }
 
 
 const char *virNetServerClientRemoteAddrString(virNetServerClientPtr client)
 {
+    if (!client->sock)
+        return NULL;
     return virNetSocketRemoteAddrString(client->sock);
 }
 
@@ -570,10 +577,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
         virNetTLSSessionFree(client->tls);
         client->tls = NULL;
     }
-    if (client->sock) {
-        virNetSocketFree(client->sock);
-        client->sock = NULL;
-    }
+    client->wantClose = true;
 
     while (client->rx) {
         virNetMessagePtr msg
@@ -586,6 +590,11 @@ void virNetServerClientClose(virNetServerClientPtr client)
         virNetMessageFree(msg);
     }
 
+    if (client->sock) {
+        virNetSocketFree(client->sock);
+        client->sock = NULL;
+    }
+
     virNetServerClientUnlock(client);
 }