blob: 9bfe20f78a905f2acf1ed37bd3aebf07ce405f5f [file] [log] [blame]
/* PSA Firmware Framework service API */

/*
 * Copyright The Mbed TLS Contributors
 * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
 */
7
8#include <psa/service.h>
9#include <sys/types.h>
10#include <sys/ipc.h>
11#include <sys/msg.h>
12#include <stdlib.h>
13#include <stdio.h>
14#include <string.h>
15#include <strings.h>
16#include <unistd.h>
17#include <time.h>
18#include <assert.h>
19#include <psasim/init.h>
20#include "common.h"
21
22#define MAX_CLIENTS 128
23#define MAX_MESSAGES 32
24
25struct connection {
26 uint32_t client;
27 void *rhandle;
28 int client_to_server_q; // this should be called client to server
29};
30
31static psa_msg_t messages[MAX_MESSAGES]; /* Message slots */
32static uint8_t pending_message[MAX_MESSAGES] = { 0 }; /* Booleans indicating active message slots */
33static uint32_t message_client[MAX_MESSAGES] = { 0 }; /* Each client's response queue */
34static int nsacl[32];
35static int strict_policy[32] = { 0 };
36static uint32_t rot_svc_versions[32];
37static int rot_svc_incoming_queue[32] = { -1 };
38static struct connection connections[MAX_CLIENTS] = { { 0 } };
39
40static uint32_t exposed_signals = 0;
41
42void print_vectors(vector_sizes_t *sizes)
43{
44 INFO("Printing iovec sizes");
45 for (int j = 0; j < PSA_MAX_IOVEC; j++) {
46 INFO("Invec %d: %lu", j, sizes->invec_sizes[j]);
47 }
48
49 for (int j = 0; j < PSA_MAX_IOVEC; j++) {
50 INFO("Outvec %d: %lu", j, sizes->outvec_sizes[j]);
51 }
52}
53
54int find_connection(uint32_t client)
55{
56 for (int i = 1; i < MAX_CLIENTS; i++) {
57 if (client == connections[i].client) {
58 return i;
59 }
60 }
61 return -1;
62}
63
64void destroy_connection(uint32_t client)
65{
66 int idx = find_connection(client);
67 if (idx >= 0) {
68 connections[idx].client = 0;
69 connections[idx].rhandle = 0;
70 INFO("Destroying connection");
71 } else {
72 INFO("Couldn't destroy connection for %u", client);
73 }
74}
75
/*
 * Return the index of a free connection slot (one whose client id is 0),
 * or -1 if every slot is in use. Declared with (void) so it has a real
 * prototype instead of the obsolescent empty parameter list.
 */
int find_free_connection(void)
{
    INFO("Allocating connection");
    return find_connection(0);
}
81
82static void reply(psa_handle_t msg_handle, psa_status_t status)
83{
84 pending_message[msg_handle] = 1;
85 psa_reply(msg_handle, status);
86 pending_message[msg_handle] = 0;
87}
88
89psa_signal_t psa_wait(psa_signal_t signal_mask, uint32_t timeout)
90{
91
92 psa_signal_t mask;
93 struct message msg;
94 vector_sizes_t sizes;
95 struct msqid_ds qinfo;
96 uint32_t requested_version;
97 ssize_t len;
98 int idx;
99
100 if (timeout == PSA_POLL) {
101 INFO("psa_wait: Called in polling mode");
102 }
103
104 do {
105 mask = signal_mask;
106
107 /* Check the status of each queue */
108 for (int i = 0; i < 32; i++) {
109 if (mask & 0x1) {
110 if (i < 3) {
111 // do nothing (reserved)
112 } else if (i == 3) {
113 // this must be psa doorbell
114 } else {
115
116 /* Check if this signal corresponds to a queue */
117 if (rot_svc_incoming_queue[i] >= 0 && (pending_message[i] == 0)) {
118
119 /* AFAIK there is no "peek" method in SysV, so try to get a message */
120 len = msgrcv(rot_svc_incoming_queue[i],
121 &msg,
122 sizeof(struct message_text),
123 0,
124 IPC_NOWAIT);
125 if (len > 0) {
126
127 INFO("Storing that QID in message_client[%d]\n", i);
128 INFO("The message handle will be %d\n", i);
129
130 msgctl(rot_svc_incoming_queue[i], IPC_STAT, &qinfo);
131 messages[i].client_id = qinfo.msg_lspid; /* PID of last msgsnd(2) call */
132 message_client[i] = msg.message_text.qid;
133 idx = find_connection(msg.message_text.qid);
134
135 if (msg.message_type & NON_SECURE) {
136 /* This is a non-secure message */
137
138 /* Check if NS client is allowed for this RoT service */
139 if (nsacl[i] <= 0) {
140#if 0
141 INFO(
142 "Rejecting non-secure client due to manifest security policy");
143 reply(i, PSA_ERROR_CONNECTION_REFUSED);
144 continue; /* Skip to next signal */
145#endif
146 }
147
148 msg.message_type &= ~(NON_SECURE); /* clear */
149 messages[i].client_id = messages[i].client_id * -1;
150 }
151
152 INFO("Got a message from client ID %d\n", messages[i].client_id);
153 INFO("Message type is %lu\n", msg.message_type);
154 INFO("PSA message type is %d\n", msg.message_text.psa_type);
155
156 messages[i].handle = i;
157
158 switch (msg.message_text.psa_type) {
159 case PSA_IPC_CONNECT:
160
161 if (len >= 16) {
162 memcpy(&requested_version, msg.message_text.buf,
163 sizeof(requested_version));
164 INFO("Requesting version %u\n", requested_version);
165 INFO("Implemented version %u\n", rot_svc_versions[i]);
166 /* TODO: need to check whether the policy is strict,
167 * and if so, then reject the client if the number doesn't match */
168
169 if (requested_version > rot_svc_versions[i]) {
170 INFO(
171 "Rejecting client because requested version that was too high");
172 reply(i, PSA_ERROR_CONNECTION_REFUSED);
173 continue; /* Skip to next signal */
174 }
175
176 if (strict_policy[i] == 1 &&
177 (requested_version != rot_svc_versions[i])) {
178 INFO(
179 "Rejecting client because enforcing a STRICT version policy");
180 reply(i, PSA_ERROR_CONNECTION_REFUSED);
181 continue; /* Skip to next signal */
182 } else {
183 INFO("Not rejecting client");
184 }
185 }
186
187 messages[i].type = PSA_IPC_CONNECT;
188
189 if (idx < 0) {
190 idx = find_free_connection();
191 }
192
193 if (idx >= 0) {
194 connections[idx].client = msg.message_text.qid;
195 } else {
196 /* We've run out of system wide connections */
197 reply(i, PSA_ERROR_CONNECTION_BUSY);
198 INFO("Ran out of free connections");
199 continue;
200 }
201
202 break;
203 case PSA_IPC_DISCONNECT:
204 messages[i].type = PSA_IPC_DISCONNECT;
205 break;
206 case VERSION_REQUEST:
207 INFO("Got a version request");
208 reply(i, rot_svc_versions[i]);
209 continue; /* Skip to next signal */
210 break;
211
212 default:
213
214 /* PSA CALL */
215 if (msg.message_text.psa_type >= 0) {
216 messages[i].type = msg.message_text.psa_type;
217 memcpy(&sizes, msg.message_text.buf, sizeof(sizes));
218 print_vectors(&sizes);
219 memcpy(&messages[i].in_size, &sizes.invec_sizes,
220 (sizeof(size_t) * PSA_MAX_IOVEC));
221 memcpy(&messages[i].out_size, &sizes.outvec_sizes,
222 (sizeof(size_t) * PSA_MAX_IOVEC));
223 } else {
224 FATAL("UNKNOWN MESSAGE TYPE RECEIVED %li\n",
225 msg.message_type);
226 }
227 break;
228 }
229 messages[i].handle = i;
230
231 /* Check if the client has a connection */
232 if (idx >= 0) {
233 messages[i].rhandle = connections[idx].rhandle;
234 } else {
235 /* Client is begging for a programmer error */
236 reply(i, PSA_ERROR_PROGRAMMER_ERROR);
237 continue;
238 }
239
240 /* House keeping */
241 pending_message[i] = 1; /* set message as pending */
242 exposed_signals |= (0x1 << i); /* assert the signal */
243 }
244 }
245 }
246 mask = mask >> 1;
247 }
248 }
249
250 if ((timeout == PSA_BLOCK) && (exposed_signals > 0)) {
251 break;
252 } else {
253 /* There is no 'select' function in SysV to block on multiple queues, so busy-wait :( */
254 usleep(50000);
255 }
256 } while (timeout == PSA_BLOCK);
257
258 INFO("\n");
259
260 /* Assert signals */
261 return signal_mask & exposed_signals;
262}
263
264static int signal_to_index(psa_signal_t signal)
265{
266
267 int i;
268 int count = 0;
269 int ret = -1;
270
271 for (i = 0; i < 32; i++) {
272 if (signal & 0x1) {
273 ret = i;
274 count++;
275 }
276 signal = signal >> 1;
277 }
278
279 if (count > 1) {
280 INFO("ERROR: Too many signals");
281 return -1; /* Too many signals */
282 }
283 return ret;
284}
285
286static void clear_signal(psa_signal_t signal)
287{
288 exposed_signals = exposed_signals & ~signal;
289}
290
291void raise_signal(psa_signal_t signal)
292{
293 exposed_signals |= signal;
294}
295
296psa_status_t psa_get(psa_signal_t signal, psa_msg_t *msg)
297{
298 int index = signal_to_index(signal);
299 if (index < 0) {
300 PROGRAMMER_ERROR("Bad signal\n");
301 }
302
303 clear_signal(signal);
304
305 assert(messages[index].handle != 0);
306
307 if (pending_message[index] == 1) {
308 INFO("There is a pending message!");
309 memcpy(msg, &messages[index], sizeof(struct psa_msg_t));
310 assert(msg->handle != 0);
311 return PSA_SUCCESS;
312 } else {
313 INFO("no pending message");
314 }
315
316 return PSA_ERROR_DOES_NOT_EXIST;
317}
318
319static int is_valid_msg_handle(psa_handle_t h)
320{
321 if (h > 0 && h < MAX_MESSAGES) {
322 return 1;
323 }
324 PROGRAMMER_ERROR("Not a valid message handle");
325}
326
327static inline int is_call_msg(psa_handle_t h)
328{
329 assert(messages[h].type >= PSA_IPC_CALL);
330 return 1;
331}
332
333void psa_set_rhandle(psa_handle_t msg_handle, void *rhandle)
334{
335 is_valid_msg_handle(msg_handle);
336 int idx = find_connection(message_client[msg_handle]);
337 INFO("Setting rhandle to %p", rhandle);
338 assert(idx >= 0);
339 connections[idx].rhandle = rhandle;
340}
341
342/* Sends a message from the server to the client. Does not wait for a response */
343static void send_msg(psa_handle_t msg_handle,
344 int ctrl_msg,
345 psa_status_t status,
346 size_t amount,
347 const void *data,
348 size_t data_amount)
349{
350
351 struct message response;
352 int flags = 0;
353
354 assert(ctrl_msg > 0); /* According to System V, it must be greater than 0 */
355
356 response.message_type = ctrl_msg;
357 if (ctrl_msg == PSA_REPLY) {
358 memcpy(response.message_text.buf, &status, sizeof(psa_status_t));
359 } else if (ctrl_msg == READ_REQUEST || ctrl_msg == WRITE_REQUEST || ctrl_msg == SKIP_REQUEST) {
360 memcpy(response.message_text.buf, &status, sizeof(psa_status_t));
361 memcpy(response.message_text.buf+sizeof(size_t), &amount, sizeof(size_t));
362 if (ctrl_msg == WRITE_REQUEST) {
363 /* TODO: Check if too big */
364 memcpy(response.message_text.buf + (sizeof(size_t) * 2), data, data_amount);
365 }
366 }
367
368 /* TODO: sizeof doesn't need to be so big here for small responses */
369 if (msgsnd(message_client[msg_handle], &response, sizeof(response.message_text), flags) == -1) {
370 INFO("Failed to reply");
371 }
372}
373
374static size_t skip(psa_handle_t msg_handle, uint32_t invec_idx, size_t num_bytes)
375{
376 if (num_bytes < (messages[msg_handle].in_size[invec_idx] - num_bytes)) {
377 messages[msg_handle].in_size[invec_idx] = messages[msg_handle].in_size[invec_idx] -
378 num_bytes;
379 return num_bytes;
380 } else {
381 if (num_bytes >= messages[msg_handle].in_size[invec_idx]) {
382 size_t ret = messages[msg_handle].in_size[invec_idx];
383 messages[msg_handle].in_size[invec_idx] = 0;
384 return ret;
385 } else {
386 return num_bytes;
387 }
388 }
389}
390
391size_t psa_read(psa_handle_t msg_handle, uint32_t invec_idx,
392 void *buffer, size_t num_bytes)
393{
394 size_t sofar = 0;
395 struct message msg = { 0 };
396 int idx;
397 ssize_t len;
398
399 is_valid_msg_handle(msg_handle);
400 is_call_msg(msg_handle);
401
402 if (invec_idx >= PSA_MAX_IOVEC) {
403 PROGRAMMER_ERROR("Invalid iovec number");
404 }
405
406 /* If user wants more data than what's available, truncate their request */
407 if (num_bytes > messages[msg_handle].in_size[invec_idx]) {
408 num_bytes = messages[msg_handle].in_size[invec_idx];
409 }
410
411 while (sofar < num_bytes) {
412 INFO("Server: requesting %lu bytes from client\n", (num_bytes - sofar));
413 send_msg(msg_handle, READ_REQUEST, invec_idx, (num_bytes - sofar), NULL, 0);
414
415 idx = find_connection(message_client[msg_handle]);
416 assert(idx >= 0);
417
418 len = msgrcv(connections[idx].client_to_server_q, &msg, sizeof(struct message_text), 0, 0);
419 len = (len - sizeof(msg.message_text.qid));
420
421 if (len < 0) {
422 FATAL("Internal error: failed to dispatch read request to the client");
423 }
424
425 if (len > (num_bytes - sofar)) {
426 if ((num_bytes - sofar) > 0) {
427 memcpy(buffer+sofar, msg.message_text.buf, (num_bytes - sofar));
428 }
429 } else {
430 memcpy(buffer + sofar, msg.message_text.buf, len);
431 }
432
433 INFO("Printing what i got so far: %s\n", msg.message_text.buf);
434
435 sofar = sofar + len;
436 }
437
438 /* Update the seek count */
439 skip(msg_handle, invec_idx, num_bytes);
440 INFO("Finished psa_read");
441 return sofar;
442}
443
444void psa_write(psa_handle_t msg_handle, uint32_t outvec_idx,
445 const void *buffer, size_t num_bytes)
446{
447
448 size_t sofar = 0;
449 struct message msg = { 0 };
450 int idx;
451 ssize_t len;
452
453 is_valid_msg_handle(msg_handle);
454 is_call_msg(msg_handle);
455
456 if (outvec_idx >= PSA_MAX_IOVEC) {
457 PROGRAMMER_ERROR("Invalid iovec number");
458 }
459
460 if (num_bytes > messages[msg_handle].out_size[outvec_idx]) {
461 PROGRAMMER_ERROR("Program tried to write too much data %lu/%lu", num_bytes,
462 messages[msg_handle].out_size[outvec_idx]);
463 }
464
465 while (sofar < num_bytes) {
466 size_t sending = (num_bytes - sofar);
467 if (sending >= MAX_FRAGMENT_SIZE) {
468 sending = MAX_FRAGMENT_SIZE - (sizeof(size_t) * 2);
469 }
470
471 INFO("Server: sending %lu bytes to client\n", sending);
472
473 send_msg(msg_handle, WRITE_REQUEST, outvec_idx, sending, buffer, sending);
474
475 idx = find_connection(message_client[msg_handle]);
476 assert(idx >= 0);
477
478 len = msgrcv(connections[idx].client_to_server_q, &msg, sizeof(struct message_text), 0, 0);
479 if (len < 1) {
480 FATAL("Client didn't give me a full response");
481 }
482 sofar = sofar + len;
483 }
484
485 /* Update the seek count */
486 messages[msg_handle].out_size[outvec_idx] -= num_bytes;
487}
488
489size_t psa_skip(psa_handle_t msg_handle, uint32_t invec_idx, size_t num_bytes)
490{
491
492 is_valid_msg_handle(msg_handle);
493 is_call_msg(msg_handle);
494
495 size_t ret = skip(msg_handle, invec_idx, num_bytes);
496
497 /* notify client to skip */
498 send_msg(msg_handle, SKIP_REQUEST, invec_idx, num_bytes, NULL, 0);
499 return ret;
500}
501
/* Remove SysV message queue `myqid`; a failure is only logged. */
static void destroy_temporary_queue(int myqid)
{
    int rc = msgctl(myqid, IPC_RMID, NULL);

    if (rc == 0) {
        return;
    }
    INFO("ERROR: Failed to delete msg queue %d", myqid);
}
509
/*
 * Create a new private SysV message queue with mode 0660.
 * Returns its id, or -1 (after logging) if msgget() fails.
 * Declared with (void) so it has a real prototype.
 */
static int make_temporary_queue(void)
{
    int myqid;
    if ((myqid = msgget(IPC_PRIVATE, 0660)) == -1) {
        INFO("msgget: myqid");
        return -1;
    }
    return myqid;
}
519
520/**
521 * Assumes msg_handle is the index into the message array
522 */
523void psa_reply(psa_handle_t msg_handle, psa_status_t status)
524{
525 int idx, q;
526 is_valid_msg_handle(msg_handle);
527
528 if (pending_message[msg_handle] != 1) {
529 PROGRAMMER_ERROR("Not a valid message handle");
530 }
531
532 if (messages[msg_handle].type == PSA_IPC_CONNECT) {
533 switch (status) {
534 case PSA_SUCCESS:
535 idx = find_connection(message_client[msg_handle]);
536 q = make_temporary_queue();
537 if (q > 0 && idx >= 0) {
538 connections[idx].client_to_server_q = q;
539 status = q;
540 } else {
541 FATAL("What happened?");
542 }
543 break;
544 case PSA_ERROR_CONNECTION_REFUSED:
545 destroy_connection(message_client[msg_handle]);
546 break;
547 case PSA_ERROR_CONNECTION_BUSY:
548 destroy_connection(message_client[msg_handle]);
549 break;
550 case PSA_ERROR_PROGRAMMER_ERROR:
551 destroy_connection(message_client[msg_handle]);
552 break;
553 default:
554 PROGRAMMER_ERROR("Not a valid reply %d\n", status);
555 }
556 } else if (messages[msg_handle].type == PSA_IPC_DISCONNECT) {
557 idx = find_connection(message_client[msg_handle]);
558 if (idx >= 0) {
559 destroy_temporary_queue(connections[idx].client_to_server_q);
560 }
561 destroy_connection(message_client[msg_handle]);
562 }
563
564 send_msg(msg_handle, PSA_REPLY, status, 0, NULL, 0);
565
566 pending_message[msg_handle] = 0;
567 message_client[msg_handle] = 0;
568}
569
570/* TODO: make sure you only clear interrupt signals, and not others */
571void psa_eoi(psa_signal_t signal)
572{
573 int index = signal_to_index(signal);
574 if (index >= 0 && (rot_svc_incoming_queue[index] >= 0)) {
575 clear_signal(signal);
576 } else {
577 PROGRAMMER_ERROR("Tried to EOI a signal that isn't an interrupt");
578 }
579}
580
581void psa_notify(int32_t partition_id)
582{
583 char pathname[PATHNAMESIZE] = { 0 };
584
585 if (partition_id < 0) {
586 PROGRAMMER_ERROR("Not a valid secure partition");
587 }
588
589 snprintf(pathname, PATHNAMESIZE, "/tmp/psa_notify_%u", partition_id);
590 INFO("psa_notify: notifying partition %u using %s",
591 partition_id, pathname);
592 INFO("psa_notify is unimplemented");
593}
594
595void psa_clear(void)
596{
597 clear_signal(PSA_DOORBELL);
598}
599
600void __init_psasim(const char **array,
601 int size,
602 const int allow_ns_clients_array[32],
603 const uint32_t versions[32],
604 const int strict_policy_array[32])
605{
606
607 static uint8_t library_initialised = 0;
608 key_t key;
609 int qid;
610 FILE *fp;
611 char doorbell_path[PATHNAMESIZE] = { 0 };
612 snprintf(doorbell_path, PATHNAMESIZE, "/tmp/psa_notify_%u", getpid());
613
614 if (library_initialised > 0) {
615 return;
616 } else {
617 library_initialised = 1;
618 }
619
620 if (size != 32) {
621 FATAL("Unsupported value. Aborting.");
622 }
623
624 array[3] = doorbell_path;
625
626 for (int i = 0; i < 32; i++) {
627 if (strncmp(array[i], "", 1) != 0) {
628 INFO("Setting up %s", array[i]);
629
630 /* Create file if doesn't exist */
631 fp = fopen(array[i], "ab+");
632 if (fp) {
633 fclose(fp);
634 }
635
636 if ((key = ftok(array[i], PROJECT_ID)) == -1) {
637 FATAL("Error finding message queue during initialisation");
638 }
639
640 /* TODO: Investigate. Permissions are likely to be too relaxed */
641 if ((qid = msgget(key, IPC_CREAT | 0660)) == -1) {
642 FATAL("Error opening message queue during initialisation");
643 } else {
644 rot_svc_incoming_queue[i] = qid;
645 }
646 }
647 }
648
649 memcpy(nsacl, allow_ns_clients_array, sizeof(int) * 32);
650 memcpy(strict_policy, strict_policy_array, sizeof(int) * 32);
651 memcpy(rot_svc_versions, versions, sizeof(uint32_t) * 32);
652 bzero(&connections, sizeof(struct connection) * MAX_CLIENTS);
653
654 __psa_ff_client_security_state = 0; /* Set the client status to SECURE */
655}