]> xenbits.xensource.com Git - libvirt.git/commitdiff
Add data encryption using SASL SSF layer
authorDaniel P. Berrange <berrange@redhat.com>
Wed, 5 Dec 2007 15:27:08 +0000 (15:27 +0000)
committerDaniel P. Berrange <berrange@redhat.com>
Wed, 5 Dec 2007 15:27:08 +0000 (15:27 +0000)
ChangeLog
qemud/internal.h
qemud/qemud.c
qemud/remote.c
src/remote_internal.c

index 455e615ca4b2860283a882dd42bf00de519fa492..c38b87d09ba920889f4584bea30ca6689b5790e8 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,9 @@
+Wed Dec  5 10:25:00 EST 2007 Daniel P. Berrange <berrange@redhat.com>
+
+       * src/remote_internal.c, qemud/qemud.c, qemud/internal.h,
+       qemud/remote.c: Add support for SASL SSF layer providing
+       data encryption of the connection post-authentication.
+
 Wed Dec  5 10:20:00 EST 2007 Daniel P. Berrange <berrange@redhat.com>
 
        * configure.in: Add checks for SASL library
index 7bb83ef493ff37fee5e2f77f75b55e4416e51e9b..48fbcc352b836eafe0ae5e3fc29a144126dadc45 100644 (file)
@@ -73,10 +73,17 @@ enum qemud_mode {
     QEMUD_MODE_TLS_HANDSHAKE,
 };
 
-/* These have to remain compatible with gnutls_record_get_direction. */
-enum qemud_tls_direction {
-    QEMUD_TLS_DIRECTION_READ = 0,
-    QEMUD_TLS_DIRECTION_WRITE = 1,
+/* Whether we're passing reads & writes through a sasl SSF */
+enum qemud_sasl_ssf {
+    QEMUD_SASL_SSF_NONE = 0,
+    QEMUD_SASL_SSF_READ = 1,
+    QEMUD_SASL_SSF_WRITE = 2,
+};
+
+enum qemud_sock_type {
+    QEMUD_SOCK_TYPE_UNIX = 0,
+    QEMUD_SOCK_TYPE_TCP = 1,
+    QEMUD_SOCK_TYPE_TLS = 2,
 };
 
 /* Stores the per-client connection state */
@@ -90,13 +97,18 @@ struct qemud_client {
     struct sockaddr_storage addr;
     socklen_t addrlen;
 
-    /* If set, TLS is required on this socket. */
-    int tls;
-    gnutls_session_t session;
-    enum qemud_tls_direction direction;
+    int type; /* qemud_sock_type */
+    gnutls_session_t tlssession;
     int auth;
 #if HAVE_SASL
     sasl_conn_t *saslconn;
+    int saslSSF;
+    const char *saslDecoded;
+    unsigned int saslDecodedLength;
+    unsigned int saslDecodedOffset;
+    const char *saslEncoded;
+    unsigned int saslEncodedLength;
+    unsigned int saslEncodedOffset;
 #endif
 
     unsigned int incomingSerial;
@@ -121,8 +133,7 @@ struct qemud_client {
 struct qemud_socket {
     int fd;
     int readonly;
-    /* If set, TLS is required on this socket. */
-    int tls;
+    int type; /* qemud_sock_type */
     int auth;
     int port;
     struct qemud_socket *next;
index 273352faf68bf1be7c1535db9992f48936594f47..840b29761b7469d9d8740b46afb2ae8214a785d5 100644 (file)
@@ -461,6 +461,7 @@ static int qemudListenUnix(struct qemud_server *server,
 
     sock->readonly = readonly;
     sock->port = -1;
+    sock->type = QEMUD_SOCK_TYPE_UNIX;
 
     if ((sock->fd = socket(PF_UNIX, SOCK_STREAM, 0)) < 0) {
         qemudLog(QEMUD_ERR, "Failed to create socket: %s",
@@ -575,7 +576,7 @@ remoteMakeSockets (int *fds, int max_fds, int *nfds_r, const char *service)
 static int
 remoteListenTCP (struct qemud_server *server,
                  const char *port,
-                 int tls,
+                 int type,
                  int auth)
 {
     int fds[2];
@@ -604,7 +605,7 @@ remoteListenTCP (struct qemud_server *server,
         server->nsockets++;
 
         sock->fd = fds[i];
-        sock->tls = tls;
+        sock->type = type;
         sock->auth = auth;
 
         if (getsockname(sock->fd, (struct sockaddr *)(&sa), &salen) < 0)
@@ -743,10 +744,10 @@ static struct qemud_server *qemudInitialize(int sigread) {
 
     if (ipsock) {
 #if HAVE_SASL
-        if (listen_tcp && remoteListenTCP (server, tcp_port, 0, REMOTE_AUTH_SASL) < 0)
+        if (listen_tcp && remoteListenTCP (server, tcp_port, QEMUD_SOCK_TYPE_TCP, REMOTE_AUTH_SASL) < 0)
             goto cleanup;
 #else
-        if (listen_tcp && remoteListenTCP (server, tcp_port, 0, REMOTE_AUTH_NONE) < 0)
+        if (listen_tcp && remoteListenTCP (server, tcp_port, QEMUD_SOCK_TYPE_TCP, REMOTE_AUTH_NONE) < 0)
             goto cleanup;
 #endif
 
@@ -754,7 +755,7 @@ static struct qemud_server *qemudInitialize(int sigread) {
             if (remoteInitializeGnuTLS () < 0)
                 goto cleanup;
 
-            if (remoteListenTCP (server, tls_port, 1, REMOTE_AUTH_NONE) < 0)
+            if (remoteListenTCP (server, tls_port, QEMUD_SOCK_TYPE_TLS, REMOTE_AUTH_NONE) < 0)
                 goto cleanup;
         }
     }
@@ -787,7 +788,7 @@ static struct qemud_server *qemudInitialize(int sigread) {
          */
         sock = server->sockets;
         while (sock) {
-            if (sock->port != -1 && sock->tls) {
+            if (sock->port != -1 && sock->type == QEMUD_SOCK_TYPE_TLS) {
                 port = sock->port;
                 break;
             }
@@ -979,7 +980,7 @@ remoteCheckAccess (struct qemud_client *client)
     int found, err;
 
     /* Verify client certificate. */
-    if (remoteCheckCertificate (client->session) == -1) {
+    if (remoteCheckCertificate (client->tlssession) == -1) {
         qemudLog (QEMUD_ERR, "remoteCheckCertificate: failed to verify client's certificate");
         if (!tls_no_verify_certificate) return -1;
         else qemudLog (QEMUD_INFO, "remoteCheckCertificate: tls_no_verify_certificate is set so the bad certificate is ignored");
@@ -1031,7 +1032,6 @@ remoteCheckAccess (struct qemud_client *client)
     client->bufferOffset = 0;
     client->buffer[0] = '\1';
     client->mode = QEMUD_MODE_TX_PACKET;
-    client->direction = QEMUD_TLS_DIRECTION_WRITE;
     return 0;
 }
 
@@ -1065,12 +1065,12 @@ static int qemudDispatchServer(struct qemud_server *server, struct qemud_socket
     client->magic = QEMUD_CLIENT_MAGIC;
     client->fd = fd;
     client->readonly = sock->readonly;
-    client->tls = sock->tls;
+    client->type = sock->type;
     client->auth = sock->auth;
     memcpy (&client->addr, &addr, sizeof addr);
     client->addrlen = addrlen;
 
-    if (!client->tls) {
+    if (client->type != QEMUD_SOCK_TYPE_TLS) {
         client->mode = QEMUD_MODE_RX_HEADER;
         client->bufferLength = REMOTE_MESSAGE_HEADER_XDR_LEN;
 
@@ -1079,15 +1079,15 @@ static int qemudDispatchServer(struct qemud_server *server, struct qemud_socket
     } else {
         int ret;
 
-        client->session = remoteInitializeTLSSession ();
-        if (client->session == NULL)
+        client->tlssession = remoteInitializeTLSSession ();
+        if (client->tlssession == NULL)
             goto cleanup;
 
-        gnutls_transport_set_ptr (client->session,
+        gnutls_transport_set_ptr (client->tlssession,
                                   (gnutls_transport_ptr_t) (long) fd);
 
         /* Begin the TLS handshake. */
-        ret = gnutls_handshake (client->session);
+        ret = gnutls_handshake (client->tlssession);
         if (ret == 0) {
             /* Unlikely, but ...  Next step is to check the certificate. */
             if (remoteCheckAccess (client) == -1)
@@ -1099,7 +1099,6 @@ static int qemudDispatchServer(struct qemud_server *server, struct qemud_socket
             /* Most likely. */
             client->mode = QEMUD_MODE_TLS_HANDSHAKE;
             client->bufferLength = -1;
-            client->direction = gnutls_record_get_direction (client->session);
 
             if (qemudRegisterClientEvent (server, client, 0) < 0)
                 goto cleanup;
@@ -1117,7 +1116,7 @@ static int qemudDispatchServer(struct qemud_server *server, struct qemud_socket
     return 0;
 
  cleanup:
-    if (client->session) gnutls_deinit (client->session);
+    if (client->tlssession) gnutls_deinit (client->tlssession);
     close (fd);
     free (client);
     return -1;
@@ -1150,24 +1149,21 @@ static void qemudDispatchClientFailure(struct qemud_server *server, struct qemud
 #if HAVE_SASL
     if (client->saslconn) sasl_dispose(&client->saslconn);
 #endif
-    if (client->tls && client->session) gnutls_deinit (client->session);
+    if (client->tlssession) gnutls_deinit (client->tlssession);
     close(client->fd);
     free(client);
 }
 
 
 
-static int qemudClientRead(struct qemud_server *server,
-                           struct qemud_client *client) {
-    int ret, len;
-    char *data;
-
-    data = client->buffer + client->bufferOffset;
-    len = client->bufferLength - client->bufferOffset;
+static int qemudClientReadBuf(struct qemud_server *server,
+                              struct qemud_client *client,
+                              char *data, unsigned len) {
+    int ret;
 
     /*qemudDebug ("qemudClientRead: len = %d", len);*/
 
-    if (!client->tls) {
+    if (!client->tlssession) {
         if ((ret = read (client->fd, data, len)) <= 0) {
             if (ret == 0 || errno != EAGAIN) {
                 if (ret != 0)
@@ -1177,8 +1173,7 @@ static int qemudClientRead(struct qemud_server *server,
             return -1;
         }
     } else {
-        ret = gnutls_record_recv (client->session, data, len);
-        client->direction = gnutls_record_get_direction (client->session);
+        ret = gnutls_record_recv (client->tlssession, data, len);
         if (qemudRegisterClientEvent (server, client, 1) < 0)
             qemudDispatchClientFailure (server, client);
         else if (ret <= 0) {
@@ -1193,10 +1188,80 @@ static int qemudClientRead(struct qemud_server *server,
         }
     }
 
+    return ret;
+}
+
+static int qemudClientReadPlain(struct qemud_server *server,
+                                struct qemud_client *client) {
+    int ret;
+    ret = qemudClientReadBuf(server, client,
+                             client->buffer + client->bufferOffset,
+                             client->bufferLength - client->bufferOffset);
+    if (ret < 0)
+        return ret;
     client->bufferOffset += ret;
     return 0;
 }
 
+#if HAVE_SASL
+static int qemudClientReadSASL(struct qemud_server *server,
+                               struct qemud_client *client) {
+    int got, want;
+
+    /* We're doing a SSF data read, so now its times to ensure
+     * future writes are under SSF too.
+     *
+     * cf remoteSASLCheckSSF in remote.c
+     */
+    client->saslSSF |= QEMUD_SASL_SSF_WRITE;
+
+    /* Need to read some more data off the wire */
+    if (client->saslDecoded == NULL) {
+        char encoded[8192];
+        int encodedLen = sizeof(encoded);
+        encodedLen = qemudClientReadBuf(server, client, encoded, encodedLen);
+
+        if (encodedLen < 0)
+            return -1;
+
+        sasl_decode(client->saslconn, encoded, encodedLen,
+                    &client->saslDecoded, &client->saslDecodedLength);
+
+        client->saslDecodedOffset = 0;
+    }
+
+    /* Some buffered decoded data to return now */
+    got = client->saslDecodedLength - client->saslDecodedOffset;
+    want = client->bufferLength - client->bufferOffset;
+
+    if (want > got)
+        want = got;
+
+    memcpy(client->buffer + client->bufferOffset,
+           client->saslDecoded + client->saslDecodedOffset, want);
+    client->saslDecodedOffset += want;
+    client->bufferOffset += want;
+
+    if (client->saslDecodedOffset == client->saslDecodedLength) {
+        client->saslDecoded = NULL;
+        client->saslDecodedOffset = client->saslDecodedLength = 0;
+    }
+
+    return 0;
+}
+#endif
+
+static int qemudClientRead(struct qemud_server *server,
+                           struct qemud_client *client) {
+#if HAVE_SASL
+    if (client->saslSSF & QEMUD_SASL_SSF_READ)
+        return qemudClientReadSASL(server, client);
+    else
+#endif
+        return qemudClientReadPlain(server, client);
+}
+
+
 static void qemudDispatchClientRead(struct qemud_server *server, struct qemud_client *client) {
 
     /*qemudDebug ("qemudDispatchClientRead: mode = %d", client->mode);*/
@@ -1239,7 +1304,6 @@ static void qemudDispatchClientRead(struct qemud_server *server, struct qemud_cl
         client->mode = QEMUD_MODE_RX_PAYLOAD;
         client->bufferLength = len - REMOTE_MESSAGE_HEADER_XDR_LEN;
         client->bufferOffset = 0;
-        if (client->tls) client->direction = QEMUD_TLS_DIRECTION_READ;
 
         if (qemudRegisterClientEvent(server, client, 1) < 0) {
             qemudDispatchClientFailure(server, client);
@@ -1267,7 +1331,7 @@ static void qemudDispatchClientRead(struct qemud_server *server, struct qemud_cl
         int ret;
 
         /* Continue the handshake. */
-        ret = gnutls_handshake (client->session);
+        ret = gnutls_handshake (client->tlssession);
         if (ret == 0) {
             /* Finished.  Next step is to check the certificate. */
             if (remoteCheckAccess (client) == -1)
@@ -1279,7 +1343,6 @@ static void qemudDispatchClientRead(struct qemud_server *server, struct qemud_cl
                       gnutls_strerror (ret));
             qemudDispatchClientFailure (server, client);
         } else {
-            client->direction = gnutls_record_get_direction (client->session);
             if (qemudRegisterClientEvent (server ,client, 1) < 0)
                 qemudDispatchClientFailure (server, client);
         }
@@ -1294,15 +1357,11 @@ static void qemudDispatchClientRead(struct qemud_server *server, struct qemud_cl
 }
 
 
-static int qemudClientWrite(struct qemud_server *server,
-                            struct qemud_client *client) {
-    int ret, len;
-    char *data;
-
-    data = client->buffer + client->bufferOffset;
-    len = client->bufferLength - client->bufferOffset;
-
-    if (!client->tls) {
+static int qemudClientWriteBuf(struct qemud_server *server,
+                               struct qemud_client *client,
+                               const char *data, int len) {
+    int ret;
+    if (!client->tlssession) {
         if ((ret = write(client->fd, data, len)) == -1) {
             if (errno != EAGAIN) {
                 qemudLog (QEMUD_ERR, "write: %s", strerror (errno));
@@ -1311,8 +1370,7 @@ static int qemudClientWrite(struct qemud_server *server,
             return -1;
         }
     } else {
-        ret = gnutls_record_send (client->session, data, len);
-        client->direction = gnutls_record_get_direction (client->session);
+        ret = gnutls_record_send (client->tlssession, data, len);
         if (qemudRegisterClientEvent (server, client, 1) < 0)
             qemudDispatchClientFailure (server, client);
         else if (ret < 0) {
@@ -1324,12 +1382,72 @@ static int qemudClientWrite(struct qemud_server *server,
             return -1;
         }
     }
+    return ret;
+}
+
 
+static int qemudClientWritePlain(struct qemud_server *server,
+                                 struct qemud_client *client) {
+    int ret = qemudClientWriteBuf(server, client,
+                                  client->buffer + client->bufferOffset,
+                                  client->bufferLength - client->bufferOffset);
+    if (ret < 0)
+        return -1;
     client->bufferOffset += ret;
     return 0;
 }
 
 
+#if HAVE_SASL
+static int qemudClientWriteSASL(struct qemud_server *server,
+                                struct qemud_client *client) {
+    int ret;
+
+    /* Not got any pending encoded data, so we need to encode raw stuff */
+    if (client->saslEncoded == NULL) {
+        int err;
+        err = sasl_encode(client->saslconn,
+                          client->buffer + client->bufferOffset,
+                          client->bufferLength - client->bufferOffset,
+                          &client->saslEncoded,
+                          &client->saslEncodedLength);
+
+        client->saslEncodedOffset = 0;
+    }
+
+    /* Send some of the encoded stuff out on the wire */
+    ret = qemudClientWriteBuf(server, client,
+                              client->saslEncoded + client->saslEncodedOffset,
+                              client->saslEncodedLength - client->saslEncodedOffset);
+
+    if (ret < 0)
+        return -1;
+
+    /* Note how much we sent */
+    client->saslEncodedOffset += ret;
+
+    /* Sent all encoded, so update raw buffer to indicate completion */
+    if (client->saslEncodedOffset == client->saslEncodedLength) {
+        client->saslEncoded = NULL;
+        client->saslEncodedOffset = client->saslEncodedLength = 0;
+        client->bufferOffset = client->bufferLength;
+    }
+
+    return 0;
+}
+#endif
+
+static int qemudClientWrite(struct qemud_server *server,
+                            struct qemud_client *client) {
+#if HAVE_SASL
+    if (client->saslSSF & QEMUD_SASL_SSF_WRITE)
+        return qemudClientWriteSASL(server, client);
+    else
+#endif
+        return qemudClientWritePlain(server, client);
+}
+
+
 static void qemudDispatchClientWrite(struct qemud_server *server, struct qemud_client *client) {
     switch (client->mode) {
     case QEMUD_MODE_TX_PACKET: {
@@ -1341,7 +1459,6 @@ static void qemudDispatchClientWrite(struct qemud_server *server, struct qemud_c
             client->mode = QEMUD_MODE_RX_HEADER;
             client->bufferLength = REMOTE_MESSAGE_HEADER_XDR_LEN;
             client->bufferOffset = 0;
-            if (client->tls) client->direction = QEMUD_TLS_DIRECTION_READ;
 
             if (qemudRegisterClientEvent (server, client, 1) < 0)
                 qemudDispatchClientFailure (server, client);
@@ -1354,7 +1471,7 @@ static void qemudDispatchClientWrite(struct qemud_server *server, struct qemud_c
         int ret;
 
         /* Continue the handshake. */
-        ret = gnutls_handshake (client->session);
+        ret = gnutls_handshake (client->tlssession);
         if (ret == 0) {
             /* Finished.  Next step is to check the certificate. */
             if (remoteCheckAccess (client) == -1)
@@ -1366,7 +1483,6 @@ static void qemudDispatchClientWrite(struct qemud_server *server, struct qemud_c
                       gnutls_strerror (ret));
             qemudDispatchClientFailure (server, client);
         } else {
-            client->direction = gnutls_record_get_direction (client->session);
             if (qemudRegisterClientEvent (server, client, 1))
                 qemudDispatchClientFailure (server, client);
         }
@@ -1406,25 +1522,37 @@ static void qemudDispatchClientEvent(int fd, int events, void *opaque) {
 static int qemudRegisterClientEvent(struct qemud_server *server,
                                     struct qemud_client *client,
                                     int removeFirst) {
+    int mode;
+    switch (client->mode) {
+    case QEMUD_MODE_TLS_HANDSHAKE:
+        if (gnutls_record_get_direction (client->tlssession) == 0)
+            mode = POLLIN;
+        else
+            mode = POLLOUT;
+        break;
+
+    case QEMUD_MODE_RX_HEADER:
+    case QEMUD_MODE_RX_PAYLOAD:
+        mode = POLLIN;
+        break;
+
+    case QEMUD_MODE_TX_PACKET:
+        mode = POLLOUT;
+        break;
+
+    default:
+        return -1;
+    }
+
     if (removeFirst)
         if (virEventRemoveHandleImpl(client->fd) < 0)
             return -1;
 
-    if (client->tls) {
-        if (virEventAddHandleImpl(client->fd,
-                                  (client->direction ?
-                                   POLLOUT : POLLIN) | POLLERR | POLLHUP,
-                                  qemudDispatchClientEvent,
-                                  server) < 0)
+    if (virEventAddHandleImpl(client->fd,
+                              mode | POLLERR | POLLHUP,
+                              qemudDispatchClientEvent,
+                              server) < 0)
             return -1;
-    } else {
-        if (virEventAddHandleImpl(client->fd,
-                                  (client->mode == QEMUD_MODE_TX_PACKET ?
-                                   POLLOUT : POLLIN) | POLLERR | POLLHUP,
-                                  qemudDispatchClientEvent,
-                                  server) < 0)
-            return -1;
-    }
 
     return 0;
 }
index be256425bb03db212d29b09eaa05dca72ad40d70..ea0bb9509b58fcc85d36ca23802220e44be42cd2 100644 (file)
@@ -284,7 +284,6 @@ remoteDispatchClientRequest (struct qemud_server *server ATTRIBUTE_UNUSED,
     client->mode = QEMUD_MODE_TX_PACKET;
     client->bufferLength = len;
     client->bufferOffset = 0;
-    if (client->tls) client->direction = QEMUD_TLS_DIRECTION_WRITE;
 }
 
 /* An error occurred during the dispatching process itself (ie. not
@@ -369,7 +368,6 @@ remoteDispatchSendError (struct qemud_client *client,
     client->mode = QEMUD_MODE_TX_PACKET;
     client->bufferLength = len;
     client->bufferOffset = 0;
-    if (client->tls) client->direction = QEMUD_TLS_DIRECTION_WRITE;
 }
 
 static void
@@ -2042,6 +2040,7 @@ remoteDispatchAuthSaslInit (struct qemud_client *client,
                             remote_auth_sasl_init_ret *ret)
 {
     const char *mechlist = NULL;
+    sasl_security_properties_t secprops;
     int err;
     struct sockaddr_storage sa;
     socklen_t salen;
@@ -2097,6 +2096,60 @@ remoteDispatchAuthSaslInit (struct qemud_client *client,
         return -2;
     }
 
+    /* Inform SASL that we've got an external SSF layer from TLS */
+    if (client->type == QEMUD_SOCK_TYPE_TLS) {
+        gnutls_cipher_algorithm_t cipher;
+        sasl_ssf_t ssf;
+
+        cipher = gnutls_cipher_get(client->tlssession);
+        if (!(ssf = (sasl_ssf_t)gnutls_cipher_get_key_size(cipher))) {
+            qemudLog(QEMUD_ERR, "cannot TLS get cipher size");
+            remoteDispatchFailAuth(client, req);
+            sasl_dispose(&client->saslconn);
+            client->saslconn = NULL;
+            return -2;
+        }
+        ssf *= 8; /* tls key size is bytes, sasl wants bits */
+
+        err = sasl_setprop(client->saslconn, SASL_SSF_EXTERNAL, &ssf);
+        if (err != SASL_OK) {
+            qemudLog(QEMUD_ERR, "cannot set SASL external SSF %d (%s)",
+                     err, sasl_errstring(err, NULL, NULL));
+            remoteDispatchFailAuth(client, req);
+            sasl_dispose(&client->saslconn);
+            client->saslconn = NULL;
+            return -2;
+        }
+    }
+
+    memset (&secprops, 0, sizeof secprops);
+    if (client->type == QEMUD_SOCK_TYPE_TLS ||
+        client->type == QEMUD_SOCK_TYPE_UNIX) {
+        /* If we've got TLS or UNIX domain sock, we don't care about SSF */
+        secprops.min_ssf = 0;
+        secprops.max_ssf = 0;
+        secprops.maxbufsize = 8192;
+        secprops.security_flags = 0;
+    } else {
+        /* Plain TCP, better get an SSF layer */
+        secprops.min_ssf = 56; /* Good enough to require kerberos */
+        secprops.max_ssf = 100000; /* Arbitrary big number */
+        secprops.maxbufsize = 8192;
+        /* Forbid any anonymous or trivially crackable auth */
+        secprops.security_flags =
+            SASL_SEC_NOANONYMOUS | SASL_SEC_NOPLAINTEXT;
+    }
+
+    err = sasl_setprop(client->saslconn, SASL_SEC_PROPS, &secprops);
+    if (err != SASL_OK) {
+        qemudLog(QEMUD_ERR, "cannot set SASL security props %d (%s)",
+                 err, sasl_errstring(err, NULL, NULL));
+        remoteDispatchFailAuth(client, req);
+        sasl_dispose(&client->saslconn);
+        client->saslconn = NULL;
+        return -2;
+    }
+
     err = sasl_listmech(client->saslconn,
                         NULL, /* Don't need to set user */
                         "", /* Prefix */
@@ -2127,6 +2180,49 @@ remoteDispatchAuthSaslInit (struct qemud_client *client,
 }
 
 
+/* We asked for an SSF layer, so sanity check that we actaully
+ * got what we asked for */
+static int
+remoteSASLCheckSSF (struct qemud_client *client,
+                    remote_message_header *req) {
+    const void *val;
+    int err, ssf;
+
+    if (client->type == QEMUD_SOCK_TYPE_TLS ||
+        client->type == QEMUD_SOCK_TYPE_UNIX)
+        return 0; /* TLS or UNIX domain sockets trivially OK */
+
+    err = sasl_getprop(client->saslconn, SASL_SSF, &val);
+    if (err != SASL_OK) {
+        qemudLog(QEMUD_ERR, "cannot query SASL ssf on connection %d (%s)",
+                 err, sasl_errstring(err, NULL, NULL));
+        remoteDispatchFailAuth(client, req);
+        sasl_dispose(&client->saslconn);
+        client->saslconn = NULL;
+        return -1;
+    }
+    ssf = *(const int *)val;
+    REMOTE_DEBUG("negotiated an SSF of %d", ssf);
+    if (ssf < 56) { /* 56 is good for Kerberos */
+        qemudLog(QEMUD_ERR, "negotiated SSF %d was not strong enough", ssf);
+        remoteDispatchFailAuth(client, req);
+        sasl_dispose(&client->saslconn);
+        client->saslconn = NULL;
+        return -1;
+    }
+
+    /* Only setup for read initially, because we're about to send an RPC
+     * reply which must be in plain text. When the next incoming RPC
+     * arrives, we'll switch on writes too
+     *
+     * cf qemudClientReadSASL  in qemud.c
+     */
+    client->saslSSF = QEMUD_SASL_SSF_READ;
+
+    /* We have a SSF !*/
+    return 0;
+}
+
 /*
  * This starts the SASL authentication negotiation.
  */
@@ -2192,6 +2288,9 @@ remoteDispatchAuthSaslStart (struct qemud_client *client,
     if (err == SASL_CONTINUE) {
         ret->complete = 0;
     } else {
+        if (remoteSASLCheckSSF(client, req) < 0)
+            return -2;
+
         REMOTE_DEBUG("Authentication successful %d", client->fd);
         ret->complete = 1;
         client->auth = REMOTE_AUTH_NONE;
@@ -2263,6 +2362,9 @@ remoteDispatchAuthSaslStep (struct qemud_client *client,
     if (err == SASL_CONTINUE) {
         ret->complete = 0;
     } else {
+        if (remoteSASLCheckSSF(client, req) < 0)
+            return -2;
+
         REMOTE_DEBUG("Authentication successful %d", client->fd);
         ret->complete = 1;
         client->auth = REMOTE_AUTH_NONE;
index f05fd86743dccddd72aa58735c2cd1c008f9d35a..92e153fade5a23dab8a0bdad973f01e356125e66 100644 (file)
@@ -79,6 +79,9 @@ struct private_data {
     FILE *debugLog;             /* Debug remote protocol */
 #if HAVE_SASL
     sasl_conn_t *saslconn;      /* SASL context */
+    const char *saslDecoded;
+    unsigned int saslDecodedLength;
+    unsigned int saslDecodedOffset;
 #endif
 };
 
@@ -2907,15 +2910,14 @@ static char *addrToString(struct sockaddr_storage *sa, socklen_t salen)
 
 /* Perform the SASL authentication process
  *
- * XXX negotiate a session encryption layer for non-TLS sockets
  * XXX fetch credentials from a libvirt client app callback
- * XXX max packet size spec
  * XXX better mechanism negotiation ? Ask client app ?
  */
 static int
 remoteAuthSASL (virConnectPtr conn, struct private_data *priv, int in_open)
 {
     sasl_conn_t *saslconn = NULL;
+    sasl_security_properties_t secprops;
     remote_auth_sasl_init_ret iret;
     remote_auth_sasl_start_args sargs;
     remote_auth_sasl_start_ret sret;
@@ -2929,6 +2931,8 @@ remoteAuthSASL (virConnectPtr conn, struct private_data *priv, int in_open)
     struct sockaddr_storage sa;
     socklen_t salen;
     char *localAddr, *remoteAddr;
+    const void *val;
+    sasl_ssf_t ssf;
 
     remoteDebug(priv, "Client initialize SASL authentication");
     /* Sets up the SASL library as a whole */
@@ -2987,6 +2991,51 @@ remoteAuthSASL (virConnectPtr conn, struct private_data *priv, int in_open)
         return -1;
     }
 
+    /* Initialize some connection props we care about */
+    if (priv->uses_tls) {
+        gnutls_cipher_algorithm_t cipher;
+
+        cipher = gnutls_cipher_get(priv->session);
+        if (!(ssf = (sasl_ssf_t)gnutls_cipher_get_key_size(cipher))) {
+            __virRaiseError (in_open ? NULL : conn, NULL, NULL, VIR_FROM_REMOTE,
+                             VIR_ERR_INTERNAL_ERROR, VIR_ERR_ERROR, NULL, NULL, NULL, 0, 0,
+                             "invalid cipher size for TLS session");
+            sasl_dispose(&saslconn);
+            return -1;
+        }
+        ssf *= 8; /* key size is bytes, sasl wants bits */
+
+        remoteDebug(priv, "Setting external SSF %d", ssf);
+        err = sasl_setprop(saslconn, SASL_SSF_EXTERNAL, &ssf);
+        if (err != SASL_OK) {
+            __virRaiseError (in_open ? NULL : conn, NULL, NULL, VIR_FROM_REMOTE,
+                             VIR_ERR_INTERNAL_ERROR, VIR_ERR_ERROR, NULL, NULL, NULL, 0, 0,
+                             "cannot set external SSF %d (%s)",
+                             err, sasl_errstring(err, NULL, NULL));
+            sasl_dispose(&saslconn);
+            return -1;
+        }
+    }
+
+    memset (&secprops, 0, sizeof secprops);
+    /* If we've got TLS, we don't care about SSF */
+    secprops.min_ssf = priv->uses_tls ? 0 : 56; /* Equiv to DES supported by all Kerberos */
+    secprops.max_ssf = priv->uses_tls ? 0 : 100000; /* Very strong ! AES == 256 */
+    secprops.maxbufsize = 100000;
+    /* If we're not TLS, then forbid any anonymous or trivially crackable auth */
+    secprops.security_flags = priv->uses_tls ? 0 :
+        SASL_SEC_NOANONYMOUS | SASL_SEC_NOPLAINTEXT;
+
+    err = sasl_setprop(saslconn, SASL_SEC_PROPS, &secprops);
+    if (err != SASL_OK) {
+        __virRaiseError (in_open ? NULL : conn, NULL, NULL, VIR_FROM_REMOTE,
+                         VIR_ERR_INTERNAL_ERROR, VIR_ERR_ERROR, NULL, NULL, NULL, 0, 0,
+                         "cannot set security props %d (%s)",
+                         err, sasl_errstring(err, NULL, NULL));
+        sasl_dispose(&saslconn);
+        return -1;
+    }
+
     /* First call is to inquire about supported mechanisms in the server */
     memset (&iret, 0, sizeof iret);
     if (call (conn, priv, in_open, REMOTE_PROC_AUTH_SASL_INIT,
@@ -3103,9 +3152,30 @@ remoteAuthSASL (virConnectPtr conn, struct private_data *priv, int in_open)
         }
     }
 
+    /* Check for suitable SSF if non-TLS */
+    if (!priv->uses_tls) {
+        err = sasl_getprop(saslconn, SASL_SSF, &val);
+        if (err != SASL_OK) {
+            __virRaiseError (in_open ? NULL : conn, NULL, NULL, VIR_FROM_REMOTE,
+                             VIR_ERR_AUTH_FAILED, VIR_ERR_ERROR, NULL, NULL, NULL, 0, 0,
+                             "cannot query SASL ssf on connection %d (%s)",
+                             err, sasl_errstring(err, NULL, NULL));
+            sasl_dispose(&saslconn);
+            return -1;
+        }
+        ssf = *(const int *)val;
+        remoteDebug(priv, "SASL SSF value %d", ssf);
+        if (ssf < 56) { /* 56 == DES level, good for Kerberos */
+            __virRaiseError (in_open ? NULL : conn, NULL, NULL, VIR_FROM_REMOTE,
+                             VIR_ERR_AUTH_FAILED, VIR_ERR_ERROR, NULL, NULL, NULL, 0, 0,
+                             "negotiation SSF %d was not strong enough", ssf);
+            sasl_dispose(&saslconn);
+            return -1;
+        }
+    }
+
     remoteDebug(priv, "SASL authentication complete");
-    /* XXX keep this around for wire encoding */
-    sasl_dispose(&saslconn);
+    priv->saslconn = saslconn;
     return 0;
 }
 #endif /* HAVE_SASL */
@@ -3306,11 +3376,11 @@ call (virConnectPtr conn, struct private_data *priv,
 }
 
 static int
-really_write (virConnectPtr conn, struct private_data *priv,
-              int in_open /* if we are in virConnectOpen */,
-              char *bytes, int len)
+really_write_buf (virConnectPtr conn, struct private_data *priv,
+                  int in_open /* if we are in virConnectOpen */,
+                  const char *bytes, int len)
 {
-    char *p;
+    const char *p;
     int err;
 
     p = bytes;
@@ -3348,57 +3418,158 @@ really_write (virConnectPtr conn, struct private_data *priv,
 }
 
 static int
-really_read (virConnectPtr conn, struct private_data *priv,
-             int in_open /* if we are in virConnectOpen */,
-             char *bytes, int len)
+really_write_plain (virConnectPtr conn, struct private_data *priv,
+                    int in_open /* if we are in virConnectOpen */,
+                    char *bytes, int len)
+{
+    return really_write_buf(conn, priv, in_open, bytes, len);
+}
+
+#if HAVE_SASL
+static int
+really_write_sasl (virConnectPtr conn, struct private_data *priv,
+              int in_open /* if we are in virConnectOpen */,
+              char *bytes, int len)
+{
+    const char *output;
+    unsigned int outputlen;
+    int err;
+
+    err = sasl_encode(priv->saslconn, bytes, len, &output, &outputlen);
+    if (err != SASL_OK) {
+        return -1;
+    }
+
+    return really_write_buf(conn, priv, in_open, output, outputlen);
+}
+#endif
+
+static int
+really_write (virConnectPtr conn, struct private_data *priv,
+              int in_open /* if we are in virConnectOpen */,
+              char *bytes, int len)
+{
+#if HAVE_SASL
+    if (priv->saslconn)
+        return really_write_sasl(conn, priv, in_open, bytes, len);
+    else
+#endif
+        return really_write_plain(conn, priv, in_open, bytes, len);
+}
+
+static int
+really_read_buf (virConnectPtr conn, struct private_data *priv,
+                 int in_open /* if we are in virConnectOpen */,
+                 char *bytes, int len)
 {
-    char *p;
     int err;
 
-    p = bytes;
     if (priv->uses_tls) {
-        do {
-            err = gnutls_record_recv (priv->session, p, len);
-            if (err < 0) {
-                if (err == GNUTLS_E_INTERRUPTED || err == GNUTLS_E_AGAIN)
-                    continue;
-                error (in_open ? NULL : conn,
-                       VIR_ERR_GNUTLS_ERROR, gnutls_strerror (err));
-                return -1;
-            }
-            if (err == 0) {
-                error (in_open ? NULL : conn,
-                       VIR_ERR_RPC, "socket closed unexpectedly");
-                return -1;
-            }
-            len -= err;
-            p += err;
+    tlsreread:
+        err = gnutls_record_recv (priv->session, bytes, len);
+        if (err < 0) {
+            if (err == GNUTLS_E_INTERRUPTED)
+                goto tlsreread;
+            error (in_open ? NULL : conn,
+                   VIR_ERR_GNUTLS_ERROR, gnutls_strerror (err));
+            return -1;
         }
-        while (len > 0);
+        if (err == 0) {
+            error (in_open ? NULL : conn,
+                   VIR_ERR_RPC, "socket closed unexpectedly");
+            return -1;
+        }
+        return err;
     } else {
-        do {
-            err = read (priv->sock, p, len);
-            if (err == -1) {
-                if (errno == EINTR || errno == EAGAIN)
-                    continue;
-                error (in_open ? NULL : conn,
-                       VIR_ERR_SYSTEM_ERROR, strerror (errno));
-                return -1;
-            }
-            if (err == 0) {
-                error (in_open ? NULL : conn,
-                       VIR_ERR_RPC, "socket closed unexpectedly");
-                return -1;
-            }
-            len -= err;
-            p += err;
+    reread:
+        err = read (priv->sock, bytes, len);
+        if (err == -1) {
+            if (errno == EINTR)
+                goto reread;
+            error (in_open ? NULL : conn,
+                   VIR_ERR_SYSTEM_ERROR, strerror (errno));
+            return -1;
         }
-        while (len > 0);
+        if (err == 0) {
+            error (in_open ? NULL : conn,
+                   VIR_ERR_RPC, "socket closed unexpectedly");
+            return -1;
+        }
+        return err;
     }
 
     return 0;
 }
 
+static int
+really_read_plain (virConnectPtr conn, struct private_data *priv,
+                   int in_open /* if we are in virConnectOpen */,
+                   char *bytes, int len)
+{
+    do {
+        int ret = really_read_buf (conn, priv, in_open, bytes, len);
+        if (ret < 0)
+            return -1;
+
+        len -= ret;
+        bytes += ret;
+    } while (len > 0);
+
+    return 0;
+}
+
+#if HAVE_SASL
+static int
+really_read_sasl (virConnectPtr conn, struct private_data *priv,
+                  int in_open /* if we are in virConnectOpen */,
+                  char *bytes, int len)
+{
+    do {
+        int want, got;
+        if (priv->saslDecoded == NULL) {
+            char encoded[8192];
+            int encodedLen = sizeof(encoded);
+            int err, ret;
+            ret = really_read_buf (conn, priv, in_open, encoded, encodedLen);
+            if (ret < 0)
+                return -1;
+
+            err = sasl_decode(priv->saslconn, encoded, ret,
+                              &priv->saslDecoded, &priv->saslDecodedLength);
+        }
+
+        got = priv->saslDecodedLength - priv->saslDecodedOffset;
+        want = len;
+        if (want > got)
+            want = got;
+
+        memcpy(bytes, priv->saslDecoded + priv->saslDecodedOffset, want);
+        priv->saslDecodedOffset += want;
+        if (priv->saslDecodedOffset == priv->saslDecodedLength) {
+            priv->saslDecoded = NULL;
+            priv->saslDecodedOffset = priv->saslDecodedLength = 0;
+        }
+        bytes += want;
+        len -= want;
+    } while (len > 0);
+
+    return 0;
+}
+#endif
+
+static int
+really_read (virConnectPtr conn, struct private_data *priv,
+             int in_open /* if we are in virConnectOpen */,
+             char *bytes, int len)
+{
+#if HAVE_SASL
+    if (priv->saslconn)
+        return really_read_sasl (conn, priv, in_open, bytes, len);
+    else
+#endif
+        return really_read_plain (conn, priv, in_open, bytes, len);
+}
+
 /* For errors internal to this library. */
 static void
 error (virConnectPtr conn, virErrorNumber code, const char *info)