mbedtls: Fix setting allowed cipher suites
[project/ustream-ssl.git] / ustream-ssl.c
index f403cb5..dd0faf9 100644 (file)
  */
 
 #include <errno.h>
-
-#include <openssl/ssl.h>
-#include <openssl/err.h>
-
+#include <stdlib.h>
+#include <string.h>
 #include <libubox/ustream.h>
-#include "ustream-io.h"
-#include "ustream-ssl.h"
-
-static void ssl_init(void)
-{
-       static bool _init = false;
-
-       if (_init)
-               return;
-
-       SSL_load_error_strings();
-       SSL_library_init();
 
-       _init = true;
-}
+#include "ustream-ssl.h"
+#include "ustream-internal.h"
 
 static void ustream_ssl_error_cb(struct uloop_timeout *t)
 {
@@ -45,79 +31,61 @@ static void ustream_ssl_error_cb(struct uloop_timeout *t)
        int error = us->error;
 
        if (us->notify_error)
-               us->notify_error(us, error, ERR_error_string(us->error, buffer));
-}
-
-static void ustream_ssl_error(struct ustream_ssl *us, int error)
-{
-       us->error = error;
-       uloop_timeout_set(&us->error_timer, 0);
+               us->notify_error(us, error, __ustream_ssl_strerror(us->error, buffer, sizeof(buffer)));
 }
 
 static void ustream_ssl_check_conn(struct ustream_ssl *us)
 {
-       int ret;
-
        if (us->connected || us->error)
                return;
 
-       if (us->server)
-               ret = SSL_accept(us->ssl);
-       else
-               ret = SSL_connect(us->ssl);
-
-       if (ret == 1) {
+       if (__ustream_ssl_connect(us) == U_SSL_OK) {
                us->connected = true;
                if (us->notify_connected)
                        us->notify_connected(us);
-               return;
+               ustream_write_pending(&us->stream);
        }
-
-       ret = SSL_get_error(us->ssl, ret);
-       if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE)
-               return;
-
-       ustream_ssl_error(us, ret);
 }
 
-static void ustream_ssl_notify_read(struct ustream *s, int bytes)
+static bool __ustream_ssl_poll(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
        char *buf;
-       int wr = 0, len, ret;
+       int len, ret;
+       bool more = false;
 
        ustream_ssl_check_conn(us);
        if (!us->connected || us->error)
-               return;
-
-       buf = ustream_reserve(&us->stream, 1, &len);
-       if (!len)
-               return;
+               return false;
 
        do {
-               ret = SSL_read(us->ssl, buf, len);
-               if (ret < 0) {
-                       ret = SSL_get_error(us->ssl, ret);
-
-                       if (ret == SSL_ERROR_WANT_READ)
-                               break;
-
-                       ustream_ssl_error(us, ret);
+               buf = ustream_reserve(&us->stream, 1, &len);
+               if (!len)
                        break;
-               }
-               if (ret == 0) {
+
+               ret = __ustream_ssl_read(us, buf, len);
+               switch (ret) {
+               case U_SSL_PENDING:
+                       return more;
+               case U_SSL_ERROR:
+                       return false;
+               case 0:
                        us->stream.eof = true;
                        ustream_state_change(&us->stream);
-                       break;
+                       return false;
+               default:
+                       ustream_fill_read(&us->stream, ret);
+                       more = true;
+                       continue;
                }
-
-               buf += ret;
-               len -= ret;
-               wr += ret;
        } while (1);
 
-       if (wr)
-               ustream_fill_read(&us->stream, wr);
+       return more;
+}
+
+static void ustream_ssl_notify_read(struct ustream *s, int bytes)
+{
+       __ustream_ssl_poll(s);
 }
 
 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
@@ -137,7 +105,6 @@ static void ustream_ssl_notify_state(struct ustream *s)
 static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool more)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
-       int ret;
 
        if (!us->connected || us->error)
                return 0;
@@ -145,14 +112,7 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m
        if (us->conn->w.data_bytes)
                return 0;
 
-       ret = SSL_write(us->ssl, buf, len);
-       if (ret < 0) {
-               int err = SSL_get_error(us->ssl, ret);
-               if (err == SSL_ERROR_WANT_WRITE)
-                       return 0;
-       }
-
-       return ret;
+       return __ustream_ssl_write(us, buf, len);
 }
 
 static void ustream_ssl_set_read_blocked(struct ustream *s)
@@ -174,20 +134,26 @@ static void ustream_ssl_free(struct ustream *s)
        }
 
        uloop_timeout_cancel(&us->error_timer);
-       SSL_shutdown(us->ssl);
-       SSL_free(us->ssl);
+       __ustream_ssl_session_free(us->ssl);
+       free(us->peer_cn);
+
        us->ctx = NULL;
        us->ssl = NULL;
        us->conn = NULL;
+       us->peer_cn = NULL;
        us->connected = false;
        us->error = false;
+       us->valid_cert = false;
+       us->valid_cn = false;
 }
 
 static bool ustream_ssl_poll(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
+       bool fd_poll;
 
-       return ustream_poll(us->conn);
+       fd_poll = ustream_poll(us->conn);
+       return __ustream_ssl_poll(us->conn) || fd_poll;
 }
 
 static void ustream_ssl_stream_init(struct ustream_ssl *us)
@@ -206,76 +172,14 @@ static void ustream_ssl_stream_init(struct ustream_ssl *us)
        ustream_init_defaults(s);
 }
 
-static void *_ustream_ssl_context_new(bool server)
-{
-       SSL_CTX *c;
-       const void *m;
-
-       ssl_init();
-
-#ifdef CYASSL_OPENSSL_H_
-       if (server)
-               m = SSLv23_server_method();
-       else
-               m = SSLv23_client_method();
-#else
-       if (server)
-               m = TLSv1_server_method();
-       else
-               m = TLSv1_client_method();
-#endif
-
-       c = SSL_CTX_new((void *) m);
-       if (!c)
-               return NULL;
-
-       if (server)
-               SSL_CTX_set_verify(c, SSL_VERIFY_NONE, NULL);
-
-       return c;
-}
-
-static int _ustream_ssl_context_set_crt_file(void *ctx, const char *file)
-{
-       int ret;
-
-       ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_PEM);
-       if (ret < 1)
-               ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_ASN1);
-
-       if (ret < 1)
-               return -1;
-
-       return 0;
-}
-
-static int _ustream_ssl_context_set_key_file(void *ctx, const char *file)
-{
-       int ret;
-
-       ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_PEM);
-       if (ret < 1)
-               ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_ASN1);
-
-       if (ret < 1)
-               return -1;
-
-       return 0;
-}
-
-static void _ustream_ssl_context_free(void *ctx)
-{
-       SSL_CTX_free(ctx);
-}
-
-static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, void *ctx, bool server)
+static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
 {
        us->error_timer.cb = ustream_ssl_error_cb;
        us->server = server;
        us->conn = conn;
        us->ctx = ctx;
 
-       us->ssl = SSL_new(us->ctx);
+       us->ssl = __ustream_ssl_session_new(us->ctx);
        if (!us->ssl)
                return -ENOMEM;
 
@@ -283,13 +187,28 @@ static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, void
        ustream_set_io(ctx, us->ssl, conn);
        ustream_ssl_stream_init(us);
 
+       if (us->server_name)
+               __ustream_ssl_set_server_name(us);
+
+       ustream_ssl_check_conn(us);
+
+       return 0;
+}
+
+static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
+{
+       us->peer_cn = strdup(name);
+       __ustream_ssl_update_peer_cn(us);
+
        return 0;
 }
 
 const struct ustream_ssl_ops ustream_ssl_ops = {
-       .context_new = _ustream_ssl_context_new,
-       .context_set_crt_file = _ustream_ssl_context_set_crt_file,
-       .context_set_key_file = _ustream_ssl_context_set_key_file,
-       .context_free = _ustream_ssl_context_free,
+       .context_new = __ustream_ssl_context_new,
+       .context_set_crt_file = __ustream_ssl_set_crt_file,
+       .context_set_key_file = __ustream_ssl_set_key_file,
+       .context_add_ca_crt_file = __ustream_ssl_add_ca_crt_file,
+       .context_free = __ustream_ssl_context_free,
        .init = _ustream_ssl_init,
+       .set_peer_cn = _ustream_ssl_set_peer_cn,
 };