#include "virterror_internal.h"
#include "logging.h"
#include "event.h"
+#include "threads.h"
#define VIR_FROM_THIS VIR_FROM_RPC
#define virNetError(code, ...) \
__FUNCTION__, __LINE__, __VA_ARGS__)
struct _virNetClientStream {
+ virMutex lock;
+
virNetClientProgramPtr prog;
int proc;
unsigned serial;
size_t incomingOffset;
size_t incomingLength;
-
virNetClientStreamEventCallback cb;
void *cbOpaque;
virFreeCallback cbFree;
virNetClientStreamPtr st = opaque;
int events = 0;
- /* XXX we need a mutex on 'st' to protect this callback */
+
+ virMutexLock(&st->lock);
if (st->cb &&
(st->cbEvents & VIR_STREAM_EVENT_READABLE) &&
virFreeCallback cbFree = st->cbFree;
st->cbDispatch = 1;
+ virMutexUnlock(&st->lock);
(cb)(st, events, cbOpaque);
+ virMutexLock(&st->lock);
st->cbDispatch = 0;
if (!st->cb && cbFree)
(cbFree)(cbOpaque);
}
+ virMutexUnlock(&st->lock);
}
return NULL;
}
- virNetClientProgramRef(prog);
-
st->refs = 1;
st->prog = prog;
st->proc = proc;
st->serial = serial;
+ if (virMutexInit(&st->lock) < 0) {
+ virNetError(VIR_ERR_INTERNAL_ERROR, "%s",
+ _("cannot initialize mutex"));
+ VIR_FREE(st);
+ return NULL;
+ }
+
+ virNetClientProgramRef(prog);
+
return st;
}
void virNetClientStreamRef(virNetClientStreamPtr st)
{
+ virMutexLock(&st->lock);
st->refs++;
+ virMutexUnlock(&st->lock);
}
void virNetClientStreamFree(virNetClientStreamPtr st)
{
+ virMutexLock(&st->lock);
st->refs--;
- if (st->refs > 0)
+ if (st->refs > 0) {
+ virMutexUnlock(&st->lock);
return;
+ }
+
+ virMutexUnlock(&st->lock);
virResetError(&st->err);
VIR_FREE(st->incoming);
+ virMutexDestroy(&st->lock);
virNetClientProgramFree(st->prog);
VIR_FREE(st);
}
bool virNetClientStreamMatches(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
+ bool match = false;
+ virMutexLock(&st->lock);
if (virNetClientProgramMatches(st->prog, msg) &&
st->proc == msg->header.proc &&
st->serial == msg->header.serial)
- return 1;
- return 0;
+ match = true;
+ virMutexUnlock(&st->lock);
+ return match;
}
bool virNetClientStreamRaiseError(virNetClientStreamPtr st)
{
- if (st->err.code == VIR_ERR_OK)
+ virMutexLock(&st->lock);
+ if (st->err.code == VIR_ERR_OK) {
+ virMutexUnlock(&st->lock);
return false;
+ }
virRaiseErrorFull(__FILE__, __FUNCTION__, __LINE__,
st->err.domain,
st->err.int1,
st->err.int2,
"%s", st->err.message ? st->err.message : _("Unknown error"));
-
+ virMutexUnlock(&st->lock);
return true;
}
virNetMessageError err;
int ret = -1;
+ virMutexLock(&st->lock);
+
if (st->err.code != VIR_ERR_OK)
VIR_DEBUG("Overwriting existing stream error %s", NULLSTR(st->err.message));
cleanup:
xdr_free((xdrproc_t)xdr_virNetMessageError, (void*)&err);
+ virMutexUnlock(&st->lock);
return ret;
}
int virNetClientStreamQueuePacket(virNetClientStreamPtr st,
virNetMessagePtr msg)
{
- size_t avail = st->incomingLength - st->incomingOffset;
- size_t need = msg->bufferLength - msg->bufferOffset;
+ int ret = -1;
+ size_t need;
+ virMutexLock(&st->lock);
+ need = msg->bufferLength - msg->bufferOffset;
+ size_t avail = st->incomingLength - st->incomingOffset;
if (need > avail) {
size_t extra = need - avail;
if (VIR_REALLOC_N(st->incoming,
st->incomingLength + extra) < 0) {
VIR_DEBUG("Out of memory handling stream data");
- return -1;
+ goto cleanup;
}
st->incomingLength += extra;
}
VIR_DEBUG("Stream incoming data offset %zu length %zu",
st->incomingOffset, st->incomingLength);
- return 0;
+
+ ret = 0;
+
+cleanup:
+ virMutexUnlock(&st->lock);
+ return ret;
}
if (!(msg = virNetMessageNew()))
return -1;
+ virMutexLock(&st->lock);
+
msg->header.prog = virNetClientProgramGetProgram(st->prog);
msg->header.vers = virNetClientProgramGetVersion(st->prog);
msg->header.status = status;
msg->header.serial = st->serial;
msg->header.proc = st->proc;
+ virMutexUnlock(&st->lock);
+
if (virNetMessageEncodeHeader(msg) < 0)
goto error;
int rv = -1;
VIR_DEBUG("st=%p client=%p data=%p nbytes=%zu nonblock=%d",
st, client, data, nbytes, nonblock);
+ virMutexLock(&st->lock);
if (!st->incomingOffset) {
virNetMessagePtr msg;
int ret;
msg->header.proc = st->proc;
VIR_DEBUG("Dummy packet to wait for stream data");
+ virMutexUnlock(&st->lock);
ret = virNetClientSend(client, msg, true);
-
+ virMutexLock(&st->lock);
virNetMessageFree(msg);
if (ret < 0)
virNetClientStreamEventTimerUpdate(st);
cleanup:
+ virMutexUnlock(&st->lock);
return rv;
}
void *opaque,
virFreeCallback ff)
{
+ int ret = -1;
+
+ virMutexLock(&st->lock);
if (st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("multiple stream callbacks not supported"));
- return 1;
+ goto cleanup;
}
- virNetClientStreamRef(st);
+ st->refs++;
if ((st->cbTimer =
virEventAddTimeout(-1,
virNetClientStreamEventTimer,
st,
virNetClientStreamEventTimerFree)) < 0) {
- virNetClientStreamFree(st);
- return -1;
+ st->refs--;
+ goto cleanup;
}
st->cb = cb;
virNetClientStreamEventTimerUpdate(st);
- return 0;
+ ret = 0;
+
+cleanup:
+ virMutexUnlock(&st->lock);
+ return ret;
}
int virNetClientStreamEventUpdateCallback(virNetClientStreamPtr st,
int events)
{
+ int ret = -1;
+
+ virMutexLock(&st->lock);
if (!st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("no stream callback registered"));
- return -1;
+ goto cleanup;
}
st->cbEvents = events;
virNetClientStreamEventTimerUpdate(st);
- return 0;
+ ret = 0;
+
+cleanup:
+ virMutexUnlock(&st->lock);
+ return ret;
}
int virNetClientStreamEventRemoveCallback(virNetClientStreamPtr st)
{
+ int ret = -1;
+
+ virMutexUnlock(&st->lock);
if (!st->cb) {
virNetError(VIR_ERR_INTERNAL_ERROR,
"%s", _("no stream callback registered"));
- return -1;
+ goto cleanup;
}
if (!st->cbDispatch &&
st->cbEvents = 0;
virEventRemoveTimeout(st->cbTimer);
- return 0;
+ ret = 0;
+
+cleanup:
+ virMutexUnlock(&st->lock);
+ return ret;
}