[libvirt] [PATCH v2 1/2] rpc: Switch to dynamically allocated message buffer

Michal Privoznik mprivozn at redhat.com
Tue May 15 15:04:33 UTC 2012


Currently, we allocate the buffer for RPC messages statically.
This is not much of a problem while the RPC limits are small. However, if
we ever want to increase those limits, we need to allocate the buffer
dynamically, based on the RPC message length (= the first 4 bytes). This
will decrease our memory usage in most cases while still being flexible
enough for corner cases.
---
 src/rpc/virnetclient.c       |   16 ++-
 src/rpc/virnetmessage.c      |   12 ++-
 src/rpc/virnetmessage.h      |    5 +-
 src/rpc/virnetserverclient.c |   24 +++-
 tests/virnetmessagetest.c    |  393 +++++++++++++++++++++++-------------------
 5 files changed, 266 insertions(+), 184 deletions(-)

diff --git a/src/rpc/virnetclient.c b/src/rpc/virnetclient.c
index 3a60db6..69d72f7 100644
--- a/src/rpc/virnetclient.c
+++ b/src/rpc/virnetclient.c
@@ -801,7 +801,12 @@ virNetClientCallDispatchReply(virNetClientPtr client)
         return -1;
     }
 
-    memcpy(thecall->msg->buffer, client->msg.buffer, sizeof(client->msg.buffer));
+    if (VIR_REALLOC_N(thecall->msg->buffer, client->msg.bufferLength) < 0) {
+        virReportOOMError();
+        return -1;
+    }
+
+    memcpy(thecall->msg->buffer, client->msg.buffer, client->msg.bufferLength);
     memcpy(&thecall->msg->header, &client->msg.header, sizeof(client->msg.header));
     thecall->msg->bufferLength = client->msg.bufferLength;
     thecall->msg->bufferOffset = client->msg.bufferOffset;
@@ -987,6 +992,7 @@ virNetClientIOWriteMessage(virNetClientPtr client,
         }
         thecall->msg->donefds = 0;
         thecall->msg->bufferOffset = thecall->msg->bufferLength = 0;
+        VIR_FREE(thecall->msg->buffer); /* any reply is realloc'd into it */
         if (thecall->expectReply)
             thecall->mode = VIR_NET_CLIENT_MODE_WAIT_RX;
         else
@@ -1030,8 +1036,13 @@ virNetClientIOReadMessage(virNetClientPtr client)
     ssize_t ret;
 
     /* Start by reading length word */
-    if (client->msg.bufferLength == 0)
+    if (client->msg.bufferLength == 0) {
         client->msg.bufferLength = 4;
+        if (VIR_ALLOC_N(client->msg.buffer, client->msg.bufferLength) < 0) {
+            virReportOOMError();
+            return -ENOMEM;
+        }
+    }
 
     wantData = client->msg.bufferLength - client->msg.bufferOffset;
 
@@ -1108,6 +1119,7 @@ virNetClientIOHandleInput(virNetClientPtr client)
 
                 ret = virNetClientCallDispatch(client);
                 client->msg.bufferOffset = client->msg.bufferLength = 0;
+                VIR_FREE(client->msg.buffer); /* next read allocates anew */
                 /*
                  * We've completed one call, but we don't want to
                  * spin around the loop forever if there are many
diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c
index 17ecc90..dc4c212 100644
--- a/src/rpc/virnetmessage.c
+++ b/src/rpc/virnetmessage.c
@@ -61,6 +61,7 @@ void virNetMessageClear(virNetMessagePtr msg)
     for (i = 0 ; i < msg->nfds ; i++)
         VIR_FORCE_CLOSE(msg->fds[i]);
     VIR_FREE(msg->fds);
+    VIR_FREE(msg->buffer); /* the memset below would leak it */
     memset(msg, 0, sizeof(*msg));
     msg->tracked = tracked;
 }
@@ -79,6 +80,7 @@ void virNetMessageFree(virNetMessagePtr msg)
 
     for (i = 0 ; i < msg->nfds ; i++)
         VIR_FORCE_CLOSE(msg->fds[i]);
+    VIR_FREE(msg->buffer);
     VIR_FREE(msg->fds);
     VIR_FREE(msg);
 }
@@ -144,6 +146,10 @@ int virNetMessageDecodeLength(virNetMessagePtr msg)
     /* Extend our declared buffer length and carry
        on reading the header + payload */
-    msg->bufferLength += len;
+    if (VIR_REALLOC_N(msg->buffer, msg->bufferLength + len) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+    msg->bufferLength += len; /* only once the buffer can hold it */
 
     VIR_DEBUG("Got length, now need %zu total (%u more)",
               msg->bufferLength, len);
@@ -212,7 +218,15 @@ int virNetMessageEncodeHeader(virNetMessagePtr msg)
     int ret = -1;
     unsigned int len = 0;
 
-    msg->bufferLength = sizeof(msg->buffer);
+    /* Grow rather than allocate: the message may still own a buffer
+     * from an earlier use, which VIR_ALLOC_N would leak. Nothing has
+     * been set up yet, so return directly instead of via cleanup. */
+    if (VIR_REALLOC_N(msg->buffer,
+                      VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX) < 0) {
+        virReportOOMError();
+        return ret;
+    }
+    msg->bufferLength = VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX;
     msg->bufferOffset = 0;
 
     /* Format the header. */
diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h
index c54e7c6..8f36a70 100644
--- a/src/rpc/virnetmessage.h
+++ b/src/rpc/virnetmessage.h
@@ -31,13 +31,10 @@ typedef virNetMessage *virNetMessagePtr;
 
 typedef void (*virNetMessageFreeCallback)(virNetMessagePtr msg, void *opaque);
 
-/* Never allocate this (huge) buffer on the stack. Always
- * use virNetMessageNew() to allocate on the heap
- */
 struct _virNetMessage {
     bool tracked;
 
-    char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX];
+    char *buffer; /* Heap buffer of at least bufferLength bytes, or NULL */
     size_t bufferLength;
     size_t bufferOffset;
 
diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c
index 67600fd..6ae4e25 100644
--- a/src/rpc/virnetserverclient.c
+++ b/src/rpc/virnetserverclient.c
@@ -313,6 +313,11 @@ virNetServerClientCheckAccess(virNetServerClientPtr client)
      * (NB. The '\1' byte is sent in an encrypted record).
      */
     confirm->bufferLength = 1;
+    if (VIR_ALLOC_N(confirm->buffer, confirm->bufferLength) < 0) {
+        virReportOOMError();
+        virNetMessageFree(confirm);
+        return -1;
+    }
     confirm->bufferOffset = 0;
     confirm->buffer[0] = '\1';
 
@@ -373,6 +378,10 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock,
     if (!(client->rx = virNetMessageNew(true)))
         goto error;
     client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+    if (VIR_ALLOC_N(client->rx->buffer, client->rx->bufferLength) < 0) {
+        virReportOOMError();
+        goto error;
+    }
     client->nrequests = 1;
 
     PROBE(RPC_SERVER_CLIENT_NEW,
@@ -922,7 +931,13 @@ readmore:
                 client->wantClose = true;
             } else {
                 client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
-                client->nrequests++;
+                if (VIR_ALLOC_N(client->rx->buffer,
+                                client->rx->bufferLength) < 0) {
+                    virReportOOMError();
+                    client->wantClose = true;
+                } else {
+                    client->nrequests++;
+                }
             }
         }
         virNetServerClientUpdateEvent(client);
@@ -1019,8 +1034,13 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client)
                     client->nrequests < client->nrequests_max) {
                     /* Ready to recv more messages */
                     virNetMessageClear(msg);
+                    msg->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
+                    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+                        virReportOOMError();
+                        virNetMessageFree(msg);
+                        return; /* NOTE(review): skips rest of this handler */
+                    }
                     client->rx = msg;
-                    client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX;
                     msg = NULL;
                     client->nrequests++;
                 }
diff --git a/tests/virnetmessagetest.c b/tests/virnetmessagetest.c
index 28dc09f..6c294ca 100644
--- a/tests/virnetmessagetest.c
+++ b/tests/virnetmessagetest.c
@@ -35,7 +35,7 @@
 
 static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED)
 {
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true); /* NULL on OOM */
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x1c,  /* Length */
         0x11, 0x22, 0x33, 0x44,  /* Program */
@@ -45,128 +45,153 @@ static int testMessageHeaderEncode(const void *args ATTRIBUTE_UNUSED)
         0x00, 0x00, 0x00, 0x99,  /* Serial */
         0x00, 0x00, 0x00, 0x00,  /* Status */
     };
-    memset(&msg, 0, sizeof(msg));
-
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_CALL;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_OK;
+    /* virNetMessageEncodeHeader() sets msg->bufferLength to the
+     * size of the buffer it allocates, which is this many bytes */
+    size_t msg_buf_size = VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX;
+    int ret = -1;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
+    if (!msg) {
+        virReportOOMError();
         return -1;
+    }
+
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_OK;
+
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferOffset) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferOffset) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  sizeof(expect), msg.bufferOffset);
-        return -1;
+                  sizeof(expect), msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != sizeof(msg.buffer)) {
-        VIR_DEBUG("Expect message offset %zu got %zu",
-                  sizeof(msg.buffer), msg.bufferLength);
-        return -1;
+    if (msg->bufferLength != msg_buf_size) {
+        VIR_DEBUG("Expect message length %zu got %zu",
+                  msg_buf_size, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
-        return -1;
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessageHeaderDecode(const void *args ATTRIBUTE_UNUSED)
 {
-    static virNetMessage msg = {
-        .bufferOffset = 0,
-        .bufferLength = 0x4,
-        .buffer = {
-            0x00, 0x00, 0x00, 0x1c,  /* Length */
-            0x11, 0x22, 0x33, 0x44,  /* Program */
-            0x00, 0x00, 0x00, 0x01,  /* Version */
-            0x00, 0x00, 0x06, 0x66,  /* Procedure */
-            0x00, 0x00, 0x00, 0x01,  /* Type */
-            0x00, 0x00, 0x00, 0x99,  /* Serial */
-            0x00, 0x00, 0x00, 0x01,  /* Status */
-        },
-        .header = { 0, 0, 0, 0, 0, 0 },
+    virNetMessagePtr msg = virNetMessageNew(true);
+    static const char input_buf[] = {
+        0x00, 0x00, 0x00, 0x1c,  /* Length */
+        0x11, 0x22, 0x33, 0x44,  /* Program */
+        0x00, 0x00, 0x00, 0x01,  /* Version */
+        0x00, 0x00, 0x06, 0x66,  /* Procedure */
+        0x00, 0x00, 0x00, 0x01,  /* Type */
+        0x00, 0x00, 0x00, 0x99,  /* Serial */
+        0x00, 0x00, 0x00, 0x01,  /* Status */
     };
+    int ret = -1;
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
+
+    msg->bufferLength = 4;
+    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+    memcpy(msg->buffer, input_buf, msg->bufferLength);
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_CALL;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_OK;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_CALL;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_OK;
 
-    if (virNetMessageDecodeLength(&msg) < 0) {
+    if (virNetMessageDecodeLength(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0x4) {
+    if (msg->bufferOffset != 0x4) {
         VIR_DEBUG("Expecting offset %zu got %zu",
-                  (size_t)4, msg.bufferOffset);
-        return -1;
+                  (size_t)4, msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x1c) {
+    if (msg->bufferLength != 0x1c) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x1c, msg.bufferLength);
-        return -1;
+                  (size_t)0x1c, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodeHeader(&msg) < 0) {
+    memcpy(msg->buffer, input_buf, msg->bufferLength);
+
+    if (virNetMessageDecodeHeader(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != msg.bufferLength) {
+    if (msg->bufferOffset != msg->bufferLength) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  msg.bufferOffset, msg.bufferLength);
-        return -1;
+                  msg->bufferOffset, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (msg.header.prog != 0x11223344) {
+    if (msg->header.prog != 0x11223344) {
         VIR_DEBUG("Expect prog %d got %d",
-                  0x11223344, msg.header.prog);
-        return -1;
+                  0x11223344, msg->header.prog);
+        goto cleanup;
     }
-    if (msg.header.vers != 0x1) {
+    if (msg->header.vers != 0x1) {
         VIR_DEBUG("Expect vers %d got %d",
-                  0x11223344, msg.header.vers);
-        return -1;
+                  0x1, msg->header.vers);
+        goto cleanup;
     }
-    if (msg.header.proc != 0x666) {
+    if (msg->header.proc != 0x666) {
         VIR_DEBUG("Expect proc %d got %d",
-                  0x666, msg.header.proc);
-        return -1;
+                  0x666, msg->header.proc);
+        goto cleanup;
     }
-    if (msg.header.type != VIR_NET_REPLY) {
+    if (msg->header.type != VIR_NET_REPLY) {
         VIR_DEBUG("Expect type %d got %d",
-                  VIR_NET_REPLY, msg.header.type);
-        return -1;
+                  VIR_NET_REPLY, msg->header.type);
+        goto cleanup;
     }
-    if (msg.header.serial != 0x99) {
+    if (msg->header.serial != 0x99) {
         VIR_DEBUG("Expect serial %d got %d",
-                  0x99, msg.header.serial);
-        return -1;
+                  0x99, msg->header.serial);
+        goto cleanup;
     }
-    if (msg.header.status != VIR_NET_ERROR) {
+    if (msg->header.status != VIR_NET_ERROR) {
         VIR_DEBUG("Expect status %d got %d",
-                  VIR_NET_ERROR, msg.header.status);
-        return -1;
+                  VIR_NET_ERROR, msg->header.status);
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
 {
     virNetMessageError err;
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true);
     int ret = -1;
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x74,  /* Length */
@@ -200,7 +225,12 @@ static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
         0x00, 0x00, 0x00, 0x02,  /* Error int2 */
         0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
     };
-    memset(&msg, 0, sizeof(msg));
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
+
     memset(&err, 0, sizeof(err));
 
     err.code = VIR_ERR_INTERNAL_ERROR;
@@ -223,33 +253,33 @@ static int testMessagePayloadEncode(const void *args ATTRIBUTE_UNUSED)
     err.int1 = 1;
     err.int2 = 2;
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_MESSAGE;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_ERROR;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_MESSAGE;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_ERROR;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
+    if (virNetMessageEncodeHeader(msg) < 0) /* allocates msg->buffer */
         goto cleanup;
 
-    if (virNetMessageEncodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
+    if (virNetMessageEncodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0)
         goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferLength) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferLength) {
         VIR_DEBUG("Expect message length %zu got %zu",
-                  sizeof(expect), msg.bufferLength);
+                  sizeof(expect), msg->bufferLength);
         goto cleanup;
     }
 
-    if (msg.bufferOffset != 0) {
+    if (msg->bufferOffset != 0) {
         VIR_DEBUG("Expect message offset 0 got %zu",
-                  msg.bufferOffset);
+                  msg->bufferOffset);
         goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
         goto cleanup;
     }
 
@@ -267,166 +297,182 @@ cleanup:
     VIR_FREE(err.str1);
     VIR_FREE(err.str2);
     VIR_FREE(err.str3);
+    virNetMessageFree(msg);
     return ret;
 }
 
 static int testMessagePayloadDecode(const void *args ATTRIBUTE_UNUSED)
 {
     virNetMessageError err;
-    static virNetMessage msg = {
-        .bufferOffset = 0,
-        .bufferLength = 0x4,
-        .buffer = {
-            0x00, 0x00, 0x00, 0x74,  /* Length */
-            0x11, 0x22, 0x33, 0x44,  /* Program */
-            0x00, 0x00, 0x00, 0x01,  /* Version */
-            0x00, 0x00, 0x06, 0x66,  /* Procedure */
-            0x00, 0x00, 0x00, 0x02,  /* Type */
-            0x00, 0x00, 0x00, 0x99,  /* Serial */
-            0x00, 0x00, 0x00, 0x01,  /* Status */
-
-            0x00, 0x00, 0x00, 0x01,  /* Error code */
-            0x00, 0x00, 0x00, 0x07,  /* Error domain */
-            0x00, 0x00, 0x00, 0x01,  /* Error message pointer */
-            0x00, 0x00, 0x00, 0x0b,  /* Error message length */
-            'H', 'e', 'l', 'l',  /* Error message string */
-            'o', ' ', 'W', 'o',
-            'r', 'l', 'd', '\0',
-            0x00, 0x00, 0x00, 0x02,  /* Error level */
-            0x00, 0x00, 0x00, 0x00,  /* Error domain pointer */
-            0x00, 0x00, 0x00, 0x01,  /* Error str1 pointer */
-            0x00, 0x00, 0x00, 0x03,  /* Error str1 length */
-            'O', 'n', 'e', '\0',  /* Error str1 message */
-            0x00, 0x00, 0x00, 0x01,  /* Error str2 pointer */
-            0x00, 0x00, 0x00, 0x03,  /* Error str2 length */
-            'T', 'w', 'o', '\0',  /* Error str2 message */
-            0x00, 0x00, 0x00, 0x01,  /* Error str3 pointer */
-            0x00, 0x00, 0x00, 0x05,  /* Error str3 length */
-            'T', 'h', 'r', 'e',  /* Error str3 message */
-            'e', '\0', '\0', '\0',
-            0x00, 0x00, 0x00, 0x01,  /* Error int1 */
-            0x00, 0x00, 0x00, 0x02,  /* Error int2 */
-            0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
-        },
-        .header = { 0, 0, 0, 0, 0, 0 },
+    virNetMessagePtr msg = virNetMessageNew(true);
+    static const char input_buffer[] = {
+        0x00, 0x00, 0x00, 0x74,  /* Length */
+        0x11, 0x22, 0x33, 0x44,  /* Program */
+        0x00, 0x00, 0x00, 0x01,  /* Version */
+        0x00, 0x00, 0x06, 0x66,  /* Procedure */
+        0x00, 0x00, 0x00, 0x02,  /* Type */
+        0x00, 0x00, 0x00, 0x99,  /* Serial */
+        0x00, 0x00, 0x00, 0x01,  /* Status */
+
+        0x00, 0x00, 0x00, 0x01,  /* Error code */
+        0x00, 0x00, 0x00, 0x07,  /* Error domain */
+        0x00, 0x00, 0x00, 0x01,  /* Error message pointer */
+        0x00, 0x00, 0x00, 0x0b,  /* Error message length */
+        'H', 'e', 'l', 'l',  /* Error message string */
+        'o', ' ', 'W', 'o',
+        'r', 'l', 'd', '\0',
+        0x00, 0x00, 0x00, 0x02,  /* Error level */
+        0x00, 0x00, 0x00, 0x00,  /* Error domain pointer */
+        0x00, 0x00, 0x00, 0x01,  /* Error str1 pointer */
+        0x00, 0x00, 0x00, 0x03,  /* Error str1 length */
+        'O', 'n', 'e', '\0',  /* Error str1 message */
+        0x00, 0x00, 0x00, 0x01,  /* Error str2 pointer */
+        0x00, 0x00, 0x00, 0x03,  /* Error str2 length */
+        'T', 'w', 'o', '\0',  /* Error str2 message */
+        0x00, 0x00, 0x00, 0x01,  /* Error str3 pointer */
+        0x00, 0x00, 0x00, 0x05,  /* Error str3 length */
+        'T', 'h', 'r', 'e',  /* Error str3 message */
+        'e', '\0', '\0', '\0',
+        0x00, 0x00, 0x00, 0x01,  /* Error int1 */
+        0x00, 0x00, 0x00, 0x02,  /* Error int2 */
+        0x00, 0x00, 0x00, 0x00,  /* Error network pointer */
     };
-    memset(&err, 0, sizeof(err));
+    int ret = -1;
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
+
+    /* Initialise err before any "goto cleanup", which xdr_free()s it */
+    memset(&err, 0, sizeof(err));
+    msg->bufferLength = 4;
+    if (VIR_ALLOC_N(msg->buffer, msg->bufferLength) < 0) {
+        virReportOOMError();
+        goto cleanup;
+    }
+    memcpy(msg->buffer, input_buffer, msg->bufferLength);
 
-    if (virNetMessageDecodeLength(&msg) < 0) {
+    if (virNetMessageDecodeLength(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0x4) {
+    if (msg->bufferOffset != 0x4) {
         VIR_DEBUG("Expecting offset %zu got %zu",
-                  (size_t)4, msg.bufferOffset);
-        return -1;
+                  (size_t)4, msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x74) {
+    if (msg->bufferLength != 0x74) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x74, msg.bufferLength);
-        return -1;
+                  (size_t)0x74, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodeHeader(&msg) < 0) {
+    memcpy(msg->buffer, input_buffer, msg->bufferLength);
+
+    if (virNetMessageDecodeHeader(msg) < 0) {
         VIR_DEBUG("Failed to decode message header");
-        return -1;
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 28) {
+    if (msg->bufferOffset != 28) {
         VIR_DEBUG("Expect message offset %zu got %zu",
-                  msg.bufferOffset, (size_t)28);
-        return -1;
+                  (size_t)28, msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (msg.bufferLength != 0x74) {
+    if (msg->bufferLength != 0x74) {
         VIR_DEBUG("Expecting length %zu got %zu",
-                  (size_t)0x1c, msg.bufferLength);
-        return -1;
+                  (size_t)0x74, msg->bufferLength);
+        goto cleanup;
     }
 
-    if (virNetMessageDecodePayload(&msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) {
+    if (virNetMessageDecodePayload(msg, (xdrproc_t)xdr_virNetMessageError, &err) < 0) {
         VIR_DEBUG("Failed to decode message payload");
-        return -1;
+        goto cleanup;
     }
 
     if (err.code != VIR_ERR_INTERNAL_ERROR) {
         VIR_DEBUG("Expect code %d got %d",
                   VIR_ERR_INTERNAL_ERROR, err.code);
-        return -1;
+        goto cleanup;
     }
 
     if (err.domain != VIR_FROM_RPC) {
         VIR_DEBUG("Expect domain %d got %d",
                   VIR_ERR_RPC, err.domain);
-        return -1;
+        goto cleanup;
     }
 
     if (err.message == NULL ||
         STRNEQ(*err.message, "Hello World")) {
         VIR_DEBUG("Expect str1 'Hello World' got %s",
                   err.message ? *err.message : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.dom != NULL) {
         VIR_DEBUG("Expect NULL dom");
-        return -1;
+        goto cleanup;
     }
 
     if (err.level != VIR_ERR_ERROR) {
         VIR_DEBUG("Expect leve %d got %d",
                   VIR_ERR_ERROR, err.level);
-        return -1;
+        goto cleanup;
     }
 
     if (err.str1 == NULL ||
         STRNEQ(*err.str1, "One")) {
         VIR_DEBUG("Expect str1 'One' got %s",
                   err.str1 ? *err.str1 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.str2 == NULL ||
         STRNEQ(*err.str2, "Two")) {
         VIR_DEBUG("Expect str3 'Two' got %s",
                   err.str2 ? *err.str2 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.str3 == NULL ||
         STRNEQ(*err.str3, "Three")) {
         VIR_DEBUG("Expect str3 'Three' got %s",
                   err.str3 ? *err.str3 : "(null)");
-        return -1;
+        goto cleanup;
     }
 
     if (err.int1 != 1) {
         VIR_DEBUG("Expect int1 1 got %d",
                   err.int1);
-        return -1;
+        goto cleanup;
     }
 
     if (err.int2 != 2) {
         VIR_DEBUG("Expect int2 2 got %d",
                   err.int2);
-        return -1;
+        goto cleanup;
     }
 
     if (err.net != NULL) {
         VIR_DEBUG("Expect NULL network");
-        return -1;
+        goto cleanup;
     }
 
+    ret = 0;
+cleanup:
     xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
-    return 0;
+    virNetMessageFree(msg);
+    return ret;
 }
 
 static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED)
 {
     char stream[] = "The quick brown fox jumps over the lazy dog";
-    static virNetMessage msg;
+    virNetMessagePtr msg = virNetMessageNew(true);
     static const char expect[] = {
         0x00, 0x00, 0x00, 0x47,  /* Length */
         0x11, 0x22, 0x33, 0x44,  /* Program */
@@ -448,39 +488,47 @@ static int testMessagePayloadStreamEncode(const void *args ATTRIBUTE_UNUSED)
         'a', 'z', 'y', ' ',
         'd', 'o', 'g',
     };
-    memset(&msg, 0, sizeof(msg));
+    int ret = -1;
+
+    if (!msg) {
+        virReportOOMError();
+        return -1;
+    }
 
-    msg.header.prog = 0x11223344;
-    msg.header.vers = 0x01;
-    msg.header.proc = 0x666;
-    msg.header.type = VIR_NET_STREAM;
-    msg.header.serial = 0x99;
-    msg.header.status = VIR_NET_CONTINUE;
+    msg->header.prog = 0x11223344;
+    msg->header.vers = 0x01;
+    msg->header.proc = 0x666;
+    msg->header.type = VIR_NET_STREAM;
+    msg->header.serial = 0x99;
+    msg->header.status = VIR_NET_CONTINUE;
 
-    if (virNetMessageEncodeHeader(&msg) < 0)
-        return -1;
+    if (virNetMessageEncodeHeader(msg) < 0)
+        goto cleanup;
 
-    if (virNetMessageEncodePayloadRaw(&msg, stream, strlen(stream)) < 0)
-        return -1;
+    if (virNetMessageEncodePayloadRaw(msg, stream, strlen(stream)) < 0)
+        goto cleanup;
 
-    if (ARRAY_CARDINALITY(expect) != msg.bufferLength) {
+    if (ARRAY_CARDINALITY(expect) != msg->bufferLength) {
         VIR_DEBUG("Expect message length %zu got %zu",
-                  sizeof(expect), msg.bufferLength);
-        return -1;
+                  sizeof(expect), msg->bufferLength);
+        goto cleanup;
     }
 
-    if (msg.bufferOffset != 0) {
+    if (msg->bufferOffset != 0) {
         VIR_DEBUG("Expect message offset 0 got %zu",
-                  msg.bufferOffset);
-        return -1;
+                  msg->bufferOffset);
+        goto cleanup;
     }
 
-    if (memcmp(expect, msg.buffer, sizeof(expect)) != 0) {
-        virtTestDifferenceBin(stderr, expect, msg.buffer, sizeof(expect));
-        return -1;
+    if (memcmp(expect, msg->buffer, sizeof(expect)) != 0) {
+        virtTestDifferenceBin(stderr, expect, msg->buffer, sizeof(expect));
+        goto cleanup;
     }
 
-    return 0;
+    ret = 0;
+cleanup:
+    virNetMessageFree(msg);
+    return ret;
 }
 
 
-- 
1.7.8.5




More information about the libvir-list mailing list