[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[libvirt] [PATCH 1/4] Add mutex locking and reference counting to virNetSocket



From: "Daniel P. Berrange" <berrange redhat com>

Remove the need for a virNetSocket object to be protected by
locks held by the object using it, by giving virNetSocket its
own native locking and reference counting.

* src/rpc/virnetsocket.c: Add locking & reference counting
---
 src/rpc/virnetsocket.c |  147 +++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 120 insertions(+), 27 deletions(-)

diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 7ea1ab7..8dd4d3a 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -40,6 +40,7 @@
 #include "logging.h"
 #include "files.h"
 #include "event.h"
+#include "threads.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 
@@ -49,6 +50,9 @@
 
 
 struct _virNetSocket {
+    virMutex lock;
+    int refs;
+
     int fd;
     int watch;
     pid_t pid;
@@ -122,6 +126,14 @@ static virNetSocketPtr virNetSocketNew(virSocketAddrPtr localAddr,
         return NULL;
     }
 
+    if (virMutexInit(&sock->lock) < 0) {
+        virReportSystemError(errno, "%s",
+                             _("Unable to initialize mutex"));
+        VIR_FREE(sock);
+        return NULL;
+    }
+    sock->refs = 1;
+
     if (localAddr)
         sock->localAddr = *localAddr;
     if (remoteAddr)
@@ -627,6 +639,13 @@ void virNetSocketFree(virNetSocketPtr sock)
     if (!sock)
         return;
 
+    virMutexLock(&sock->lock);
+    sock->refs--;
+    if (sock->refs > 0) {
+        virMutexUnlock(&sock->lock);
+        return;
+    }
+
     VIR_DEBUG("sock=%p fd=%d", sock, sock->fd);
     if (sock->watch > 0) {
         virEventRemoveHandle(sock->watch);
@@ -657,27 +676,41 @@ void virNetSocketFree(virNetSocketPtr sock)
     VIR_FREE(sock->localAddrStr);
     VIR_FREE(sock->remoteAddrStr);
 
+    virMutexUnlock(&sock->lock);
+    virMutexDestroy(&sock->lock);
+
     VIR_FREE(sock);
 }
 
 
 int virNetSocketGetFD(virNetSocketPtr sock)
 {
-    return sock->fd;
+    int fd;
+    virMutexLock(&sock->lock);
+    fd = sock->fd;
+    virMutexUnlock(&sock->lock);
+    return fd;
 }
 
 
 bool virNetSocketIsLocal(virNetSocketPtr sock)
 {
+    bool isLocal = false;
+    virMutexLock(&sock->lock);
     if (sock->localAddr.data.sa.sa_family == AF_UNIX)
-        return true;
-    return false;
+        isLocal = true;
+    virMutexUnlock(&sock->lock);
+    return isLocal;
 }
 
 
 int virNetSocketGetPort(virNetSocketPtr sock)
 {
-    return virSocketGetPort(&sock->localAddr);
+    int port;
+    virMutexLock(&sock->lock);
+    port = virSocketGetPort(&sock->localAddr);
+    virMutexUnlock(&sock->lock);
+    return port;
 }
 
 
@@ -688,15 +721,19 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock,
 {
     struct ucred cr;
     unsigned int cr_len = sizeof (cr);
+    virMutexLock(&sock->lock);
 
     if (getsockopt(sock->fd, SOL_SOCKET, SO_PEERCRED, &cr, &cr_len) < 0) {
         virReportSystemError(errno, "%s",
                              _("Failed to get client socket identity"));
+        virMutexUnlock(&sock->lock);
         return -1;
     }
 
     *pid = cr.pid;
     *uid = cr.uid;
+
+    virMutexUnlock(&sock->lock);
     return 0;
 }
 #else
@@ -715,7 +752,11 @@ int virNetSocketGetLocalIdentity(virNetSocketPtr sock ATTRIBUTE_UNUSED,
 int virNetSocketSetBlocking(virNetSocketPtr sock,
                             bool blocking)
 {
-    return virSetBlocking(sock->fd, blocking);
+    int ret;
+    virMutexLock(&sock->lock);
+    ret = virSetBlocking(sock->fd, blocking);
+    virMutexUnlock(&sock->lock);
+    return ret;
 }
 
 
@@ -751,6 +792,7 @@ static ssize_t virNetSocketTLSSessionRead(char *buf,
 void virNetSocketSetTLSSession(virNetSocketPtr sock,
                                virNetTLSSessionPtr sess)
 {
+    virMutexLock(&sock->lock);
     virNetTLSSessionFree(sock->tlsSession);
     sock->tlsSession = sess;
     virNetTLSSessionSetIOCallbacks(sess,
@@ -758,6 +800,7 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
                                    virNetSocketTLSSessionRead,
                                    sock);
     virNetTLSSessionRef(sess);
+    virMutexUnlock(&sock->lock);
 }
 
 
@@ -765,20 +808,25 @@ void virNetSocketSetTLSSession(virNetSocketPtr sock,
 void virNetSocketSetSASLSession(virNetSocketPtr sock,
                                 virNetSASLSessionPtr sess)
 {
+    virMutexLock(&sock->lock);
     virNetSASLSessionFree(sock->saslSession);
     sock->saslSession = sess;
     virNetSASLSessionRef(sess);
+    virMutexUnlock(&sock->lock);
 }
 #endif
 
 
 bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
 {
+    bool hasCached = false;
+    virMutexLock(&sock->lock);
 #if HAVE_SASL
     if (sock->saslDecoded)
-        return true;
+        hasCached = true;
 #endif
-    return false;
+    virMutexUnlock(&sock->lock);
+    return hasCached;
 }
 
 
@@ -965,39 +1013,54 @@ static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size
 
 ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
 {
+    ssize_t ret;
+    virMutexLock(&sock->lock);
 #if HAVE_SASL
     if (sock->saslSession)
-        return virNetSocketReadSASL(sock, buf, len);
+        ret = virNetSocketReadSASL(sock, buf, len);
     else
 #endif
-        return virNetSocketReadWire(sock, buf, len);
+        ret = virNetSocketReadWire(sock, buf, len);
+    virMutexUnlock(&sock->lock);
+    return ret;
 }
 
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
 {
+    ssize_t ret;
+
+    virMutexLock(&sock->lock);
 #if HAVE_SASL
     if (sock->saslSession)
-        return virNetSocketWriteSASL(sock, buf, len);
+        ret = virNetSocketWriteSASL(sock, buf, len);
     else
 #endif
-        return virNetSocketWriteWire(sock, buf, len);
+        ret = virNetSocketWriteWire(sock, buf, len);
+    virMutexUnlock(&sock->lock);
+    return ret;
 }
 
 
 int virNetSocketListen(virNetSocketPtr sock)
 {
+    virMutexLock(&sock->lock);
     if (listen(sock->fd, 30) < 0) {
         virReportSystemError(errno, "%s", _("Unable to listen on socket"));
+        virMutexUnlock(&sock->lock);
         return -1;
     }
+    virMutexUnlock(&sock->lock);
     return 0;
 }
 
 int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
 {
-    int fd;
+    int fd = -1;
     virSocketAddr localAddr;
     virSocketAddr remoteAddr;
+    int ret = -1;
+
+    virMutexLock(&sock->lock);
 
     *clientsock = NULL;
 
@@ -1007,30 +1070,35 @@ int virNetSocketAccept(virNetSocketPtr sock, virNetSocketPtr *clientsock)
     remoteAddr.len = sizeof(remoteAddr.data.stor);
     if ((fd = accept(sock->fd, &remoteAddr.data.sa, &remoteAddr.len)) < 0) {
         if (errno == ECONNABORTED ||
-            errno == EAGAIN)
-            return 0;
+            errno == EAGAIN) {
+            ret = 0;
+            goto cleanup;
+        }
 
         virReportSystemError(errno, "%s",
                              _("Unable to accept client"));
-        return -1;
+        goto cleanup;
     }
 
     localAddr.len = sizeof(localAddr.data);
     if (getsockname(fd, &localAddr.data.sa, &localAddr.len) < 0) {
         virReportSystemError(errno, "%s", _("Unable to get local socket name"));
-        VIR_FORCE_CLOSE(fd);
-        return -1;
+        goto cleanup;
     }
 
     if (!(*clientsock = virNetSocketNew(&localAddr,
                                         &remoteAddr,
                                         true,
-                                        fd, -1, 0))) {
-        VIR_FORCE_CLOSE(fd);
-        return -1;
-    }
+                                        fd, -1, 0)))
+        goto cleanup;
 
-    return 0;
+    fd = -1;
+    ret = 0;
+
+cleanup:
+    VIR_FORCE_CLOSE(fd);
+    virMutexUnlock(&sock->lock);
+    return ret;
 }
 
 
@@ -1040,52 +1108,96 @@ static void virNetSocketEventHandle(int watch ATTRIBUTE_UNUSED,
                                     void *opaque)
 {
     virNetSocketPtr sock = opaque;
+    virNetSocketIOFunc func;
+    void *eopaque;
 
-    sock->func(sock, events, sock->opaque);
+    virMutexLock(&sock->lock);
+    func = sock->func;
+    eopaque = sock->opaque;
+    virMutexUnlock(&sock->lock);
+
+    if (func)
+        func(sock, events, eopaque);
 }
 
+/*
+ * Free callback handed to virEventAddHandle: drops the reference that
+ * virNetSocketAddIOCallback took on behalf of the event loop.
+ *
+ * NB: virNetSocketFree takes sock->lock, so this must not be invoked
+ * synchronously from virEventRemoveHandle, which is called with that
+ * lock held.
+ */
+static void virNetSocketEventFree(void *opaque)
+{
+    virNetSocketPtr sock = opaque;
+
+    virNetSocketFree(sock);
+}
+
+
 int virNetSocketAddIOCallback(virNetSocketPtr sock,
                               int events,
                               virNetSocketIOFunc func,
                               void *opaque)
 {
+    int ret = -1;
+
+    virMutexLock(&sock->lock);
     if (sock->watch > 0) {
         VIR_DEBUG("Watch already registered on socket %p", sock);
-        return -1;
+        goto cleanup;
     }
 
+    /* The event loop holds a reference, dropped in virNetSocketEventFree */
+    sock->refs++;
     if ((sock->watch = virEventAddHandle(sock->fd,
                                          events,
                                          virNetSocketEventHandle,
                                          sock,
-                                         NULL)) < 0) {
+                                         virNetSocketEventFree)) < 0) {
         VIR_DEBUG("Failed to register watch on socket %p", sock);
-        return -1;
+        /* Registration failed, so virNetSocketEventFree will never run */
+        sock->refs--;
+        goto cleanup;
     }
     sock->func = func;
     sock->opaque = opaque;
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&sock->lock);
+    return ret;
 }
 
 void virNetSocketUpdateIOCallback(virNetSocketPtr sock,
                                   int events)
 {
+    virMutexLock(&sock->lock);
     if (sock->watch <= 0) {
         VIR_DEBUG("Watch not registered on socket %p", sock);
+        virMutexUnlock(&sock->lock);
         return;
     }
 
     virEventUpdateHandle(sock->watch, events);
+
+    virMutexUnlock(&sock->lock);
 }
 
 void virNetSocketRemoveIOCallback(virNetSocketPtr sock)
 {
+    virMutexLock(&sock->lock);
+
     if (sock->watch <= 0) {
         VIR_DEBUG("Watch not registered on socket %p", sock);
+        virMutexUnlock(&sock->lock);
         return;
     }
 
     virEventRemoveHandle(sock->watch);
     sock->watch = 0;
+
+    virMutexUnlock(&sock->lock);
 }
-- 
1.7.6


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]