[libvirt] [go PATCH 22/37] stream: fix error reporting thread safety

Daniel P. Berrangé berrange at redhat.com
Mon Jul 16 13:24:08 UTC 2018


Create a wrapper function for each stream C API, accepting a
virErrorPtr parameter into which any error is copied. This avoids
accessing a thread-local from a goroutine, which may race with other
goroutines doing native API calls in the same OS thread.

Signed-off-by: Daniel P. Berrangé <berrange at redhat.com>
---
 stream.go         |  80 +++++++++-------
 stream_wrapper.go | 231 +++++++++++++++++++++++++++++++++++++++-------
 stream_wrapper.h  |  84 +++++++++++++----
 3 files changed, 312 insertions(+), 83 deletions(-)

diff --git a/stream.go b/stream.go
index c5c1ef7..515ae08 100644
--- a/stream.go
+++ b/stream.go
@@ -64,9 +64,10 @@ type Stream struct {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamAbort
 func (v *Stream) Abort() error {
-	result := C.virStreamAbort(v.ptr)
+	var err C.virError
+	result := C.virStreamAbortWrapper(v.ptr, &err)
 	if result == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -74,9 +75,10 @@ func (v *Stream) Abort() error {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamFinish
 func (v *Stream) Finish() error {
-	result := C.virStreamFinish(v.ptr)
+	var err C.virError
+	result := C.virStreamFinishWrapper(v.ptr, &err)
 	if result == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -84,27 +86,30 @@ func (v *Stream) Finish() error {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamFree
 func (v *Stream) Free() error {
-	ret := C.virStreamFree(v.ptr)
+	var err C.virError
+	ret := C.virStreamFreeWrapper(v.ptr, &err)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 	return nil
 }
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamRef
 func (c *Stream) Ref() error {
-	ret := C.virStreamRef(c.ptr)
+	var err C.virError
+	ret := C.virStreamRefWrapper(c.ptr, &err)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 	return nil
 }
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamRecv
 func (v *Stream) Recv(p []byte) (int, error) {
-	n := C.virStreamRecv(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)))
+	var err C.virError
+	n := C.virStreamRecvWrapper(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)), &err)
 	if n < 0 {
-		return 0, GetLastError()
+		return 0, makeError(&err)
 	}
 	if n == 0 {
 		return 0, io.EOF
@@ -119,9 +124,10 @@ func (v *Stream) RecvFlags(p []byte, flags StreamRecvFlagsValues) (int, error) {
 		return 0, GetNotImplementedError("virStreamRecvFlags")
 	}
 
-	n := C.virStreamRecvFlagsWrapper(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)), C.uint(flags))
+	var err C.virError
+	n := C.virStreamRecvFlagsWrapper(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)), C.uint(flags), &err)
 	if n < 0 {
-		return 0, GetLastError()
+		return 0, makeError(&err)
 	}
 	if n == 0 {
 		return 0, io.EOF
@@ -137,9 +143,10 @@ func (v *Stream) RecvHole(flags uint) (int64, error) {
 	}
 
 	var len C.longlong
-	ret := C.virStreamRecvHoleWrapper(v.ptr, &len, C.uint(flags))
+	var err C.virError
+	ret := C.virStreamRecvHoleWrapper(v.ptr, &len, C.uint(flags), &err)
 	if ret < 0 {
-		return 0, GetLastError()
+		return 0, makeError(&err)
 	}
 
 	return int64(len), nil
@@ -147,9 +154,10 @@ func (v *Stream) RecvHole(flags uint) (int64, error) {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamSend
 func (v *Stream) Send(p []byte) (int, error) {
-	n := C.virStreamSend(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)))
+	var err C.virError
+	n := C.virStreamSendWrapper(v.ptr, (*C.char)(unsafe.Pointer(&p[0])), C.size_t(len(p)), &err)
 	if n < 0 {
-		return 0, GetLastError()
+		return 0, makeError(&err)
 	}
 	if n == 0 {
 		return 0, io.EOF
@@ -164,9 +172,10 @@ func (v *Stream) SendHole(len int64, flags uint32) error {
 		return GetNotImplementedError("virStreamSendHole")
 	}
 
-	ret := C.virStreamSendHoleWrapper(v.ptr, C.longlong(len), C.uint(flags))
+	var err C.virError
+	ret := C.virStreamSendHoleWrapper(v.ptr, C.longlong(len), C.uint(flags), &err)
 	if ret < 0 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -220,10 +229,11 @@ func (v *Stream) RecvAll(handler StreamSinkFunc) error {
 
 	callbackID := registerCallbackId(handler)
 
-	ret := C.virStreamRecvAllWrapper(v.ptr, (C.int)(callbackID))
+	var err C.virError
+	ret := C.virStreamRecvAllWrapper(v.ptr, (C.int)(callbackID), &err)
 	freeCallbackId(callbackID)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -238,11 +248,12 @@ func (v *Stream) SparseRecvAll(handler StreamSinkFunc, holeHandler StreamSinkHol
 	callbackID := registerCallbackId(handler)
 	holeCallbackID := registerCallbackId(holeHandler)
 
-	ret := C.virStreamSparseRecvAllWrapper(v.ptr, (C.int)(callbackID), (C.int)(holeCallbackID))
+	var err C.virError
+	ret := C.virStreamSparseRecvAllWrapper(v.ptr, (C.int)(callbackID), (C.int)(holeCallbackID), &err)
 	freeCallbackId(callbackID)
 	freeCallbackId(holeCallbackID)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -325,10 +336,11 @@ func (v *Stream) SendAll(handler StreamSourceFunc) error {
 
 	callbackID := registerCallbackId(handler)
 
-	ret := C.virStreamSendAllWrapper(v.ptr, (C.int)(callbackID))
+	var err C.virError
+	ret := C.virStreamSendAllWrapper(v.ptr, (C.int)(callbackID), &err)
 	freeCallbackId(callbackID)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -344,12 +356,13 @@ func (v *Stream) SparseSendAll(handler StreamSourceFunc, holeHandler StreamSourc
 	holeCallbackID := registerCallbackId(holeHandler)
 	skipCallbackID := registerCallbackId(skipHandler)
 
-	ret := C.virStreamSparseSendAllWrapper(v.ptr, (C.int)(callbackID), (C.int)(holeCallbackID), (C.int)(skipCallbackID))
+	var err C.virError
+	ret := C.virStreamSparseSendAllWrapper(v.ptr, (C.int)(callbackID), (C.int)(holeCallbackID), (C.int)(skipCallbackID), &err)
 	freeCallbackId(callbackID)
 	freeCallbackId(holeCallbackID)
 	freeCallbackId(skipCallbackID)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -361,9 +374,10 @@ type StreamEventCallback func(*Stream, StreamEventType)
 func (v *Stream) EventAddCallback(events StreamEventType, callback StreamEventCallback) error {
 	callbackID := registerCallbackId(callback)
 
-	ret := C.virStreamEventAddCallbackWrapper(v.ptr, (C.int)(events), (C.int)(callbackID))
+	var err C.virError
+	ret := C.virStreamEventAddCallbackWrapper(v.ptr, (C.int)(events), (C.int)(callbackID), &err)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -383,9 +397,10 @@ func streamEventCallback(st C.virStreamPtr, events int, callbackID int) {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamEventUpdateCallback
 func (v *Stream) EventUpdateCallback(events StreamEventType) error {
-	ret := C.virStreamEventUpdateCallback(v.ptr, (C.int)(events))
+	var err C.virError
+	ret := C.virStreamEventUpdateCallbackWrapper(v.ptr, (C.int)(events), &err)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
@@ -393,9 +408,10 @@ func (v *Stream) EventUpdateCallback(events StreamEventType) error {
 
 // See also https://libvirt.org/html/libvirt-libvirt-stream.html#virStreamEventRemoveCallback
 func (v *Stream) EventRemoveCallback() error {
-	ret := C.virStreamEventRemoveCallback(v.ptr)
+	var err C.virError
+	ret := C.virStreamEventRemoveCallbackWrapper(v.ptr, &err)
 	if ret == -1 {
-		return GetLastError()
+		return makeError(&err)
 	}
 
 	return nil
diff --git a/stream_wrapper.go b/stream_wrapper.go
index 90cc110..e563a74 100644
--- a/stream_wrapper.go
+++ b/stream_wrapper.go
@@ -81,84 +81,251 @@ static int streamSinkHoleCallbackHelper(virStreamPtr st, long long length, void
     return streamSinkHoleCallback(st, length, cbdata->holeCallbackID);
 }
 
-int virStreamSendAllWrapper(virStreamPtr st, int callbackID)
+int
+virStreamAbortWrapper(virStreamPtr stream,
+                      virErrorPtr err)
+{
+    int ret = virStreamAbort(stream);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+void
+streamEventCallback(virStreamPtr st, int events, int callbackID);
+
+static void
+streamEventCallbackHelper(virStreamPtr st, int events, void *opaque)
+{
+    streamEventCallback(st, events, (int)(intptr_t)opaque);
+}
+
+int
+virStreamEventAddCallbackWrapper(virStreamPtr stream,
+                                 int events,
+                                 int callbackID,
+                                 virErrorPtr err)
+{
+    int ret = virStreamEventAddCallback(stream, events, streamEventCallbackHelper, (void *)(intptr_t)callbackID, NULL);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamEventRemoveCallbackWrapper(virStreamPtr stream,
+                                    virErrorPtr err)
+{
+    int ret = virStreamEventRemoveCallback(stream);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamEventUpdateCallbackWrapper(virStreamPtr stream,
+                                    int events,
+                                    virErrorPtr err)
+{
+    int ret = virStreamEventUpdateCallback(stream, events);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamFinishWrapper(virStreamPtr stream,
+                       virErrorPtr err)
+{
+    int ret = virStreamFinish(stream);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamFreeWrapper(virStreamPtr stream,
+                     virErrorPtr err)
+{
+    int ret = virStreamFree(stream);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamRecvWrapper(virStreamPtr stream,
+                     char *data,
+                     size_t nbytes,
+                     virErrorPtr err)
+{
+    int ret = virStreamRecv(stream, data, nbytes);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamRecvAllWrapper(virStreamPtr stream,
+                        int callbackID,
+                        virErrorPtr err)
 {
     struct CallbackData cbdata = { .callbackID = callbackID };
-    return virStreamSendAll(st, streamSourceCallbackHelper, &cbdata);
+    int ret = virStreamRecvAll(stream, streamSinkCallbackHelper, &cbdata);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 }
 
-int virStreamSparseSendAllWrapper(virStreamPtr st, int callbackID, int holeCallbackID, int skipCallbackID)
+
+int
+virStreamRecvFlagsWrapper(virStreamPtr stream,
+                          char *data,
+                          size_t nbytes,
+                          unsigned int flags,
+                          virErrorPtr err)
 {
-    struct CallbackData cbdata = { .callbackID = callbackID, .holeCallbackID = holeCallbackID, .skipCallbackID = skipCallbackID };
 #if LIBVIR_VERSION_NUMBER < 3004000
     assert(0); // Caller should have checked version
 #else
-    return virStreamSparseSendAll(st, streamSourceCallbackHelper, streamSourceHoleCallbackHelper, streamSourceSkipCallbackHelper, &cbdata);
+    int ret = virStreamRecvFlags(stream, data, nbytes, flags);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 #endif
 }
 
 
-int virStreamRecvAllWrapper(virStreamPtr st, int callbackID)
-{
-    struct CallbackData cbdata = { .callbackID = callbackID };
-    return virStreamRecvAll(st, streamSinkCallbackHelper, &cbdata);
-}
-
-int virStreamSparseRecvAllWrapper(virStreamPtr st, int callbackID, int holeCallbackID)
+int
+virStreamRecvHoleWrapper(virStreamPtr stream,
+                         long long *length,
+                         unsigned int flags,
+                         virErrorPtr err)
 {
-    struct CallbackData cbdata = { .callbackID = callbackID, .holeCallbackID = holeCallbackID };
 #if LIBVIR_VERSION_NUMBER < 3004000
     assert(0); // Caller should have checked version
 #else
-    return virStreamSparseRecvAll(st, streamSinkCallbackHelper, streamSinkHoleCallbackHelper, &cbdata);
+    int ret = virStreamRecvHole(stream, length, flags);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 #endif
 }
 
-void streamEventCallback(virStreamPtr st, int events, int callbackID);
 
-static void streamEventCallbackHelper(virStreamPtr st, int events, void *opaque)
+int
+virStreamRefWrapper(virStreamPtr stream,
+                    virErrorPtr err)
 {
-    streamEventCallback(st, events, (int)(intptr_t)opaque);
+    int ret = virStreamRef(stream);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 }
 
-int virStreamEventAddCallbackWrapper(virStreamPtr st, int events, int callbackID)
+
+int
+virStreamSendWrapper(virStreamPtr stream,
+                     const char *data,
+                     size_t nbytes,
+                     virErrorPtr err)
 {
-    return virStreamEventAddCallback(st, events, streamEventCallbackHelper, (void *)(intptr_t)callbackID, NULL);
+    int ret = virStreamSend(stream, data, nbytes);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 }
 
-int virStreamRecvFlagsWrapper(virStreamPtr st,
-			     char *data,
-			     size_t nbytes,
-			     unsigned int flags)
+
+int
+virStreamSendAllWrapper(virStreamPtr stream,
+                        int callbackID,
+                        virErrorPtr err)
+{
+    struct CallbackData cbdata = { .callbackID = callbackID };
+    int ret = virStreamSendAll(stream, streamSourceCallbackHelper, &cbdata);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
+}
+
+
+int
+virStreamSendHoleWrapper(virStreamPtr stream,
+                         long long length,
+                         unsigned int flags,
+                         virErrorPtr err)
 {
 #if LIBVIR_VERSION_NUMBER < 3004000
     assert(0); // Caller should have checked version
 #else
-    return virStreamRecvFlags(st, data, nbytes, flags);
+    int ret = virStreamSendHole(stream, length, flags);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 #endif
 }
 
-int virStreamSendHoleWrapper(virStreamPtr st,
-			    long long length,
-			    unsigned int flags)
+
+int
+virStreamSparseRecvAllWrapper(virStreamPtr stream,
+                              int callbackID,
+                              int holeCallbackID,
+                              virErrorPtr err)
 {
+    struct CallbackData cbdata = { .callbackID = callbackID, .holeCallbackID = holeCallbackID };
 #if LIBVIR_VERSION_NUMBER < 3004000
     assert(0); // Caller should have checked version
 #else
-    return virStreamSendHole(st, length, flags);
+    int ret = virStreamSparseRecvAll(stream, streamSinkCallbackHelper, streamSinkHoleCallbackHelper, &cbdata);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 #endif
 }
 
-int virStreamRecvHoleWrapper(virStreamPtr st,
-			    long long *length,
-			    unsigned int flags)
+
+int
+virStreamSparseSendAllWrapper(virStreamPtr stream,
+                              int callbackID,
+                              int holeCallbackID,
+                              int skipCallbackID,
+                              virErrorPtr err)
 {
+    struct CallbackData cbdata = { .callbackID = callbackID, .holeCallbackID = holeCallbackID, .skipCallbackID = skipCallbackID };
 #if LIBVIR_VERSION_NUMBER < 3004000
     assert(0); // Caller should have checked version
 #else
-    return virStreamRecvHole(st, length, flags);
+    int ret = virStreamSparseSendAll(stream, streamSourceCallbackHelper, streamSourceHoleCallbackHelper, streamSourceSkipCallbackHelper, &cbdata);
+    if (ret < 0) {
+        virCopyLastError(err);
+    }
+    return ret;
 #endif
 }
 
+
 */
 import "C"
diff --git a/stream_wrapper.h b/stream_wrapper.h
index 5484441..b9c6e57 100644
--- a/stream_wrapper.h
+++ b/stream_wrapper.h
@@ -32,43 +32,89 @@
 #include "stream_compat.h"
 
 int
-virStreamSendAllWrapper(virStreamPtr st,
-                        int callbackID);
+virStreamAbortWrapper(virStreamPtr stream,
+                      virErrorPtr err);
 
 int
-virStreamRecvAllWrapper(virStreamPtr st,
-                        int callbackID);
+virStreamEventAddCallbackWrapper(virStreamPtr st,
+                                 int events,
+                                 int callbackID,
+                                 virErrorPtr err);
 
 int
-virStreamSparseSendAllWrapper(virStreamPtr st,
-                              int callbackID,
-                              int holeCallbackID,
-                              int skipCallbackID);
+virStreamEventRemoveCallbackWrapper(virStreamPtr stream,
+                                    virErrorPtr err);
 
 int
-virStreamSparseRecvAllWrapper(virStreamPtr st,
-                              int callbackID,
-                              int holeCallbackID);
+virStreamEventUpdateCallbackWrapper(virStreamPtr stream,
+                                    int events,
+                                    virErrorPtr err);
 
 int
-virStreamEventAddCallbackWrapper(virStreamPtr st,
-                                 int events,
-                                 int callbackID);
+virStreamFinishWrapper(virStreamPtr stream,
+                       virErrorPtr err);
+
+int
+virStreamFreeWrapper(virStreamPtr stream,
+                     virErrorPtr err);
+
+int
+virStreamRecvWrapper(virStreamPtr stream,
+                     char *data,
+                     size_t nbytes,
+                     virErrorPtr err);
+
+int
+virStreamRecvAllWrapper(virStreamPtr st,
+                        int callbackID,
+                        virErrorPtr err);
 
 int
 virStreamRecvFlagsWrapper(virStreamPtr st,
                           char *data,
                           size_t nbytes,
-                          unsigned int flags);
+                          unsigned int flags,
+                          virErrorPtr err);
+
+int
+virStreamRecvHoleWrapper(virStreamPtr stream,
+                         long long *length,
+                         unsigned int flags,
+                         virErrorPtr err);
+
+int
+virStreamRefWrapper(virStreamPtr stream,
+                    virErrorPtr err);
+
+int
+virStreamSendWrapper(virStreamPtr stream,
+                     const char *data,
+                     size_t nbytes,
+                     virErrorPtr err);
+
+int
+virStreamSendAllWrapper(virStreamPtr st,
+                        int callbackID,
+                        virErrorPtr err);
 
 int
 virStreamSendHoleWrapper(virStreamPtr st,
                          long long length,
-                         unsigned int flags);
+                         unsigned int flags,
+                         virErrorPtr err);
 
 int
-virStreamRecvHoleWrapper(virStreamPtr,
-                         long long *length,
-                         unsigned int flags);
+virStreamSparseRecvAllWrapper(virStreamPtr st,
+                              int callbackID,
+                              int holeCallbackID,
+                              virErrorPtr err);
+
+int
+virStreamSparseSendAllWrapper(virStreamPtr st,
+                              int callbackID,
+                              int holeCallbackID,
+                              int skipCallbackID,
+                              virErrorPtr err);
+
 
 #endif /* LIBVIRT_GO_STREAM_WRAPPER_H__ */
-- 
2.17.1




More information about the libvir-list mailing list