blob: b2b6a08f54324f4590158c1075ebf23e8a3f3312 [file] [log] [blame]
Valerio Setti4f4ade92024-05-03 17:28:04 +02001/* PSA Firmware Framework service API */
2
3/*
4 * Copyright The Mbed TLS Contributors
5 * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 */
7
8#include <sys/types.h>
9#include <sys/ipc.h>
10#include <sys/msg.h>
11#include <stdlib.h>
12#include <stdio.h>
13#include <string.h>
14#include <strings.h>
15#include <unistd.h>
16#include <time.h>
17#include <assert.h>
18
19#include "psa/service.h"
20#include "psasim/init.h"
21#include "psa/error.h"
22#include "psa/common.h"
23#include "psa/util.h"
24
/* Size of the connection table; slot 0 is never handed out by find_connection(). */
#define MAX_CLIENTS 128
/* Number of message slots; slot i holds the message received on signal bit i. */
#define MAX_MESSAGES 32

/* Busy-wait polling interval used by psa_wait(), in milliseconds. */
#define SLEEP_MS 50

/* Server-side state for one connected client. */
struct connection {
    uint32_t client;        /* client's reply queue ID; 0 marks a free slot */
    void *rhandle;          /* opaque pointer set via psa_set_rhandle() */
    int client_to_server_q; /* private queue created when a connect succeeds */
};
35
36/* Note that this implementation is functional and not secure. */
37extern int __psa_ff_client_security_state;
38
39static psa_msg_t messages[MAX_MESSAGES]; /* Message slots */
40static uint8_t pending_message[MAX_MESSAGES] = { 0 }; /* Booleans indicating active message slots */
41static uint32_t message_client[MAX_MESSAGES] = { 0 }; /* Each client's response queue */
42static int nsacl[32];
43static int strict_policy[32] = { 0 };
44static uint32_t rot_svc_versions[32];
45static int rot_svc_incoming_queue[32] = { -1 };
46static struct connection connections[MAX_CLIENTS] = { { 0 } };
47
48static uint32_t exposed_signals = 0;
49
50void print_vectors(vector_sizes_t *sizes)
51{
52 INFO("Printing iovec sizes");
53 for (int j = 0; j < PSA_MAX_IOVEC; j++) {
54 INFO("Invec %d: %lu", j, sizes->invec_sizes[j]);
55 }
56
57 for (int j = 0; j < PSA_MAX_IOVEC; j++) {
58 INFO("Outvec %d: %lu", j, sizes->outvec_sizes[j]);
59 }
60}
61
62int find_connection(uint32_t client)
63{
64 for (int i = 1; i < MAX_CLIENTS; i++) {
65 if (client == connections[i].client) {
66 return i;
67 }
68 }
69 return -1;
70}
71
72void destroy_connection(uint32_t client)
73{
74 int idx = find_connection(client);
75 if (idx >= 0) {
76 connections[idx].client = 0;
77 connections[idx].rhandle = 0;
78 INFO("Destroying connection");
79 } else {
80 ERROR("Couldn't destroy connection for %u", client);
81 }
82}
83
/* Return the index of an unused connection slot, or -1 if the table is full. */
int find_free_connection()
{
    int free_slot;

    INFO("Allocating connection");
    /* Free slots are the ones whose client field is 0. */
    free_slot = find_connection(0);
    return free_slot;
}
89
90static void reply(psa_handle_t msg_handle, psa_status_t status)
91{
92 pending_message[msg_handle] = 1;
93 psa_reply(msg_handle, status);
94 pending_message[msg_handle] = 0;
95}
96
97psa_signal_t psa_wait(psa_signal_t signal_mask, uint32_t timeout)
98{
99 psa_signal_t mask;
100 struct message msg;
101 vector_sizes_t sizes;
102 struct msqid_ds qinfo;
103 uint32_t requested_version;
104 ssize_t len;
105 int idx;
106#if !defined(PSASIM_USE_USLEEP)
107 const struct timespec ts_delay = { .tv_sec = 0, .tv_nsec = SLEEP_MS * 1000000 };
108#endif
109
110 if (timeout == PSA_POLL) {
111 INFO("psa_wait: Called in polling mode");
112 }
113
114 do {
115 mask = signal_mask;
116
117 /* Check the status of each queue */
118 for (int i = 0; i < 32; i++) {
119 if (mask & 0x1) {
120 if (i < 3) {
121 // do nothing (reserved)
122 } else if (i == 3) {
123 // this must be psa doorbell
124 } else {
125 /* Check if this signal corresponds to a queue */
126 if (rot_svc_incoming_queue[i] >= 0 && (pending_message[i] == 0)) {
127
128 /* AFAIK there is no "peek" method in SysV, so try to get a message */
129 len = msgrcv(rot_svc_incoming_queue[i],
130 &msg,
131 sizeof(struct message_text),
132 0,
133 IPC_NOWAIT);
134 if (len > 0) {
135
136 INFO("Storing that QID in message_client[%d]", i);
137 INFO("The message handle will be %d", i);
138
139 msgctl(rot_svc_incoming_queue[i], IPC_STAT, &qinfo);
140 messages[i].client_id = qinfo.msg_lspid; /* PID of last msgsnd(2) call */
141 message_client[i] = msg.message_text.qid;
142 idx = find_connection(msg.message_text.qid);
143
144 if (msg.message_type & NON_SECURE) {
145 /* This is a non-secure message */
146
147 /* Check if NS client is allowed for this RoT service */
148 if (nsacl[i] <= 0) {
149#if 0
150 INFO(
151 "Rejecting non-secure client due to manifest security policy");
152 reply(i, PSA_ERROR_CONNECTION_REFUSED);
153 continue; /* Skip to next signal */
154#endif
155 }
156
157 msg.message_type &= ~(NON_SECURE); /* clear */
158 messages[i].client_id = messages[i].client_id * -1;
159 }
160
161 INFO("Got a message from client ID %d", messages[i].client_id);
162 INFO("Message type is %lu", msg.message_type);
163 INFO("PSA message type is %d", msg.message_text.psa_type);
164
165 messages[i].handle = i;
166
167 switch (msg.message_text.psa_type) {
168 case PSA_IPC_CONNECT:
169
170 if (len >= 16) {
171 memcpy(&requested_version, msg.message_text.buf,
172 sizeof(requested_version));
173 INFO("Requesting version %u", requested_version);
174 INFO("Implemented version %u", rot_svc_versions[i]);
175 /* TODO: need to check whether the policy is strict,
176 * and if so, then reject the client if the number doesn't match */
177
178 if (requested_version > rot_svc_versions[i]) {
179 INFO(
180 "Rejecting client because requested version that was too high");
181 reply(i, PSA_ERROR_CONNECTION_REFUSED);
182 continue; /* Skip to next signal */
183 }
184
185 if (strict_policy[i] == 1 &&
186 (requested_version != rot_svc_versions[i])) {
187 INFO(
188 "Rejecting client because enforcing a STRICT version policy");
189 reply(i, PSA_ERROR_CONNECTION_REFUSED);
190 continue; /* Skip to next signal */
191 } else {
192 INFO("Not rejecting client");
193 }
194 }
195
196 messages[i].type = PSA_IPC_CONNECT;
197
198 if (idx < 0) {
199 idx = find_free_connection();
200 }
201
202 if (idx >= 0) {
203 connections[idx].client = msg.message_text.qid;
204 } else {
205 /* We've run out of system wide connections */
206 reply(i, PSA_ERROR_CONNECTION_BUSY);
207 ERROR("Ran out of free connections");
208 continue;
209 }
210
211 break;
212 case PSA_IPC_DISCONNECT:
213 messages[i].type = PSA_IPC_DISCONNECT;
214 break;
215 case VERSION_REQUEST:
216 INFO("Got a version request");
217 reply(i, rot_svc_versions[i]);
218 continue; /* Skip to next signal */
219 break;
220
221 default:
222
223 /* PSA CALL */
224 if (msg.message_text.psa_type >= 0) {
225 messages[i].type = msg.message_text.psa_type;
226 memcpy(&sizes, msg.message_text.buf, sizeof(sizes));
227 print_vectors(&sizes);
228 memcpy(&messages[i].in_size, &sizes.invec_sizes,
229 (sizeof(size_t) * PSA_MAX_IOVEC));
230 memcpy(&messages[i].out_size, &sizes.outvec_sizes,
231 (sizeof(size_t) * PSA_MAX_IOVEC));
232 } else {
233 FATAL("UNKNOWN MESSAGE TYPE RECEIVED %li",
234 msg.message_type);
235 }
236 break;
237 }
238 messages[i].handle = i;
239
240 /* Check if the client has a connection */
241 if (idx >= 0) {
242 messages[i].rhandle = connections[idx].rhandle;
243 } else {
244 /* Client is begging for a programmer error */
245 reply(i, PSA_ERROR_PROGRAMMER_ERROR);
246 continue;
247 }
248
249 /* House keeping */
250 pending_message[i] = 1; /* set message as pending */
251 exposed_signals |= (0x1 << i); /* assert the signal */
252 }
253 }
254 }
255 mask = mask >> 1;
256 }
257 }
258
259 if ((timeout == PSA_BLOCK) && (exposed_signals > 0)) {
260 break;
261 } else {
262 /* There is no 'select' function in SysV to block on multiple queues, so busy-wait :( */
263#if defined(PSASIM_USE_USLEEP)
264 usleep(SLEEP_MS * 1000);
265#else /* PSASIM_USE_USLEEP */
266 nanosleep(&ts_delay, NULL);
267#endif /* PSASIM_USE_USLEEP */
268 }
269 } while (timeout == PSA_BLOCK);
270
271 /* Assert signals */
272 return signal_mask & exposed_signals;
273}
274
275static int signal_to_index(psa_signal_t signal)
276{
277 int i;
278 int count = 0;
279 int ret = -1;
280
281 for (i = 0; i < 32; i++) {
282 if (signal & 0x1) {
283 ret = i;
284 count++;
285 }
286 signal = signal >> 1;
287 }
288
289 if (count > 1) {
290 ERROR("ERROR: Too many signals");
291 return -1; /* Too many signals */
292 }
293 return ret;
294}
295
296static void clear_signal(psa_signal_t signal)
297{
298 exposed_signals = exposed_signals & ~signal;
299}
300
301void raise_signal(psa_signal_t signal)
302{
303 exposed_signals |= signal;
304}
305
306psa_status_t psa_get(psa_signal_t signal, psa_msg_t *msg)
307{
308 int index = signal_to_index(signal);
309 if (index < 0) {
310 ERROR("Bad signal");
311 }
312
313 clear_signal(signal);
314
315 assert(messages[index].handle != 0);
316
317 if (pending_message[index] == 1) {
318 INFO("There is a pending message!");
319 memcpy(msg, &messages[index], sizeof(struct psa_msg_t));
320 assert(msg->handle != 0);
321 return PSA_SUCCESS;
322 } else {
323 INFO("no pending message");
324 }
325
326 return PSA_ERROR_DOES_NOT_EXIST;
327}
328
329static inline int is_valid_msg_handle(psa_handle_t h)
330{
331 if (h > 0 && h < MAX_MESSAGES) {
332 return 1;
333 }
334 ERROR("Not a valid message handle");
335 return 0;
336}
337
338static inline int is_call_msg(psa_handle_t h)
339{
340 assert(messages[h].type >= PSA_IPC_CALL);
341 return 1;
342}
343
344void psa_set_rhandle(psa_handle_t msg_handle, void *rhandle)
345{
346 is_valid_msg_handle(msg_handle);
347 int idx = find_connection(message_client[msg_handle]);
348 INFO("Setting rhandle to %p", rhandle);
349 assert(idx >= 0);
350 connections[idx].rhandle = rhandle;
351}
352
353/* Sends a message from the server to the client. Does not wait for a response */
354static void send_msg(psa_handle_t msg_handle,
355 int ctrl_msg,
356 psa_status_t status,
357 size_t amount,
358 const void *data,
359 size_t data_amount)
360{
361 struct message response;
362 int flags = 0;
363
364 assert(ctrl_msg > 0); /* According to System V, it must be greater than 0 */
365
366 response.message_type = ctrl_msg;
367 if (ctrl_msg == PSA_REPLY) {
368 memcpy(response.message_text.buf, &status, sizeof(psa_status_t));
369 } else if (ctrl_msg == READ_REQUEST || ctrl_msg == WRITE_REQUEST || ctrl_msg == SKIP_REQUEST) {
370 memcpy(response.message_text.buf, &status, sizeof(psa_status_t));
371 memcpy(response.message_text.buf+sizeof(size_t), &amount, sizeof(size_t));
372 if (ctrl_msg == WRITE_REQUEST) {
373 /* TODO: Check if too big */
374 memcpy(response.message_text.buf + (sizeof(size_t) * 2), data, data_amount);
375 }
376 }
377
378 /* TODO: sizeof doesn't need to be so big here for small responses */
379 if (msgsnd(message_client[msg_handle], &response, sizeof(response.message_text), flags) == -1) {
380 ERROR("Failed to reply");
381 }
382}
383
384static size_t skip(psa_handle_t msg_handle, uint32_t invec_idx, size_t num_bytes)
385{
386 if (num_bytes < (messages[msg_handle].in_size[invec_idx] - num_bytes)) {
387 messages[msg_handle].in_size[invec_idx] = messages[msg_handle].in_size[invec_idx] -
388 num_bytes;
389 return num_bytes;
390 } else {
391 if (num_bytes >= messages[msg_handle].in_size[invec_idx]) {
392 size_t ret = messages[msg_handle].in_size[invec_idx];
393 messages[msg_handle].in_size[invec_idx] = 0;
394 return ret;
395 } else {
396 return num_bytes;
397 }
398 }
399}
400
401size_t psa_read(psa_handle_t msg_handle, uint32_t invec_idx,
402 void *buffer, size_t num_bytes)
403{
404 size_t sofar = 0;
405 struct message msg = { 0 };
406 int idx;
407 ssize_t len;
408
409 is_valid_msg_handle(msg_handle);
410 is_call_msg(msg_handle);
411
412 if (invec_idx >= PSA_MAX_IOVEC) {
413 ERROR("Invalid iovec number");
414 }
415
416 /* If user wants more data than what's available, truncate their request */
417 if (num_bytes > messages[msg_handle].in_size[invec_idx]) {
418 num_bytes = messages[msg_handle].in_size[invec_idx];
419 }
420
421 while (sofar < num_bytes) {
422 INFO("Server: requesting %lu bytes from client", (num_bytes - sofar));
423 send_msg(msg_handle, READ_REQUEST, invec_idx, (num_bytes - sofar), NULL, 0);
424
425 idx = find_connection(message_client[msg_handle]);
426 assert(idx >= 0);
427
428 len = msgrcv(connections[idx].client_to_server_q, &msg, sizeof(struct message_text), 0, 0);
429 len = (len - sizeof(msg.message_text.qid));
430
431 if (len < 0) {
432 FATAL("Internal error: failed to dispatch read request to the client");
433 }
434
435 if (len > (num_bytes - sofar)) {
436 if ((num_bytes - sofar) > 0) {
437 memcpy(buffer+sofar, msg.message_text.buf, (num_bytes - sofar));
438 }
439 } else {
440 memcpy(buffer + sofar, msg.message_text.buf, len);
441 }
442
443 INFO("Printing what i got so far: %s", msg.message_text.buf);
444
445 sofar = sofar + len;
446 }
447
448 /* Update the seek count */
449 skip(msg_handle, invec_idx, num_bytes);
450 INFO("Finished psa_read");
451 return sofar;
452}
453
454void psa_write(psa_handle_t msg_handle, uint32_t outvec_idx,
455 const void *buffer, size_t num_bytes)
456{
457
458 size_t sofar = 0;
459 struct message msg = { 0 };
460 int idx;
461 ssize_t len;
462
463 is_valid_msg_handle(msg_handle);
464 is_call_msg(msg_handle);
465
466 if (outvec_idx >= PSA_MAX_IOVEC) {
467 ERROR("Invalid iovec number");
468 }
469
470 if (num_bytes > messages[msg_handle].out_size[outvec_idx]) {
471 ERROR("Program tried to write too much data %lu/%lu", num_bytes,
472 messages[msg_handle].out_size[outvec_idx]);
473 }
474
475 while (sofar < num_bytes) {
476 size_t sending = (num_bytes - sofar);
477 if (sending >= MAX_FRAGMENT_SIZE) {
478 sending = MAX_FRAGMENT_SIZE - (sizeof(size_t) * 2);
479 }
480
481 INFO("Server: sending %lu bytes to client", sending);
482
483 send_msg(msg_handle, WRITE_REQUEST, outvec_idx, sending, buffer, sending);
484
485 idx = find_connection(message_client[msg_handle]);
486 assert(idx >= 0);
487
488 len = msgrcv(connections[idx].client_to_server_q, &msg, sizeof(struct message_text), 0, 0);
489 if (len < 1) {
490 FATAL("Client didn't give me a full response");
491 }
492 sofar = sofar + len;
493 }
494
495 /* Update the seek count */
496 messages[msg_handle].out_size[outvec_idx] -= num_bytes;
497}
498
499size_t psa_skip(psa_handle_t msg_handle, uint32_t invec_idx, size_t num_bytes)
500{
501
502 is_valid_msg_handle(msg_handle);
503 is_call_msg(msg_handle);
504
505 size_t ret = skip(msg_handle, invec_idx, num_bytes);
506
507 /* notify client to skip */
508 send_msg(msg_handle, SKIP_REQUEST, invec_idx, num_bytes, NULL, 0);
509 return ret;
510}
511
/* Remove the SysV queue `myqid`; a failure is only logged. */
static void destroy_temporary_queue(int myqid)
{
    int rc = msgctl(myqid, IPC_RMID, NULL);

    if (rc == 0) {
        return;
    }
    INFO("ERROR: Failed to delete msg queue %d", myqid);
}
518}
519
/* Create a private SysV queue (mode 0660); returns its ID, or -1 on failure. */
static int make_temporary_queue()
{
    int myqid = msgget(IPC_PRIVATE, 0660);

    if (myqid == -1) {
        INFO("msgget: myqid");
        return -1;
    }
    return myqid;
}
528}
529
530/**
531 * Assumes msg_handle is the index into the message array
532 */
533void psa_reply(psa_handle_t msg_handle, psa_status_t status)
534{
535 int idx, q;
536 is_valid_msg_handle(msg_handle);
537
538 if (pending_message[msg_handle] != 1) {
539 ERROR("Not a valid message handle");
540 }
541
542 if (messages[msg_handle].type == PSA_IPC_CONNECT) {
543 switch (status) {
544 case PSA_SUCCESS:
545 idx = find_connection(message_client[msg_handle]);
546 q = make_temporary_queue();
547 if (q > 0 && idx >= 0) {
548 connections[idx].client_to_server_q = q;
549 status = q;
550 } else {
551 FATAL("What happened?");
552 }
553 break;
554 case PSA_ERROR_CONNECTION_REFUSED:
555 destroy_connection(message_client[msg_handle]);
556 break;
557 case PSA_ERROR_CONNECTION_BUSY:
558 destroy_connection(message_client[msg_handle]);
559 break;
560 case PSA_ERROR_PROGRAMMER_ERROR:
561 destroy_connection(message_client[msg_handle]);
562 break;
563 default:
564 ERROR("Not a valid reply %d", status);
565 }
566 } else if (messages[msg_handle].type == PSA_IPC_DISCONNECT) {
567 idx = find_connection(message_client[msg_handle]);
568 if (idx >= 0) {
569 destroy_temporary_queue(connections[idx].client_to_server_q);
570 }
571 destroy_connection(message_client[msg_handle]);
572 }
573
574 send_msg(msg_handle, PSA_REPLY, status, 0, NULL, 0);
575
576 pending_message[msg_handle] = 0;
577 message_client[msg_handle] = 0;
578}
579
580/* TODO: make sure you only clear interrupt signals, and not others */
581void psa_eoi(psa_signal_t signal)
582{
583 int index = signal_to_index(signal);
584 if (index >= 0 && (rot_svc_incoming_queue[index] >= 0)) {
585 clear_signal(signal);
586 } else {
587 ERROR("Tried to EOI a signal that isn't an interrupt");
588 }
589}
590
591void psa_notify(int32_t partition_id)
592{
593 char pathname[PATHNAMESIZE] = { 0 };
594
595 if (partition_id < 0) {
596 ERROR("Not a valid secure partition");
597 }
598
599 snprintf(pathname, PATHNAMESIZE, "/tmp/psa_notify_%u", partition_id);
600 INFO("psa_notify: notifying partition %u using %s",
601 partition_id, pathname);
602 INFO("psa_notify is unimplemented");
603}
604
605void psa_clear(void)
606{
607 clear_signal(PSA_DOORBELL);
608}
609
610void __init_psasim(const char **array,
611 int size,
612 const int allow_ns_clients_array[32],
613 const uint32_t versions[32],
614 const int strict_policy_array[32])
615{
616
617 static uint8_t library_initialised = 0;
618 key_t key;
619 int qid;
620 FILE *fp;
621 char doorbell_path[PATHNAMESIZE] = { 0 };
622 char queue_path[PATHNAMESIZE];
623 snprintf(doorbell_path, PATHNAMESIZE, TMP_FILE_BASE_PATH "psa_notify_%u", getpid());
624
625 if (library_initialised > 0) {
626 return;
627 } else {
628 library_initialised = 1;
629 }
630
631 if (size != 32) {
632 FATAL("Unsupported value. Aborting.");
633 }
634
635 array[3] = doorbell_path;
636
637 for (int i = 0; i < 32; i++) {
638 if (strncmp(array[i], "", 1) != 0) {
639 INFO("Setting up %s", array[i]);
640 memset(queue_path, 0, sizeof(queue_path));
641 sprintf(queue_path, "%s%s", TMP_FILE_BASE_PATH, array[i]);
642
643 /* Create file if doesn't exist */
644 fp = fopen(queue_path, "ab+");
645 if (fp) {
646 fclose(fp);
647 }
648
649 if ((key = ftok(queue_path, PROJECT_ID)) == -1) {
650 FATAL("Error finding message queue during initialisation");
651 }
652
653 /* TODO: Investigate. Permissions are likely to be too relaxed */
654 if ((qid = msgget(key, IPC_CREAT | 0660)) == -1) {
655 FATAL("Error opening message queue during initialisation");
656 } else {
657 rot_svc_incoming_queue[i] = qid;
658 }
659 }
660 }
661
662 memcpy(nsacl, allow_ns_clients_array, sizeof(int) * 32);
663 memcpy(strict_policy, strict_policy_array, sizeof(int) * 32);
664 memcpy(rot_svc_versions, versions, sizeof(uint32_t) * 32);
665 memset(&connections, 0, sizeof(struct connection) * MAX_CLIENTS);
666
667 __psa_ff_client_security_state = 0; /* Set the client status to SECURE */
668}