blob: 3b11e1b283de7973887b683485c0ba77318820e6 [file] [log] [blame]
//===- RPCUtils.h - Utilities for building RPC APIs -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utilities to support construction of simple RPC APIs.
//
// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
// programmers, high performance, low memory overhead, and efficient use of the
// communications channel.
//
//===----------------------------------------------------------------------===//
16
17#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
18#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19
20#include <map>
21#include <thread>
22#include <vector>
23
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ExecutionEngine/Orc/OrcError.h"
26#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
Andrew Scull0372a572018-11-16 15:47:06 +000027#include "llvm/Support/MSVCErrorWorkarounds.h"
Andrew Scull5e1ddfa2018-08-14 10:06:54 +010028
29#include <future>
30
31namespace llvm {
32namespace orc {
33namespace rpc {
34
35/// Base class of all fatal RPC errors (those that necessarily result in the
36/// termination of the RPC session).
37class RPCFatalError : public ErrorInfo<RPCFatalError> {
38public:
39 static char ID;
40};
41
42/// RPCConnectionClosed is returned from RPC operations if the RPC connection
43/// has already been closed due to either an error or graceful disconnection.
44class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45public:
46 static char ID;
47 std::error_code convertToErrorCode() const override;
48 void log(raw_ostream &OS) const override;
49};
50
51/// BadFunctionCall is returned from handleOne when the remote makes a call with
52/// an unrecognized function id.
53///
54/// This error is fatal because Orc RPC needs to know how to parse a function
55/// call to know where the next call starts, and if it doesn't recognize the
56/// function id it cannot parse the call.
57template <typename FnIdT, typename SeqNoT>
58class BadFunctionCall
59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60public:
61 static char ID;
62
63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65
66 std::error_code convertToErrorCode() const override {
67 return orcError(OrcErrorCode::UnexpectedRPCCall);
68 }
69
70 void log(raw_ostream &OS) const override {
71 OS << "Call to invalid RPC function id '" << FnId << "' with "
72 "sequence number " << SeqNo;
73 }
74
75private:
76 FnIdT FnId;
77 SeqNoT SeqNo;
78};
79
80template <typename FnIdT, typename SeqNoT>
81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82
83/// InvalidSequenceNumberForResponse is returned from handleOne when a response
84/// call arrives with a sequence number that doesn't correspond to any in-flight
85/// function call.
86///
87/// This error is fatal because Orc RPC needs to know how to parse the rest of
88/// the response call to know where the next call starts, and if it doesn't have
89/// a result parser for this sequence number it can't do that.
90template <typename SeqNoT>
91class InvalidSequenceNumberForResponse
92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93public:
94 static char ID;
95
96 InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97 : SeqNo(std::move(SeqNo)) {}
98
99 std::error_code convertToErrorCode() const override {
100 return orcError(OrcErrorCode::UnexpectedRPCCall);
101 };
102
103 void log(raw_ostream &OS) const override {
104 OS << "Response has unknown sequence number " << SeqNo;
105 }
106private:
107 SeqNoT SeqNo;
108};
109
110template <typename SeqNoT>
111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112
113/// This non-fatal error will be passed to asynchronous result handlers in place
114/// of a result if the connection goes down before a result returns, or if the
115/// function to be called cannot be negotiated with the remote.
116class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117public:
118 static char ID;
119
120 std::error_code convertToErrorCode() const override;
121 void log(raw_ostream &OS) const override;
122};
123
124/// This error is returned if the remote does not have a handler installed for
125/// the given RPC function.
126class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127public:
128 static char ID;
129
130 CouldNotNegotiate(std::string Signature);
131 std::error_code convertToErrorCode() const override;
132 void log(raw_ostream &OS) const override;
133 const std::string &getSignature() const { return Signature; }
134private:
135 std::string Signature;
136};
137
138template <typename DerivedFunc, typename FnT> class Function;
139
140// RPC Function class.
141// DerivedFunc should be a user defined class with a static 'getName()' method
142// returning a const char* representing the function's name.
143template <typename DerivedFunc, typename RetT, typename... ArgTs>
144class Function<DerivedFunc, RetT(ArgTs...)> {
145public:
146 /// User defined function type.
147 using Type = RetT(ArgTs...);
148
149 /// Return type.
150 using ReturnType = RetT;
151
152 /// Returns the full function prototype as a string.
153 static const char *getPrototype() {
Andrew Walbran16937d02019-10-22 13:54:20 +0100154 static std::string Name = [] {
155 std::string Name;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100156 raw_string_ostream(Name)
157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
Andrew Walbran16937d02019-10-22 13:54:20 +0100159 return Name;
160 }();
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100161 return Name.data();
162 }
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100163};
164
/// Allocates RPC function ids during autonegotiation.
/// Every specialization of this class has to supply four members:
///
///   static T getInvalidId():
///     Returns a reserved id that represents missing functions during
///     autonegotiation.
///
///   static T getResponseId():
///     Returns a reserved id that is used to send function responses
///     (return values).
///
///   static T getNegotiateId():
///     Returns a reserved id for the negotiate function, which is used to
///     negotiate ids for user defined functions.
///
///   template <typename Func> T allocate():
///     Allocates a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator implementation, selected for integral T.
/// Ids 0, 1 and 2 are the reserved invalid/response/negotiate ids; allocate()
/// hands out consecutive ids starting at 3.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  /// Returns the next unused id and advances the internal counter.
  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id that is not reserved.
  T NextId = 3;
};
199
200namespace detail {
201
/// Provides a typedef for a tuple containing the decayed argument types.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
private:
  // Storage type for a single argument: reference stripped, then decayed.
  template <typename ArgT>
  using StoredArg = typename std::decay<
      typename std::remove_reference<ArgT>::type>::type;

public:
  using Type = std::tuple<StoredArg<ArgTs>...>;
};
211
212// ResultTraits provides typedefs and utilities specific to the return type
213// of functions.
214template <typename RetT> class ResultTraits {
215public:
216 // The return type wrapped in llvm::Expected.
217 using ErrorReturnType = Expected<RetT>;
218
219#ifdef _MSC_VER
220 // The ErrorReturnType wrapped in a std::promise.
Andrew Scull0372a572018-11-16 15:47:06 +0000221 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100222
223 // The ErrorReturnType wrapped in a std::future.
Andrew Scull0372a572018-11-16 15:47:06 +0000224 using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100225#else
226 // The ErrorReturnType wrapped in a std::promise.
227 using ReturnPromiseType = std::promise<ErrorReturnType>;
228
229 // The ErrorReturnType wrapped in a std::future.
230 using ReturnFutureType = std::future<ErrorReturnType>;
231#endif
232
233 // Create a 'blank' value of the ErrorReturnType, ready and safe to
234 // overwrite.
235 static ErrorReturnType createBlankErrorReturnValue() {
236 return ErrorReturnType(RetT());
237 }
238
239 // Consume an abandoned ErrorReturnType.
240 static void consumeAbandoned(ErrorReturnType RetOrErr) {
241 consumeError(RetOrErr.takeError());
242 }
243};
244
245// ResultTraits specialization for void functions.
246template <> class ResultTraits<void> {
247public:
248 // For void functions, ErrorReturnType is llvm::Error.
249 using ErrorReturnType = Error;
250
251#ifdef _MSC_VER
252 // The ErrorReturnType wrapped in a std::promise.
Andrew Scull0372a572018-11-16 15:47:06 +0000253 using ReturnPromiseType = std::promise<MSVCPError>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100254
255 // The ErrorReturnType wrapped in a std::future.
Andrew Scull0372a572018-11-16 15:47:06 +0000256 using ReturnFutureType = std::future<MSVCPError>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100257#else
258 // The ErrorReturnType wrapped in a std::promise.
259 using ReturnPromiseType = std::promise<ErrorReturnType>;
260
261 // The ErrorReturnType wrapped in a std::future.
262 using ReturnFutureType = std::future<ErrorReturnType>;
263#endif
264
265 // Create a 'blank' value of the ErrorReturnType, ready and safe to
266 // overwrite.
267 static ErrorReturnType createBlankErrorReturnValue() {
268 return ErrorReturnType::success();
269 }
270
271 // Consume an abandoned ErrorReturnType.
272 static void consumeAbandoned(ErrorReturnType Err) {
273 consumeError(std::move(Err));
274 }
275};
276
277// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
278// handlers for void RPC functions to return either void (in which case they
279// implicitly succeed) or Error (in which case their error return is
280// propagated). See usage in HandlerTraits::runHandlerHelper.
281template <> class ResultTraits<Error> : public ResultTraits<void> {};
282
283// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
284// handlers for RPC functions returning a T to return either a T (in which
285// case they implicitly succeed) or Expected<T> (in which case their error
286// return is propagated). See usage in HandlerTraits::runHandlerHelper.
287template <typename RetT>
288class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
289
290// Determines whether an RPC function's defined error return type supports
291// error return value.
292template <typename T>
293class SupportsErrorReturn {
294public:
295 static const bool value = false;
296};
297
298template <>
299class SupportsErrorReturn<Error> {
300public:
301 static const bool value = true;
302};
303
304template <typename T>
305class SupportsErrorReturn<Expected<T>> {
306public:
307 static const bool value = true;
308};
309
310// RespondHelper packages return values based on whether or not the declared
311// RPC function return type supports error returns.
312template <bool FuncSupportsErrorReturn>
313class RespondHelper;
314
315// RespondHelper specialization for functions that support error returns.
316template <>
317class RespondHelper<true> {
318public:
319
320 // Send Expected<T>.
321 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
322 typename FunctionIdT, typename SequenceNumberT>
323 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
324 SequenceNumberT SeqNo,
325 Expected<HandlerRetT> ResultOrErr) {
326 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
327 return ResultOrErr.takeError();
328
329 // Open the response message.
330 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
331 return Err;
332
333 // Serialize the result.
334 if (auto Err =
335 SerializationTraits<ChannelT, WireRetT,
336 Expected<HandlerRetT>>::serialize(
337 C, std::move(ResultOrErr)))
338 return Err;
339
340 // Close the response message.
341 return C.endSendMessage();
342 }
343
344 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
345 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
346 SequenceNumberT SeqNo, Error Err) {
347 if (Err && Err.isA<RPCFatalError>())
348 return Err;
349 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
350 return Err2;
351 if (auto Err2 = serializeSeq(C, std::move(Err)))
352 return Err2;
353 return C.endSendMessage();
354 }
355
356};
357
358// RespondHelper specialization for functions that do not support error returns.
359template <>
360class RespondHelper<false> {
361public:
362
363 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
364 typename FunctionIdT, typename SequenceNumberT>
365 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
366 SequenceNumberT SeqNo,
367 Expected<HandlerRetT> ResultOrErr) {
368 if (auto Err = ResultOrErr.takeError())
369 return Err;
370
371 // Open the response message.
372 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
373 return Err;
374
375 // Serialize the result.
376 if (auto Err =
377 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
378 C, *ResultOrErr))
379 return Err;
380
381 // Close the response message.
382 return C.endSendMessage();
383 }
384
385 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
386 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
387 SequenceNumberT SeqNo, Error Err) {
388 if (Err)
389 return Err;
390 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
391 return Err2;
392 return C.endSendMessage();
393 }
394
395};
396
397
398// Send a response of the given wire return type (WireRetT) over the
399// channel, with the given sequence number.
400template <typename WireRetT, typename HandlerRetT, typename ChannelT,
401 typename FunctionIdT, typename SequenceNumberT>
402Error respond(ChannelT &C, const FunctionIdT &ResponseId,
403 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
404 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
405 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
406}
407
408// Send an empty response message on the given channel to indicate that
409// the handler ran.
410template <typename WireRetT, typename ChannelT, typename FunctionIdT,
411 typename SequenceNumberT>
412Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
413 Error Err) {
414 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
415 sendResult(C, ResponseId, SeqNo, std::move(Err));
416}
417
418// Converts a given type to the equivalent error return type.
419template <typename T> class WrappedHandlerReturn {
420public:
421 using Type = Expected<T>;
422};
423
424template <typename T> class WrappedHandlerReturn<Expected<T>> {
425public:
426 using Type = Expected<T>;
427};
428
429template <> class WrappedHandlerReturn<void> {
430public:
431 using Type = Error;
432};
433
434template <> class WrappedHandlerReturn<Error> {
435public:
436 using Type = Error;
437};
438
439template <> class WrappedHandlerReturn<ErrorSuccess> {
440public:
441 using Type = Error;
442};
443
444// Traits class that strips the response function from the list of handler
445// arguments.
446template <typename FnT> class AsyncHandlerTraits;
447
448template <typename ResultT, typename... ArgTs>
449class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
450public:
451 using Type = Error(ArgTs...);
452 using ResultType = Expected<ResultT>;
453};
454
455template <typename... ArgTs>
456class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
457public:
458 using Type = Error(ArgTs...);
459 using ResultType = Error;
460};
461
462template <typename... ArgTs>
463class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
464public:
465 using Type = Error(ArgTs...);
466 using ResultType = Error;
467};
468
469template <typename... ArgTs>
470class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
471public:
472 using Type = Error(ArgTs...);
473 using ResultType = Error;
474};
475
476template <typename ResponseHandlerT, typename... ArgTs>
477class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
478 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
479 ArgTs...)> {};
480
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types): it inherits from the appropriate
// specialization for the call operator of the given non-function type.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<
          decltype(&std::remove_reference<HandlerT>::type::operator())> {};
489
490// Traits for handlers with a given function type.
491template <typename RetT, typename... ArgTs>
492class HandlerTraits<RetT(ArgTs...)> {
493public:
494 // Function type of the handler.
495 using Type = RetT(ArgTs...);
496
497 // Return type of the handler.
498 using ReturnType = RetT;
499
500 // Call the given handler with the given arguments.
501 template <typename HandlerT, typename... TArgTs>
502 static typename WrappedHandlerReturn<RetT>::Type
503 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
504 return unpackAndRunHelper(Handler, Args,
505 llvm::index_sequence_for<TArgTs...>());
506 }
507
508 // Call the given handler with the given arguments.
509 template <typename HandlerT, typename ResponderT, typename... TArgTs>
510 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
511 std::tuple<TArgTs...> &Args) {
512 return unpackAndRunAsyncHelper(Handler, Responder, Args,
513 llvm::index_sequence_for<TArgTs...>());
514 }
515
516 // Call the given handler with the given arguments.
517 template <typename HandlerT>
518 static typename std::enable_if<
519 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
520 Error>::type
521 run(HandlerT &Handler, ArgTs &&... Args) {
522 Handler(std::move(Args)...);
523 return Error::success();
524 }
525
526 template <typename HandlerT, typename... TArgTs>
527 static typename std::enable_if<
528 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529 typename HandlerTraits<HandlerT>::ReturnType>::type
530 run(HandlerT &Handler, TArgTs... Args) {
531 return Handler(std::move(Args)...);
532 }
533
534 // Serialize arguments to the channel.
535 template <typename ChannelT, typename... CArgTs>
536 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
537 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
538 }
539
540 // Deserialize arguments from the channel.
541 template <typename ChannelT, typename... CArgTs>
542 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
543 return deserializeArgsHelper(C, Args,
544 llvm::index_sequence_for<CArgTs...>());
545 }
546
547private:
548 template <typename ChannelT, typename... CArgTs, size_t... Indexes>
549 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
550 llvm::index_sequence<Indexes...> _) {
551 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
552 C, std::get<Indexes>(Args)...);
553 }
554
555 template <typename HandlerT, typename ArgTuple, size_t... Indexes>
556 static typename WrappedHandlerReturn<
557 typename HandlerTraits<HandlerT>::ReturnType>::Type
558 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
559 llvm::index_sequence<Indexes...>) {
560 return run(Handler, std::move(std::get<Indexes>(Args))...);
561 }
562
563
564 template <typename HandlerT, typename ResponderT, typename ArgTuple,
565 size_t... Indexes>
566 static typename WrappedHandlerReturn<
567 typename HandlerTraits<HandlerT>::ReturnType>::Type
568 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
569 ArgTuple &Args,
570 llvm::index_sequence<Indexes...>) {
571 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
572 }
573};
574
575// Handler traits for free functions.
576template <typename RetT, typename... ArgTs>
577class HandlerTraits<RetT(*)(ArgTs...)>
578 : public HandlerTraits<RetT(ArgTs...)> {};
579
580// Handler traits for class methods (especially call operators for lambdas).
581template <typename Class, typename RetT, typename... ArgTs>
582class HandlerTraits<RetT (Class::*)(ArgTs...)>
583 : public HandlerTraits<RetT(ArgTs...)> {};
584
585// Handler traits for const class methods (especially call operators for
586// lambdas).
587template <typename Class, typename RetT, typename... ArgTs>
588class HandlerTraits<RetT (Class::*)(ArgTs...) const>
589 : public HandlerTraits<RetT(ArgTs...)> {};
590
591// Utility to peel the Expected wrapper off a response handler error type.
592template <typename HandlerT> class ResponseHandlerArg;
593
594template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
595public:
596 using ArgType = Expected<ArgT>;
597 using UnwrappedArgType = ArgT;
598};
599
600template <typename ArgT>
601class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
602public:
603 using ArgType = Expected<ArgT>;
604 using UnwrappedArgType = ArgT;
605};
606
607template <> class ResponseHandlerArg<Error(Error)> {
608public:
609 using ArgType = Error;
610};
611
612template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
613public:
614 using ArgType = Error;
615};
616
617// ResponseHandler represents a handler for a not-yet-received function call
618// result.
619template <typename ChannelT> class ResponseHandler {
620public:
621 virtual ~ResponseHandler() {}
622
623 // Reads the function result off the wire and acts on it. The meaning of
624 // "act" will depend on how this method is implemented in any given
625 // ResponseHandler subclass but could, for example, mean running a
626 // user-specified handler or setting a promise value.
627 virtual Error handleResponse(ChannelT &C) = 0;
628
629 // Abandons this outstanding result.
630 virtual void abandon() = 0;
631
632 // Create an error instance representing an abandoned response.
633 static Error createAbandonedResponseError() {
634 return make_error<ResponseAbandoned>();
635 }
636};
637
638// ResponseHandler subclass for RPC functions with non-void returns.
639template <typename ChannelT, typename FuncRetT, typename HandlerT>
640class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
641public:
642 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
643
644 // Handle the result by deserializing it from the channel then passing it
645 // to the user defined handler.
646 Error handleResponse(ChannelT &C) override {
647 using UnwrappedArgType = typename ResponseHandlerArg<
648 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
649 UnwrappedArgType Result;
650 if (auto Err =
651 SerializationTraits<ChannelT, FuncRetT,
652 UnwrappedArgType>::deserialize(C, Result))
653 return Err;
654 if (auto Err = C.endReceiveMessage())
655 return Err;
656 return Handler(std::move(Result));
657 }
658
659 // Abandon this response by calling the handler with an 'abandoned response'
660 // error.
661 void abandon() override {
662 if (auto Err = Handler(this->createAbandonedResponseError())) {
663 // Handlers should not fail when passed an abandoned response error.
664 report_fatal_error(std::move(Err));
665 }
666 }
667
668private:
669 HandlerT Handler;
670};
671
672// ResponseHandler subclass for RPC functions with void returns.
673template <typename ChannelT, typename HandlerT>
674class ResponseHandlerImpl<ChannelT, void, HandlerT>
675 : public ResponseHandler<ChannelT> {
676public:
677 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
678
679 // Handle the result (no actual value, just a notification that the function
680 // has completed on the remote end) by calling the user-defined handler with
681 // Error::success().
682 Error handleResponse(ChannelT &C) override {
683 if (auto Err = C.endReceiveMessage())
684 return Err;
685 return Handler(Error::success());
686 }
687
688 // Abandon this response by calling the handler with an 'abandoned response'
689 // error.
690 void abandon() override {
691 if (auto Err = Handler(this->createAbandonedResponseError())) {
692 // Handlers should not fail when passed an abandoned response error.
693 report_fatal_error(std::move(Err));
694 }
695 }
696
697private:
698 HandlerT Handler;
699};
700
701template <typename ChannelT, typename FuncRetT, typename HandlerT>
702class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
703 : public ResponseHandler<ChannelT> {
704public:
705 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
706
707 // Handle the result by deserializing it from the channel then passing it
708 // to the user defined handler.
709 Error handleResponse(ChannelT &C) override {
710 using HandlerArgType = typename ResponseHandlerArg<
711 typename HandlerTraits<HandlerT>::Type>::ArgType;
712 HandlerArgType Result((typename HandlerArgType::value_type()));
713
714 if (auto Err =
715 SerializationTraits<ChannelT, Expected<FuncRetT>,
716 HandlerArgType>::deserialize(C, Result))
717 return Err;
718 if (auto Err = C.endReceiveMessage())
719 return Err;
720 return Handler(std::move(Result));
721 }
722
723 // Abandon this response by calling the handler with an 'abandoned response'
724 // error.
725 void abandon() override {
726 if (auto Err = Handler(this->createAbandonedResponseError())) {
727 // Handlers should not fail when passed an abandoned response error.
728 report_fatal_error(std::move(Err));
729 }
730 }
731
732private:
733 HandlerT Handler;
734};
735
736template <typename ChannelT, typename HandlerT>
737class ResponseHandlerImpl<ChannelT, Error, HandlerT>
738 : public ResponseHandler<ChannelT> {
739public:
740 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
741
742 // Handle the result by deserializing it from the channel then passing it
743 // to the user defined handler.
744 Error handleResponse(ChannelT &C) override {
745 Error Result = Error::success();
746 if (auto Err =
747 SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
748 return Err;
749 if (auto Err = C.endReceiveMessage())
750 return Err;
751 return Handler(std::move(Result));
752 }
753
754 // Abandon this response by calling the handler with an 'abandoned response'
755 // error.
756 void abandon() override {
757 if (auto Err = Handler(this->createAbandonedResponseError())) {
758 // Handlers should not fail when passed an abandoned response error.
759 report_fatal_error(std::move(Err));
760 }
761 }
762
763private:
764 HandlerT Handler;
765};
766
767// Create a ResponseHandler from a given user handler.
768template <typename ChannelT, typename FuncRetT, typename HandlerT>
769std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
770 return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
771 std::move(H));
772}
773
// Functor that binds an object to one of its (non-const) member functions.
// This is useful for installing methods as result handlers.
// The object is held by reference, so it must outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Target(Instance), Fn(Method) {}

  // Call the bound method on the bound object, moving the arguments into it.
  RetT operator()(ArgTs &&... Args) {
    return (Target.*Fn)(std::move(Args)...);
  }

private:
  ClassT &Target;
  MethodT Fn;
};
790
791// Helper that provides a Functor for deserializing arguments.
792template <typename... ArgTs> class ReadArgs {
793public:
794 Error operator()() { return Error::success(); }
795};
796
797template <typename ArgT, typename... ArgTs>
798class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
799public:
800 ReadArgs(ArgT &Arg, ArgTs &... Args)
801 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
802
803 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
804 this->Arg = std::move(ArgVal);
805 return ReadArgs<ArgTs...>::operator()(ArgVals...);
806 }
807
808private:
809 ArgT &Arg;
810};
811
// Hands out sequence numbers and recycles released ones. Every operation
// takes the internal mutex.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Forget every released number and restart counting at zero, which makes
  // all sequence numbers available again.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number. The most recently released
  // number is re-used if there is one; otherwise a fresh number is produced.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
844
// Checks that predicate P holds for each corresponding pair of element types
// taken from the T1 and T2 tuples.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Two empty tuples: nothing left to check.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Non-empty tuples: P must hold for the two head types and, recursively, for
// the two tails.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
private:
  using TailCheck =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && TailCheck::value;
};
864
865template <template <class, class> class P, typename T1Sig, typename T2Sig>
866class RPCArgTypeCheck {
867public:
868 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
869 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
870
871 static_assert(std::tuple_size<T1Tuple>::value >=
872 std::tuple_size<T2Tuple>::value,
873 "Too many arguments to RPC call");
874 static_assert(std::tuple_size<T1Tuple>::value <=
875 std::tuple_size<T2Tuple>::value,
876 "Too few arguments to RPC call");
877
878 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
879};
880
881template <typename ChannelT, typename WireT, typename ConcreteT>
882class CanSerialize {
883private:
884 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
885
886 template <typename T>
887 static std::true_type
888 check(typename std::enable_if<
889 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
890 std::declval<const ConcreteT &>())),
891 Error>::value,
892 void *>::type);
893
894 template <typename> static std::false_type check(...);
895
896public:
897 static const bool value = decltype(check<S>(0))::value;
898};
899
900template <typename ChannelT, typename WireT, typename ConcreteT>
901class CanDeserialize {
902private:
903 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
904
905 template <typename T>
906 static std::true_type
907 check(typename std::enable_if<
908 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
909 std::declval<ConcreteT &>())),
910 Error>::value,
911 void *>::type);
912
913 template <typename> static std::false_type check(...);
914
915public:
916 static const bool value = decltype(check<S>(0))::value;
917};
918
919/// Contains primitive utilities for defining, calling and handling calls to
920/// remote procedures. ChannelT is a bidirectional stream conforming to the
921/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
922/// identifier type that must be serializable on ChannelT, and SequenceNumberT
923/// is an integral type that will be used to number in-flight function calls.
924///
925/// These utilities support the construction of very primitive RPC utilities.
926/// Their intent is to ensure correct serialization and deserialization of
927/// procedure arguments, and to keep the client and server's view of the API in
928/// sync.
929template <typename ImplT, typename ChannelT, typename FunctionIdT,
930 typename SequenceNumberT>
931class RPCEndpointBase {
932protected:
933 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
934 public:
935 static const char *getName() { return "__orc_rpc$invalid"; }
936 };
937
938 class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
939 public:
940 static const char *getName() { return "__orc_rpc$response"; }
941 };
942
943 class OrcRPCNegotiate
944 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
945 public:
946 static const char *getName() { return "__orc_rpc$negotiate"; }
947 };
948
949 // Helper predicate for testing for the presence of SerializeTraits
950 // serializers.
951 template <typename WireT, typename ConcreteT>
952 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
953 public:
954 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
955
956 static_assert(value, "Missing serializer for argument (Can't serialize the "
957 "first template type argument of CanSerializeCheck "
958 "from the second)");
959 };
960
961 // Helper predicate for testing for the presence of SerializeTraits
962 // deserializers.
963 template <typename WireT, typename ConcreteT>
964 class CanDeserializeCheck
965 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
966 public:
967 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
968
969 static_assert(value, "Missing deserializer for argument (Can't deserialize "
970 "the second template type argument of "
971 "CanDeserializeCheck from the first)");
972 };
973
974public:
975 /// Construct an RPC instance on a channel.
976 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
977 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
978 // Hold ResponseId in a special variable, since we expect Response to be
979 // called relatively frequently, and want to avoid the map lookup.
980 ResponseId = FnIdAllocator.getResponseId();
981 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
982
983 // Register the negotiate function id and handler.
984 auto NegotiateId = FnIdAllocator.getNegotiateId();
985 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
986 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
987 [this](const std::string &Name) { return handleNegotiate(Name); });
988 }
989
990
991 /// Negotiate a function id for Func with the other end of the channel.
992 template <typename Func> Error negotiateFunction(bool Retry = false) {
993 return getRemoteFunctionId<Func>(true, Retry).takeError();
994 }
995
996 /// Append a call Func, does not call send on the channel.
997 /// The first argument specifies a user-defined handler to be run when the
998 /// function returns. The handler should take an Expected<Func::ReturnType>,
999 /// or an Error (if Func::ReturnType is void). The handler will be called
1000 /// with an error if the return value is abandoned due to a channel error.
1001 template <typename Func, typename HandlerT, typename... ArgTs>
1002 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1003
1004 static_assert(
1005 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1006 void(ArgTs...)>::value,
1007 "");
1008
1009 // Look up the function ID.
1010 FunctionIdT FnId;
1011 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1012 FnId = *FnIdOrErr;
1013 else {
1014 // Negotiation failed. Notify the handler then return the negotiate-failed
1015 // error.
1016 cantFail(Handler(make_error<ResponseAbandoned>()));
1017 return FnIdOrErr.takeError();
1018 }
1019
1020 SequenceNumberT SeqNo; // initialized in locked scope below.
1021 {
1022 // Lock the pending responses map and sequence number manager.
1023 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1024
1025 // Allocate a sequence number.
1026 SeqNo = SequenceNumberMgr.getSequenceNumber();
1027 assert(!PendingResponses.count(SeqNo) &&
1028 "Sequence number already allocated");
1029
1030 // Install the user handler.
1031 PendingResponses[SeqNo] =
1032 detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1033 std::move(Handler));
1034 }
1035
1036 // Open the function call message.
1037 if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1038 abandonPendingResponses();
1039 return Err;
1040 }
1041
1042 // Serialize the call arguments.
1043 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1044 C, Args...)) {
1045 abandonPendingResponses();
1046 return Err;
1047 }
1048
1049 // Close the function call messagee.
1050 if (auto Err = C.endSendMessage()) {
1051 abandonPendingResponses();
1052 return Err;
1053 }
1054
1055 return Error::success();
1056 }
1057
1058 Error sendAppendedCalls() { return C.send(); };
1059
1060 template <typename Func, typename HandlerT, typename... ArgTs>
1061 Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1062 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1063 return Err;
1064 return C.send();
1065 }
1066
1067 /// Handle one incoming call.
1068 Error handleOne() {
1069 FunctionIdT FnId;
1070 SequenceNumberT SeqNo;
1071 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1072 abandonPendingResponses();
1073 return Err;
1074 }
1075 if (FnId == ResponseId)
1076 return handleResponse(SeqNo);
1077 auto I = Handlers.find(FnId);
1078 if (I != Handlers.end())
1079 return I->second(C, SeqNo);
1080
1081 // else: No handler found. Report error to client?
1082 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1083 SeqNo);
1084 }
1085
1086 /// Helper for handling setter procedures - this method returns a functor that
1087 /// sets the variables referred to by Args... to values deserialized from the
1088 /// channel.
1089 /// E.g.
1090 ///
1091 /// typedef Function<0, bool, int> Func1;
1092 ///
1093 /// ...
1094 /// bool B;
1095 /// int I;
1096 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1097 /// /* Handle Args */ ;
1098 ///
1099 template <typename... ArgTs>
1100 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1101 return detail::ReadArgs<ArgTs...>(Args...);
1102 }
1103
1104 /// Abandon all outstanding result handlers.
1105 ///
1106 /// This will call all currently registered result handlers to receive an
1107 /// "abandoned" error as their argument. This is used internally by the RPC
1108 /// in error situations, but can also be called directly by clients who are
1109 /// disconnecting from the remote and don't or can't expect responses to their
1110 /// outstanding calls. (Especially for outstanding blocking calls, calling
1111 /// this function may be necessary to avoid dead threads).
1112 void abandonPendingResponses() {
1113 // Lock the pending responses map and sequence number manager.
1114 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1115
1116 for (auto &KV : PendingResponses)
1117 KV.second->abandon();
1118 PendingResponses.clear();
1119 SequenceNumberMgr.reset();
1120 }
1121
1122 /// Remove the handler for the given function.
1123 /// A handler must currently be registered for this function.
1124 template <typename Func>
1125 void removeHandler() {
1126 auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1127 assert(IdItr != LocalFunctionIds.end() &&
1128 "Function does not have a registered handler");
1129 auto HandlerItr = Handlers.find(IdItr->second);
1130 assert(HandlerItr != Handlers.end() &&
1131 "Function does not have a registered handler");
1132 Handlers.erase(HandlerItr);
1133 }
1134
1135 /// Clear all handlers.
1136 void clearHandlers() {
1137 Handlers.clear();
1138 }
1139
1140protected:
1141
1142 FunctionIdT getInvalidFunctionId() const {
1143 return FnIdAllocator.getInvalidId();
1144 }
1145
1146 /// Add the given handler to the handler map and make it available for
1147 /// autonegotiation and execution.
1148 template <typename Func, typename HandlerT>
1149 void addHandlerImpl(HandlerT Handler) {
1150
1151 static_assert(detail::RPCArgTypeCheck<
1152 CanDeserializeCheck, typename Func::Type,
1153 typename detail::HandlerTraits<HandlerT>::Type>::value,
1154 "");
1155
1156 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1157 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1158 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1159 }
1160
1161 template <typename Func, typename HandlerT>
1162 void addAsyncHandlerImpl(HandlerT Handler) {
1163
1164 static_assert(detail::RPCArgTypeCheck<
1165 CanDeserializeCheck, typename Func::Type,
1166 typename detail::AsyncHandlerTraits<
1167 typename detail::HandlerTraits<HandlerT>::Type
1168 >::Type>::value,
1169 "");
1170
1171 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1172 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1173 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1174 }
1175
1176 Error handleResponse(SequenceNumberT SeqNo) {
1177 using Handler = typename decltype(PendingResponses)::mapped_type;
1178 Handler PRHandler;
1179
1180 {
1181 // Lock the pending responses map and sequence number manager.
1182 std::unique_lock<std::mutex> Lock(ResponsesMutex);
1183 auto I = PendingResponses.find(SeqNo);
1184
1185 if (I != PendingResponses.end()) {
1186 PRHandler = std::move(I->second);
1187 PendingResponses.erase(I);
1188 SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1189 } else {
1190 // Unlock the pending results map to prevent recursive lock.
1191 Lock.unlock();
1192 abandonPendingResponses();
1193 return make_error<
1194 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1195 }
1196 }
1197
1198 assert(PRHandler &&
1199 "If we didn't find a response handler we should have bailed out");
1200
1201 if (auto Err = PRHandler->handleResponse(C)) {
1202 abandonPendingResponses();
1203 return Err;
1204 }
1205
1206 return Error::success();
1207 }
1208
1209 FunctionIdT handleNegotiate(const std::string &Name) {
1210 auto I = LocalFunctionIds.find(Name);
1211 if (I == LocalFunctionIds.end())
1212 return getInvalidFunctionId();
1213 return I->second;
1214 }
1215
1216 // Find the remote FunctionId for the given function.
1217 template <typename Func>
1218 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1219 bool NegotiateIfInvalid) {
1220 bool DoNegotiate;
1221
1222 // Check if we already have a function id...
1223 auto I = RemoteFunctionIds.find(Func::getPrototype());
1224 if (I != RemoteFunctionIds.end()) {
1225 // If it's valid there's nothing left to do.
1226 if (I->second != getInvalidFunctionId())
1227 return I->second;
1228 DoNegotiate = NegotiateIfInvalid;
1229 } else
1230 DoNegotiate = NegotiateIfNotInMap;
1231
1232 // We don't have a function id for Func yet, but we're allowed to try to
1233 // negotiate one.
1234 if (DoNegotiate) {
1235 auto &Impl = static_cast<ImplT &>(*this);
1236 if (auto RemoteIdOrErr =
1237 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1238 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1239 if (*RemoteIdOrErr == getInvalidFunctionId())
1240 return make_error<CouldNotNegotiate>(Func::getPrototype());
1241 return *RemoteIdOrErr;
1242 } else
1243 return RemoteIdOrErr.takeError();
1244 }
1245
1246 // No key was available in the map and we weren't allowed to try to
1247 // negotiate one, so return an unknown function error.
1248 return make_error<CouldNotNegotiate>(Func::getPrototype());
1249 }
1250
1251 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1252
1253 // Wrap the given user handler in the necessary argument-deserialization code,
1254 // result-serialization code, and call to the launch policy (if present).
1255 template <typename Func, typename HandlerT>
1256 WrappedHandlerFn wrapHandler(HandlerT Handler) {
1257 return [this, Handler](ChannelT &Channel,
1258 SequenceNumberT SeqNo) mutable -> Error {
1259 // Start by deserializing the arguments.
1260 using ArgsTuple =
1261 typename detail::FunctionArgsTuple<
1262 typename detail::HandlerTraits<HandlerT>::Type>::Type;
1263 auto Args = std::make_shared<ArgsTuple>();
1264
1265 if (auto Err =
1266 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1267 Channel, *Args))
1268 return Err;
1269
1270 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1271 // for RPCArgs. Void cast RPCArgs to work around this for now.
1272 // FIXME: Remove this workaround once we can assume a working GCC version.
1273 (void)Args;
1274
1275 // End receieve message, unlocking the channel for reading.
1276 if (auto Err = Channel.endReceiveMessage())
1277 return Err;
1278
1279 using HTraits = detail::HandlerTraits<HandlerT>;
1280 using FuncReturn = typename Func::ReturnType;
1281 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1282 HTraits::unpackAndRun(Handler, *Args));
1283 };
1284 }
1285
1286 // Wrap the given user handler in the necessary argument-deserialization code,
1287 // result-serialization code, and call to the launch policy (if present).
1288 template <typename Func, typename HandlerT>
1289 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1290 return [this, Handler](ChannelT &Channel,
1291 SequenceNumberT SeqNo) mutable -> Error {
1292 // Start by deserializing the arguments.
1293 using AHTraits = detail::AsyncHandlerTraits<
1294 typename detail::HandlerTraits<HandlerT>::Type>;
1295 using ArgsTuple =
1296 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1297 auto Args = std::make_shared<ArgsTuple>();
1298
1299 if (auto Err =
1300 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1301 Channel, *Args))
1302 return Err;
1303
1304 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1305 // for RPCArgs. Void cast RPCArgs to work around this for now.
1306 // FIXME: Remove this workaround once we can assume a working GCC version.
1307 (void)Args;
1308
1309 // End receieve message, unlocking the channel for reading.
1310 if (auto Err = Channel.endReceiveMessage())
1311 return Err;
1312
1313 using HTraits = detail::HandlerTraits<HandlerT>;
1314 using FuncReturn = typename Func::ReturnType;
1315 auto Responder =
1316 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1317 return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1318 std::move(RetVal));
1319 };
1320
1321 return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1322 };
1323 }
1324
1325 ChannelT &C;
1326
1327 bool LazyAutoNegotiation;
1328
1329 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1330
1331 FunctionIdT ResponseId;
1332 std::map<std::string, FunctionIdT> LocalFunctionIds;
1333 std::map<const char *, FunctionIdT> RemoteFunctionIds;
1334
1335 std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1336
1337 std::mutex ResponsesMutex;
1338 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1339 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1340 PendingResponses;
1341};
1342
1343} // end namespace detail
1344
1345template <typename ChannelT, typename FunctionIdT = uint32_t,
1346 typename SequenceNumberT = uint32_t>
1347class MultiThreadedRPCEndpoint
1348 : public detail::RPCEndpointBase<
1349 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1350 ChannelT, FunctionIdT, SequenceNumberT> {
1351private:
1352 using BaseClass =
1353 detail::RPCEndpointBase<
1354 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1355 ChannelT, FunctionIdT, SequenceNumberT>;
1356
1357public:
1358 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1359 : BaseClass(C, LazyAutoNegotiation) {}
1360
1361 /// Add a handler for the given RPC function.
1362 /// This installs the given handler functor for the given RPC Function, and
1363 /// makes the RPC function available for negotiation/calling from the remote.
1364 template <typename Func, typename HandlerT>
1365 void addHandler(HandlerT Handler) {
1366 return this->template addHandlerImpl<Func>(std::move(Handler));
1367 }
1368
1369 /// Add a class-method as a handler.
1370 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1371 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1372 addHandler<Func>(
1373 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1374 }
1375
1376 template <typename Func, typename HandlerT>
1377 void addAsyncHandler(HandlerT Handler) {
1378 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1379 }
1380
1381 /// Add a class-method as a handler.
1382 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1383 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1384 addAsyncHandler<Func>(
1385 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1386 }
1387
1388 /// Return type for non-blocking call primitives.
1389 template <typename Func>
1390 using NonBlockingCallResult = typename detail::ResultTraits<
1391 typename Func::ReturnType>::ReturnFutureType;
1392
1393 /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1394 /// of a future result and the sequence number assigned to the result.
1395 ///
1396 /// This utility function is primarily used for single-threaded mode support,
1397 /// where the sequence number can be used to wait for the corresponding
1398 /// result. In multi-threaded mode the appendCallNB method, which does not
1399 /// return the sequence numeber, should be preferred.
1400 template <typename Func, typename... ArgTs>
1401 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1402 using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1403 using ErrorReturn = typename RTraits::ErrorReturnType;
1404 using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1405
1406 // FIXME: Stack allocate and move this into the handler once LLVM builds
1407 // with C++14.
1408 auto Promise = std::make_shared<ErrorReturnPromise>();
1409 auto FutureResult = Promise->get_future();
1410
1411 if (auto Err = this->template appendCallAsync<Func>(
1412 [Promise](ErrorReturn RetOrErr) {
1413 Promise->set_value(std::move(RetOrErr));
1414 return Error::success();
1415 },
1416 Args...)) {
1417 RTraits::consumeAbandoned(FutureResult.get());
1418 return std::move(Err);
1419 }
1420 return std::move(FutureResult);
1421 }
1422
1423 /// The same as appendCallNBWithSeq, except that it calls C.send() to
1424 /// flush the channel after serializing the call.
1425 template <typename Func, typename... ArgTs>
1426 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1427 auto Result = appendCallNB<Func>(Args...);
1428 if (!Result)
1429 return Result;
1430 if (auto Err = this->C.send()) {
1431 this->abandonPendingResponses();
1432 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1433 std::move(Result->get()));
1434 return std::move(Err);
1435 }
1436 return Result;
1437 }
1438
1439 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1440 /// for void functions or an Expected<T> for functions returning a T.
1441 ///
1442 /// This function is for use in threaded code where another thread is
1443 /// handling responses and incoming calls.
1444 template <typename Func, typename... ArgTs,
1445 typename AltRetT = typename Func::ReturnType>
1446 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1447 callB(const ArgTs &... Args) {
1448 if (auto FutureResOrErr = callNB<Func>(Args...))
1449 return FutureResOrErr->get();
1450 else
1451 return FutureResOrErr.takeError();
1452 }
1453
1454 /// Handle incoming RPC calls.
1455 Error handlerLoop() {
1456 while (true)
1457 if (auto Err = this->handleOne())
1458 return Err;
1459 return Error::success();
1460 }
1461};
1462
1463template <typename ChannelT, typename FunctionIdT = uint32_t,
1464 typename SequenceNumberT = uint32_t>
1465class SingleThreadedRPCEndpoint
1466 : public detail::RPCEndpointBase<
1467 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1468 ChannelT, FunctionIdT, SequenceNumberT> {
1469private:
1470 using BaseClass =
1471 detail::RPCEndpointBase<
1472 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1473 ChannelT, FunctionIdT, SequenceNumberT>;
1474
1475public:
1476 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1477 : BaseClass(C, LazyAutoNegotiation) {}
1478
1479 template <typename Func, typename HandlerT>
1480 void addHandler(HandlerT Handler) {
1481 return this->template addHandlerImpl<Func>(std::move(Handler));
1482 }
1483
1484 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1485 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1486 addHandler<Func>(
1487 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1488 }
1489
1490 template <typename Func, typename HandlerT>
1491 void addAsyncHandler(HandlerT Handler) {
1492 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1493 }
1494
1495 /// Add a class-method as a handler.
1496 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1497 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1498 addAsyncHandler<Func>(
1499 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1500 }
1501
1502 template <typename Func, typename... ArgTs,
1503 typename AltRetT = typename Func::ReturnType>
1504 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1505 callB(const ArgTs &... Args) {
1506 bool ReceivedResponse = false;
1507 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1508 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1509
1510 // We have to 'Check' result (which we know is in a success state at this
1511 // point) so that it can be overwritten in the async handler.
1512 (void)!!Result;
1513
1514 if (auto Err = this->template appendCallAsync<Func>(
1515 [&](ResultType R) {
1516 Result = std::move(R);
1517 ReceivedResponse = true;
1518 return Error::success();
1519 },
1520 Args...)) {
1521 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1522 std::move(Result));
1523 return std::move(Err);
1524 }
1525
1526 while (!ReceivedResponse) {
1527 if (auto Err = this->handleOne()) {
1528 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1529 std::move(Result));
1530 return std::move(Err);
1531 }
1532 }
1533
1534 return Result;
1535 }
1536};
1537
1538/// Asynchronous dispatch for a function on an RPC endpoint.
1539template <typename RPCClass, typename Func>
1540class RPCAsyncDispatch {
1541public:
1542 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1543
1544 template <typename HandlerT, typename... ArgTs>
1545 Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1546 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1547 }
1548
1549private:
1550 RPCClass &Endpoint;
1551};
1552
1553/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1554template <typename Func, typename RPCEndpointT>
1555RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1556 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1557}
1558
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001559/// Allows a set of asynchrounous calls to be dispatched, and then
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001560/// waited on as a group.
1561class ParallelCallGroup {
1562public:
1563
1564 ParallelCallGroup() = default;
1565 ParallelCallGroup(const ParallelCallGroup &) = delete;
1566 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1567
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001568 /// Make as asynchronous call.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001569 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1570 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1571 const ArgTs &... Args) {
1572 // Increment the count of outstanding calls. This has to happen before
1573 // we invoke the call, as the handler may (depending on scheduling)
1574 // be run immediately on another thread, and we don't want the decrement
1575 // in the wrapped handler below to run before the increment.
1576 {
1577 std::unique_lock<std::mutex> Lock(M);
1578 ++NumOutstandingCalls;
1579 }
1580
1581 // Wrap the user handler in a lambda that will decrement the
1582 // outstanding calls count, then poke the condition variable.
1583 using ArgType = typename detail::ResponseHandlerArg<
1584 typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1585 // FIXME: Move handler into wrapped handler once we have C++14.
1586 auto WrappedHandler = [this, Handler](ArgType Arg) {
1587 auto Err = Handler(std::move(Arg));
1588 std::unique_lock<std::mutex> Lock(M);
1589 --NumOutstandingCalls;
1590 CV.notify_all();
1591 return Err;
1592 };
1593
1594 return AsyncDispatch(std::move(WrappedHandler), Args...);
1595 }
1596
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001597 /// Blocks until all calls have been completed and their return value
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001598 /// handlers run.
1599 void wait() {
1600 std::unique_lock<std::mutex> Lock(M);
1601 while (NumOutstandingCalls > 0)
1602 CV.wait(Lock);
1603 }
1604
1605private:
1606 std::mutex M;
1607 std::condition_variable CV;
1608 uint32_t NumOutstandingCalls = 0;
1609};
1610
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001611/// Convenience class for grouping RPC Functions into APIs that can be
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001612/// negotiated as a block.
1613///
1614template <typename... Funcs>
1615class APICalls {
1616public:
1617
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001618 /// Test whether this API contains Function F.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001619 template <typename F>
1620 class Contains {
1621 public:
1622 static const bool value = false;
1623 };
1624
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001625 /// Negotiate all functions in this API.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001626 template <typename RPCEndpoint>
1627 static Error negotiate(RPCEndpoint &R) {
1628 return Error::success();
1629 }
1630};
1631
1632template <typename Func, typename... Funcs>
1633class APICalls<Func, Funcs...> {
1634public:
1635
1636 template <typename F>
1637 class Contains {
1638 public:
1639 static const bool value = std::is_same<F, Func>::value |
1640 APICalls<Funcs...>::template Contains<F>::value;
1641 };
1642
1643 template <typename RPCEndpoint>
1644 static Error negotiate(RPCEndpoint &R) {
1645 if (auto Err = R.template negotiateFunction<Func>())
1646 return Err;
1647 return APICalls<Funcs...>::negotiate(R);
1648 }
1649
1650};
1651
1652template <typename... InnerFuncs, typename... Funcs>
1653class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1654public:
1655
1656 template <typename F>
1657 class Contains {
1658 public:
1659 static const bool value =
1660 APICalls<InnerFuncs...>::template Contains<F>::value |
1661 APICalls<Funcs...>::template Contains<F>::value;
1662 };
1663
1664 template <typename RPCEndpoint>
1665 static Error negotiate(RPCEndpoint &R) {
1666 if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1667 return Err;
1668 return APICalls<Funcs...>::negotiate(R);
1669 }
1670
1671};
1672
1673} // end namespace rpc
1674} // end namespace orc
1675} // end namespace llvm
1676
1677#endif