afb01f35f1d6b9a3cc16d9c85ca1568c67748a87
[project/ustream-ssl.git] / ustream-ssl.c
1 #include <errno.h>
2
3 #include <openssl/ssl.h>
4 #include <openssl/err.h>
5
6 #include <libubox/ustream.h>
7 #include "ustream-io.h"
8 #include "ustream-ssl.h"
9
10 static void ssl_init(void)
11 {
12         static bool _init = false;
13
14         if (_init)
15                 return;
16
17         SSL_load_error_strings();
18         SSL_library_init();
19
20         _init = true;
21 }
22
23 static void ustream_ssl_error_cb(struct uloop_timeout *t)
24 {
25         struct ustream_ssl *us = container_of(t, struct ustream_ssl, error_timer);
26         static char buffer[128];
27         int error = us->error;
28
29         if (us->notify_error)
30                 us->notify_error(us, error, ERR_error_string(us->error, buffer));
31 }
32
33 static void ustream_ssl_error(struct ustream_ssl *us, int error)
34 {
35         us->error = error;
36         uloop_timeout_set(&us->error_timer, 0);
37 }
38
39 static void ustream_ssl_check_conn(struct ustream_ssl *us)
40 {
41         int ret;
42
43         if (us->connected || us->error)
44                 return;
45
46         if (us->server)
47                 ret = SSL_accept(us->ssl);
48         else
49                 ret = SSL_connect(us->ssl);
50
51         if (ret == 1) {
52                 us->connected = true;
53                 if (us->notify_connected)
54                         us->notify_connected(us);
55                 return;
56         }
57
58         ret = SSL_get_error(us->ssl, ret);
59         if (ret == SSL_ERROR_WANT_READ || ret == SSL_ERROR_WANT_WRITE)
60                 return;
61
62         ustream_ssl_error(us, ret);
63 }
64
65 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
66 {
67         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
68         char *buf;
69         int wr = 0, len, ret;
70
71         ustream_ssl_check_conn(us);
72         if (!us->connected || us->error)
73                 return;
74
75         buf = ustream_reserve(&us->stream, 1, &len);
76         if (!len)
77                 return;
78
79         do {
80                 ret = SSL_read(us->ssl, buf, len);
81                 if (ret < 0) {
82                         ret = SSL_get_error(us->ssl, ret);
83
84                         if (ret == SSL_ERROR_WANT_READ)
85                                 break;
86
87                         ustream_ssl_error(us, ret);
88                         break;
89                 }
90                 if (ret == 0) {
91                         us->stream.eof = true;
92                         ustream_state_change(&us->stream);
93                         break;
94                 }
95
96                 buf += ret;
97                 len -= ret;
98                 wr += ret;
99         } while (1);
100
101         if (wr)
102                 ustream_fill_read(&us->stream, wr);
103 }
104
105 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
106 {
107         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
108
109         ustream_ssl_check_conn(us);
110         ustream_write_pending(s->next);
111 }
112
113 static void ustream_ssl_notify_state(struct ustream *s)
114 {
115         s->next->write_error = true;
116         ustream_state_change(s->next);
117 }
118
119 static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool more)
120 {
121         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
122         int ret;
123
124         if (!us->connected || us->error)
125                 return 0;
126
127         if (us->conn->w.data_bytes)
128                 return 0;
129
130         ret = SSL_write(us->ssl, buf, len);
131         if (ret < 0) {
132                 int err = SSL_get_error(us->ssl, ret);
133                 if (err == SSL_ERROR_WANT_WRITE)
134                         return 0;
135         }
136
137         return ret;
138 }
139
140 static void ustream_ssl_set_read_blocked(struct ustream *s)
141 {
142         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
143
144         ustream_set_read_blocked(us->conn, !!s->read_blocked);
145 }
146
147 static void ustream_ssl_free(struct ustream *s)
148 {
149         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
150
151         if (us->conn) {
152                 us->conn->next = NULL;
153                 us->conn->notify_read = NULL;
154                 us->conn->notify_write = NULL;
155                 us->conn->notify_state = NULL;
156         }
157
158         uloop_timeout_cancel(&us->error_timer);
159         SSL_shutdown(us->ssl);
160         SSL_free(us->ssl);
161         us->ctx = NULL;
162         us->ssl = NULL;
163         us->conn = NULL;
164         us->connected = false;
165         us->error = false;
166 }
167
168 static void ustream_ssl_stream_init(struct ustream_ssl *us)
169 {
170         struct ustream *conn = us->conn;
171         struct ustream *s = &us->stream;
172
173         conn->notify_read = ustream_ssl_notify_read;
174         conn->notify_write = ustream_ssl_notify_write;
175         conn->notify_state = ustream_ssl_notify_state;
176
177         s->free = ustream_ssl_free;
178         s->write = ustream_ssl_write;
179         s->set_read_blocked = ustream_ssl_set_read_blocked;
180         ustream_init_defaults(s);
181 }
182
183 static void *_ustream_ssl_context_new(bool server)
184 {
185         SSL_CTX *c;
186         const void *m;
187
188         ssl_init();
189
190 #ifdef CYASSL_OPENSSL_H_
191         if (server)
192                 m = SSLv23_server_method();
193         else
194                 m = SSLv23_client_method();
195 #else
196         if (server)
197                 m = TLSv1_server_method();
198         else
199                 m = TLSv1_client_method();
200 #endif
201
202         c = SSL_CTX_new((void *) m);
203         if (!c)
204                 return NULL;
205
206         if (server)
207                 SSL_CTX_set_verify(c, SSL_VERIFY_NONE, NULL);
208
209         return c;
210 }
211
212 static int _ustream_ssl_context_set_crt_file(void *ctx, const char *file)
213 {
214         int ret;
215
216         ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_PEM);
217         if (ret < 1)
218                 ret = SSL_CTX_use_certificate_file(ctx, file, SSL_FILETYPE_ASN1);
219
220         if (ret < 1)
221                 return -1;
222
223         return 0;
224 }
225
226 static int _ustream_ssl_context_set_key_file(void *ctx, const char *file)
227 {
228         int ret;
229
230         ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_PEM);
231         if (ret < 1)
232                 ret = SSL_CTX_use_PrivateKey_file(ctx, file, SSL_FILETYPE_ASN1);
233
234         if (ret < 1)
235                 return -1;
236
237         return 0;
238 }
239
240 static void _ustream_ssl_context_free(void *ctx)
241 {
242         SSL_CTX_free(ctx);
243 }
244
245 static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, void *ctx, bool server)
246 {
247         us->error_timer.cb = ustream_ssl_error_cb;
248         us->server = server;
249         us->conn = conn;
250         us->ctx = ctx;
251
252         us->ssl = SSL_new(us->ctx);
253         if (!us->ssl)
254                 return -ENOMEM;
255
256         conn->next = &us->stream;
257         ustream_set_io(ctx, us->ssl, conn);
258         ustream_ssl_stream_init(us);
259
260         return 0;
261 }
262
263 const struct ustream_ssl_ops ustream_ssl_ops = {
264         .context_new = _ustream_ssl_context_new,
265         .context_set_crt_file = _ustream_ssl_context_set_crt_file,
266         .context_set_key_file = _ustream_ssl_context_set_key_file,
267         .context_free = _ustream_ssl_context_free,
268         .init = _ustream_ssl_init,
269 };