[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]

[libvirt] [PATCH 3/4] Use a virFreeCallback on virNetSocket to ensure safe release



From: "Daniel P. Berrange" <berrange redhat com>

When unregistering an I/O callback from a virNetSocket object,
there is still a chance that an event may be dispatched to the
callback. In that case the virNetSocket may already have been
freed. Make use of a virFreeCallback when registering the I/O
callbacks, so that a reference can be held for the entire time
the callback is registered and released only once the event
loop has finished with it.

* src/rpc/virnetsocket.c: Register a free function for the
  file handle watch
* src/rpc/virnetsocket.h, src/rpc/virnetserverservice.c,
  src/rpc/virnetserverclient.c, src/rpc/virnetclient.c: Add
  a free function for the socket I/O watches
---
 src/rpc/virnetclient.c        |   13 ++++++++++++-
 src/rpc/virnetserverclient.c  |   13 ++++++++++++-
 src/rpc/virnetserverservice.c |   20 ++++++++++++++++++--
 src/rpc/virnetsocket.c        |   30 ++++++++++++++++++++++++++++--
 src/rpc/virnetsocket.h        |    3 ++-
 5 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 4a9eabc..27542a5 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -110,6 +110,13 @@ static void virNetClientIncomingEvent(virNetSocketPtr sock,
                                       int events,
                                       void *opaque);
 
+static void virNetClientEventFree(void *opaque)
+{
+    virNetClientPtr client = opaque;
+
+    virNetClientFree(client);
+}
+
 static virNetClientPtr virNetClientNew(virNetSocketPtr sock,
                                        const char *hostname)
 {
@@ -140,11 +147,15 @@ static virNetClientPtr virNetClientNew(virNetSocketPtr sock,
         goto no_memory;
 
     /* Set up a callback to listen on the socket data */
+    client->refs++;
     if (virNetSocketAddIOCallback(client->sock,
                                   VIR_EVENT_HANDLE_READABLE,
                                   virNetClientIncomingEvent,
-                                  client) < 0)
+                                  client,
+                                  virNetClientEventFree) < 0) {
+        client->refs--;
         VIR_DEBUG("Failed to add event watch, disabling events");
+    }
 
     VIR_DEBUG("client=%p refs=%d", client, client->refs);
     return client;
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index 341981f..317d59c 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -160,6 +160,13 @@ virNetServerClientCalculateHandleMode(virNetServerClientPtr client) {
     return mode;
 }
 
+static void virNetServerClientEventFree(void *opaque)
+{
+    virNetServerClientPtr client = opaque;
+
+    virNetServerClientFree(client);
+}
+
 /*
  * @server: a locked or unlocked server object
  * @client: a locked client object
@@ -168,12 +175,16 @@ static int virNetServerClientRegisterEvent(virNetServerClientPtr client)
 {
     int mode = virNetServerClientCalculateHandleMode(client);
 
+    client->refs++;
     VIR_DEBUG("Registering client event callback %d", mode);
     if (virNetSocketAddIOCallback(client->sock,
                                   mode,
                                   virNetServerClientDispatchEvent,
-                                  client) < 0)
+                                  client,
+                                  virNetServerClientEventFree) < 0) {
+        client->refs--;
         return -1;
+    }
 
     return 0;
 }
diff --git a/src/rpc/virnetserverservice.c b/src/rpc/virnetserverservice.c
index 8c250e2..d5648dc 100644
--- a/src/rpc/virnetserverservice.c
+++ b/src/rpc/virnetserverservice.c
@@ -91,6 +91,14 @@ error:
 }
 
 
+static void virNetServerServiceEventFree(void *opaque)
+{
+    virNetServerServicePtr svc = opaque;
+
+    virNetServerServiceFree(svc);
+}
+
+
 virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
                                                  const char *service,
                                                  int auth,
@@ -124,11 +132,15 @@ virNetServerServicePtr virNetServerServiceNewTCP(const char *nodename,
 
         /* IO callback is initially disabled, until we're ready
          * to deal with incoming clients */
+        virNetServerServiceRef(svc);
         if (virNetSocketAddIOCallback(svc->socks[i],
                                       0,
                                       virNetServerServiceAccept,
-                                      svc) < 0)
+                                      svc,
+                                      virNetServerServiceEventFree) < 0) {
+            virNetServerServiceFree(svc);
             goto error;
+        }
     }
 
 
@@ -180,11 +192,15 @@ virNetServerServicePtr virNetServerServiceNewUNIX(const char *path,
 
         /* IO callback is initially disabled, until we're ready
          * to deal with incoming clients */
+        virNetServerServiceRef(svc);
         if (virNetSocketAddIOCallback(svc->socks[i],
                                       0,
                                       virNetServerServiceAccept,
-                                      svc) < 0)
+                                      svc,
+                                      virNetServerServiceEventFree) < 0) {
+            virNetServerServiceFree(svc);
             goto error;
+        }
     }
 
 
diff --git a/src/rpc/virnetsocket.c b/src/rpc/virnetsocket.c
index 8dd4d3a..43460d9 100644
--- a/src/rpc/virnetsocket.c
+++ b/src/rpc/virnetsocket.c
@@ -58,8 +58,12 @@ struct _virNetSocket {
     pid_t pid;
     int errfd;
     bool client;
+
+    /* Event callback fields */
     virNetSocketIOFunc func;
     void *opaque;
+    virFreeCallback ff;
+
     virSocketAddr localAddr;
     virSocketAddr remoteAddr;
     char *localAddrStr;
@@ -1121,10 +1125,31 @@ static void virNetSocketEventHandle(int watch ATTRIBUTE_UNUSED,
 }
 
 
+static void virNetSocketEventFree(void *opaque)
+{
+    virNetSocketPtr sock = opaque;
+    virFreeCallback ff;
+    void *eopaque;
+
+    virMutexLock(&sock->lock);
+    ff = sock->ff;
+    eopaque = sock->opaque;
+    sock->func = NULL;
+    sock->ff = NULL;
+    sock->opaque = NULL;
+    virMutexUnlock(&sock->lock);
+
+    if (ff)
+        ff(eopaque);
+
+    virNetSocketFree(sock);
+}
+
 int virNetSocketAddIOCallback(virNetSocketPtr sock,
                               int events,
                               virNetSocketIOFunc func,
-                              void *opaque)
+                              void *opaque,
+                              virFreeCallback ff)
 {
     int ret = -1;
 
@@ -1139,12 +1164,13 @@ int virNetSocketAddIOCallback(virNetSocketPtr sock,
                                          events,
                                          virNetSocketEventHandle,
                                          sock,
-                                         NULL)) < 0) {
+                                         virNetSocketEventFree)) < 0) {
         VIR_DEBUG("Failed to register watch on socket %p", sock);
         goto cleanup;
     }
     sock->func = func;
     sock->opaque = opaque;
+    sock->ff = ff;
 
     ret = 0;
 
diff --git a/src/rpc/virnetsocket.h b/src/rpc/virnetsocket.h
index 5f882ac..e13ab8f 100644
--- a/src/rpc/virnetsocket.h
+++ b/src/rpc/virnetsocket.h
@@ -109,7 +109,8 @@ int virNetSocketAccept(virNetSocketPtr sock,
 int virNetSocketAddIOCallback(virNetSocketPtr sock,
                               int events,
                               virNetSocketIOFunc func,
-                              void *opaque);
+                              void *opaque,
+                              virFreeCallback ff);
 
 void virNetSocketUpdateIOCallback(virNetSocketPtr sock,
                                   int events);
-- 
1.7.6


[Date Prev][Date Next]   [Thread Prev][Thread Next]   [Thread Index] [Date Index] [Author Index]