Fix ustream_ssl_poll
[project/ustream-ssl.git] / ustream-ssl.c
1 /*
2  * ustream-ssl - library for SSL over ustream
3  *
4  * Copyright (C) 2012 Felix Fietkau <nbd@openwrt.org>
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18
19 #include <errno.h>
20 #include <stdlib.h>
21 #include <string.h>
22 #include <libubox/ustream.h>
23
24 #include "ustream-ssl.h"
25 #include "ustream-internal.h"
26
27 static void ustream_ssl_error_cb(struct uloop_timeout *t)
28 {
29         struct ustream_ssl *us = container_of(t, struct ustream_ssl, error_timer);
30         static char buffer[128];
31         int error = us->error;
32
33         if (us->notify_error)
34                 us->notify_error(us, error, __ustream_ssl_strerror(us->error, buffer, sizeof(buffer)));
35 }
36
37 static void ustream_ssl_check_conn(struct ustream_ssl *us)
38 {
39         if (us->connected || us->error)
40                 return;
41
42         if (__ustream_ssl_connect(us) == U_SSL_OK) {
43                 us->connected = true;
44                 if (us->notify_connected)
45                         us->notify_connected(us);
46                 ustream_write_pending(&us->stream);
47         }
48 }
49
50 static bool __ustream_ssl_poll(struct ustream *s)
51 {
52         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
53         char *buf;
54         int len, ret;
55         bool more = false;
56
57         ustream_ssl_check_conn(us);
58         if (!us->connected || us->error)
59                 return false;
60
61         do {
62                 buf = ustream_reserve(&us->stream, 1, &len);
63                 if (!len)
64                         break;
65
66                 ret = __ustream_ssl_read(us, buf, len);
67                 switch (ret) {
68                 case U_SSL_PENDING:
69                         return more;
70                 case U_SSL_ERROR:
71                         return false;
72                 case 0:
73                         us->stream.eof = true;
74                         ustream_state_change(&us->stream);
75                         return false;
76                 default:
77                         ustream_fill_read(&us->stream, ret);
78                         more = true;
79                         continue;
80                 }
81         } while (1);
82
83         return more;
84 }
85
86 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
87 {
88         __ustream_ssl_poll(s);
89 }
90
91 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
92 {
93         struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
94
95         ustream_ssl_check_conn(us);
96         ustream_write_pending(s->next);
97 }
98
99 static void ustream_ssl_notify_state(struct ustream *s)
100 {
101         s->next->write_error = true;
102         ustream_state_change(s->next);
103 }
104
105 static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool more)
106 {
107         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
108
109         if (!us->connected || us->error)
110                 return 0;
111
112         if (us->conn->w.data_bytes)
113                 return 0;
114
115         return __ustream_ssl_write(us, buf, len);
116 }
117
118 static void ustream_ssl_set_read_blocked(struct ustream *s)
119 {
120         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
121
122         ustream_set_read_blocked(us->conn, !!s->read_blocked);
123 }
124
125 static void ustream_ssl_free(struct ustream *s)
126 {
127         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
128
129         if (us->conn) {
130                 us->conn->next = NULL;
131                 us->conn->notify_read = NULL;
132                 us->conn->notify_write = NULL;
133                 us->conn->notify_state = NULL;
134         }
135
136         uloop_timeout_cancel(&us->error_timer);
137         __ustream_ssl_session_free(us->ssl);
138         free(us->peer_cn);
139
140         us->ctx = NULL;
141         us->ssl = NULL;
142         us->conn = NULL;
143         us->peer_cn = NULL;
144         us->connected = false;
145         us->error = false;
146         us->valid_cert = false;
147         us->valid_cn = false;
148 }
149
150 static bool ustream_ssl_poll(struct ustream *s)
151 {
152         struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
153         bool fd_poll;
154
155         fd_poll = ustream_poll(us->conn);
156         return __ustream_ssl_poll(us->conn) || fd_poll;
157 }
158
159 static void ustream_ssl_stream_init(struct ustream_ssl *us)
160 {
161         struct ustream *conn = us->conn;
162         struct ustream *s = &us->stream;
163
164         conn->notify_read = ustream_ssl_notify_read;
165         conn->notify_write = ustream_ssl_notify_write;
166         conn->notify_state = ustream_ssl_notify_state;
167
168         s->free = ustream_ssl_free;
169         s->write = ustream_ssl_write;
170         s->poll = ustream_ssl_poll;
171         s->set_read_blocked = ustream_ssl_set_read_blocked;
172         ustream_init_defaults(s);
173 }
174
175 static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
176 {
177         us->error_timer.cb = ustream_ssl_error_cb;
178         us->server = server;
179         us->conn = conn;
180         us->ctx = ctx;
181
182         us->ssl = __ustream_ssl_session_new(us->ctx);
183         if (!us->ssl)
184                 return -ENOMEM;
185
186         conn->next = &us->stream;
187         ustream_set_io(ctx, us->ssl, conn);
188         ustream_ssl_stream_init(us);
189         ustream_ssl_check_conn(us);
190
191         return 0;
192 }
193
194 static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
195 {
196         us->peer_cn = strdup(name);
197         __ustream_ssl_update_peer_cn(us);
198
199         return 0;
200 }
201
202 const struct ustream_ssl_ops ustream_ssl_ops = {
203         .context_new = __ustream_ssl_context_new,
204         .context_set_crt_file = __ustream_ssl_set_crt_file,
205         .context_set_key_file = __ustream_ssl_set_key_file,
206         .context_add_ca_crt_file = __ustream_ssl_add_ca_crt_file,
207         .context_free = __ustream_ssl_context_free,
208         .init = _ustream_ssl_init,
209         .set_peer_cn = _ustream_ssl_set_peer_cn,
210 };