[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[libvirt] [PATCH 08/13] Turn virNetTLSContext and virNetTLSSession into virObject instances



From: "Daniel P. Berrange" <berrange redhat com>

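Make virNetTLSContext and virNetTLSSession use the virObject
APIs for reference counting. The virNetTLSContextRef/Free and
virNetTLSSessionRef/Free functions are removed in favour of
virObjectRef/virObjectUnref. The refs argument and the
rpc_tls_*_ref/free probes are dropped from libvirt_probes.d.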

Signed-off-by: Daniel P. Berrange <berrange redhat com>
---
 daemon/libvirtd.c             |    4 +-
 src/libvirt_private.syms      |    2 -
 src/libvirt_probes.d          |    8 +--
 src/remote/remote_driver.c    |    2 +-
 src/rpc/virnetclient.c        |    6 +--
 src/rpc/virnetserver.c        |    3 +-
 src/rpc/virnetserverclient.c  |   11 ++---
 src/rpc/virnetserverservice.c |   10 ++--
 src/rpc/virnetsocket.c        |    7 ++-
 src/rpc/virnettlscontext.c    |  110 +++++++++++++++--------------------------
 src/rpc/virnettlscontext.h    |   10 +---
 tests/virnettlscontexttest.c  |   10 ++--
 12 files changed, 66 insertions(+), 117 deletions(-)

diff --git a/daemon/libvirtd.c b/daemon/libvirtd.c
index 79f37ae..211a4bc 100644
--- a/daemon/libvirtd.c
+++ b/daemon/libvirtd.c
@@ -541,7 +541,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
                                             false,
                                             config->max_client_requests,
                                             ctxt))) {
-                virNetTLSContextFree(ctxt);
+                virObjectUnref(ctxt);
                 goto error;
             }
             if (virNetServerAddService(srv, svcTLS,
@@ -549,7 +549,7 @@ static int daemonSetupNetworking(virNetServerPtr srv,
                                        !config->listen_tcp ? "_libvirt._tcp" : NULL) < 0)
                 goto error;
 
-            virNetTLSContextFree(ctxt);
+            virObjectUnref(ctxt);
         }
     }
 
diff --git a/src/libvirt_private.syms b/src/libvirt_private.syms
index 3551fd0..035658e 100644
--- a/src/libvirt_private.syms
+++ b/src/libvirt_private.syms
@@ -1481,11 +1481,9 @@ virNetSocketWrite;
 
 # virnettlscontext.h
 virNetTLSContextCheckCertificate;
-virNetTLSContextFree;
 virNetTLSContextNewClient;
 virNetTLSContextNewServer;
 virNetTLSContextNewServerPath;
-virNetTLSSessionFree;
 virNetTLSSessionHandshake;
 virNetTLSSessionNew;
 virNetTLSSessionSetIOCallbacks;
diff --git a/src/libvirt_probes.d b/src/libvirt_probes.d
index ceb3caa..3b138a9 100644
--- a/src/libvirt_probes.d
+++ b/src/libvirt_probes.d
@@ -61,19 +61,15 @@ provider libvirt {
 
 	# file: src/rpc/virnettlscontext.c
 	# prefix: rpc
-	probe rpc_tls_context_new(void *ctxt, int refs, const char *cacert, const char *cacrl,
+	probe rpc_tls_context_new(void *ctxt, const char *cacert, const char *cacrl,
 				  const char *cert, const char *key, int sanityCheckCert, int requireValidCert, int isServer);
-	probe rpc_tls_context_ref(void *ctxt, int refs);
-	probe rpc_tls_context_free(void *ctxt, int refs);
 
 	probe rpc_tls_context_session_allow(void *ctxt, void *sess, const char *dname);
 	probe rpc_tls_context_session_deny(void *ctxt, void *sess, const char *dname);
 	probe rpc_tls_context_session_fail(void *ctxt, void *sess);
 
 
-	probe rpc_tls_session_new(void *sess, void *ctxt, int refs, const char *hostname, int isServer);
-	probe rpc_tls_session_ref(void *sess, int refs);
-	probe rpc_tls_session_free(void *sess, int refs);
+	probe rpc_tls_session_new(void *sess, void *ctxt, const char *hostname, int isServer);
 
 	probe rpc_tls_session_handshake_pass(void *sess);
 	probe rpc_tls_session_handshake_fail(void *sess);
diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c
index eac50e6..28035de 100644
--- a/src/remote/remote_driver.c
+++ b/src/remote/remote_driver.c
@@ -908,7 +908,7 @@ doRemoteClose (virConnectPtr conn, struct private_data *priv)
               (xdrproc_t) xdr_void, (char *) NULL) == -1)
         ret = -1;
 
-    virNetTLSContextFree(priv->tls);
+    virObjectUnref(priv->tls);
     priv->tls = NULL;
     virNetClientClose(priv->client);
     virNetClientFree(priv->client);
diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 49d238e..2b51246 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -475,7 +475,7 @@ void virNetClientFree(virNetClientPtr client)
     if (client->sock)
         virNetSocketRemoveIOCallback(client->sock);
     virNetSocketFree(client->sock);
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
 #if HAVE_SASL
     virNetSASLSessionFree(client->sasl);
 #endif
@@ -499,7 +499,7 @@ virNetClientCloseLocked(virNetClientPtr client)
     virNetSocketRemoveIOCallback(client->sock);
     virNetSocketFree(client->sock);
     client->sock = NULL;
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
     client->tls = NULL;
 #if HAVE_SASL
     virNetSASLSessionFree(client->sasl);
@@ -661,7 +661,7 @@ int virNetClientSetTLSSession(virNetClientPtr client,
     return 0;
 
 error:
-    virNetTLSSessionFree(client->tls);
+    virObjectUnref(client->tls);
     client->tls = NULL;
     virNetClientUnlock(client);
     return -1;
diff --git a/src/rpc/virnetserver.c b/src/rpc/virnetserver.c
index 4a02aab..17da40c 100644
--- a/src/rpc/virnetserver.c
+++ b/src/rpc/virnetserver.c
@@ -655,8 +655,7 @@ no_memory:
 int virNetServerSetTLSContext(virNetServerPtr srv,
                               virNetTLSContextPtr tls)
 {
-    srv->tls = tls;
-    virNetTLSContextRef(tls);
+    srv->tls = virObjectRef(tls);
     return 0;
 }
 
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index a56031c..85a457e 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -348,7 +348,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     client->sock = sock;
     client->auth = auth;
     client->readonly = readonly;
-    client->tlsCtxt = tls;
+    client->tlsCtxt = virObjectRef(tls);
     client->nrequests_max = nrequests_max;
 
     client->sockTimer = virEventAddTimeout(-1, virNetServerClientSockTimerFunc,
@@ -356,9 +356,6 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     if (client->sockTimer < 0)
         goto error;
 
-    if (tls)
-        virNetTLSContextRef(tls);
-
     /* Prepare one for packet receive */
     if (!(client->rx = virNetMessageNew(true)))
         goto error;
@@ -600,8 +597,8 @@ void virNetServerClientFree(virNetServerClientPtr client)
 #endif
     if (client->sockTimer > 0)
         virEventRemoveTimeout(client->sockTimer);
-    virNetTLSSessionFree(client->tls);
-    virNetTLSContextFree(client->tlsCtxt);
+    virObjectUnref(client->tls);
+    virObjectUnref(client->tlsCtxt);
     virNetSocketFree(client->sock);
     virNetServerClientUnlock(client);
     virMutexDestroy(&client->lock);
@@ -656,7 +653,7 @@ void virNetServerClientClose(virNetServerClientPtr client)
         virNetSocketRemoveIOCallback(client->sock);
 
     if (client->tls) {
-        virNetTLSSessionFree(client->tls);
+        virObjectUnref(client->tls);
         client->tls = NULL;
     }
     client->wantClose = true;
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 28202a4..b4689b4 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -116,9 +116,7 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
     svc->auth = auth;
     svc->readonly = readonly;
     svc->nrequests_client_max = nrequests_client_max;
-    svc->tls = tls;
-    if (tls)
-        virNetTLSContextRef(tls);
+    svc->tls = virObjectRef(tls);
 
     if (virNetSocketNewListenTCP(nodename,
                                  service,
@@ -172,9 +170,7 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
     svc->auth = auth;
     svc->readonly = readonly;
     svc->nrequests_client_max = nrequests_client_max;
-    svc->tls = tls;
-    if (tls)
-        virNetTLSContextRef(tls);
+    svc->tls = virObjectRef(tls);
 
     svc->nsocks = 1;
     if (VIR_ALLOC_N(svc->socks, svc->nsocks) < 0)
@@ -265,7 +261,7 @@ void virNetServerServiceFree(virNetServerServicePtr svc)
         virNetSocketFree(svc->socks[i]);
     VIR_FREE(svc->socks);
 
-    virNetTLSContextFree(svc->tls);
+    virObjectUnref(svc->tls);
 
     VIR_FREE(svc);
 }
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 0b32ffe..a851dad 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -748,7 +748,7 @@ void virNetSocketFree(virNetSocketPtr sock)
     /* Make sure it can't send any more I/O during shutdown */
     if (sock->tlsSession)
         virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
-    virNetTLSSessionFree(sock->tlsSession);
+    virObjectUnref(sock->tlsSession);
 #if HAVE_SASL
     virNetSASLSessionFree(sock->saslSession);
 #endif
@@ -909,13 +909,14 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
                                virNetTLSSessionPtr sess)
 {
     virMutexLock(&sock->lock);
-    virNetTLSSessionFree(sock->tlsSession);
+    /* Ref the new session before dropping the old one, so that
+     * re-setting the currently installed session cannot free it */
+    virObjectRef(sess);
+    virObjectUnref(sock->tlsSession);
     sock->tlsSession = sess;
     virNetTLSSessionSetIOCallbacks(sess,
                                    virNetSocketTLSSessionWrite,
                                    virNetSocketTLSSessionRead,
                                    sock);
-    virNetTLSSessionRef(sess);
     virMutexUnlock(&sock->lock);
 }
 
diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
index bf92088..74e13c7 100644
--- a/src/rpc/virnettlscontext.c
+++ b/src/rpc/virnettlscontext.c
@@ -53,8 +53,9 @@
                          __FUNCTION__, __LINE__, __VA_ARGS__)
 
 struct _virNetTLSContext {
+    virObject object;
+
     virMutex lock;
-    int refs;
 
     gnutls_certificate_credentials_t x509cred;
     gnutls_dh_params_t dhParams;
@@ -65,9 +66,9 @@ struct _virNetTLSContext {
 };
 
 struct _virNetTLSSession {
-    virMutex lock;
+    virObject object;
 
-    int refs;
+    virMutex lock;
 
     bool handshakeComplete;
 
@@ -79,6 +80,29 @@ struct _virNetTLSSession {
     void *opaque;
 };
 
+static virClassPtr virNetTLSContextClass;
+static virClassPtr virNetTLSSessionClass;
+static void virNetTLSContextDispose(void *obj);
+static void virNetTLSSessionDispose(void *obj);
+
+
+static int virNetTLSContextOnceInit(void)
+{
+    if (!(virNetTLSContextClass = virClassNew("virNetTLSContext",
+                                              sizeof(virNetTLSContext),
+                                              virNetTLSContextDispose)))
+        return -1;
+
+    if (!(virNetTLSSessionClass = virClassNew("virNetTLSSession",
+                                              sizeof(virNetTLSSession),
+                                              virNetTLSSessionDispose)))
+        return -1;
+
+    return 0;
+}
+
+VIR_ONCE_GLOBAL_INIT(virNetTLSContext)
+
 
 static int
 virNetTLSContextCheckCertFile(const char *type, const char *file, bool allowMissing)
@@ -650,10 +674,11 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
     char *gnutlsdebug;
     int err;
 
-    if (VIR_ALLOC(ctxt) < 0) {
-        virReportOOMError();
+    if (virNetTLSContextInitialize() < 0)
+        return NULL;
+
+    if (!(ctxt = virObjectNew(virNetTLSContextClass)))
         return NULL;
-    }
 
     if (virMutexInit(&ctxt->lock) < 0) {
         virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -662,8 +687,6 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
         return NULL;
     }
 
-    ctxt->refs = 1;
-
     if ((gnutlsdebug = getenv("LIBVIRT_GNUTLS_DEBUG")) != NULL) {
         int val;
         if (virStrToLong_i(gnutlsdebug, NULL, 10, &val) < 0)
@@ -719,8 +742,8 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
     ctxt->isServer = isServer;
 
     PROBE(RPC_TLS_CONTEXT_NEW,
-          "ctxt=%p refs=%d cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
-          ctxt, ctxt->refs, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
+          "ctxt=%p cacert=%s cacrl=%s cert=%s key=%s sanityCheckCert=%d requireValidCert=%d isServer=%d",
+          ctxt, cacert, NULLSTR(cacrl), cert, key, sanityCheckCert, requireValidCert, isServer);
 
     return ctxt;
 
@@ -930,17 +953,6 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
 }
 
 
-void virNetTLSContextRef(virNetTLSContextPtr ctxt)
-{
-    virMutexLock(&ctxt->lock);
-    ctxt->refs++;
-    PROBE(RPC_TLS_CONTEXT_REF,
-          "ctxt=%p refs=%d",
-          ctxt, ctxt->refs);
-    virMutexUnlock(&ctxt->lock);
-}
-
-
 static int virNetTLSContextValidCertificate(virNetTLSContextPtr ctxt,
                                             virNetTLSSessionPtr sess)
 {
@@ -1109,30 +1121,16 @@ cleanup:
     return ret;
 }
 
-void virNetTLSContextFree(virNetTLSContextPtr ctxt)
+static void virNetTLSContextDispose(void *obj)
 {
-    if (!ctxt)
-        return;
-
-    virMutexLock(&ctxt->lock);
-    PROBE(RPC_TLS_CONTEXT_FREE,
-          "ctxt=%p refs=%d",
-          ctxt, ctxt->refs);
-    ctxt->refs--;
-    if (ctxt->refs > 0) {
-        virMutexUnlock(&ctxt->lock);
-        return;
-    }
+    virNetTLSContextPtr ctxt = obj;
 
     gnutls_dh_params_deinit(ctxt->dhParams);
     gnutls_certificate_free_credentials(ctxt->x509cred);
-    virMutexUnlock(&ctxt->lock);
     virMutexDestroy(&ctxt->lock);
-    VIR_FREE(ctxt);
 }
 
 
-
 static ssize_t
 virNetTLSSessionPush(void *opaque, const void *buf, size_t len)
 {
@@ -1170,10 +1168,8 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
     VIR_DEBUG("ctxt=%p hostname=%s isServer=%d",
               ctxt, NULLSTR(hostname), ctxt->isServer);
 
-    if (VIR_ALLOC(sess) < 0) {
-        virReportOOMError();
+    if (!(sess = virObjectNew(virNetTLSSessionClass)))
         return NULL;
-    }
 
     if (virMutexInit(&sess->lock) < 0) {
         virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
@@ -1182,7 +1178,6 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
         return NULL;
     }
 
-    sess->refs = 1;
     if (hostname &&
         !(sess->hostname = strdup(hostname))) {
         virReportOOMError();
@@ -1233,27 +1228,17 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
     sess->isServer = ctxt->isServer;
 
     PROBE(RPC_TLS_SESSION_NEW,
-          "sess=%p refs=%d ctxt=%p hostname=%s isServer=%d",
-          sess, sess->refs, ctxt, hostname, sess->isServer);
+          "sess=%p ctxt=%p hostname=%s isServer=%d",
+          sess, ctxt, hostname, sess->isServer);
 
     return sess;
 
 error:
-    virNetTLSSessionFree(sess);
+    virObjectUnref(sess);
     return NULL;
 }
 
 
-void virNetTLSSessionRef(virNetTLSSessionPtr sess)
-{
-    virMutexLock(&sess->lock);
-    sess->refs++;
-    PROBE(RPC_TLS_SESSION_REF,
-          "sess=%p refs=%d",
-          sess, sess->refs);
-    virMutexUnlock(&sess->lock);
-}
-
 void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionWriteFunc writeFunc,
                                     virNetTLSSessionReadFunc readFunc,
@@ -1396,26 +1381,13 @@ cleanup:
 }
 
 
-void virNetTLSSessionFree(virNetTLSSessionPtr sess)
+static void virNetTLSSessionDispose(void *obj)
 {
-    if (!sess)
-        return;
-
-    virMutexLock(&sess->lock);
-    PROBE(RPC_TLS_SESSION_FREE,
-          "sess=%p refs=%d",
-          sess, sess->refs);
-    sess->refs--;
-    if (sess->refs > 0) {
-        virMutexUnlock(&sess->lock);
-        return;
-    }
+    virNetTLSSessionPtr sess = obj;
 
     VIR_FREE(sess->hostname);
     gnutls_deinit(sess->session);
-    virMutexUnlock(&sess->lock);
     virMutexDestroy(&sess->lock);
-    VIR_FREE(sess);
 }
 
 /*
diff --git a/src/rpc/virnettlscontext.h b/src/rpc/virnettlscontext.h
index fdfce6d..4821016 100644
--- a/src/rpc/virnettlscontext.h
+++ b/src/rpc/virnettlscontext.h
@@ -22,6 +22,7 @@
 # define __VIR_NET_TLS_CONTEXT_H__
 
 # include "internal.h"
+# include "virobject.h"
 
 typedef struct _virNetTLSContext virNetTLSContext;
 typedef virNetTLSContext *virNetTLSContextPtr;
@@ -58,13 +59,9 @@ virNetTLSContextPtr virNetTLSContextNewClient(const char *cacert,
                                               bool sanityCheckCert,
                                               bool requireValidCert);
 
-void virNetTLSContextRef(virNetTLSContextPtr ctxt);
-
 int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
                                      virNetTLSSessionPtr sess);
 
-void virNetTLSContextFree(virNetTLSContextPtr ctxt);
-
 
 typedef ssize_t (*virNetTLSSessionWriteFunc)(const char *buf, size_t len,
                                              void *opaque);
@@ -79,8 +76,6 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
                                     virNetTLSSessionReadFunc readFunc,
                                     void *opaque);
 
-void virNetTLSSessionRef(virNetTLSSessionPtr sess);
-
 ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
                               const char *buf, size_t len);
 ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
@@ -99,7 +94,4 @@ virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess);
 
 int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess);
 
-void virNetTLSSessionFree(virNetTLSSessionPtr sess);
-
-
 #endif
diff --git a/tests/virnettlscontexttest.c b/tests/virnettlscontexttest.c
index e745487..32e1f77 100644
--- a/tests/virnettlscontexttest.c
+++ b/tests/virnettlscontexttest.c
@@ -496,7 +496,7 @@ static int testTLSContextInit(const void *opaque)
     ret = 0;
 
 cleanup:
-    virNetTLSContextFree(ctxt);
+    virObjectUnref(ctxt);
     gnutls_x509_crt_deinit(data->careq.crt);
     gnutls_x509_crt_deinit(data->certreq.crt);
     data->careq.crt = data->certreq.crt = NULL;
@@ -710,10 +710,10 @@ static int testTLSSessionInit(const void *opaque)
     ret = 0;
 
 cleanup:
-    virNetTLSContextFree(serverCtxt);
-    virNetTLSContextFree(clientCtxt);
-    virNetTLSSessionFree(serverSess);
-    virNetTLSSessionFree(clientSess);
+    virObjectUnref(serverCtxt);
+    virObjectUnref(clientCtxt);
+    virObjectUnref(serverSess);
+    virObjectUnref(clientSess);
     gnutls_x509_crt_deinit(data->careq.crt);
     if (data->othercareq.filename)
         gnutls_x509_crt_deinit(data->othercareq.crt);
-- 
1.7.10.2


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]