[libvirt] [PATCH 8/9] Fix locking wrt virNetClientStreamPtr object

Daniel P. Berrange berrange at redhat.com
Tue Jun 28 17:01:58 UTC 2011


The client stream object can be used independently of the
virNetClientPtr object, so it must have full locking of its
own and not rely on any caller.

* src/remote/remote_driver.c: Remove locking around stream
  callback
* src/rpc/virnetclientstream.c: Add locking to all APIs
  and callbacks
---
 src/remote/remote_driver.c   |    3 -
 src/rpc/virnetclientstream.c |  112 +++++++++++++++++++++++++++++++++---------
 2 files changed, 89 insertions(+), 26 deletions(-)

diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c
index 2ac87c8..bb686c8 100644
--- a/src/remote/remote_driver.c
+++ b/src/remote/remote_driver.c
@@ -3254,11 +3254,8 @@ static void remoteStreamEventCallback(virNetClientStreamPtr stream ATTRIBUTE_UNU
                                       void *opaque)
 {
     struct remoteStreamCallbackData *cbdata = opaque;
-    struct private_data *priv = cbdata->st->conn->privateData;
 
-    remoteDriverUnlock(priv);
     (cbdata->cb)(cbdata->st, events, cbdata->opaque);
-    remoteDriverLock(priv);
 }
 
 
diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
index 99c7b41..9da5aee 100644
--- a/src/rpc/virnetclientstream.c
+++ b/src/rpc/virnetclientstream.c
@@ -28,6 +28,7 @@
 #include "virterror_internal.h"
 #include "logging.h"
 #include "event.h"
+#include "threads.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 #define virNetError(code, ...)                                    \
@@ -35,6 +36,8 @@
                          __FUNCTION__, __LINE__, __VA_ARGS__)
 
 struct _virNetClientStream {
+    virMutex lock;
+
     virNetClientProgramPtr prog;
     int proc;
     unsigned serial;
@@ -53,7 +56,6 @@ struct _virNetClientStream {
     size_t incomingOffset;
     size_t incomingLength;
 
-
     virNetClientStreamEventCallback cb;
     void *cbOpaque;
     virFreeCallback cbFree;
@@ -89,7 +91,8 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
     virNetClientStreamPtr st = opaque;
     int events = 0;
 
-    /* XXX we need a mutex on 'st' to protect this callback */
+
+    virMutexLock(&st->lock);
 
     if (st->cb &&
         (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
@@ -106,12 +109,15 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
         virFreeCallback cbFree = st->cbFree;
 
         st->cbDispatch = 1;
+        virMutexUnlock(&st->lock);
         (cb)(st, events, cbOpaque);
+        virMutexLock(&st->lock);
         st->cbDispatch = 0;
 
         if (!st->cb && cbFree)
             (cbFree)(cbOpaque);
     }
+    virMutexUnlock(&st->lock);
 }
 
 
@@ -134,30 +140,45 @@ virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
         return NULL;
     }
 
-    virNetClientProgramRef(prog);
-
     st->refs = 1;
     st->prog = prog;
     st->proc = proc;
     st->serial = serial;
 
+    if (virMutexInit(&st->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("cannot initialize mutex"));
+        VIR_FREE(st);
+        return NULL;
+    }
+
+    virNetClientProgramRef(prog);
+
     return st;
 }
 
 
 void virNetClientStreamRef(virNetClientStreamPtr st)
 {
+    virMutexLock(&st->lock);
     st->refs++;
+    virMutexUnlock(&st->lock);
 }
 
 void virNetClientStreamFree(virNetClientStreamPtr st)
 {
+    virMutexLock(&st->lock);
     st->refs--;
-    if (st->refs > 0)
+    if (st->refs > 0) {
+        virMutexUnlock(&st->lock);
         return;
+    }
+
+    virMutexUnlock(&st->lock);
 
     virResetError(&st->err);
     VIR_FREE(st->incoming);
+    virMutexDestroy(&st->lock);
     virNetClientProgramFree(st->prog);
     VIR_FREE(st);
 }
@@ -165,18 +186,24 @@ void virNetClientStreamFree(virNetClientStreamPtr st)
 bool virNetClientStreamMatches(virNetClientStreamPtr st,
                                virNetMessagePtr msg)
 {
+    bool match = false;
+    virMutexLock(&st->lock);
     if (virNetClientProgramMatches(st->prog, msg) &&
         st->proc == msg->header.proc &&
         st->serial == msg->header.serial)
-        return 1;
-    return 0;
+        match = true;
+    virMutexUnlock(&st->lock);
+    return match;
 }
 
 
 bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
 {
-    if (st->err.code == VIR_ERR_OK)
+    virMutexLock(&st->lock);
+    if (st->err.code == VIR_ERR_OK) {
+        virMutexUnlock(&st->lock);
         return false;
+    }
 
     virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
                       st->err.domain,
@@ -188,7 +215,7 @@ bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
                       st->err.int1,
                       st->err.int2,
                       "%s", st->err.message ? st->err.message : _("Unknown error"));
-
+    virMutexUnlock(&st->lock);
     return true;
 }
 
@@ -199,6 +226,8 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
     virNetMessageError err;
     int ret = -1;
 
+    virMutexLock(&st->lock);
+
     if (st->err.code != VIR_ERR_OK)
         VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
 
@@ -242,6 +271,7 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
 
 cleanup:
     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    virMutexUnlock(&st->lock);
     return ret;
 }
 
@@ -249,15 +279,18 @@ cleanup:
 int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
                                   virNetMessagePtr msg)
 {
-    size_t avail = st->incomingLength - st->incomingOffset;
-    size_t need = msg->bufferLength - msg->bufferOffset;
+    int ret = -1;
+    size_t need;
 
+    virMutexLock(&st->lock);
+    need = msg->bufferLength - msg->bufferOffset;
+    size_t avail = st->incomingLength - st->incomingOffset;
     if (need > avail) {
         size_t extra = need - avail;
         if (VIR_REALLOC_N(st->incoming,
                           st->incomingLength + extra) < 0) {
             VIR_DEBUG("Out of memory handling stream data");
-            return -1;
+            goto cleanup;
         }
         st->incomingLength += extra;
     }
@@ -269,7 +302,12 @@ int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
 
     VIR_DEBUG("Stream incoming data offset %zu length %zu",
               st->incomingOffset, st->incomingLength);
-    return 0;
+
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 
@@ -286,6 +324,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
     if (!(msg = virNetMessageNew()))
         return -1;
 
+    virMutexLock(&st->lock);
+
     msg->header.prog = virNetClientProgramGetProgram(st->prog);
     msg->header.vers = virNetClientProgramGetVersion(st->prog);
     msg->header.status = status;
@@ -293,6 +333,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
     msg->header.serial = st->serial;
     msg->header.proc = st->proc;
 
+    virMutexUnlock(&st->lock);
+
     if (virNetMessageEncodeHeader(msg) < 0)
         goto error;
 
@@ -329,6 +371,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
     int rv = -1;
     VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
               st, client, data, nbytes, nonblock);
+    virMutexLock(&st->lock);
     if (!st->incomingOffset) {
         virNetMessagePtr msg;
         int ret;
@@ -351,8 +394,9 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
         msg->header.proc = st->proc;
 
         VIR_DEBUG("Dummy packet to wait for stream data");
+        virMutexUnlock(&st->lock);
         ret = virNetClientSend(client, msg, true);
-
+        virMutexLock(&st->lock);
         virNetMessageFree(msg);
 
         if (ret < 0)
@@ -380,6 +424,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
     virNetClientStreamEventTimerUpdate(st);
 
 cleanup:
+    virMutexUnlock(&st->lock);
     return rv;
 }
 
@@ -390,20 +435,23 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
                                        void *opaque,
                                        virFreeCallback ff)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock);
     if (st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("multiple stream callbacks not supported"));
-        return 1;
+        goto cleanup;
     }
 
-    virNetClientStreamRef(st);
+    st->refs++;
     if ((st->cbTimer =
          virEventAddTimeout(-1,
                             virNetClientStreamEventTimer,
                             st,
                             virNetClientStreamEventTimerFree)) < 0) {
-        virNetClientStreamFree(st);
-        return -1;
+        st->refs--;
+        goto cleanup;
     }
 
     st->cb = cb;
@@ -413,31 +461,45 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
 
     virNetClientStreamEventTimerUpdate(st);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
                                           int events)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock);
     if (!st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("no stream callback registered"));
-        return -1;
+        goto cleanup;
     }
 
     st->cbEvents = events;
 
     virNetClientStreamEventTimerUpdate(st);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock);
     if (!st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("no stream callback registered"));
-        return -1;
+        goto cleanup;
     }
 
     if (!st->cbDispatch &&
@@ -449,5 +511,9 @@ int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
     st->cbEvents = 0;
     virEventRemoveTimeout(st->cbTimer);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
-- 
1.7.4.4




More information about the libvir-list mailing list