virSocketAddr remoteAddr;
char *localAddrStr;
char *remoteAddrStr;
+
+ virNetTLSSessionPtr tlsSession;
+#if HAVE_SASL
+ virNetSASLSessionPtr saslSession;
+
+ const char *saslDecoded;
+ size_t saslDecodedLength;
+ size_t saslDecodedOffset;
+
+ const char *saslEncoded;
+ size_t saslEncodedLength;
+ size_t saslEncodedOffset;
+#endif
};
}
-#if HAVE_SYS_UN_H
+#ifdef HAVE_SYS_UN_H
int virNetSocketNewConnectUNIX(const char *path,
bool spawnDaemon,
const char *binary,
unlink(sock->localAddr.data.un.sun_path);
#endif
+ /* Make sure it can't send any more I/O during shutdown */
+ if (sock->tlsSession)
+ virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+ virNetTLSSessionFree(sock->tlsSession);
+#if HAVE_SASL
+ virNetSASLSessionFree(sock->saslSession);
+#endif
+
VIR_FORCE_CLOSE(sock->fd);
VIR_FORCE_CLOSE(sock->errfd);
return sock->remoteAddrStr;
}
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+ size_t len,
+ void *opaque)
+{
+ virNetSocketPtr sock = opaque;
+ return write(sock->fd, buf, len);
+}
+
+
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+ size_t len,
+ void *opaque)
+{
+ virNetSocketPtr sock = opaque;
+ return read(sock->fd, buf, len);
+}
+
+
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+ virNetTLSSessionPtr sess)
+{
+ virNetTLSSessionFree(sock->tlsSession);
+ sock->tlsSession = sess;
+ virNetTLSSessionSetIOCallbacks(sess,
+ virNetSocketTLSSessionWrite,
+ virNetSocketTLSSessionRead,
+ sock);
+ virNetTLSSessionRef(sess);
+}
+
+
+#if HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+ virNetSASLSessionPtr sess)
+{
+ virNetSASLSessionFree(sock->saslSession);
+ sock->saslSession = sess;
+ virNetSASLSessionRef(sess);
+}
+#endif
+
+
+bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+{
+#if HAVE_SASL
+ if (sock->saslDecoded)
+ return true;
+#endif
+ return false;
+}
+
+
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
{
char *errout = NULL;
ssize_t ret;
reread:
- ret = read(sock->fd, buf, len);
+ if (sock->tlsSession &&
+ virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+ VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+ ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+ } else {
+ ret = read(sock->fd, buf, len);
+ }
if ((ret < 0) && (errno == EINTR))
goto reread;
if ((ret < 0) && (errno == EAGAIN))
return 0;
+
if (ret <= 0 &&
sock->errfd != -1 &&
virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
return ret;
}
-ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
{
ssize_t ret;
rewrite:
- ret = write(sock->fd, buf, len);
+ if (sock->tlsSession &&
+ virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+ VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+ ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+ } else {
+ ret = write(sock->fd, buf, len);
+ }
if (ret < 0) {
if (errno == EINTR)
}
+#if HAVE_SASL
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+ ssize_t got;
+
+ /* Need to read some more data off the wire */
+ if (sock->saslDecoded == NULL) {
+ ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+ char *encoded;
+ if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
+ virReportOOMError();
+ return -1;
+ }
+ encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+ if (encodedLen <= 0) {
+ VIR_FREE(encoded);
+ return encodedLen;
+ }
+
+ if (virNetSASLSessionDecode(sock->saslSession,
+ encoded, encodedLen,
+ &sock->saslDecoded, &sock->saslDecodedLength) < 0) {
+ VIR_FREE(encoded);
+ return -1;
+ }
+ VIR_FREE(encoded);
+
+ sock->saslDecodedOffset = 0;
+ }
+
+ /* Some buffered decoded data to return now */
+ got = sock->saslDecodedLength - sock->saslDecodedOffset;
+
+ if (len > got)
+ len = got;
+
+ memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+ sock->saslDecodedOffset += len;
+
+ if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+ sock->saslDecoded = NULL;
+ sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+ }
+
+ return len;
+}
+
+
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+ int ret;
+ size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+
+ /* SASL doesn't necessarily let us send the whole
+ buffer at once */
+ if (tosend > len)
+ tosend = len;
+
+ /* Not got any pending encoded data, so we need to encode raw stuff */
+ if (sock->saslEncoded == NULL) {
+ if (virNetSASLSessionEncode(sock->saslSession,
+ buf, tosend,
+ &sock->saslEncoded,
+ &sock->saslEncodedLength) < 0)
+ return -1;
+
+ sock->saslEncodedOffset = 0;
+ }
+
+ /* Send some of the encoded stuff out on the wire */
+ ret = virNetSocketWriteWire(sock,
+ sock->saslEncoded + sock->saslEncodedOffset,
+ sock->saslEncodedLength - sock->saslEncodedOffset);
+
+ if (ret <= 0)
+ return ret; /* -1 error, 0 == egain */
+
+ /* Note how much we sent */
+ sock->saslEncodedOffset += ret;
+
+ /* Sent all encoded, so update raw buffer to indicate completion */
+ if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+ sock->saslEncoded = NULL;
+ sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+ /* Mark as complete, so caller detects completion */
+ return tosend;
+ } else {
+ /* Still have stuff pending in saslEncoded buffer.
+ * Pretend to caller that we didn't send any yet.
+ * The caller will then retry with same buffer
+ * shortly, which lets us finish saslEncoded.
+ */
+ return 0;
+ }
+}
+#endif
+
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+#if HAVE_SASL
+ if (sock->saslSession)
+ return virNetSocketReadSASL(sock, buf, len);
+ else
+#endif
+ return virNetSocketReadWire(sock, buf, len);
+}
+
+ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
+{
+#if HAVE_SASL
+ if (sock->saslSession)
+ return virNetSocketWriteSASL(sock, buf, len);
+ else
+#endif
+ return virNetSocketWriteWire(sock, buf, len);
+}
+
+
int virNetSocketListen(virNetSocketPtr sock)
{
if (listen(sock->fd, 30) < 0) {