[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

Re: [libvirt] [PATCH 4/4] Add mutex protection to SASL and TLS modules



On Mon, Jul 25, 2011 at 06:03:25PM +0100, Daniel P. Berrange wrote:
> From: "Daniel P. Berrange" <berrange redhat com>
> 
> The virNetSASLContext, virNetSASLSession, virNetTLSContext and
> virNetTLSSession classes previously relied on their owners
> (virNetClient / virNetServer / virNetServerClient) to provide
> locking protection for concurrent usage. When virNetSocket
> gained its own locking code, this invalidated the implicit
> safety the SASL/TLS modules relied on. Thus we need to give
> them all explicit locking of their own via new mutexes.
> 
> * src/rpc/virnetsaslcontext.c, src/rpc/virnettlscontext.c: Add
>   a mutex per object
> ---
>  src/rpc/virnetsaslcontext.c |  286 ++++++++++++++++++++++++++++++++-----------
>  src/rpc/virnettlscontext.c  |  105 +++++++++++++---
>  2 files changed, 298 insertions(+), 93 deletions(-)
> 
> diff --git a/src/rpc/virnetsaslcontext.c b/src/rpc/virnetsaslcontext.c
> index 6b2a883..71796b9 100644
> --- a/src/rpc/virnetsaslcontext.c
> +++ b/src/rpc/virnetsaslcontext.c
> @@ -27,6 +27,7 @@
>  
>  #include "virterror_internal.h"
>  #include "memory.h"
> +#include "threads.h"
>  #include "logging.h"
>  
>  #define VIR_FROM_THIS VIR_FROM_RPC
> @@ -36,11 +37,13 @@
>  
>  
>  struct _virNetSASLContext {
> +    virMutex lock;
>      const char *const*usernameWhitelist;
>      int refs;
>  };
>  
>  struct _virNetSASLSession {
> +    virMutex lock;
>      sasl_conn_t *conn;
>      int refs;
>      size_t maxbufsize;
> @@ -65,6 +68,13 @@ virNetSASLContextPtr virNetSASLContextNewClient(void)
>          return NULL;
>      }
>  
> +    if (virMutexInit(&ctxt->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(ctxt);
> +        return NULL;
> +    }
> +
>      ctxt->refs = 1;
>  
>      return ctxt;
> @@ -88,6 +98,13 @@ virNetSASLContextPtr virNetSASLContextNewServer(const char *const*usernameWhitel
>          return NULL;
>      }
>  
> +    if (virMutexInit(&ctxt->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(ctxt);
> +        return NULL;
> +    }
> +
>      ctxt->usernameWhitelist = usernameWhitelist;
>      ctxt->refs = 1;
>  
> @@ -98,21 +115,28 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
>                                     const char *identity)
>  {
>      const char *const*wildcards;
> +    int ret = -1;
> +
> +    virMutexLock(&ctxt->lock);
>  
>      /* If the list is not set, allow any DN. */
>      wildcards = ctxt->usernameWhitelist;
> -    if (!wildcards)
> -        return 1; /* No ACL, allow all */
> +    if (!wildcards) {
> +        ret = 1; /* No ACL, allow all */
> +        goto cleanup;
> +    }
>  
>      while (*wildcards) {
> -        int ret = fnmatch (*wildcards, identity, 0);
> -        if (ret == 0) /* Succesful match */
> -            return 1;
> -        if (ret != FNM_NOMATCH) {
> +        int rv = fnmatch (*wildcards, identity, 0);
> +        if (rv == 0) {
> +            ret = 1;
> +            goto cleanup; /* Successful match */
> +        }
> +        if (rv != FNM_NOMATCH) {
>              virNetError(VIR_ERR_INTERNAL_ERROR,
>                          _("Malformed TLS whitelist regular expression '%s'"),
>                          *wildcards);
> -            return -1;
> +            goto cleanup;
>          }
>  
>          wildcards++;
> @@ -124,13 +148,19 @@ int virNetSASLContextCheckIdentity(virNetSASLContextPtr ctxt,
>      /* This is the most common error: make it informative. */
>      virNetError(VIR_ERR_SYSTEM_ERROR, "%s",
>                  _("Client's username is not on the list of allowed clients"));
> -    return 0;
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&ctxt->lock);
> +    return ret;
>  }
>  
>  
>  void virNetSASLContextRef(virNetSASLContextPtr ctxt)
>  {
> +    virMutexLock(&ctxt->lock);
>      ctxt->refs++;
> +    virMutexUnlock(&ctxt->lock);
>  }
>  
>  void virNetSASLContextFree(virNetSASLContextPtr ctxt)
> @@ -138,10 +168,15 @@ void virNetSASLContextFree(virNetSASLContextPtr ctxt)
>      if (!ctxt)
>          return;
>  
> +    virMutexLock(&ctxt->lock);
>      ctxt->refs--;
> -    if (ctxt->refs > 0)
> +    if (ctxt->refs > 0) {
> +        virMutexUnlock(&ctxt->lock);
>          return;
> +    }
>  
> +    virMutexUnlock(&ctxt->lock);
> +    virMutexDestroy(&ctxt->lock);
>      VIR_FREE(ctxt);
>  }
>  
> @@ -160,6 +195,13 @@ virNetSASLSessionPtr virNetSASLSessionNewClient(virNetSASLContextPtr ctxt ATTRIB
>          goto cleanup;
>      }
>  
> +    if (virMutexInit(&sasl->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(sasl);
> +        return NULL;
> +    }
> +
>      sasl->refs = 1;
>      /* Arbitrary size for amount of data we can encode in a single block */
>      sasl->maxbufsize = 1 << 16;
> @@ -198,6 +240,13 @@ virNetSASLSessionPtr virNetSASLSessionNewServer(virNetSASLContextPtr ctxt ATTRIB
>          goto cleanup;
>      }
>  
> +    if (virMutexInit(&sasl->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(sasl);
> +        return NULL;
> +    }
> +
>      sasl->refs = 1;
>      /* Arbitrary size for amount of data we can encode in a single block */
>      sasl->maxbufsize = 1 << 16;
> @@ -226,43 +275,56 @@ cleanup:
>  
>  void virNetSASLSessionRef(virNetSASLSessionPtr sasl)
>  {
> +    virMutexLock(&sasl->lock);
>      sasl->refs++;
> +    virMutexUnlock(&sasl->lock);
>  }
>  
>  int virNetSASLSessionExtKeySize(virNetSASLSessionPtr sasl,
>                                  int ssf)
>  {
>      int err;
> +    int ret = -1;
> +    virMutexLock(&sasl->lock);
>  
>      err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf);
>      if (err != SASL_OK) {
>          virNetError(VIR_ERR_INTERNAL_ERROR,
>                      _("cannot set external SSF %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return -1;
> +        goto cleanup;
>      }
> -    return 0;
> +
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  const char *virNetSASLSessionGetIdentity(virNetSASLSessionPtr sasl)
>  {
> -    const void *val;
> +    const void *val = NULL;
>      int err;
> +    virMutexLock(&sasl->lock);
>  
>      err = sasl_getprop(sasl->conn, SASL_USERNAME, &val);
>      if (err != SASL_OK) {
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("cannot query SASL username on connection %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return NULL;
> +        val = NULL;
> +        goto cleanup;
>      }
>      if (val == NULL) {
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("no client username was found"));
> -        return NULL;
> +        goto cleanup;
>      }
>      VIR_DEBUG("SASL client username %s", (const char *)val);
>  
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
>      return (const char*)val;
>  }
>  
> @@ -272,14 +334,20 @@ int virNetSASLSessionGetKeySize(virNetSASLSessionPtr sasl)
>      int err;
>      int ssf;
>      const void *val;
> +
> +    virMutexLock(&sasl->lock);
>      err = sasl_getprop(sasl->conn, SASL_SSF, &val);
>      if (err != SASL_OK) {
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("cannot query SASL ssf on connection %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return -1;
> +        ssf = -1;
> +        goto cleanup;
>      }
>      ssf = *(const int *)val;
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
>      return ssf;
>  }
>  
> @@ -290,10 +358,12 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
>  {
>      sasl_security_properties_t secprops;
>      int err;
> +    int ret = -1;
>  
> +    virMutexLock(&sasl->lock);
>      VIR_DEBUG("minSSF=%d maxSSF=%d allowAnonymous=%d maxbufsize=%zu",
>                minSSF, maxSSF, allowAnonymous, sasl->maxbufsize);
>  
>      memset(&secprops, 0, sizeof secprops);
>  
>      secprops.min_ssf = minSSF;
> @@ -307,10 +377,14 @@ int virNetSASLSessionSecProps(virNetSASLSessionPtr sasl,
>          virNetError(VIR_ERR_INTERNAL_ERROR,
>                      _("cannot set security props %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return -1;
> +        goto cleanup;
>      }
>  
> -    return 0;
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  
> @@ -336,9 +410,10 @@ static int virNetSASLSessionUpdateBufSize(virNetSASLSessionPtr sasl)
>  char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
>  {
>      const char *mechlist;
> -    char *ret;
> +    char *ret = NULL;
>      int err;
>  
> +    virMutexLock(&sasl->lock);
>      err = sasl_listmech(sasl->conn,
>                          NULL, /* Don't need to set user */
>                          "", /* Prefix */
> @@ -351,12 +426,15 @@ char *virNetSASLSessionListMechanisms(virNetSASLSessionPtr sasl)
>          virNetError(VIR_ERR_INTERNAL_ERROR,
>                      _("cannot list SASL mechanisms %d (%s)"),
>                      err, sasl_errdetail(sasl->conn));
> -        return NULL;
> +        goto cleanup;
>      }
>      if (!(ret = strdup(mechlist))) {
>          virReportOOMError();
> -        return NULL;
> +        goto cleanup;
>      }
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
>      return ret;
>  }
>  
> @@ -369,35 +447,44 @@ int virNetSASLSessionClientStart(virNetSASLSessionPtr sasl,
>                                   const char **mech)
>  {
>      unsigned outlen = 0;
> +    int err;
> +    int ret = -1;
>  
>      VIR_DEBUG("sasl=%p mechlist=%s prompt_need=%p clientout=%p clientoutlen=%p mech=%p",
>                sasl, mechlist, prompt_need, clientout, clientoutlen, mech);
>  
> -    int err = sasl_client_start(sasl->conn,
> -                                mechlist,
> -                                prompt_need,
> -                                clientout,
> -                                &outlen,
> -                                mech);
> +    virMutexLock(&sasl->lock);
> +    err = sasl_client_start(sasl->conn,
> +                            mechlist,
> +                            prompt_need,
> +                            clientout,
> +                            &outlen,
> +                            mech);
>  
>      *clientoutlen = outlen;
>  
>      switch (err) {
>      case SASL_OK:
>          if (virNetSASLSessionUpdateBufSize(sasl) < 0)
> -            return -1;
> -        return VIR_NET_SASL_COMPLETE;
> +            goto cleanup;
> +        ret = VIR_NET_SASL_COMPLETE;
> +        break;
>      case SASL_CONTINUE:
> -        return VIR_NET_SASL_CONTINUE;
> +        ret = VIR_NET_SASL_CONTINUE;
> +        break;
>      case SASL_INTERACT:
> -        return VIR_NET_SASL_INTERACT;
> -
> +        ret = VIR_NET_SASL_INTERACT;
> +        break;
>      default:
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("Failed to start SASL negotiation: %d (%s)"),
>                      err, sasl_errdetail(sasl->conn));
> -        return -1;
> +        break;
>      }
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  
> @@ -410,34 +497,43 @@ int virNetSASLSessionClientStep(virNetSASLSessionPtr sasl,
>  {
>      unsigned inlen = serverinlen;
>      unsigned outlen = 0;
> +    int err;
> +    int ret = -1;
>  
>      VIR_DEBUG("sasl=%p serverin=%s serverinlen=%zu prompt_need=%p clientout=%p clientoutlen=%p",
>                sasl, serverin, serverinlen, prompt_need, clientout, clientoutlen);
>  
> -    int err = sasl_client_step(sasl->conn,
> -                               serverin,
> -                               inlen,
> -                               prompt_need,
> -                               clientout,
> -                               &outlen);
> +    virMutexLock(&sasl->lock);
> +    err = sasl_client_step(sasl->conn,
> +                           serverin,
> +                           inlen,
> +                           prompt_need,
> +                           clientout,
> +                           &outlen);
>      *clientoutlen = outlen;
>  
>      switch (err) {
>      case SASL_OK:
>          if (virNetSASLSessionUpdateBufSize(sasl) < 0)
> -            return -1;
> -        return VIR_NET_SASL_COMPLETE;
> +            goto cleanup;
> +        ret = VIR_NET_SASL_COMPLETE;
> +        break;
>      case SASL_CONTINUE:
> -        return VIR_NET_SASL_CONTINUE;
> +        ret = VIR_NET_SASL_CONTINUE;
> +        break;
>      case SASL_INTERACT:
> -        return VIR_NET_SASL_INTERACT;
> -
> +        ret = VIR_NET_SASL_INTERACT;
> +        break;
>      default:
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("Failed to step SASL negotiation: %d (%s)"),
>                      err, sasl_errdetail(sasl->conn));
> -        return -1;
> +        break;
>      }
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
> @@ -449,31 +545,41 @@ int virNetSASLSessionServerStart(virNetSASLSessionPtr sasl,
>  {
>      unsigned inlen = clientinlen;
>      unsigned outlen = 0;
> -    int err = sasl_server_start(sasl->conn,
> -                                mechname,
> -                                clientin,
> -                                inlen,
> -                                serverout,
> -                                &outlen);
> +    int err;
> +    int ret = -1;
> +
> +    virMutexLock(&sasl->lock);
> +    err = sasl_server_start(sasl->conn,
> +                            mechname,
> +                            clientin,
> +                            inlen,
> +                            serverout,
> +                            &outlen);
>  
>      *serveroutlen = outlen;
>  
>      switch (err) {
>      case SASL_OK:
>          if (virNetSASLSessionUpdateBufSize(sasl) < 0)
> -            return -1;
> -        return VIR_NET_SASL_COMPLETE;
> +            goto cleanup;
> +        ret = VIR_NET_SASL_COMPLETE;
> +        break;
>      case SASL_CONTINUE:
> -        return VIR_NET_SASL_CONTINUE;
> +        ret = VIR_NET_SASL_CONTINUE;
> +        break;
>      case SASL_INTERACT:
> -        return VIR_NET_SASL_INTERACT;
> -
> +        ret = VIR_NET_SASL_INTERACT;
> +        break;
>      default:
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("Failed to start SASL negotiation: %d (%s)"),
>                      err, sasl_errdetail(sasl->conn));
> -        return -1;
> +        break;
>      }
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  
> @@ -485,36 +591,49 @@ int virNetSASLSessionServerStep(virNetSASLSessionPtr sasl,
>  {
>      unsigned inlen = clientinlen;
>      unsigned outlen = 0;
> +    int err;
> +    int ret = -1;
>  
> -    int err = sasl_server_step(sasl->conn,
> -                               clientin,
> -                               inlen,
> -                               serverout,
> -                               &outlen);
> +    virMutexLock(&sasl->lock);
> +    err = sasl_server_step(sasl->conn,
> +                           clientin,
> +                           inlen,
> +                           serverout,
> +                           &outlen);
>  
>      *serveroutlen = outlen;
>  
>      switch (err) {
>      case SASL_OK:
>          if (virNetSASLSessionUpdateBufSize(sasl) < 0)
> -            return -1;
> -        return VIR_NET_SASL_COMPLETE;
> +            goto cleanup;
> +        ret = VIR_NET_SASL_COMPLETE;
> +        break;
>      case SASL_CONTINUE:
> -        return VIR_NET_SASL_CONTINUE;
> +        ret = VIR_NET_SASL_CONTINUE;
> +        break;
>      case SASL_INTERACT:
> -        return VIR_NET_SASL_INTERACT;
> -
> +        ret = VIR_NET_SASL_INTERACT;
> +        break;
>      default:
>          virNetError(VIR_ERR_AUTH_FAILED,
>                      _("Failed to start SASL negotiation: %d (%s)"),
>                      err, sasl_errdetail(sasl->conn));
> -        return -1;
> +        break;
>      }
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  size_t virNetSASLSessionGetMaxBufSize(virNetSASLSessionPtr sasl)
>  {
> -    return sasl->maxbufsize;
> +    size_t ret;
> +    virMutexLock(&sasl->lock);
> +    ret = sasl->maxbufsize;
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
> @@ -526,12 +645,14 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
>      unsigned inlen = inputLen;
>      unsigned outlen = 0;
>      int err;
> +    ssize_t ret = -1;
>  
> +    virMutexLock(&sasl->lock);
>      if (inputLen > sasl->maxbufsize) {
>          virReportSystemError(EINVAL,
>                               _("SASL data length %zu too long, max %zu"),
>                               inputLen, sasl->maxbufsize);
> -        return -1;
> +        goto cleanup;
>      }
>  
>      err = sasl_encode(sasl->conn,
> @@ -545,9 +666,13 @@ ssize_t virNetSASLSessionEncode(virNetSASLSessionPtr sasl,
>          virNetError(VIR_ERR_INTERNAL_ERROR,
>                      _("failed to encode SASL data: %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return -1;
> +        goto cleanup;
>      }
> -    return 0;
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
> @@ -559,12 +684,14 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
>      unsigned inlen = inputLen;
>      unsigned outlen = 0;
>      int err;
> +    ssize_t ret = -1;
>  
> +    virMutexLock(&sasl->lock);
>      if (inputLen > sasl->maxbufsize) {
>          virReportSystemError(EINVAL,
>                               _("SASL data length %zu too long, max %zu"),
>                               inputLen, sasl->maxbufsize);
> -        return -1;
> +        goto cleanup;
>      }
>  
>      err = sasl_decode(sasl->conn,
> @@ -577,9 +704,13 @@ ssize_t virNetSASLSessionDecode(virNetSASLSessionPtr sasl,
>          virNetError(VIR_ERR_INTERNAL_ERROR,
>                      _("failed to decode SASL data: %d (%s)"),
>                      err, sasl_errstring(err, NULL, NULL));
> -        return -1;
> +        goto cleanup;
>      }
> -    return 0;
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&sasl->lock);
> +    return ret;
>  }
>  
>  void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
> @@ -587,12 +718,17 @@ void virNetSASLSessionFree(virNetSASLSessionPtr sasl)
>      if (!sasl)
>          return;
>  
> +    virMutexLock(&sasl->lock);
>      sasl->refs--;
> -    if (sasl->refs > 0)
> +    if (sasl->refs > 0) {
> +        virMutexUnlock(&sasl->lock);
>          return;
> +    }
>  
>      if (sasl->conn)
>          sasl_dispose(&sasl->conn);
>  
> +    virMutexUnlock(&sasl->lock);
> +    virMutexDestroy(&sasl->lock);
>      VIR_FREE(sasl);
>  }
> diff --git a/src/rpc/virnettlscontext.c b/src/rpc/virnettlscontext.c
> index bde4e7a..db03669 100644
> --- a/src/rpc/virnettlscontext.c
> +++ b/src/rpc/virnettlscontext.c
> @@ -34,6 +34,7 @@
>  #include "virterror_internal.h"
>  #include "util.h"
>  #include "logging.h"
> +#include "threads.h"
>  #include "configmake.h"
>  
>  #define DH_BITS 1024
> @@ -52,6 +53,7 @@
>                           __FUNCTION__, __LINE__, __VA_ARGS__)
>  
>  struct _virNetTLSContext {
> +    virMutex lock;
>      int refs;
>  
>      gnutls_certificate_credentials_t x509cred;
> @@ -63,6 +65,8 @@ struct _virNetTLSContext {
>  };
>  
>  struct _virNetTLSSession {
> +    virMutex lock;
> +
>      int refs;
>  
>      bool handshakeComplete;
> @@ -653,6 +657,13 @@ static virNetTLSContextPtr virNetTLSContextNew(const char *cacert,
>          return NULL;
>      }
>  
> +    if (virMutexInit(&ctxt->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(ctxt);
> +        return NULL;
> +    }
> +
>      ctxt->refs = 1;
>  
>      /* Initialise GnuTLS. */
> @@ -1053,18 +1064,29 @@ authfail:
>  int virNetTLSContextCheckCertificate(virNetTLSContextPtr ctxt,
>                                       virNetTLSSessionPtr sess)
>  {
> +    int ret = -1;
> +
> +    virMutexLock(&ctxt->lock);
> +    virMutexLock(&sess->lock);
>      if (virNetTLSContextValidCertificate(ctxt, sess) < 0) {
>          virErrorPtr err = virGetLastError();
>          VIR_WARN("Certificate check failed %s", err && err->message ? err->message : "<unknown>");
>          if (ctxt->requireValidCert) {
>              virNetError(VIR_ERR_AUTH_FAILED, "%s",
>                          _("Failed to verify peer's certificate"));
> -            return -1;
> +            goto cleanup;
>          }
>          virResetLastError();
>          VIR_INFO("Ignoring bad certificate at user request");
>      }
> -    return 0;
> +
> +    ret = 0;
> +
> +cleanup:
> +    virMutexUnlock(&ctxt->lock);
> +    virMutexUnlock(&sess->lock);
> +
> +    return ret;
>  }
>  
>  void virNetTLSContextFree(virNetTLSContextPtr ctxt)
> @@ -1072,12 +1094,17 @@ void virNetTLSContextFree(virNetTLSContextPtr ctxt)
>      if (!ctxt)
>          return;
>  
> +    virMutexLock(&ctxt->lock);
>      ctxt->refs--;
> -    if (ctxt->refs > 0)
> +    if (ctxt->refs > 0) {
> +        virMutexUnlock(&ctxt->lock);
>          return;
> +    }
>  
>      gnutls_dh_params_deinit(ctxt->dhParams);
>      gnutls_certificate_free_credentials(ctxt->x509cred);
> +    virMutexUnlock(&ctxt->lock);
> +    virMutexDestroy(&ctxt->lock);
>      VIR_FREE(ctxt);
>  }
>  
> @@ -1124,6 +1151,13 @@ virNetTLSSessionPtr virNetTLSSessionNew(virNetTLSContextPtr ctxt,
>          return NULL;
>      }
>  
> +    if (virMutexInit(&sess->lock) < 0) {
> +        virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
> +                    _("Failed to initialize mutex"));
> +        VIR_FREE(sess);
> +        return NULL;
> +    }
> +
>      sess->refs = 1;
>      if (hostname &&
>          !(sess->hostname = strdup(hostname))) {
> @@ -1184,7 +1218,9 @@ error:
>  
>  void virNetTLSSessionRef(virNetTLSSessionPtr sess)
>  {
> +    virMutexLock(&sess->lock);
>      sess->refs++;
> +    virMutexUnlock(&sess->lock);
>  }
>  
>  void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
> @@ -1192,9 +1228,11 @@ void virNetTLSSessionSetIOCallbacks(virNetTLSSessionPtr sess,
>                                      virNetTLSSessionReadFunc readFunc,
>                                      void *opaque)
>  {
> +    virMutexLock(&sess->lock);
>      sess->writeFunc = writeFunc;
>      sess->readFunc = readFunc;
>      sess->opaque = opaque;
> +    virMutexUnlock(&sess->lock);
>  }
>  
>  
> @@ -1202,10 +1240,12 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
>                                const char *buf, size_t len)
>  {
>      ssize_t ret;
> +
> +    virMutexLock(&sess->lock);
>      ret = gnutls_record_send(sess->session, buf, len);
>  
>      if (ret >= 0)
> -        return ret;
> +        goto cleanup;
>  
>      switch (ret) {
>      case GNUTLS_E_AGAIN:
> @@ -1222,7 +1262,11 @@ ssize_t virNetTLSSessionWrite(virNetTLSSessionPtr sess,
>          break;
>      }
>  
> -    return -1;
> +    ret = -1;
> +
> +cleanup:
> +    virMutexUnlock(&sess->lock);
> +    return ret;
>  }
>  
>  ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
> @@ -1230,10 +1274,11 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
>  {
>      ssize_t ret;
>  
> +    virMutexLock(&sess->lock);
>      ret = gnutls_record_recv(sess->session, buf, len);
>  
>      if (ret >= 0)
> -        return ret;
> +        goto cleanup;
>  
>      switch (ret) {
>      case GNUTLS_E_AGAIN:
> @@ -1247,21 +1292,29 @@ ssize_t virNetTLSSessionRead(virNetTLSSessionPtr sess,
>          break;
>      }
>  
> -    return -1;
> +    ret = -1;
> +
> +cleanup:
> +    virMutexUnlock(&sess->lock);
> +    return ret;
>  }
>  
>  int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
>  {
> +    int ret;
>      VIR_DEBUG("sess=%p", sess);
> -    int ret = gnutls_handshake(sess->session);
> +    virMutexLock(&sess->lock);
> +    ret = gnutls_handshake(sess->session);
>      VIR_DEBUG("Ret=%d", ret);
>      if (ret == 0) {
>          sess->handshakeComplete = true;
>          VIR_DEBUG("Handshake is complete");
> -        return 0;
> +        goto cleanup;
> +    }
> +    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) {
> +        ret = 1;
> +        goto cleanup;
>      }
> -    if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN)
> -        return 1;
>  
>  #if 0
>      PROBE(CLIENT_TLS_FAIL, "fd=%d",
> @@ -1271,32 +1324,43 @@ int virNetTLSSessionHandshake(virNetTLSSessionPtr sess)
>      virNetError(VIR_ERR_AUTH_FAILED,
>                  _("TLS handshake failed %s"),
>                  gnutls_strerror(ret));
> -    return -1;
> +    ret = -1;
> +
> +cleanup:
> +    virMutexUnlock(&sess->lock);
> +    return ret;
>  }
>  
>  virNetTLSSessionHandshakeStatus
>  virNetTLSSessionGetHandshakeStatus(virNetTLSSessionPtr sess)
>  {
> +    virNetTLSSessionHandshakeStatus ret;
> +    virMutexLock(&sess->lock);
>      if (sess->handshakeComplete)
> -        return VIR_NET_TLS_HANDSHAKE_COMPLETE;
> +        ret = VIR_NET_TLS_HANDSHAKE_COMPLETE;
>      else if (gnutls_record_get_direction(sess->session) == 0)
> -        return VIR_NET_TLS_HANDSHAKE_RECVING;
> +        ret = VIR_NET_TLS_HANDSHAKE_RECVING;
>      else
> -        return VIR_NET_TLS_HANDSHAKE_SENDING;
> +        ret = VIR_NET_TLS_HANDSHAKE_SENDING;
> +    virMutexUnlock(&sess->lock);
> +    return ret;
>  }
>  
>  int virNetTLSSessionGetKeySize(virNetTLSSessionPtr sess)
>  {
>      gnutls_cipher_algorithm_t cipher;
>      int ssf;
> -
> +    virMutexLock(&sess->lock);
>      cipher = gnutls_cipher_get(sess->session);
>      if (!(ssf = gnutls_cipher_get_key_size(cipher))) {
>          virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
>                      _("invalid cipher size for TLS session"));
> -        return -1;
> +        ssf = -1;
> +        goto cleanup;
>      }
>  
> +cleanup:
> +    virMutexUnlock(&sess->lock);
>      return ssf;
>  }
>  
> @@ -1306,11 +1370,16 @@ void virNetTLSSessionFree(virNetTLSSessionPtr sess)
>      if (!sess)
>          return;
>  
> +    virMutexLock(&sess->lock);
>      sess->refs--;
> -    if (sess->refs > 0)
> +    if (sess->refs > 0) {
> +        virMutexUnlock(&sess->lock);
>          return;
> +    }
>  
>      VIR_FREE(sess->hostname);
>      gnutls_deinit(sess->session);
> +    virMutexUnlock(&sess->lock);
> +    virMutexDestroy(&sess->lock);
>      VIR_FREE(sess);
>  }

  ACK

Daniel

-- 
Daniel Veillard      | libxml Gnome XML XSLT toolkit  http://xmlsoft.org/
daniel veillard com  | Rpmfind RPM search engine http://rpmfind.net/
http://veillard.com/ | virtualization library  http://libvirt.org/


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]