add s->poll wrapper
[project/ustream-ssl.git] / ustream-ssl.c
1 /*
2  * ustream-ssl - library for SSL over ustream
3  *
4  * Copyright (C) 2012 Felix Fietkau <nbd@openwrt.org>
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18
19 #include <errno.h>
20
21 #include <openssl/ssl.h>
22 #include <openssl/err.h>
23
24 #include <libubox/ustream.h>
25 #include "ustream-io.h"
26 #include "ustream-ssl.h"
27
28 static void ssl_init(void)
29 {
30         static bool _init = false;
31
32         if (_init)
33                 return;
34
35         SSL_load_error_strings();
36         SSL_library_init();
37
38         _init = true;
39 }
40
41 static void ustream_ssl_error_cb(struct uloop_timeout *t)
42 {
43         struct ustream_ssl *us = container_of(t, struct ustream_ssl, error_timer);
44         static char buffer[128];
45         int error = us->error;
46
47         if (us->notify_error)
48                 us->notify_error(us, error, ERR_error_string(us->error, buffer));
49 }
50
51 static void ustream_ssl_error(struct ustream_ssl *us, int error)
52 {
53         us->error = error;
54         uloop_timeout_set(&us->error_timer, 0);
55 }
56
57 static void ustream_ssl_check_conn(struct ustream_ssl *us)
58 {
59         int ret;
60
61         if (us->connected || us->error)
62                 return;
63
64         if (us->server)
65                 ret = SSL_accept(us->ssl);
66         else
67                 ret = SSL_connect(us->ssl);
68
69         if (ret == 1) {
70                 us->connected = true;
71                 if (us->notify_connected)
72                         us->notify_connected(us);
73                 return;
74         }
75
76         ret = SSL_get_error(us->ssl, ret);
77         if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE)
78                 return;
79
80         ustream_ssl_error(us, ret);
81 }
82
83 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
84 {
85         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
86         char *buf;
87         int wr = 0, len, ret;
88
89         ustream_ssl_check_conn(us);
90         if (!us->connected || us->error)
91                 return;
92
93         buf = ustream_reserve(&us->stream, 1, &len);
94         if (!len)
95                 return;
96
97         do {
98                 ret = SSL_read(us->ssl, buf, len);
99                 if (ret < 0) {
100                         ret = SSL_get_error(us->ssl, ret);
101
102                         if (ret == SSL_ERROR_WANT_READ)
103                                 break;
104
105                         ustream_ssl_error(us, ret);
106                         break;
107                 }
108                 if (ret == 0) {
109                         us->stream.eof = true;
110                         ustream_state_change(&us->stream);
111                         break;
112                 }
113
114                 buf += ret;
115                 len -= ret;
116                 wr += ret;
117         } while (1);
118
119         if (wr)
120                 ustream_fill_read(&us->stream, wr);
121 }
122
123 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
124 {
125         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
126
127         ustream_ssl_check_conn(us);
128         ustream_write_pending(s->next);
129 }
130
131 static void ustream_ssl_notify_state(struct ustream *s)
132 {
133         s->next->write_error = true;
134         ustream_state_change(s->next);
135 }
136
137 static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool more)
138 {
139         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
140         int ret;
141
142         if (!us->connected || us->error)
143                 return 0;
144
145         if (us->conn->w.data_bytes)
146                 return 0;
147
148         ret = SSL_write(us->ssl, buf, len);
149         if (ret < 0) {
150                 int err = SSL_get_error(us->ssl, ret);
151                 if (err == SSL_ERROR_WANT_WRITE)
152                         return 0;
153         }
154
155         return ret;
156 }
157
158 static void ustream_ssl_set_read_blocked(struct ustream *s)
159 {
160         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
161
162         ustream_set_read_blocked(us->conn, !!s->read_blocked);
163 }
164
165 static void ustream_ssl_free(struct ustream *s)
166 {
167         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
168
169         if (us->conn) {
170                 us->conn->next = NULL;
171                 us->conn->notify_read = NULL;
172                 us->conn->notify_write = NULL;
173                 us->conn->notify_state = NULL;
174         }
175
176         uloop_timeout_cancel(&us->error_timer);
177         SSL_shutdown(us->ssl);
178         SSL_free(us->ssl);
179         us->ctx = NULL;
180         us->ssl = NULL;
181         us->conn = NULL;
182         us->connected = false;
183         us->error = false;
184 }
185
186 static bool ustream_ssl_poll(struct ustream *s)
187 {
188         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
189
190         return ustream_poll(us->conn);
191 }
192
193 static void ustream_ssl_stream_init(struct ustream_ssl *us)
194 {
195         struct ustream *conn = us->conn;
196         struct ustream *s = &us->stream;
197
198         conn->notify_read = ustream_ssl_notify_read;
199         conn->notify_write = ustream_ssl_notify_write;
200         conn->notify_state = ustream_ssl_notify_state;
201
202         s->free = ustream_ssl_free;
203         s->write = ustream_ssl_write;
204         s->poll = ustream_ssl_poll;
205         s->set_read_blocked = ustream_ssl_set_read_blocked;
206         ustream_init_defaults(s);
207 }
208
209 static void *_ustream_ssl_context_new(bool server)
210 {
211         SSL_CTX *c;
212         const void *m;
213
214         ssl_init();
215
216 #ifdef CYASSL_OPENSSL_H_
217         if (server)
218                 m = SSLv23_server_method();
219         else
220                 m = SSLv23_client_method();
221 #else
222         if (server)
223                 m = TLSv1_server_method();
224         else
225                 m = TLSv1_client_method();
226 #endif
227
228         c = SSL_CTX_new((void *) m);
229         if (!c)
230                 return NULL;
231
232         if (server)
233                 SSL_CTX_set_verify(c, SSL_VERIFY_NONE, NULL);
234
235         return c;
236 }
237
238 static int _ustream_ssl_context_set_crt_file(void *ctx, const char *file)
239 {
240         int ret;
241
242         ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_PEM);
243         if (ret < 1)
244                 ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_ASN1);
245
246         if (ret < 1)
247                 return -1;
248
249         return 0;
250 }
251
252 static int _ustream_ssl_context_set_key_file(void *ctx, const char *file)
253 {
254         int ret;
255
256         ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_PEM);
257         if (ret < 1)
258                 ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_ASN1);
259
260         if (ret < 1)
261                 return -1;
262
263         return 0;
264 }
265
266 static void _ustream_ssl_context_free(void *ctx)
267 {
268         SSL_CTX_free(ctx);
269 }
270
271 static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, void *ctx, bool server)
272 {
273         us->error_timer.cb = ustream_ssl_error_cb;
274         us->server = server;
275         us->conn = conn;
276         us->ctx = ctx;
277
278         us->ssl = SSL_new(us->ctx);
279         if (!us->ssl)
280                 return -ENOMEM;
281
282         conn->next = &us->stream;
283         ustream_set_io(ctx, us->ssl, conn);
284         ustream_ssl_stream_init(us);
285
286         return 0;
287 }
288
289 const struct ustream_ssl_ops ustream_ssl_ops = {
290         .context_new = _ustream_ssl_context_new,
291         .context_set_crt_file = _ustream_ssl_context_set_crt_file,
292         .context_set_key_file = _ustream_ssl_context_set_key_file,
293         .context_free = _ustream_ssl_context_free,
294         .init = _ustream_ssl_init,
295 };