Cicrumvent possible segfaults in axTLS
[project/luci.git] / libs / nixio / src / tls-socket.c
index a305518..0f504cc 100644 (file)
  *  limitations under the License.
  */
 
-#include "nixio.h"
-#include "string.h"
-
-#ifndef WITHOUT_OPENSSL
-#include <openssl/ssl.h>
-#endif
+#include "nixio-tls.h"
+#include <string.h>
+#include <stdlib.h>
 
 static int nixio__tls_sock_perror(lua_State *L, SSL *sock, int code) {
        lua_pushnil(L);
-       lua_pushinteger(L, code);
        lua_pushinteger(L, SSL_get_error(sock, code));
-       return 3;
+       return 2;
 }
 
 static int nixio__tls_sock_pstatus(lua_State *L, SSL *sock, int code) {
-       if (code == 1) {
+       if (code > 0) {
                lua_pushboolean(L, 1);
                return 1;
        } else {
@@ -40,20 +36,44 @@ static int nixio__tls_sock_pstatus(lua_State *L, SSL *sock, int code) {
 }
 
 static SSL* nixio__checktlssock(lua_State *L) {
-       SSL **sock = (SSL **)luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META);
-       luaL_argcheck(L, *sock, 1, "invalid context");
-       return *sock;
+       if (lua_istable(L, 1)) {
+               lua_getfield(L, 1, "connection");
+               lua_replace(L, 1);
+       }
+       nixio_tls_sock *sock = luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META);
+       luaL_argcheck(L, sock->socket, 1, "invalid context");
+       return sock->socket;
 }
 
+#ifndef WITH_AXTLS
+#define nixio_tls__check_connected(L) ;
+
+#define nixio_tls__set_connected(L, val) ;
+#else
+#define nixio_tls__check_connected(L) \
+       nixio_tls_sock *ctsock = luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META);    \
+       if (!ctsock->connected) {       \
+               lua_pushnil(L);                 \
+               lua_pushinteger(L, 1);  \
+               return 2;                               \
+       }
+
+#define nixio_tls__set_connected(L, val) \
+((nixio_tls_sock*)luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META))->connected = val;
+#endif /* WITH_AXTLS */
+
 static int nixio_tls_sock_recv(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
+       nixio_tls__check_connected(L);
        int req = luaL_checkinteger(L, 2);
 
        luaL_argcheck(L, req >= 0, 2, "out of range");
 
        /* We limit the readsize to NIXIO_BUFFERSIZE */
        req = (req > NIXIO_BUFFERSIZE) ? NIXIO_BUFFERSIZE : req;
+
 #ifndef WITH_AXTLS
+
        char buffer[NIXIO_BUFFERSIZE];
        int readc = SSL_read(sock, buffer, req);
 
@@ -63,73 +83,92 @@ static int nixio_tls_sock_recv(lua_State *L) {
                lua_pushlstring(L, buffer, readc);
                return 1;
        }
+
 #else
+
        if (!req) {
                lua_pushliteral(L, "");
                return 1;
        }
 
-       /* AXTLS doesn't handle buffering for us, so we have to hack around*/
-       int buflen = 0;
-       lua_getmetatable(L, 1);
-       lua_getfield(L, -1, "_axbuffer");
-
-       if (lua_isstring(L, -1)) {
-               buflen = lua_objlen(L, -1);
-       }
+       nixio_tls_sock *t = lua_touserdata(L, 1);
 
-       if (req < buflen) {
-               const char *axbuf = lua_tostring(L, -1);
-               lua_pushlstring(L, axbuf, req);
-               lua_pushlstring(L, axbuf + req, buflen - req);
-               lua_setfield(L, -4, "_axbuffer");
+       /* AXTLS doesn't handle buffering for us, so we have to hack around*/
+       if (req < t->pbufsiz) {
+               lua_pushlstring(L, t->pbufpos, req);
+               t->pbufpos += req;
+               t->pbufsiz -= req;
                return 1;
        } else {
-               if (!lua_isstring(L, -1)) {
-                       lua_pop(L, 1);
-                       lua_pushliteral(L, "");
-               }
-
-               char *axbuf;
+               uint8_t *axbuf;
                int axread;
 
                /* while handshake pending */
-               while ((axread = ssl_read(sock, (uint8_t**)&axbuf)) == SSL_OK);
+               while ((axread = ssl_read(sock, &axbuf)) == SSL_OK);
+
+               if (t->pbufsiz) {
+                       lua_pushlstring(L, t->pbufpos, t->pbufsiz);
+               }
 
                if (axread < 0) {
                        /* There is an error */
+                       free(t->pbuffer);
+                       t->pbuffer = t->pbufpos = NULL;
 
                        if (axread != SSL_ERROR_CONN_LOST) {
-                               lua_pushliteral(L, "");
-                               lua_setfield(L, -3, "_axbuffer");
+                               t->pbufsiz = 0;
                                return nixio__tls_sock_perror(L, sock, axread);
                        } else {
-                               lua_pushliteral(L, "");
+                               if (!t->pbufsiz) {
+                                       lua_pushliteral(L, "");
+                               } else {
+                                       t->pbufsiz = 0;
+                               }
                        }
                } else {
-                       int stillwant = req - buflen;
+                       int stillwant = req - t->pbufsiz;
                        if (stillwant < axread) {
                                /* we got more data than we need */
-                               lua_pushlstring(L, axbuf, stillwant);
-                               lua_concat(L, 2);
+                               lua_pushlstring(L, (char *)axbuf, stillwant);
+                               if(t->pbufsiz) {
+                                       lua_concat(L, 2);
+                               }
 
                                /* remaining data goes into the buffer */
-                               lua_pushlstring(L, axbuf + stillwant, axread - stillwant);
+                               t->pbufpos = t->pbuffer;
+                               t->pbufsiz = axread - stillwant;
+                               t->pbuffer = realloc(t->pbuffer, t->pbufsiz);
+                               if (!t->pbuffer) {
+                                       free(t->pbufpos);
+                                       t->pbufpos = NULL;
+                                       t->pbufsiz = 0;
+                                       return luaL_error(L, "out of memory");
+                               }
+
+                               t->pbufpos = t->pbuffer;
+                               memcpy(t->pbufpos, axbuf + stillwant, t->pbufsiz);
                        } else {
-                               lua_pushlstring(L, axbuf, axread);
-                               lua_concat(L, 2);
-                               lua_pushliteral(L, "");
+                               lua_pushlstring(L, (char *)axbuf, axread);
+                               if(t->pbufsiz) {
+                                       lua_concat(L, 2);
+                               }
+
+                               /* free buffer */
+                               free(t->pbuffer);
+                               t->pbuffer = t->pbufpos = NULL;
+                               t->pbufsiz = 0;
                        }
                }
-               lua_setfield(L, -3, "_axbuffer");
                return 1;
        }
 
-#endif
+#endif /* WITH_AXTLS */
+
 }
 
 static int nixio_tls_sock_send(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
+       nixio_tls__check_connected(L);
        size_t len;
        ssize_t sent;
        const char *data = luaL_checklstring(L, 2, &len);
@@ -138,37 +177,45 @@ static int nixio_tls_sock_send(lua_State *L) {
                lua_pushinteger(L, sent);
                return 1;
        } else {
-               return nixio__tls_sock_pstatus(L, sock, len);
+               return nixio__tls_sock_pstatus(L, sock, sent);
        }
 }
 
 static int nixio_tls_sock_accept(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
-       return nixio__tls_sock_pstatus(L, sock, SSL_accept(sock));
+       const int stat = SSL_accept(sock);
+       nixio_tls__set_connected(L, stat == 1);
+       return nixio__tls_sock_pstatus(L, sock, stat);
 }
 
 static int nixio_tls_sock_connect(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
-       return nixio__tls_sock_pstatus(L, sock, SSL_connect(sock));
+       const int stat = SSL_connect(sock);
+       nixio_tls__set_connected(L, stat == 1);
+       return nixio__tls_sock_pstatus(L, sock, stat);
 }
 
 static int nixio_tls_sock_shutdown(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
+       nixio_tls__set_connected(L, 0);
        return nixio__tls_sock_pstatus(L, sock, SSL_shutdown(sock));
 }
 
 static int nixio_tls_sock__gc(lua_State *L) {
-       SSL **sock = (SSL **)luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META);
-       if (*sock) {
-               SSL_free(*sock);
-               *sock = NULL;
+       nixio_tls_sock *sock = luaL_checkudata(L, 1, NIXIO_TLS_SOCK_META);
+       if (sock->socket) {
+               SSL_free(sock->socket);
+               sock->socket = NULL;
+#ifdef WITH_AXTLS
+               free(sock->pbuffer);
+#endif
        }
        return 0;
 }
 
 static int nixio_tls_sock__tostring(lua_State *L) {
        SSL *sock = nixio__checktlssock(L);
-       lua_pushfstring(L, "nixio TLS socket: %p", sock);
+       lua_pushfstring(L, "nixio TLS connection: %p", sock);
        return 1;
 }
 
@@ -177,6 +224,8 @@ static int nixio_tls_sock__tostring(lua_State *L) {
 static const luaL_reg M[] = {
        {"recv",                nixio_tls_sock_recv},
        {"send",                nixio_tls_sock_send},
+       {"read",                nixio_tls_sock_recv},
+       {"write",               nixio_tls_sock_send},
        {"accept",              nixio_tls_sock_accept},
        {"connect",     nixio_tls_sock_connect},
        {"shutdown",    nixio_tls_sock_shutdown},
@@ -192,5 +241,5 @@ void nixio_open_tls_socket(lua_State *L) {
        luaL_register(L, NULL, M);
        lua_pushvalue(L, -1);
        lua_setfield(L, -2, "__index");
-       lua_setfield(L, -2, "tls_socket_meta");
+       lua_setfield(L, -2, "meta_tls_socket");
 }