blob: 1ba73d611d0da53d5ec677105a6231e087498b7c [file] [log] [blame]
/*
 * Copyright (c) 2016, Linaro Limited
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License Version 2 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 */
13
14#include <sys/types.h>
15#include <stdbool.h>
16#include <arpa/inet.h>
17#include <err.h>
18#include <string.h>
19#include <stdlib.h>
20#include <errno.h>
21#include <netdb.h>
22#include <netinet/in.h>
23#include <poll.h>
Jens Wiklander02389a92016-12-16 11:13:38 +010024#include <sys/socket.h>
25#include <unistd.h>
26
27#include "sock_server.h"
28
/*
 * Working state of one server thread.
 *
 * socks and fds are parallel arrays: entry n of socks describes how to
 * handle entry n of fds. Both always hold nfds entries.
 */
struct server_state {
	/* Per-slot handler and owning bind, parallel to fds */
	struct sock_state *socks;
	/* Array handed to poll(); a slot with fd == -1 is free for reuse */
	struct pollfd *fds;
	/* Number of entries in socks and fds */
	nfds_t nfds;
	/* Set by server_quit_cb() to make the poll loop terminate */
	bool got_quit;
	/* User supplied callbacks, taken from struct sock_server */
	struct sock_io_cb *cb;
};
36
/* NOTE(review): not referenced by any code visible in this file. */
#define SOCK_BUF_SIZE	512

/*
 * Per-slot companion of a struct pollfd in struct server_state.
 */
struct sock_state {
	/*
	 * Handler invoked when poll() reports events for the slot. Returning
	 * false makes the server loop stop and flag an error.
	 */
	bool (*cb)(struct server_state *srvst, size_t idx);
	/* Bind this slot belongs to, NULL for the quit fd slot */
	struct sock_server_bind *serv;
};
43
44static bool server_io_cb(struct server_state *srvst, size_t idx)
45{
46 short revents = srvst->fds[idx].revents;
47 short *events = &srvst->fds[idx].events;
48 struct sock_io_cb *cb = srvst->cb;
49 int fd;
50
51 fd = srvst->fds[idx].fd;
52 if (revents & POLLIN) {
53 if (!cb->read)
54 *events &= ~POLLIN;
55 else if (!cb->read(cb->ptr, fd, events))
56 goto close;
57 }
58
59 if (revents & POLLOUT) {
60 if (!cb->write)
61 *events &= ~POLLOUT;
62 else if (!cb->write(cb->ptr, fd, events))
63 goto close;
64 }
65
66 if (!(revents & ~(POLLIN | POLLOUT)))
67 return true;
68close:
69 if (close(fd)) {
70 warn("server_io_cb: close(%d)", fd);
71 return false;
72 }
73 srvst->fds[idx].fd = -1;
74 return true;
75}
76
77static bool server_add_state(struct server_state *srvst,
78 bool (*cb)(struct server_state *srvst, size_t idx),
79 struct sock_server_bind *serv, int fd,
80 short poll_events)
81{
82 void *p;
83 size_t n;
84
85 for (n = 0; n < srvst->nfds; n++) {
86 if (srvst->fds[n].fd == -1) {
87 srvst->socks[n].cb = cb;
88 srvst->socks[n].serv = serv;
89 srvst->fds[n].fd = fd;
90 srvst->fds[n].events = poll_events;
91 srvst->fds[n].revents = 0;
92 return true;
93 }
94 }
95
96 p = realloc(srvst->socks, sizeof(*srvst->socks) * (srvst->nfds + 1));
97 if (!p)
98 return false;
99 srvst->socks = p;
100 srvst->socks[srvst->nfds].cb = cb;
101 srvst->socks[srvst->nfds].serv = serv;
102
103 p = realloc(srvst->fds, sizeof(*srvst->fds) * (srvst->nfds + 1));
104 if (!p)
105 return false;
106 srvst->fds = p;
107 srvst->fds[srvst->nfds].fd = fd;
108 srvst->fds[srvst->nfds].events = poll_events;
109 srvst->fds[srvst->nfds].revents = 0;
110
111 srvst->nfds++;
112 return true;
113}
114
115static bool tcp_server_accept_cb(struct server_state *srvst, size_t idx)
116{
117 short revents = srvst->fds[idx].revents;
118 struct sockaddr_storage sass;
119 struct sockaddr *sa = (struct sockaddr *)&sass;
120 socklen_t len = sizeof(sass);
121 int fd;
122 short io_events = POLLIN | POLLOUT;
123
124 if (!(revents & POLLIN))
125 return false;
126
127 fd = accept(srvst->fds[idx].fd, sa, &len);
128 if (fd == -1) {
129 if (errno == EAGAIN || errno == EWOULDBLOCK ||
130 errno == ECONNABORTED)
131 return true;
132 return false;
133 }
134
135 if (srvst->cb->accept &&
136 !srvst->cb->accept(srvst->cb->ptr, fd, &io_events)) {
137 if (close(fd))
138 warn("server_accept_cb: close(%d)", fd);
139 return true;
140 }
141
142 return server_add_state(srvst, server_io_cb, srvst->socks[idx].serv,
143 fd, io_events);
144}
145
146static bool udp_server_cb(struct server_state *srvst, size_t idx)
147{
148 short revents = srvst->fds[idx].revents;
149
150 if (!(revents & POLLIN))
151 return false;
152
153 return srvst->cb->accept(srvst->cb->ptr, srvst->fds[idx].fd, NULL);
154}
155
156static bool server_quit_cb(struct server_state *srvst, size_t idx)
157{
158 (void)idx;
159 srvst->got_quit = true;
160 return true;
161}
162
163static void sock_server(struct sock_server *ts,
164 bool (*cb)(struct server_state *srvst, size_t idx))
165{
166 struct server_state srvst = { .cb = ts->cb };
167 int pres;
168 size_t n;
169 char b;
170
171 sock_server_lock(ts);
172
173 for (n = 0; n < ts->num_binds; n++) {
174 if (!server_add_state(&srvst, cb, ts->bind + n,
175 ts->bind[n].fd, POLLIN))
176 goto bad;
177 }
178
179 if (!server_add_state(&srvst, server_quit_cb, NULL,
180 ts->quit_fd, POLLIN))
181 goto bad;
182
183 while (true) {
184 sock_server_unlock(ts);
185 /*
186 * First sleep 5 ms to make it easier to test send timeouts
187 * due to this rate limit.
188 */
189 poll(NULL, 0, 5);
190 pres = poll(srvst.fds, srvst.nfds, -1);
191 sock_server_lock(ts);
192 if (pres < 0)
193 goto bad;
194
195 for (n = 0; pres && n < srvst.nfds; n++) {
196 if (srvst.fds[n].revents) {
197 pres--;
198 if (!srvst.socks[n].cb(&srvst, n))
199 goto bad;
200 }
201 }
202
203 if (srvst.got_quit)
204 goto out;
205 }
206
207bad:
208 ts->error = true;
209out:
210 for (n = 0; n < srvst.nfds; n++) {
211 /* Don't close accept and quit fds */
212 if (srvst.fds[n].fd != -1 && srvst.socks[n].serv &&
213 srvst.fds[n].fd != srvst.socks[n].serv->fd) {
214 if (close(srvst.fds[n].fd))
215 warn("sock_server: close(%d)", srvst.fds[n].fd);
216 }
217 }
218 free(srvst.socks);
219 free(srvst.fds);
220 if (read(ts->quit_fd, &b, 1) != 1)
221 ts->error = true;
222
223 sock_server_unlock(ts);
224}
225
226static void *sock_server_stream(void *arg)
227{
228 sock_server(arg, tcp_server_accept_cb);
229 return NULL;
230}
231
232static void *sock_server_dgram(void *arg)
233{
234 sock_server(arg, udp_server_cb);
235 return NULL;
236}
237
238static void sock_server_add_fd(struct sock_server *ts, struct addrinfo *ai)
239{
240 struct sock_server_bind serv;
241 struct sockaddr_storage sass;
242 struct sockaddr *sa = (struct sockaddr *)&sass;
243 struct sockaddr_in *sain = (struct sockaddr_in *)&sass;
244 struct sockaddr_in6 *sain6 = (struct sockaddr_in6 *)&sass;
245 void *src;
246 socklen_t len = sizeof(sass);
247 struct sock_server_bind *p;
248
249 memset(&serv, 0, sizeof(serv));
250
251 serv.fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
252 if (serv.fd < 0)
253 return;
254
255 if (bind(serv.fd, ai->ai_addr, ai->ai_addrlen))
256 goto bad;
257
258 if (ai->ai_socktype == SOCK_STREAM && listen(serv.fd, 5))
259 goto bad;
260
261 if (getsockname(serv.fd, sa, &len))
262 goto bad;
263
264 switch (sa->sa_family) {
265 case AF_INET:
266 src = &sain->sin_addr;
267 serv.port = ntohs(sain->sin_port);
268 break;
269 case AF_INET6:
270 src = &sain6->sin6_addr;
271 serv.port = ntohs(sain6->sin6_port);
272 default:
273 goto bad;
274 }
275
276 if (!inet_ntop(sa->sa_family, src, serv.host, sizeof(serv.host)))
277 goto bad;
278
279 p = realloc(ts->bind, sizeof(*p) * (ts->num_binds + 1));
280 if (!p)
281 goto bad;
282
283 ts->bind = p;
284 p[ts->num_binds] = serv;
285 ts->num_binds++;
286 return;
287bad:
288 if (close(serv.fd))
289 warn("sock_server_add_fd: close(%d)", serv.fd);
290}
291
292void sock_server_uninit(struct sock_server *ts)
293{
294 size_t n;
295 int e;
296
297 if (ts->stop_fd != -1) {
298 if (close(ts->stop_fd))
299 warn("sock_server_uninit: close(%d)", ts->stop_fd);
300 ts->stop_fd = -1;
301 e = pthread_join(ts->thr, NULL);
302 if (e)
303 warnx("sock_server_uninit: pthread_join: %s",
304 strerror(e));
305 }
306
307 e = pthread_mutex_destroy(&ts->mu);
308 if (e)
309 warnx("sock_server_uninit: pthread_mutex_destroy: %s",
310 strerror(e));
311
312 for (n = 0; n < ts->num_binds; n++)
313 if (close(ts->bind[n].fd))
314 warn("sock_server_uninit: close(%d)", ts->bind[n].fd);
315 free(ts->bind);
316 if (ts->quit_fd != -1 && close(ts->quit_fd))
317 warn("sock_server_uninit: close(%d)", ts->quit_fd);
318 memset(ts, 0, sizeof(*ts));
319 ts->quit_fd = -1;
320 ts->stop_fd = -1;
321}
322
/*
 * Initializes @ts and starts a server thread for sockets of type @socktype
 * (SOCK_STREAM or anything else, which is served as datagram).
 *
 * One socket is bound per passive address returned by getaddrinfo() for
 * service "0", so the port of each bind is whatever the system assigned;
 * it is recorded in ts->bind[]. A pipe is created whose read end (quit_fd)
 * is polled by the server thread and whose write end (stop_fd) is kept for
 * sock_server_uninit() to close.
 *
 * Returns true on success. On failure false is returned; if the mutex had
 * been initialized, everything acquired so far has been released with
 * sock_server_uninit() and @ts is back in its reset state.
 */
static bool sock_server_init(struct sock_server *ts, struct sock_io_cb *cb,
			     int socktype)
{
	struct addrinfo hints;
	struct addrinfo *ai;
	struct addrinfo *ai0 = NULL;
	int fd_pair[2];
	int e;

	memset(ts, 0, sizeof(*ts));
	ts->quit_fd = -1;
	ts->stop_fd = -1;
	ts->cb = cb;

	e = pthread_mutex_init(&ts->mu, NULL);
	if (e) {
		warnx("sock_server_init: pthread_mutex_init: %s", strerror(e));
		return false;
	}

	memset(&hints, 0, sizeof(hints));

	hints.ai_flags = AI_PASSIVE;
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = socktype;

	if (getaddrinfo(NULL, "0", &hints, &ai0))
		goto err;

	for (ai = ai0; ai; ai = ai->ai_next)
		sock_server_add_fd(ts, ai);

	/* The binds hold their own copies, the address list can go now */
	freeaddrinfo(ai0);

	if (!ts->num_binds)
		goto err;

	if (pipe(fd_pair))
		goto err;

	ts->quit_fd = fd_pair[0];

	if (socktype == SOCK_STREAM)
		e = pthread_create(&ts->thr, NULL, sock_server_stream, ts);
	else
		e = pthread_create(&ts->thr, NULL, sock_server_dgram, ts);
	if (e) {
		warnx("sock_server_init: pthread_create: %s", strerror(e));
		/*
		 * stop_fd is still -1 so sock_server_uninit() will neither
		 * close the write end nor try to join a thread.
		 */
		if (close(fd_pair[1]))
			warn("sock_server_init: close(%d)", fd_pair[1]);
		goto err;
	}

	ts->stop_fd = fd_pair[1];
	return true;

err:
	/* Destroys the mutex and closes whatever was opened so far */
	sock_server_uninit(ts);
	return false;
}
380
/*
 * Starts a TCP (SOCK_STREAM) server on @sock_serv using the callbacks in
 * @cb. Returns the result of sock_server_init().
 */
bool sock_server_init_tcp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int socktype = SOCK_STREAM;

	return sock_server_init(sock_serv, cb, socktype);
}
385
/*
 * Starts a UDP (SOCK_DGRAM) server on @sock_serv using the callbacks in
 * @cb. Returns the result of sock_server_init().
 */
bool sock_server_init_udp(struct sock_server *sock_serv, struct sock_io_cb *cb)
{
	const int socktype = SOCK_DGRAM;

	return sock_server_init(sock_serv, cb, socktype);
}
390
391void sock_server_lock(struct sock_server *ts)
392{
393 int e = pthread_mutex_lock(&ts->mu);
394
395 if (e)
396 errx(1, "sock_server_lock: pthread_mutex_lock: %s", strerror(e));
397}
398
399void sock_server_unlock(struct sock_server *ts)
400{
401 int e = pthread_mutex_unlock(&ts->mu);
402
403 if (e)
404 errx(1, "sock_server_unlock: pthread_mutex_unlock: %s",
405 strerror(e));
406}