[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[libvirt] [PATCH 8/9] Fix locking wrt virNetClientStreamPtr object



The client stream object can be used independently of the
virNetClientPtr object, so it must have full locking of its
own and not rely on any caller.

* src/remote/remote_driver.c: Remove locking around stream
  callback
* src/rpc/virnetclientstream.c: Add locking to all APIs
  and callbacks
---
 src/remote/remote_driver.c   |    3 -
 src/rpc/virnetclientstream.c |  112 +++++++++++++++++++++++++++++++++---------
 2 files changed, 89 insertions(+), 26 deletions(-)

diff --git a/src/remote/remote_driver.c b/src/remote/remote_driver.c
index 2ac87c8..bb686c8 100644
--- a/src/remote/remote_driver.c
+++ b/src/remote/remote_driver.c
@@ -3254,11 +3254,8 @@ static void remoteStreamEventCallback(virNetClientStreamPtr stream ATTRIBUTE_UNU
                                       void *opaque)
 {
     struct remoteStreamCallbackData *cbdata = opaque;
-    struct private_data *priv = cbdata->st->conn->privateData;
 
-    remoteDriverUnlock(priv);
     (cbdata->cb)(cbdata->st, events, cbdata->opaque);
-    remoteDriverLock(priv);
 }
 
 
diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c
index 99c7b41..9da5aee 100644
--- a/src/rpc/virnetclientstream.c
+++ b/src/rpc/virnetclientstream.c
@@ -28,6 +28,7 @@
 #include "virterror_internal.h"
 #include "logging.h"
 #include "event.h"
+#include "threads.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 #define virNetError(code, ...)                                    \
@@ -35,6 +36,8 @@
                          __FUNCTION__, __LINE__, __VA_ARGS__)
 
 struct _virNetClientStream {
+    virMutex lock;
+
     virNetClientProgramPtr prog;
     int proc;
     unsigned serial;
@@ -53,7 +56,6 @@ struct _virNetClientStream {
     size_t incomingOffset;
     size_t incomingLength;
 
-
     virNetClientStreamEventCallback cb;
     void *cbOpaque;
     virFreeCallback cbFree;
@@ -89,7 +91,8 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
     virNetClientStreamPtr st = opaque;
     int events = 0;
 
-    /* XXX we need a mutex on 'st' to protect this callback */
+
+    virMutexLock(&st->lock);
 
     if (st->cb &&
         (st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
@@ -106,12 +109,15 @@ virNetClientStreamEventTimer(int timer ATTRIBUTE_UNUSED, void *opaque)
         virFreeCallback cbFree = st->cbFree;
 
         st->cbDispatch = 1;
+        virMutexUnlock(&st->lock);
         (cb)(st, events, cbOpaque);
+        virMutexLock(&st->lock);
         st->cbDispatch = 0;
 
         if (!st->cb && cbFree)
             (cbFree)(cbOpaque);
     }
+    virMutexUnlock(&st->lock);
 }
 
 
@@ -134,30 +140,45 @@ virNetClientStreamPtr virNetClientStreamNew(virNetClientProgramPtr prog,
         return NULL;
     }
 
-    virNetClientProgramRef(prog);
-
     st->refs = 1;
     st->prog = prog;
     st->proc = proc;
     st->serial = serial;
 
+    if (virMutexInit(&st->lock) < 0) {
+        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+                    _("cannot initialize mutex"));
+        VIR_FREE(st);
+        return NULL;
+    }
+
+    virNetClientProgramRef(prog);
+
     return st;
 }
 
 
 void virNetClientStreamRef(virNetClientStreamPtr st)
 {
+    virMutexLock(&st->lock);
     st->refs++;
+    virMutexUnlock(&st->lock);
 }
 
 void virNetClientStreamFree(virNetClientStreamPtr st)
 {
+    virMutexLock(&st->lock);
     st->refs--;
-    if (st->refs > 0)
+    if (st->refs > 0) {
+        virMutexUnlock(&st->lock);
         return;
+    }
+
+    virMutexUnlock(&st->lock);
 
     virResetError(&st->err);
     VIR_FREE(st->incoming);
+    virMutexDestroy(&st->lock);
     virNetClientProgramFree(st->prog);
     VIR_FREE(st);
 }
@@ -165,18 +186,24 @@ void virNetClientStreamFree(virNetClientStreamPtr st)
 bool virNetClientStreamMatches(virNetClientStreamPtr st,
                                virNetMessagePtr msg)
 {
+    bool match = false;
+    virMutexLock(&st->lock);
     if (virNetClientProgramMatches(st->prog, msg) &&
         st->proc == msg->header.proc &&
         st->serial == msg->header.serial)
-        return 1;
-    return 0;
+        match = true;
+    virMutexUnlock(&st->lock);
+    return match;
 }
 
 
 bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
 {
-    if (st->err.code == VIR_ERR_OK)
+    virMutexLock(&st->lock);
+    if (st->err.code == VIR_ERR_OK) {
+        virMutexUnlock(&st->lock);
         return false;
+    }
 
     virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
                       st->err.domain,
@@ -188,7 +215,7 @@ bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
                       st->err.int1,
                       st->err.int2,
                       "%s", st->err.message ? st->err.message : _("Unknown error"));
-
+    virMutexUnlock(&st->lock);
     return true;
 }
 
@@ -199,6 +226,8 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
     virNetMessageError err;
     int ret = -1;
 
+    virMutexLock(&st->lock);
+
     if (st->err.code != VIR_ERR_OK)
         VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
 
@@ -242,6 +271,7 @@ int virNetClientStreamSetError(virNetClientStreamPtr st,
 
 cleanup:
     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+    virMutexUnlock(&st->lock);
     return ret;
 }
 
@@ -249,15 +279,18 @@ cleanup:
 int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
                                   virNetMessagePtr msg)
 {
-    size_t avail = st->incomingLength - st->incomingOffset;
-    size_t need = msg->bufferLength - msg->bufferOffset;
+    int ret = -1;
+    size_t need;
 
+    virMutexLock(&st->lock);
+    need = msg->bufferLength - msg->bufferOffset;
+    size_t avail = st->incomingLength - st->incomingOffset;
     if (need > avail) {
         size_t extra = need - avail;
         if (VIR_REALLOC_N(st->incoming,
                           st->incomingLength + extra) < 0) {
             VIR_DEBUG("Out of memory handling stream data");
-            return -1;
+            goto cleanup;
         }
         st->incomingLength += extra;
     }
@@ -269,7 +302,12 @@ int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
 
     VIR_DEBUG("Stream incoming data offset %zu length %zu",
               st->incomingOffset, st->incomingLength);
-    return 0;
+
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 
@@ -286,6 +324,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
     if (!(msg = virNetMessageNew()))
         return -1;
 
+    virMutexLock(&st->lock);
+
     msg->header.prog = virNetClientProgramGetProgram(st->prog);
     msg->header.vers = virNetClientProgramGetVersion(st->prog);
     msg->header.status = status;
@@ -293,6 +333,8 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st,
     msg->header.serial = st->serial;
     msg->header.proc = st->proc;
 
+    virMutexUnlock(&st->lock);
+
     if (virNetMessageEncodeHeader(msg) < 0)
         goto error;
 
@@ -329,6 +371,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
     int rv = -1;
     VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
               st, client, data, nbytes, nonblock);
+    virMutexLock(&st->lock);
     if (!st->incomingOffset) {
         virNetMessagePtr msg;
         int ret;
@@ -351,8 +394,9 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
         msg->header.proc = st->proc;
 
         VIR_DEBUG("Dummy packet to wait for stream data");
+        virMutexUnlock(&st->lock);
         ret = virNetClientSend(client, msg, true);
-
+        virMutexLock(&st->lock);
         virNetMessageFree(msg);
 
         if (ret < 0)
@@ -380,6 +424,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st,
     virNetClientStreamEventTimerUpdate(st);
 
 cleanup:
+    virMutexUnlock(&st->lock);
     return rv;
 }
 
@@ -390,20 +435,23 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
                                        void *opaque,
                                        virFreeCallback ff)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock);
     if (st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("multiple stream callbacks not supported"));
-        return 1;
+        goto cleanup;
     }
 
-    virNetClientStreamRef(st);
+    st->refs++;
     if ((st->cbTimer =
          virEventAddTimeout(-1,
                             virNetClientStreamEventTimer,
                             st,
                             virNetClientStreamEventTimerFree)) < 0) {
-        virNetClientStreamFree(st);
-        return -1;
+        st->refs--;
+        goto cleanup;
     }
 
     st->cb = cb;
@@ -413,31 +461,45 @@ int virNetClientStreamEventAddCallback(virNetClientStreamPtr st,
 
     virNetClientStreamEventTimerUpdate(st);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
                                           int events)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock);
     if (!st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("no stream callback registered"));
-        return -1;
+        goto cleanup;
     }
 
     st->cbEvents = events;
 
     virNetClientStreamEventTimerUpdate(st);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
 
 int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
 {
+    int ret = -1;
+
+    virMutexLock(&st->lock); /* paired with the unlock at cleanup */
     if (!st->cb) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     "%s", _("no stream callback registered"));
-        return -1;
+        goto cleanup;
     }
 
     if (!st->cbDispatch &&
@@ -449,5 +511,9 @@ int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
     st->cbEvents = 0;
     virEventRemoveTimeout(st->cbTimer);
 
-    return 0;
+    ret = 0;
+
+cleanup:
+    virMutexUnlock(&st->lock);
+    return ret;
 }
-- 
1.7.4.4


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]