blob: 0d2ff066d75593ae00eb9ab0f2a54b5b262ef90c [file] [log] [blame]
Jens Wiklander02389a92016-12-16 11:13:38 +01001/*
2 * Copyright (c) 2016, Linaro Limited
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License Version 2 as
6 * published by the Free Software Foundation.
7 *
8 * This program is distributed in the hope that it will be useful,
9 * but WITHOUT ANY WARRANTY; without even the implied warranty of
10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 * GNU General Public License for more details.
12 */
13
14#include <sys/types.h>
15#include <stdbool.h>
16#include <arpa/inet.h>
17#include <err.h>
18#include <string.h>
19#include <stdlib.h>
20#include <errno.h>
21#include <netdb.h>
22#include <netinet/in.h>
23#include <poll.h>
24#include <sys/errno.h>
25#include <sys/socket.h>
26#include <unistd.h>
27
28#include "sock_server.h"
29
struct sock_state;

/* Bookkeeping owned by one sock_server() poll loop. */
struct server_state {
	/* Per-slot handler information, indexed in parallel with fds. */
	struct sock_state *socks;
	/* The poll() set; a slot whose fd is -1 is free for reuse. */
	struct pollfd *fds;
	/* Number of slots in both socks and fds. */
	nfds_t nfds;
	/* Set by the quit-pipe handler to end the loop. */
	bool got_quit;
	/* User callbacks, copied from the owning struct sock_server. */
	struct sock_io_cb *cb;
};

/* NOTE(review): not referenced anywhere in this file. */
#define SOCK_BUF_SIZE 512

/* Handler attached to a single poll slot. */
struct sock_state {
	/* Called when fds[idx] reports events; false aborts the server loop. */
	bool (*cb)(struct server_state *srvst, size_t idx);
	/* Bind entry the slot originates from; NULL for the quit pipe. */
	struct sock_server_bind *serv;
};
44
45static bool server_io_cb(struct server_state *srvst, size_t idx)
46{
47 short revents = srvst->fds[idx].revents;
48 short *events = &srvst->fds[idx].events;
49 struct sock_io_cb *cb = srvst->cb;
50 int fd;
51
52 fd = srvst->fds[idx].fd;
53 if (revents & POLLIN) {
54 if (!cb->read)
55 *events &= ~POLLIN;
56 else if (!cb->read(cb->ptr, fd, events))
57 goto close;
58 }
59
60 if (revents & POLLOUT) {
61 if (!cb->write)
62 *events &= ~POLLOUT;
63 else if (!cb->write(cb->ptr, fd, events))
64 goto close;
65 }
66
67 if (!(revents & ~(POLLIN | POLLOUT)))
68 return true;
69close:
70 if (close(fd)) {
71 warn("server_io_cb: close(%d)", fd);
72 return false;
73 }
74 srvst->fds[idx].fd = -1;
75 return true;
76}
77
78static bool server_add_state(struct server_state *srvst,
79 bool (*cb)(struct server_state *srvst, size_t idx),
80 struct sock_server_bind *serv, int fd,
81 short poll_events)
82{
83 void *p;
84 size_t n;
85
86 for (n = 0; n < srvst->nfds; n++) {
87 if (srvst->fds[n].fd == -1) {
88 srvst->socks[n].cb = cb;
89 srvst->socks[n].serv = serv;
90 srvst->fds[n].fd = fd;
91 srvst->fds[n].events = poll_events;
92 srvst->fds[n].revents = 0;
93 return true;
94 }
95 }
96
97 p = realloc(srvst->socks, sizeof(*srvst->socks) * (srvst->nfds + 1));
98 if (!p)
99 return false;
100 srvst->socks = p;
101 srvst->socks[srvst->nfds].cb = cb;
102 srvst->socks[srvst->nfds].serv = serv;
103
104 p = realloc(srvst->fds, sizeof(*srvst->fds) * (srvst->nfds + 1));
105 if (!p)
106 return false;
107 srvst->fds = p;
108 srvst->fds[srvst->nfds].fd = fd;
109 srvst->fds[srvst->nfds].events = poll_events;
110 srvst->fds[srvst->nfds].revents = 0;
111
112 srvst->nfds++;
113 return true;
114}
115
116static bool tcp_server_accept_cb(struct server_state *srvst, size_t idx)
117{
118 short revents = srvst->fds[idx].revents;
119 struct sockaddr_storage sass;
120 struct sockaddr *sa = (struct sockaddr *)&sass;
121 socklen_t len = sizeof(sass);
122 int fd;
123 short io_events = POLLIN | POLLOUT;
124
125 if (!(revents & POLLIN))
126 return false;
127
128 fd = accept(srvst->fds[idx].fd, sa, &len);
129 if (fd == -1) {
130 if (errno == EAGAIN || errno == EWOULDBLOCK ||
131 errno == ECONNABORTED)
132 return true;
133 return false;
134 }
135
136 if (srvst->cb->accept &&
137 !srvst->cb->accept(srvst->cb->ptr, fd, &io_events)) {
138 if (close(fd))
139 warn("server_accept_cb: close(%d)", fd);
140 return true;
141 }
142
143 return server_add_state(srvst, server_io_cb, srvst->socks[idx].serv,
144 fd, io_events);
145}
146
147static bool udp_server_cb(struct server_state *srvst, size_t idx)
148{
149 short revents = srvst->fds[idx].revents;
150
151 if (!(revents & POLLIN))
152 return false;
153
154 return srvst->cb->accept(srvst->cb->ptr, srvst->fds[idx].fd, NULL);
155}
156
157static bool server_quit_cb(struct server_state *srvst, size_t idx)
158{
159 (void)idx;
160 srvst->got_quit = true;
161 return true;
162}
163
/*
 * Body of the server thread (entered via sock_server_stream() or
 * sock_server_dgram()).
 *
 * Builds a poll set in two steps:
 * - every socket in ts->bind[], handled by @cb;
 * - the read end of the quit pipe, handled by server_quit_cb().
 *
 * It then dispatches poll() events until the quit pipe fires or a handler
 * fails. Handlers may append more entries, e.g. accepted connections,
 * through server_add_state().
 *
 * The server lock is held except while sleeping in poll(), so every
 * handler runs with the lock held. ts->error is set on failure.
 */
static void sock_server(struct sock_server *ts,
			bool (*cb)(struct server_state *srvst, size_t idx))
{
	struct server_state srvst = { .cb = ts->cb };
	int pres;
	size_t n;
	char b;

	sock_server_lock(ts);

	for (n = 0; n < ts->num_binds; n++) {
		if (!server_add_state(&srvst, cb, ts->bind + n,
				      ts->bind[n].fd, POLLIN))
			goto bad;
	}

	if (!server_add_state(&srvst, server_quit_cb, NULL,
			      ts->quit_fd, POLLIN))
		goto bad;

	while (true) {
		sock_server_unlock(ts);
		/*
		 * First sleep 5 ms to make it easier to test send timeouts
		 * due to this rate limit.
		 */
		poll(NULL, 0, 5);
		pres = poll(srvst.fds, srvst.nfds, -1);
		sock_server_lock(ts);
		/* NOTE(review): this also makes an EINTR from poll() fatal. */
		if (pres < 0)
			goto bad;

		/*
		 * pres counts the slots with nonzero revents; stop scanning
		 * once all of them are handled. srvst is re-indexed on every
		 * iteration because a handler may realloc the arrays.
		 */
		for (n = 0; pres && n < srvst.nfds; n++) {
			if (srvst.fds[n].revents) {
				pres--;
				if (!srvst.socks[n].cb(&srvst, n))
					goto bad;
			}
		}

		if (srvst.got_quit)
			goto out;
	}

bad:
	ts->error = true;
out:
	for (n = 0; n < srvst.nfds; n++) {
		/*
		 * Don't close accept and quit fds: only accepted connections
		 * are closed here. Listening sockets still belong to
		 * ts->bind[] (closed by sock_server_uninit()), and the quit fd
		 * is the only slot with serv == NULL.
		 */
		if (srvst.fds[n].fd != -1 && srvst.socks[n].serv &&
		    srvst.fds[n].fd != srvst.socks[n].serv->fd) {
			if (close(srvst.fds[n].fd))
				warn("sock_server: close(%d)", srvst.fds[n].fd);
		}
	}
	free(srvst.socks);
	free(srvst.fds);
	/*
	 * Wait until the write end of the quit pipe is written to or closed.
	 * If it is closed without a byte being written, read() returns 0 and
	 * ts->error is set as well.
	 * NOTE(review): on the error path this blocks while still holding the
	 * server lock, so other threads calling sock_server_lock() wait until
	 * the stop fd is closed — confirm this is intended.
	 */
	if (read(ts->quit_fd, &b, 1) != 1)
		ts->error = true;

	sock_server_unlock(ts);
}
226
227static void *sock_server_stream(void *arg)
228{
229 sock_server(arg, tcp_server_accept_cb);
230 return NULL;
231}
232
233static void *sock_server_dgram(void *arg)
234{
235 sock_server(arg, udp_server_cb);
236 return NULL;
237}
238
239static void sock_server_add_fd(struct sock_server *ts, struct addrinfo *ai)
240{
241 struct sock_server_bind serv;
242 struct sockaddr_storage sass;
243 struct sockaddr *sa = (struct sockaddr *)&sass;
244 struct sockaddr_in *sain = (struct sockaddr_in *)&sass;
245 struct sockaddr_in6 *sain6 = (struct sockaddr_in6 *)&sass;
246 void *src;
247 socklen_t len = sizeof(sass);
248 struct sock_server_bind *p;
249
250 memset(&serv, 0, sizeof(serv));
251
252 serv.fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
253 if (serv.fd < 0)
254 return;
255
256 if (bind(serv.fd, ai->ai_addr, ai->ai_addrlen))
257 goto bad;
258
259 if (ai->ai_socktype == SOCK_STREAM && listen(serv.fd, 5))
260 goto bad;
261
262 if (getsockname(serv.fd, sa, &len))
263 goto bad;
264
265 switch (sa->sa_family) {
266 case AF_INET:
267 src = &sain->sin_addr;
268 serv.port = ntohs(sain->sin_port);
269 break;
270 case AF_INET6:
271 src = &sain6->sin6_addr;
272 serv.port = ntohs(sain6->sin6_port);
273 default:
274 goto bad;
275 }
276
277 if (!inet_ntop(sa->sa_family, src, serv.host, sizeof(serv.host)))
278 goto bad;
279
280 p = realloc(ts->bind, sizeof(*p) * (ts->num_binds + 1));
281 if (!p)
282 goto bad;
283
284 ts->bind = p;
285 p[ts->num_binds] = serv;
286 ts->num_binds++;
287 return;
288bad:
289 if (close(serv.fd))
290 warn("sock_server_add_fd: close(%d)", serv.fd);
291}
292
293void sock_server_uninit(struct sock_server *ts)
294{
295 size_t n;
296 int e;
297
298 if (ts->stop_fd != -1) {
299 if (close(ts->stop_fd))
300 warn("sock_server_uninit: close(%d)", ts->stop_fd);
301 ts->stop_fd = -1;
302 e = pthread_join(ts->thr, NULL);
303 if (e)
304 warnx("sock_server_uninit: pthread_join: %s",
305 strerror(e));
306 }
307
308 e = pthread_mutex_destroy(&ts->mu);
309 if (e)
310 warnx("sock_server_uninit: pthread_mutex_destroy: %s",
311 strerror(e));
312
313 for (n = 0; n < ts->num_binds; n++)
314 if (close(ts->bind[n].fd))
315 warn("sock_server_uninit: close(%d)", ts->bind[n].fd);
316 free(ts->bind);
317 if (ts->quit_fd != -1 && close(ts->quit_fd))
318 warn("sock_server_uninit: close(%d)", ts->quit_fd);
319 memset(ts, 0, sizeof(*ts));
320 ts->quit_fd = -1;
321 ts->stop_fd = -1;
322}
323
324static bool sock_server_init(struct sock_server *ts, struct sock_io_cb *cb,
325 int socktype)
326{
327 struct addrinfo hints;
328 struct addrinfo *ai;
329 struct addrinfo *ai0;
330 int fd_pair[2];
331 int e;
332
333 memset(ts, 0, sizeof(*ts));
334 ts->quit_fd = -1;
335 ts->stop_fd = -1;
336 ts->cb = cb;
337
338 e = pthread_mutex_init(&ts->mu, NULL);
339 if (e) {
340 warnx("sock_server_init: pthread_mutex_init: %s", strerror(e));
341 return false;
342 }
343
344 memset(&hints, 0, sizeof(hints));
345
346 hints.ai_flags = AI_PASSIVE;
347 hints.ai_family = AF_UNSPEC;
348 hints.ai_socktype = socktype;
349
350 if (getaddrinfo(NULL, "0", &hints, &ai0))
351 return false;
352
353 for (ai = ai0; ai; ai = ai->ai_next)
354 sock_server_add_fd(ts, ai);
355
356 if (!ts->num_binds)
357 return false;
358
359 if (pipe(fd_pair)) {
360 sock_server_uninit(ts);
361 return false;
362 }
363
364 ts->quit_fd = fd_pair[0];
365
366 if (socktype == SOCK_STREAM)
367 e = pthread_create(&ts->thr, NULL, sock_server_stream, ts);
368 else
369 e = pthread_create(&ts->thr, NULL, sock_server_dgram, ts);
370 if (e) {
371 warnx("sock_server_init: pthread_create: %s", strerror(e));
372 if (close(fd_pair[1]))
373 warn("sock_server_init: close(%d)", fd_pair[1]);
374 sock_server_uninit(ts);
375 return false;
376 }
377
378 ts->stop_fd = fd_pair[1];
379 return true;
380}
381
/*
 * Start a TCP (SOCK_STREAM) server on all wildcard addresses.
 * Returns true on success.
 */
bool sock_server_init_tcp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int type = SOCK_STREAM;

	return sock_server_init(sock_serv, cb, type);
}
386
/*
 * Start a UDP (SOCK_DGRAM) server on all wildcard addresses.
 * Returns true on success.
 */
bool sock_server_init_udp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int type = SOCK_DGRAM;

	return sock_server_init(sock_serv, cb, type);
}
391
392void sock_server_lock(struct sock_server *ts)
393{
394 int e = pthread_mutex_lock(&ts->mu);
395
396 if (e)
397 errx(1, "sock_server_lock: pthread_mutex_lock: %s", strerror(e));
398}
399
400void sock_server_unlock(struct sock_server *ts)
401{
402 int e = pthread_mutex_unlock(&ts->mu);
403
404 if (e)
405 errx(1, "sock_server_unlock: pthread_mutex_unlock: %s",
406 strerror(e));
407}