]> xenbits.xensource.com Git - libvirt.git/commitdiff
Integrate TLS/SASL directly into the socket APIs
authorDaniel P. Berrange <berrange@redhat.com>
Fri, 10 Dec 2010 12:22:03 +0000 (12:22 +0000)
committerDaniel P. Berrange <berrange@redhat.com>
Fri, 24 Jun 2011 10:48:30 +0000 (11:48 +0100)
This extends the basic virNetSocket APIs to allow them to have
a handle to the TLS/SASL session objects, once established.
This ensures that any data reads/writes are automagically
passed through the TLS/SASL encryption layers if required.

* src/rpc/virnetsocket.c, src/rpc/virnetsocket.h: Wire up
  SASL/TLS encryption

src/rpc/virnetsocket.c
src/rpc/virnetsocket.h

index a8c029631723c515fea4cf5d12f3e2f3c07e9c13..055fca268b646d97a007a612a54eedca36e22b26 100644 (file)
@@ -59,6 +59,19 @@ struct _virNetSocket {
     virSocketAddr remoteAddr;
     char *localAddrStr;
     char *remoteAddrStr;
+
+    virNetTLSSessionPtr tlsSession;
+#if HAVE_SASL
+    virNetSASLSessionPtr saslSession;
+
+    const char *saslDecoded;
+    size_t saslDecodedLength;
+    size_t saslDecodedOffset;
+
+    const char *saslEncoded;
+    size_t saslEncodedLength;
+    size_t saslEncodedOffset;
+#endif
 };
 
 
@@ -417,7 +430,7 @@ error:
 }
 
 
-#if HAVE_SYS_UN_H
+#ifdef HAVE_SYS_UN_H
 int virNetSocketNewConnectUNIX(const char *path,
                                bool spawnDaemon,
                                const char *binary,
@@ -624,6 +637,14 @@ void virNetSocketFree(virNetSocketPtr sock)
         unlink(sock->localAddr.data.un.sun_path);
 #endif
 
+    /* Make sure it can't send any more I/O during shutdown */
+    if (sock->tlsSession)
+        virNetTLSSessionSetIOCallbacks(sock->tlsSession, NULL, NULL, NULL);
+    virNetTLSSessionFree(sock->tlsSession);
+#if HAVE_SASL
+    virNetSASLSessionFree(sock->saslSession);
+#endif
+
     VIR_FORCE_CLOSE(sock->fd);
     VIR_FORCE_CLOSE(sock->errfd);
 
@@ -709,17 +730,77 @@ const char *virNetSocketRemoteAddrString(virNetSocketPtr sock)
     return sock->remoteAddrStr;
 }
 
-ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+
+static ssize_t virNetSocketTLSSessionWrite(const char *buf,
+                                           size_t len,
+                                           void *opaque)
+{
+    virNetSocketPtr sock = opaque;
+    return write(sock->fd, buf, len);
+}
+
+
+static ssize_t virNetSocketTLSSessionRead(char *buf,
+                                          size_t len,
+                                          void *opaque)
+{
+    virNetSocketPtr sock = opaque;
+    return read(sock->fd, buf, len);
+}
+
+
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess)
+{
+    virNetTLSSessionFree(sock->tlsSession);
+    sock->tlsSession = sess;
+    virNetTLSSessionSetIOCallbacks(sess,
+                                   virNetSocketTLSSessionWrite,
+                                   virNetSocketTLSSessionRead,
+                                   sock);
+    virNetTLSSessionRef(sess);
+}
+
+
+#if HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess)
+{
+    virNetSASLSessionFree(sock->saslSession);
+    sock->saslSession = sess;
+    virNetSASLSessionRef(sess);
+}
+#endif
+
+
+bool virNetSocketHasCachedData(virNetSocketPtr sock ATTRIBUTE_UNUSED)
+{
+#if HAVE_SASL
+    if (sock->saslDecoded)
+        return true;
+#endif
+    return false;
+}
+
+
+static ssize_t virNetSocketReadWire(virNetSocketPtr sock, char *buf, size_t len)
 {
     char *errout = NULL;
     ssize_t ret;
 reread:
-    ret = read(sock->fd, buf, len);
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionRead(sock->tlsSession, buf, len);
+    } else {
+        ret = read(sock->fd, buf, len);
+    }
 
     if ((ret < 0) && (errno == EINTR))
         goto reread;
     if ((ret < 0) && (errno == EAGAIN))
         return 0;
+
     if (ret <= 0 &&
         sock->errfd != -1 &&
         virFileReadLimFD(sock->errfd, 1024, &errout) >= 0 &&
@@ -751,11 +832,17 @@ reread:
     return ret;
 }
 
-ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
+static ssize_t virNetSocketWriteWire(virNetSocketPtr sock, const char *buf, size_t len)
 {
     ssize_t ret;
 rewrite:
-    ret = write(sock->fd, buf, len);
+    if (sock->tlsSession &&
+        virNetTLSSessionGetHandshakeStatus(sock->tlsSession) ==
+        VIR_NET_TLS_HANDSHAKE_COMPLETE) {
+        ret = virNetTLSSessionWrite(sock->tlsSession, buf, len);
+    } else {
+        ret = write(sock->fd, buf, len);
+    }
 
     if (ret < 0) {
         if (errno == EINTR)
@@ -777,6 +864,127 @@ rewrite:
 }
 
 
+#if HAVE_SASL
+static ssize_t virNetSocketReadSASL(virNetSocketPtr sock, char *buf, size_t len)
+{
+    ssize_t got;
+
+    /* Need to read some more data off the wire */
+    if (sock->saslDecoded == NULL) {
+        ssize_t encodedLen = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+        char *encoded;
+        if (VIR_ALLOC_N(encoded, encodedLen) < 0) {
+            virReportOOMError();
+            return -1;
+        }
+        encodedLen = virNetSocketReadWire(sock, encoded, encodedLen);
+
+        if (encodedLen <= 0) {
+            VIR_FREE(encoded);
+            return encodedLen;
+        }
+
+        if (virNetSASLSessionDecode(sock->saslSession,
+                                    encoded, encodedLen,
+                                    &sock->saslDecoded, &sock->saslDecodedLength) < 0) {
+            VIR_FREE(encoded);
+            return -1;
+        }
+        VIR_FREE(encoded);
+
+        sock->saslDecodedOffset = 0;
+    }
+
+    /* Some buffered decoded data to return now */
+    got = sock->saslDecodedLength - sock->saslDecodedOffset;
+
+    if (len > got)
+        len = got;
+
+    memcpy(buf, sock->saslDecoded + sock->saslDecodedOffset, len);
+    sock->saslDecodedOffset += len;
+
+    if (sock->saslDecodedOffset == sock->saslDecodedLength) {
+        sock->saslDecoded = NULL;
+        sock->saslDecodedOffset = sock->saslDecodedLength = 0;
+    }
+
+    return len;
+}
+
+
+static ssize_t virNetSocketWriteSASL(virNetSocketPtr sock, const char *buf, size_t len)
+{
+    int ret;
+    size_t tosend = virNetSASLSessionGetMaxBufSize(sock->saslSession);
+
+    /* SASL doesn't necessarily let us send the whole
+       buffer at once */
+    if (tosend > len)
+        tosend = len;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (sock->saslEncoded == NULL) {
+        if (virNetSASLSessionEncode(sock->saslSession,
+                                    buf, tosend,
+                                    &sock->saslEncoded,
+                                    &sock->saslEncodedLength) < 0)
+            return -1;
+
+        sock->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = virNetSocketWriteWire(sock,
+                                sock->saslEncoded + sock->saslEncodedOffset,
+                                sock->saslEncodedLength - sock->saslEncodedOffset);
+
+    if (ret <= 0)
+        return ret; /* -1 error, 0 == egain */
+
+    /* Note how much we sent */
+    sock->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (sock->saslEncodedOffset == sock->saslEncodedLength) {
+        sock->saslEncoded = NULL;
+        sock->saslEncodedOffset = sock->saslEncodedLength = 0;
+
+        /* Mark as complete, so caller detects completion */
+        return tosend;
+    } else {
+        /* Still have stuff pending in saslEncoded buffer.
+         * Pretend to caller that we didn't send any yet.
+         * The caller will then retry with same buffer
+         * shortly, which lets us finish saslEncoded.
+         */
+        return 0;
+    }
+}
+#endif
+
+
+ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len)
+{
+#if HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketReadSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketReadWire(sock, buf, len);
+}
+
+ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len)
+{
+#if HAVE_SASL
+    if (sock->saslSession)
+        return virNetSocketWriteSASL(sock, buf, len);
+    else
+#endif
+        return virNetSocketWriteWire(sock, buf, len);
+}
+
+
 int virNetSocketListen(virNetSocketPtr sock)
 {
     if (listen(sock->fd, 30) < 0) {
index 218fe8f16feda3dd415960b957609234007316ec..59ff28824ff6c369b825b7dd62c4e53212fec1f3 100644 (file)
 
 # include "network.h"
 # include "command.h"
+# include "virnettlscontext.h"
+# ifdef HAVE_SASL
+#  include "virnetsaslcontext.h"
+# endif
 
 typedef struct _virNetSocket virNetSocket;
 typedef virNetSocket *virNetSocketPtr;
@@ -83,6 +87,13 @@ int virNetSocketSetBlocking(virNetSocketPtr sock,
 ssize_t virNetSocketRead(virNetSocketPtr sock, char *buf, size_t len);
 ssize_t virNetSocketWrite(virNetSocketPtr sock, const char *buf, size_t len);
 
+void virNetSocketSetTLSSession(virNetSocketPtr sock,
+                               virNetTLSSessionPtr sess);
+# ifdef HAVE_SASL
+void virNetSocketSetSASLSession(virNetSocketPtr sock,
+                                virNetSASLSessionPtr sess);
+# endif
+bool virNetSocketHasCachedData(virNetSocketPtr sock);
 void virNetSocketFree(virNetSocketPtr sock);
 
 const char *virNetSocketLocalAddrString(virNetSocketPtr sock);