[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[libvirt] [PATCH 1/2] SASL: Introduce session mutex



Some of SASL interacting functions can be called within two or more
threads with the same pointer. Therefore we need to protect
virNetSASLSessionPtr with mutex to avoid non-consistent states.
---
 src/rpc/virnetsaslcontext.c |   67 +++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c
index 6b2a883..ef91b9d 100644
--- a/src/rpc/virnetsaslcontext.c
+++ b/src/rpc/virnetsaslcontext.c
@@ -28,6 +28,7 @@
 #include "virterror_internal.h"
 #include "memory.h"
 #include "logging.h"
+#include "threads.h"
 
 #define VIR_FROM_THIS VIR_FROM_RPC
 #define virNetError(code, ...)                                    \
@@ -41,6 +42,7 @@ struct _virNetSASLContext {
 };
 
 struct _virNetSASLSession {
+    virMutex lock;
     sasl_conn_t *conn;
     int refs;
     size_t maxbufsize;
@@ -145,6 +147,16 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt)
     VIR_FREE(ctxt);
 }
 
+static void virNetSASLSessionLock(virNetSASLSessionPtr session)
+{
+    virMutexLock(&session->lock);
+}
+
+static void virNetSASLSessionUnlock(virNetSASLSessionPtr session)
+{
+    virMutexUnlock(&session->lock);
+}
+
 virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIBUTE_UNUSED,
                                                 const char *service,
                                                 const char *hostname,
@@ -160,6 +172,9 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0)
+        goto cleanup;
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -198,6 +213,9 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
         goto cleanup;
     }
 
+    if (virMutexInit(&sasl->lock) < 0)
+        goto cleanup;
+
     sasl->refs = 1;
     /* Arbitrary size for amount of data we can encode in a single block */
     sasl->maxbufsize = 1 << 16;
@@ -226,7 +244,9 @@ cleanup:
 
 void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
 {
+    virNetSASLSessionLock(sasl);
     sasl->refs++;
+    virNetSASLSessionUnlock(sasl);
 }
 
 int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
@@ -234,13 +254,16 @@ int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
 {
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set external SSF %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -249,13 +272,16 @@ const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl)
     const void *val;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_USERNAME, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL username on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return NULL;
     }
+    virNetSASLSessionUnlock(sasl);
     if (val == NULL) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("no client username was found"));
@@ -272,13 +298,17 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl)
     int err;
     int ssf;
     const void *val;
+
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_SSF, &val);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_AUTH_FAILED,
                     _("cannot query SASL ssf on connection %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     ssf = *(const int *)val;
     return ssf;
 }
@@ -291,6 +321,7 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
     sasl_security_properties_t secprops;
     int err;
 
+    virNetSASLSessionLock(sasl);
     VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu",
               minSSF, maxSSF, allowAnonymous, sasl->maxbufsize);
 
@@ -307,8 +338,10 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot set security props %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
 
     return 0;
 }
@@ -319,17 +352,20 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl)
     unsigned *maxbufsize;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_getprop(sasl->conn, SASL_MAXOUTBUF, (const void **)&maxbufsize);
     if (err != SASL_OK) {
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot get security props %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
 
     VIR_DEBUG("Negotiated bufsize is %u vs requested size %zu",
               *maxbufsize, sasl->maxbufsize);
     sasl->maxbufsize = *maxbufsize;
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -339,6 +375,7 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
     char *ret;
     int err;
 
+    virNetSASLSessionLock(sasl);
     err = sasl_listmech(sasl->conn,
                         NULL, /* Don't need to set user */
                         "", /* Prefix */
@@ -351,8 +388,10 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("cannot list SASL mechanisms %d (%s)"),
                     err, sasl_errdetail(sasl->conn));
+        virNetSASLSessionUnlock(sasl);
         return NULL;
     }
+    virNetSASLSessionUnlock(sasl);
     if (!(ret = strdup(mechlist))) {
         virReportOOMError();
         return NULL;
@@ -373,6 +412,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
     VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p",
               sasl, mechlist, prompt_need, clientout, clientoutlen, mech);
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_client_start(sasl->conn,
                                 mechlist,
                                 prompt_need,
@@ -380,6 +420,7 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
                                 &outlen,
                                 mech);
 
+    virNetSASLSessionUnlock(sasl);
     *clientoutlen = outlen;
 
     switch (err) {
@@ -414,12 +455,14 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl,
     VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p",
               sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen);
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_client_step(sasl->conn,
                                serverin,
                                inlen,
                                prompt_need,
                                clientout,
                                &outlen);
+    virNetSASLSessionUnlock(sasl);
     *clientoutlen = outlen;
 
     switch (err) {
@@ -449,6 +492,8 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
 {
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
+
+    virNetSASLSessionLock(sasl);
     int err = sasl_server_start(sasl->conn,
                                 mechname,
                                 clientin,
@@ -456,6 +501,7 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
                                 serverout,
                                 &outlen);
 
+    virNetSASLSessionUnlock(sasl);
     *serveroutlen = outlen;
 
     switch (err) {
@@ -486,12 +532,14 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
     unsigned inlen = clientinlen;
     unsigned outlen = 0;
 
+    virNetSASLSessionLock(sasl);
     int err = sasl_server_step(sasl->conn,
                                clientin,
                                inlen,
                                serverout,
                                &outlen);
 
+    virNetSASLSessionUnlock(sasl);
     *serveroutlen = outlen;
 
     switch (err) {
@@ -514,7 +562,11 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
 
 size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl)
 {
-    return sasl->maxbufsize;
+    size_t ret;
+    virNetSASLSessionLock(sasl);
+    ret = sasl->maxbufsize;
+    virNetSASLSessionUnlock(sasl);
+    return ret;
 }
 
 ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
@@ -534,6 +586,7 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
         return -1;
     }
 
+    virNetSASLSessionLock(sasl);
     err = sasl_encode(sasl->conn,
                       input,
                       inlen,
@@ -545,8 +598,10 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to encode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -567,6 +622,7 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
         return -1;
     }
 
+    virNetSASLSessionLock(sasl);
     err = sasl_decode(sasl->conn,
                       input,
                       inlen,
@@ -577,8 +633,10 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
         virNetError(VIR_ERR_INTERNAL_ERROR,
                     _("failed to decode SASL data: %d (%s)"),
                     err, sasl_errstring(err, NULL, NULL));
+        virNetSASLSessionUnlock(sasl);
         return -1;
     }
+    virNetSASLSessionUnlock(sasl);
     return 0;
 }
 
@@ -587,12 +645,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
     if (!sasl)
         return;
 
+    virNetSASLSessionLock(sasl);
     sasl->refs--;
-    if (sasl->refs > 0)
+    if (sasl->refs > 0) {
+        virNetSASLSessionUnlock(sasl);
         return;
+    }
 
     if (sasl->conn)
         sasl_dispose(&sasl->conn);
 
+    virNetSASLSessionUnlock(sasl);
+    virMutexDestroy(&sasl->lock);
     VIR_FREE(sasl);
 }
-- 
1.7.5.rc3


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]