]> xenbits.xensource.com Git - libvirt.git/commitdiff
Turn virNetTLSContext and virNetTLSSession into virObject instances
authorDaniel P. Berrange <berrange@redhat.com>
Wed, 11 Jul 2012 13:35:48 +0000 (14:35 +0100)
committerDaniel P. Berrange <berrange@redhat.com>
Tue, 7 Aug 2012 10:47:41 +0000 (11:47 +0100)
Make virNetTLSContext and virNetTLSSession use the virObject
APIs for reference counting

Signed-off-by: Daniel P. Berrange <berrange@redhat.com>
13 files changed:
cfg.mk
daemon/libvirtd.c
src/libvirt_private.syms
src/libvirt_probes.d
src/remote/remote_driver.c
src/rpc/virnetclient.c
src/rpc/virnetserver.c
src/rpc/virnetserverclient.c
src/rpc/virnetserverservice.c
src/rpc/virnetsocket.c
src/rpc/virnettlscontext.c
src/rpc/virnettlscontext.h
tests/virnettlscontexttest.c

diff --git a/cfg.mk b/cfg.mk
index ccff146a3f52b795c278e751ddd1b13962a6b042..9ba0d665752ea6b4c243a9beeccc82f35dde9805 100644 (file)
--- a/cfg.mk
+++ b/cfg.mk
@@ -158,7 +158,6 @@ useless_free_options =                              \
   --name=virNetSocketFree                       \
   --name=virNetSASLContextFree                  \
   --name=virNetSASLSessionFree                  \
-  --name=virNetTLSSessionFree                   \
   --name=virNWFilterDefFree                    \
   --name=virNWFilterEntryFree                  \
   --name=virNWFilterHashTableFree              \
index 49b69ef8ce8f48827bedc3f92b68927f3a9728aa..7dd7d5cdc6dccb81592ad0392b08da16f968f77e 100644 (file)
@@ -541,7 +541,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
                                             false,
                                             config->max_client_requests,
                                             ctxt))) {
-                virNetTLSContextFree(ctxt);
+                virObjectUnref(ctxt);
                 goto error;
             }
             if (virNetServerAddService(srv, svcTLS,
@@ -549,7 +549,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
                                        !config->listen_tcp ? "_libvirt._tcp" : NULL) < 0)
                 goto error;
 
-            virNetTLSContextFree(ctxt);
+            virObjectUnref(ctxt);
         }
     }
 
index 2ab94a6155949f08613c6d0bea1841d9ffc7b6f1..acaa6f3c336c8b16d8750cd5fdf5844480ad340f 100644 (file)
@@ -1625,20 +1625,16 @@ virNetSocketWrite;
 
 # virnettlscontext.h
 virNetTLSContextCheckCertificate;
-virNetTLSContextFree;
 virNetTLSContextNewClient;
 virNetTLSContextNewClientPath;
 virNetTLSContextNewServer;
 virNetTLSContextNewServerPath;
-virNetTLSContextRef;
 virNetTLSInit;
-virNetTLSSessionFree;
 virNetTLSSessionGetHandshakeStatus;
 virNetTLSSessionGetKeySize;
 virNetTLSSessionHandshake;
 virNetTLSSessionNew;
 virNetTLSSessionRead;
-virNetTLSSessionRef;
 virNetTLSSessionSetIOCallbacks;
 virNetTLSSessionWrite;
 
index ceb3caa189f2c3c0b79fb16cc9d290349820474e..3b138a9adfc9d58a1f1efb0600e183eb0f9634f3 100644 (file)
@@ -61,19 +61,15 @@ provider libvirt {
 
        # file: src/rpc/virnettlscontext.c
        # prefix: rpc
-       probe rpc_tls_context_new(void *ctxt, int refs, const char *cacert, const char *cacrl,
+       probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl,
                                  const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer);
-       probe rpc_tls_context_ref(void *ctxt, int refs);
-       probe rpc_tls_context_free(void *ctxt, int refs);
 
        probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname);
        probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname);
        probe rpc_tls_context_session_fail(void *ctxt, void *sess);
 
 
-       probe rpc_tls_session_new(void *sess, void *ctxt, int refs, const char *hostname, int isServer);
-       probe rpc_tls_session_ref(void *sess, int refs);
-       probe rpc_tls_session_free(void *sess, int refs);
+       probe rpc_tls_session_new(void *sess, void *ctxt, const char *hostname, int isServer);
 
        probe rpc_tls_session_handshake_pass(void *sess);
        probe rpc_tls_session_handshake_fail(void *sess);
index afd367bea65ecea0be5acd002be9075ff285bb43..511608063a51e90db9859afc5b49202c0f3edf17 100644 (file)
@@ -943,7 +943,7 @@ doRemoteClose (virConnectPtr conn, struct private_data *priv)
               (xdrproc_t) xdr_void, (char *) NULL) == -1)
         ret = -1;
 
-    virNetTLSContextFree(priv->tls);
+    virObjectUnref(priv->tls);
     priv->tls = NULL;
     virNetClientClose(priv->client);
     virNetClientFree(priv->client);
index cb373b622a2382ede858469e137eefa369cd24bf..72f55a1a0f8ae7f912a0b33c1243b49b783ae496 100644 (file)
@@ -495,7 +495,7 @@ void virNetClientFree(virNetClientPtr client)
     if (client->sock)
         virNetSocketRemoveIOCallback(client->sock);
     virNetSocketFree(client->sock);
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
 #if HAVE_SASL
     virNetSASLSessionFree(client->sasl);
 #endif
@@ -532,7 +532,7 @@ virNetClientCloseLocked(virNetClientPtr client)
 
     virNetSocketFree(client->sock);
     client->sock = NULL;
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
     client->tls = NULL;
 #if HAVE_SASL
     virNetSASLSessionFree(client->sasl);
@@ -712,7 +712,7 @@ int virNetClientSetTLSSession(virNetClientPtr client,
     return 0;
 
 error:
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
     client->tls = NULL;
     virNetClientUnlock(client);
     return -1;
index 295e8fd8aabdc3dff955202f2f2ef14a019e59a6..afe7640dda4a01ac4a39cf741072dc8ecdd29fe5 100644 (file)
@@ -642,8 +642,7 @@ no_memory:
 int virNetServerSetTLSContext(virNetServerPtr srv,
                               virNetTLSContextPtr tls)
 {
-    srv->tls = tls;
-    virNetTLSContextRef(tls);
+    srv->tls = virObjectRef(tls);
     return 0;
 }
 
index d0a144c25569b4bee4deff5efa12ae9f4c576a71..c419e74e7b401129e6a34123ef58160b24296078 100644 (file)
@@ -346,7 +346,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     client->sock = sock;
     client->auth = auth;
     client->readonly = readonly;
-    client->tlsCtxt = tls;
+    client->tlsCtxt = virObjectRef(tls);
     client->nrequests_max = nrequests_max;
 
     client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
@@ -354,9 +354,6 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     if (client->sockTimer < 0)
         goto error;
 
-    if (tls)
-        virNetTLSContextRef(tls);
-
     /* Prepare one for packet receive */
     if (!(client->rx = virNetMessageNew(true)))
         goto error;
@@ -598,8 +595,8 @@ void virNetServerClientFree(virNetServerClientPtr client)
 #endif
     if (client->sockTimer > 0)
         virEventRemoveTimeout(client->sockTimer);
-    virNetTLSSessionFree(client->tls);
-    virNetTLSContextFree(client->tlsCtxt);
+    virObjectUnref(client->tls);
+    virObjectUnref(client->tlsCtxt);
     virNetSocketFree(client->sock);
     virNetServerClientUnlock(client);
     virMutexDestroy(&client->lock);
@@ -654,7 +651,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
         virNetSocketRemoveIOCallback(client->sock);
 
     if (client->tls) {
-        virNetTLSSessionFree(client->tls);
+        virObjectUnref(client->tls);
         client->tls = NULL;
     }
     client->wantClose = true;
index 2880df3b7550189c85bbc9aecc8e43c842d7763c..60fe89f9de7d7d43b0aef9f68dcba961d5f9f2cf 100644 (file)
@@ -116,9 +116,7 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
     svc->auth = auth;
     svc->readonly = readonly;
     svc->nrequests_client_max = nrequests_client_max;
-    svc->tls = tls;
-    if (tls)
-        virNetTLSContextRef(tls);
+    svc->tls = virObjectRef(tls);
 
     if (virNetSocketNewListenTCP(nodename,
                                  service,
@@ -172,9 +170,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
     svc->auth = auth;
     svc->readonly = readonly;
     svc->nrequests_client_max = nrequests_client_max;
-    svc->tls = tls;
-    if (tls)
-        virNetTLSContextRef(tls);
+    svc->tls = virObjectRef(tls);
 
     svc->nsocks = 1;
     if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0)
@@ -265,7 +261,7 @@ void virNetServerServiceFree(virNetServerServicePtr svc)
         virNetSocketFree(svc->socks[i]);
     VIR_FREE(svc->socks);
 
-    virNetTLSContextFree(svc->tls);
+    virObjectUnref(svc->tls);
 
     VIR_FREE(svc);
 }
index 88e55250600a903dacabd54c25030fa95fa33566..bca78b5c6d6895307b50ecb4b17d756b1151ee9f 100644 (file)
@@ -748,7 +748,7 @@ void virNetSocketFree(virNetSocketPtr sock)
     /* Make sure it can't send any more I/O during shutdown */
     if (sock->tlsSession)
         virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
-    virNetTLSSessionFree(sock->tlsSession);
+    virObjectUnref(sock->tlsSession);
 #if HAVE_SASL
     virNetSASLSessionFree(sock->saslSession);
 #endif
@@ -909,13 +909,12 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
                                virNetTLSSessionPtr sess)
 {
     virMutexLock(&sock->lock);
-    virNetTLSSessionFree(sock->tlsSession);
-    sock->tlsSession = sess;
+    virObjectUnref(sock->tlsSession);
+    sock->tlsSession = virObjectRef(sess);
     virNetTLSSessionSetIOCallbacks(sess,
                                    virNetSocketTLSSessionWrite,
                                    virNetSocketTLSSessionRead,
                                    sock);
-    virNetTLSSessionRef(sess);
     virMutexUnlock(&sock->lock);
 }
 
index 5ae22f25dd88555808b05db331cf052d03858b83..9fe6eb1c1d2e2b44d09ab09dbf93a9e63746ad48 100644 (file)
@@ -50,8 +50,9 @@
 #define VIR_FROM_THIS VIR_FROM_RPC
 
 struct _virNetTLSContext {
+    virObject object;
+
     virMutex lock;
-    int refs;
 
     gnutls_certificate_credentials_t x509cred;
     gnutls_dh_params_t dhParams;
@@ -62,9 +63,9 @@ struct _virNetTLSContext {
 };
 
 struct _virNetTLSSession {
-    virMutex lock;
+    virObject object;
 
-    int refs;
+    virMutex lock;
 
     bool handshakeComplete;
 
@@ -76,6 +77,29 @@ struct _virNetTLSSession {
     void *opaque;
 };
 
+static virClassPtr virNetTLSContextClass;
+static virClassPtr virNetTLSSessionClass;
+static void virNetTLSContextDispose(void *obj);
+static void virNetTLSSessionDispose(void *obj);
+
+
+static int virNetTLSContextOnceInit(void)
+{
+    if (!(virNetTLSContextClass = virClassNew("virNetTLSContext",
+                                              sizeof(virNetTLSContext),
+                                              virNetTLSContextDispose)))
+        return -1;
+
+    if (!(virNetTLSSessionClass = virClassNew("virNetTLSSession",
+                                              sizeof(virNetTLSSession),
+                                              virNetTLSSessionDispose)))
+        return -1;
+
+    return 0;
+}
+
+VIR_ONCE_GLOBAL_INIT(virNetTLSContext)
+
 
 static int
 virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing)
@@ -647,10 +671,11 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
     char *gnutlsdebug;
     int err;
 
-    if (VIR_ALLOC(ctxt) < 0) {
-        virReportOOMError();
+    if (virNetTLSContextInitialize() < 0)
+        return NULL;
+
+    if (!(ctxt = virObjectNew(virNetTLSContextClass)))
         return NULL;
-    }
 
     if (virMutexInit(&ctxt->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -659,8 +684,6 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
         return NULL;
     }
 
-    ctxt->refs = 1;
-
     if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) {
         int val;
         if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0)
@@ -716,8 +739,8 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
     ctxt->isServer = isServer;
 
     PROBE(RPC_TLS_CONTEXT_NEW,
-          "ctxt=%p refs=%d cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
-          ctxt, ctxt->refs, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
+          "ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
+          ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
 
     return ctxt;
 
@@ -927,17 +950,6 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
 }
 
 
-void virNetTLSContextRef(virNetTLSContextPtr ctxt)
-{
-    virMutexLock(&ctxt->lock);
-    ctxt->refs++;
-    PROBE(RPC_TLS_CONTEXT_REF,
-          "ctxt=%p refs=%d",
-          ctxt, ctxt->refs);
-    virMutexUnlock(&ctxt->lock);
-}
-
-
 static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt,
                                             virNetTLSSessionPtr sess)
 {
@@ -1106,30 +1118,16 @@ cleanup:
     return ret;
 }
 
-void virNetTLSContextFree(virNetTLSContextPtr ctxt)
+void virNetTLSContextDispose(void *obj)
 {
-    if (!ctxt)
-        return;
-
-    virMutexLock(&ctxt->lock);
-    PROBE(RPC_TLS_CONTEXT_FREE,
-          "ctxt=%p refs=%d",
-          ctxt, ctxt->refs);
-    ctxt->refs--;
-    if (ctxt->refs > 0) {
-        virMutexUnlock(&ctxt->lock);
-        return;
-    }
+    virNetTLSContextPtr ctxt = obj;
 
     gnutls_dh_params_deinit(ctxt->dhParams);
     gnutls_certificate_free_credentials(ctxt->x509cred);
-    virMutexUnlock(&ctxt->lock);
     virMutexDestroy(&ctxt->lock);
-    VIR_FREE(ctxt);
 }
 
 
-
 static ssize_t
 virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
 {
@@ -1167,10 +1165,8 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
     VIR_DEBUG("ctxt=%p hostname=%s isServer=%d",
               ctxt, NULLSTR(hostname), ctxt->isServer);
 
-    if (VIR_ALLOC(sess) < 0) {
-        virReportOOMError();
+    if (!(sess = virObjectNew(virNetTLSSessionClass)))
         return NULL;
-    }
 
     if (virMutexInit(&sess->lock) < 0) {
         virReportError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -1179,7 +1175,6 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
         return NULL;
     }
 
-    sess->refs = 1;
     if (hostname &&
         !(sess->hostname = strdup(hostname))) {
         virReportOOMError();
@@ -1230,27 +1225,17 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
     sess->isServer = ctxt->isServer;
 
     PROBE(RPC_TLS_SESSION_NEW,
-          "sess=%p refs=%d ctxt=%p hostname=%s isServer=%d",
-          sess, sess->refs, ctxt, hostname, sess->isServer);
+          "sess=%p ctxt=%p hostname=%s isServer=%d",
+          sess, ctxt, hostname, sess->isServer);
 
     return sess;
 
 error:
-    virNetTLSSessionFree(sess);
+    virObjectUnref(sess);
     return NULL;
 }
 
 
-void virNetTLSSessionRef(virNetTLSSessionPtr sess)
-{
-    virMutexLock(&sess->lock);
-    sess->refs++;
-    PROBE(RPC_TLS_SESSION_REF,
-          "sess=%p refs=%d",
-          sess, sess->refs);
-    virMutexUnlock(&sess->lock);
-}
-
 void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionWriteFunc writeFunc,
                                     virNetTLSSessionReadFunc readFunc,
@@ -1393,26 +1378,13 @@ cleanup:
 }
 
 
-void virNetTLSSessionFree(virNetTLSSessionPtr sess)
+void virNetTLSSessionDispose(void *obj)
 {
-    if (!sess)
-        return;
-
-    virMutexLock(&sess->lock);
-    PROBE(RPC_TLS_SESSION_FREE,
-          "sess=%p refs=%d",
-          sess, sess->refs);
-    sess->refs--;
-    if (sess->refs > 0) {
-        virMutexUnlock(&sess->lock);
-        return;
-    }
+    virNetTLSSessionPtr sess = obj;
 
     VIR_FREE(sess->hostname);
     gnutls_deinit(sess->session);
-    virMutexUnlock(&sess->lock);
     virMutexDestroy(&sess->lock);
-    VIR_FREE(sess);
 }
 
 /*
index 8893da9e8b98dc2365375e80984c125e35cecd5c..e47c3c043ed228294f269645fd0d9497d125c9ee 100644 (file)
@@ -22,6 +22,7 @@
 # define __VIR_NET_TLS_CONTEXT_H__
 
 # include "internal.h"
+# include "virobject.h"
 
 typedef struct _virNetTLSContext virNetTLSContext;
 typedef virNetTLSContext *virNetTLSContextPtr;
@@ -58,13 +59,9 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
                                               bool sanityCheckCert,
                                               bool requireValidCert);
 
-void virNetTLSContextRef(virNetTLSContextPtr ctxt);
-
 int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
                                      virNetTLSSessionPtr sess);
 
-void virNetTLSContextFree(virNetTLSContextPtr ctxt);
-
 
 typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len,
                                              void *opaque);
@@ -79,8 +76,6 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionReadFunc readFunc,
                                     void *opaque);
 
-void virNetTLSSessionRef(virNetTLSSessionPtr sess);
-
 ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
                               const char *buf, size_t len);
 ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
@@ -99,7 +94,4 @@ virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess);
 
 int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess);
 
-void virNetTLSSessionFree(virNetTLSSessionPtr sess);
-
-
 #endif
index 0dfaa23bfe225e291e0e2616700f2300180969db..397c9670dd68397012e4a5ab836c995948291323 100644 (file)
@@ -496,7 +496,7 @@ static int testTLSContextInit(const void *opaque)
     ret = 0;
 
 cleanup:
-    virNetTLSContextFree(ctxt);
+    virObjectUnref(ctxt);
     gnutls_x509_crt_deinit(data->careq.crt);
     gnutls_x509_crt_deinit(data->certreq.crt);
     data->careq.crt = data->certreq.crt = NULL;
@@ -710,10 +710,10 @@ static int testTLSSessionInit(const void *opaque)
     ret = 0;
 
 cleanup:
-    virNetTLSContextFree(serverCtxt);
-    virNetTLSContextFree(clientCtxt);
-    virNetTLSSessionFree(serverSess);
-    virNetTLSSessionFree(clientSess);
+    virObjectUnref(serverCtxt);
+    virObjectUnref(clientCtxt);
+    virObjectUnref(serverSess);
+    virObjectUnref(clientSess);
     gnutls_x509_crt_deinit(data->careq.crt);
     if (data->othercareq.filename)
         gnutls_x509_crt_deinit(data->othercareq.crt);