blob: 953b73e10e43da94e43bb8aa6593554db3edae85 [file] [log] [blame]
//===------- RPCUtils.h - Utilities for building RPC APIs -------*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Utilities to support construction of simple RPC APIs.
11//
12// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
13// programmers, high performance, low memory overhead, and efficient use of the
14// communications channel.
15//
16//===----------------------------------------------------------------------===//
17
18#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20
21#include <map>
22#include <thread>
23#include <vector>
24
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ExecutionEngine/Orc/OrcError.h"
27#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
Andrew Scull0372a572018-11-16 15:47:06 +000028#include "llvm/Support/MSVCErrorWorkarounds.h"
Andrew Scull5e1ddfa2018-08-14 10:06:54 +010029
30#include <future>
31
32namespace llvm {
33namespace orc {
34namespace rpc {
35
36/// Base class of all fatal RPC errors (those that necessarily result in the
37/// termination of the RPC session).
38class RPCFatalError : public ErrorInfo<RPCFatalError> {
39public:
40 static char ID;
41};
42
43/// RPCConnectionClosed is returned from RPC operations if the RPC connection
44/// has already been closed due to either an error or graceful disconnection.
45class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
46public:
47 static char ID;
48 std::error_code convertToErrorCode() const override;
49 void log(raw_ostream &OS) const override;
50};
51
52/// BadFunctionCall is returned from handleOne when the remote makes a call with
53/// an unrecognized function id.
54///
55/// This error is fatal because Orc RPC needs to know how to parse a function
56/// call to know where the next call starts, and if it doesn't recognize the
57/// function id it cannot parse the call.
58template <typename FnIdT, typename SeqNoT>
59class BadFunctionCall
60 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
61public:
62 static char ID;
63
64 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
65 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
66
67 std::error_code convertToErrorCode() const override {
68 return orcError(OrcErrorCode::UnexpectedRPCCall);
69 }
70
71 void log(raw_ostream &OS) const override {
72 OS << "Call to invalid RPC function id '" << FnId << "' with "
73 "sequence number " << SeqNo;
74 }
75
76private:
77 FnIdT FnId;
78 SeqNoT SeqNo;
79};
80
81template <typename FnIdT, typename SeqNoT>
82char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
83
84/// InvalidSequenceNumberForResponse is returned from handleOne when a response
85/// call arrives with a sequence number that doesn't correspond to any in-flight
86/// function call.
87///
88/// This error is fatal because Orc RPC needs to know how to parse the rest of
89/// the response call to know where the next call starts, and if it doesn't have
90/// a result parser for this sequence number it can't do that.
91template <typename SeqNoT>
92class InvalidSequenceNumberForResponse
93 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
94public:
95 static char ID;
96
97 InvalidSequenceNumberForResponse(SeqNoT SeqNo)
98 : SeqNo(std::move(SeqNo)) {}
99
100 std::error_code convertToErrorCode() const override {
101 return orcError(OrcErrorCode::UnexpectedRPCCall);
102 };
103
104 void log(raw_ostream &OS) const override {
105 OS << "Response has unknown sequence number " << SeqNo;
106 }
107private:
108 SeqNoT SeqNo;
109};
110
111template <typename SeqNoT>
112char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
113
114/// This non-fatal error will be passed to asynchronous result handlers in place
115/// of a result if the connection goes down before a result returns, or if the
116/// function to be called cannot be negotiated with the remote.
117class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
118public:
119 static char ID;
120
121 std::error_code convertToErrorCode() const override;
122 void log(raw_ostream &OS) const override;
123};
124
125/// This error is returned if the remote does not have a handler installed for
126/// the given RPC function.
127class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
128public:
129 static char ID;
130
131 CouldNotNegotiate(std::string Signature);
132 std::error_code convertToErrorCode() const override;
133 void log(raw_ostream &OS) const override;
134 const std::string &getSignature() const { return Signature; }
135private:
136 std::string Signature;
137};
138
139template <typename DerivedFunc, typename FnT> class Function;
140
141// RPC Function class.
142// DerivedFunc should be a user defined class with a static 'getName()' method
143// returning a const char* representing the function's name.
144template <typename DerivedFunc, typename RetT, typename... ArgTs>
145class Function<DerivedFunc, RetT(ArgTs...)> {
146public:
147 /// User defined function type.
148 using Type = RetT(ArgTs...);
149
150 /// Return type.
151 using ReturnType = RetT;
152
153 /// Returns the full function prototype as a string.
154 static const char *getPrototype() {
155 std::lock_guard<std::mutex> Lock(NameMutex);
156 if (Name.empty())
157 raw_string_ostream(Name)
158 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
159 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
160 return Name.data();
161 }
162
163private:
164 static std::mutex NameMutex;
165 static std::string Name;
166};
167
168template <typename DerivedFunc, typename RetT, typename... ArgTs>
169std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
170
171template <typename DerivedFunc, typename RetT, typename... ArgTs>
172std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
173
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
/// static T getInvalidId():
///   Returns a reserved id used to represent missing functions during
///   autonegotiation.
///
/// static T getResponseId():
///   Returns a reserved id used to send function responses (return values).
///
/// static T getNegotiateId():
///   Returns a reserved id for the negotiate function, which is used to
///   negotiate ids for user defined functions.
///
/// template <typename Func> T allocate():
///   Allocates a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator implementation, selected for integral id
/// types. Ids 0, 1 and 2 are reserved; allocate() hands out ids starting at 3.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return T(0); }
  static T getResponseId() { return T(1); }
  static T getNegotiateId() { return T(2); }

  /// Returns the next unused id and advances the counter.
  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // First id that is not one of the three reserved ids above.
  T NextId = 3;
};
208
209namespace detail {
210
/// Maps a function type to a std::tuple of its decayed argument types,
/// exposed as the nested typedef 'Type'.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already strips references before decaying, so no separate
  // remove_reference step is needed.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
220
221// ResultTraits provides typedefs and utilities specific to the return type
222// of functions.
223template <typename RetT> class ResultTraits {
224public:
225 // The return type wrapped in llvm::Expected.
226 using ErrorReturnType = Expected<RetT>;
227
228#ifdef _MSC_VER
229 // The ErrorReturnType wrapped in a std::promise.
Andrew Scull0372a572018-11-16 15:47:06 +0000230 using ReturnPromiseType = std::promise<MSVCPExpected<RetT>>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100231
232 // The ErrorReturnType wrapped in a std::future.
Andrew Scull0372a572018-11-16 15:47:06 +0000233 using ReturnFutureType = std::future<MSVCPExpected<RetT>>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100234#else
235 // The ErrorReturnType wrapped in a std::promise.
236 using ReturnPromiseType = std::promise<ErrorReturnType>;
237
238 // The ErrorReturnType wrapped in a std::future.
239 using ReturnFutureType = std::future<ErrorReturnType>;
240#endif
241
242 // Create a 'blank' value of the ErrorReturnType, ready and safe to
243 // overwrite.
244 static ErrorReturnType createBlankErrorReturnValue() {
245 return ErrorReturnType(RetT());
246 }
247
248 // Consume an abandoned ErrorReturnType.
249 static void consumeAbandoned(ErrorReturnType RetOrErr) {
250 consumeError(RetOrErr.takeError());
251 }
252};
253
254// ResultTraits specialization for void functions.
255template <> class ResultTraits<void> {
256public:
257 // For void functions, ErrorReturnType is llvm::Error.
258 using ErrorReturnType = Error;
259
260#ifdef _MSC_VER
261 // The ErrorReturnType wrapped in a std::promise.
Andrew Scull0372a572018-11-16 15:47:06 +0000262 using ReturnPromiseType = std::promise<MSVCPError>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100263
264 // The ErrorReturnType wrapped in a std::future.
Andrew Scull0372a572018-11-16 15:47:06 +0000265 using ReturnFutureType = std::future<MSVCPError>;
Andrew Scull5e1ddfa2018-08-14 10:06:54 +0100266#else
267 // The ErrorReturnType wrapped in a std::promise.
268 using ReturnPromiseType = std::promise<ErrorReturnType>;
269
270 // The ErrorReturnType wrapped in a std::future.
271 using ReturnFutureType = std::future<ErrorReturnType>;
272#endif
273
274 // Create a 'blank' value of the ErrorReturnType, ready and safe to
275 // overwrite.
276 static ErrorReturnType createBlankErrorReturnValue() {
277 return ErrorReturnType::success();
278 }
279
280 // Consume an abandoned ErrorReturnType.
281 static void consumeAbandoned(ErrorReturnType Err) {
282 consumeError(std::move(Err));
283 }
284};
285
286// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
287// handlers for void RPC functions to return either void (in which case they
288// implicitly succeed) or Error (in which case their error return is
289// propagated). See usage in HandlerTraits::runHandlerHelper.
290template <> class ResultTraits<Error> : public ResultTraits<void> {};
291
292// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
293// handlers for RPC functions returning a T to return either a T (in which
294// case they implicitly succeed) or Expected<T> (in which case their error
295// return is propagated). See usage in HandlerTraits::runHandlerHelper.
296template <typename RetT>
297class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
298
299// Determines whether an RPC function's defined error return type supports
300// error return value.
301template <typename T>
302class SupportsErrorReturn {
303public:
304 static const bool value = false;
305};
306
307template <>
308class SupportsErrorReturn<Error> {
309public:
310 static const bool value = true;
311};
312
313template <typename T>
314class SupportsErrorReturn<Expected<T>> {
315public:
316 static const bool value = true;
317};
318
319// RespondHelper packages return values based on whether or not the declared
320// RPC function return type supports error returns.
321template <bool FuncSupportsErrorReturn>
322class RespondHelper;
323
324// RespondHelper specialization for functions that support error returns.
325template <>
326class RespondHelper<true> {
327public:
328
329 // Send Expected<T>.
330 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
331 typename FunctionIdT, typename SequenceNumberT>
332 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
333 SequenceNumberT SeqNo,
334 Expected<HandlerRetT> ResultOrErr) {
335 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
336 return ResultOrErr.takeError();
337
338 // Open the response message.
339 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
340 return Err;
341
342 // Serialize the result.
343 if (auto Err =
344 SerializationTraits<ChannelT, WireRetT,
345 Expected<HandlerRetT>>::serialize(
346 C, std::move(ResultOrErr)))
347 return Err;
348
349 // Close the response message.
350 return C.endSendMessage();
351 }
352
353 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
354 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
355 SequenceNumberT SeqNo, Error Err) {
356 if (Err && Err.isA<RPCFatalError>())
357 return Err;
358 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
359 return Err2;
360 if (auto Err2 = serializeSeq(C, std::move(Err)))
361 return Err2;
362 return C.endSendMessage();
363 }
364
365};
366
367// RespondHelper specialization for functions that do not support error returns.
368template <>
369class RespondHelper<false> {
370public:
371
372 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
373 typename FunctionIdT, typename SequenceNumberT>
374 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
375 SequenceNumberT SeqNo,
376 Expected<HandlerRetT> ResultOrErr) {
377 if (auto Err = ResultOrErr.takeError())
378 return Err;
379
380 // Open the response message.
381 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
382 return Err;
383
384 // Serialize the result.
385 if (auto Err =
386 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
387 C, *ResultOrErr))
388 return Err;
389
390 // Close the response message.
391 return C.endSendMessage();
392 }
393
394 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
395 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
396 SequenceNumberT SeqNo, Error Err) {
397 if (Err)
398 return Err;
399 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
400 return Err2;
401 return C.endSendMessage();
402 }
403
404};
405
406
407// Send a response of the given wire return type (WireRetT) over the
408// channel, with the given sequence number.
409template <typename WireRetT, typename HandlerRetT, typename ChannelT,
410 typename FunctionIdT, typename SequenceNumberT>
411Error respond(ChannelT &C, const FunctionIdT &ResponseId,
412 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
413 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
414 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
415}
416
417// Send an empty response message on the given channel to indicate that
418// the handler ran.
419template <typename WireRetT, typename ChannelT, typename FunctionIdT,
420 typename SequenceNumberT>
421Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
422 Error Err) {
423 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
424 sendResult(C, ResponseId, SeqNo, std::move(Err));
425}
426
427// Converts a given type to the equivalent error return type.
428template <typename T> class WrappedHandlerReturn {
429public:
430 using Type = Expected<T>;
431};
432
433template <typename T> class WrappedHandlerReturn<Expected<T>> {
434public:
435 using Type = Expected<T>;
436};
437
438template <> class WrappedHandlerReturn<void> {
439public:
440 using Type = Error;
441};
442
443template <> class WrappedHandlerReturn<Error> {
444public:
445 using Type = Error;
446};
447
448template <> class WrappedHandlerReturn<ErrorSuccess> {
449public:
450 using Type = Error;
451};
452
453// Traits class that strips the response function from the list of handler
454// arguments.
455template <typename FnT> class AsyncHandlerTraits;
456
457template <typename ResultT, typename... ArgTs>
458class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
459public:
460 using Type = Error(ArgTs...);
461 using ResultType = Expected<ResultT>;
462};
463
464template <typename... ArgTs>
465class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
466public:
467 using Type = Error(ArgTs...);
468 using ResultType = Error;
469};
470
471template <typename... ArgTs>
472class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
473public:
474 using Type = Error(ArgTs...);
475 using ResultType = Error;
476};
477
478template <typename... ArgTs>
479class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
480public:
481 using Type = Error(ArgTs...);
482 using ResultType = Error;
483};
484
485template <typename ResponseHandlerT, typename... ArgTs>
486class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
487 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
488 ArgTs...)> {};
489
// HandlerTraits provides utilities related to RPC function handlers.
//
// This base case applies to non-function types (the template is specialized
// for function types). It inherits from the appropriate specialization for
// the type of the given non-function type's call operator.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<
          decltype(&std::remove_reference<HandlerT>::type::operator())> {
};
498
499// Traits for handlers with a given function type.
500template <typename RetT, typename... ArgTs>
501class HandlerTraits<RetT(ArgTs...)> {
502public:
503 // Function type of the handler.
504 using Type = RetT(ArgTs...);
505
506 // Return type of the handler.
507 using ReturnType = RetT;
508
509 // Call the given handler with the given arguments.
510 template <typename HandlerT, typename... TArgTs>
511 static typename WrappedHandlerReturn<RetT>::Type
512 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
513 return unpackAndRunHelper(Handler, Args,
514 llvm::index_sequence_for<TArgTs...>());
515 }
516
517 // Call the given handler with the given arguments.
518 template <typename HandlerT, typename ResponderT, typename... TArgTs>
519 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
520 std::tuple<TArgTs...> &Args) {
521 return unpackAndRunAsyncHelper(Handler, Responder, Args,
522 llvm::index_sequence_for<TArgTs...>());
523 }
524
525 // Call the given handler with the given arguments.
526 template <typename HandlerT>
527 static typename std::enable_if<
528 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
529 Error>::type
530 run(HandlerT &Handler, ArgTs &&... Args) {
531 Handler(std::move(Args)...);
532 return Error::success();
533 }
534
535 template <typename HandlerT, typename... TArgTs>
536 static typename std::enable_if<
537 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
538 typename HandlerTraits<HandlerT>::ReturnType>::type
539 run(HandlerT &Handler, TArgTs... Args) {
540 return Handler(std::move(Args)...);
541 }
542
543 // Serialize arguments to the channel.
544 template <typename ChannelT, typename... CArgTs>
545 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
546 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
547 }
548
549 // Deserialize arguments from the channel.
550 template <typename ChannelT, typename... CArgTs>
551 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
552 return deserializeArgsHelper(C, Args,
553 llvm::index_sequence_for<CArgTs...>());
554 }
555
556private:
557 template <typename ChannelT, typename... CArgTs, size_t... Indexes>
558 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
559 llvm::index_sequence<Indexes...> _) {
560 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
561 C, std::get<Indexes>(Args)...);
562 }
563
564 template <typename HandlerT, typename ArgTuple, size_t... Indexes>
565 static typename WrappedHandlerReturn<
566 typename HandlerTraits<HandlerT>::ReturnType>::Type
567 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
568 llvm::index_sequence<Indexes...>) {
569 return run(Handler, std::move(std::get<Indexes>(Args))...);
570 }
571
572
573 template <typename HandlerT, typename ResponderT, typename ArgTuple,
574 size_t... Indexes>
575 static typename WrappedHandlerReturn<
576 typename HandlerTraits<HandlerT>::ReturnType>::Type
577 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
578 ArgTuple &Args,
579 llvm::index_sequence<Indexes...>) {
580 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
581 }
582};
583
584// Handler traits for free functions.
585template <typename RetT, typename... ArgTs>
586class HandlerTraits<RetT(*)(ArgTs...)>
587 : public HandlerTraits<RetT(ArgTs...)> {};
588
589// Handler traits for class methods (especially call operators for lambdas).
590template <typename Class, typename RetT, typename... ArgTs>
591class HandlerTraits<RetT (Class::*)(ArgTs...)>
592 : public HandlerTraits<RetT(ArgTs...)> {};
593
594// Handler traits for const class methods (especially call operators for
595// lambdas).
596template <typename Class, typename RetT, typename... ArgTs>
597class HandlerTraits<RetT (Class::*)(ArgTs...) const>
598 : public HandlerTraits<RetT(ArgTs...)> {};
599
600// Utility to peel the Expected wrapper off a response handler error type.
601template <typename HandlerT> class ResponseHandlerArg;
602
603template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
604public:
605 using ArgType = Expected<ArgT>;
606 using UnwrappedArgType = ArgT;
607};
608
609template <typename ArgT>
610class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
611public:
612 using ArgType = Expected<ArgT>;
613 using UnwrappedArgType = ArgT;
614};
615
616template <> class ResponseHandlerArg<Error(Error)> {
617public:
618 using ArgType = Error;
619};
620
621template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
622public:
623 using ArgType = Error;
624};
625
626// ResponseHandler represents a handler for a not-yet-received function call
627// result.
628template <typename ChannelT> class ResponseHandler {
629public:
630 virtual ~ResponseHandler() {}
631
632 // Reads the function result off the wire and acts on it. The meaning of
633 // "act" will depend on how this method is implemented in any given
634 // ResponseHandler subclass but could, for example, mean running a
635 // user-specified handler or setting a promise value.
636 virtual Error handleResponse(ChannelT &C) = 0;
637
638 // Abandons this outstanding result.
639 virtual void abandon() = 0;
640
641 // Create an error instance representing an abandoned response.
642 static Error createAbandonedResponseError() {
643 return make_error<ResponseAbandoned>();
644 }
645};
646
647// ResponseHandler subclass for RPC functions with non-void returns.
648template <typename ChannelT, typename FuncRetT, typename HandlerT>
649class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
650public:
651 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
652
653 // Handle the result by deserializing it from the channel then passing it
654 // to the user defined handler.
655 Error handleResponse(ChannelT &C) override {
656 using UnwrappedArgType = typename ResponseHandlerArg<
657 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
658 UnwrappedArgType Result;
659 if (auto Err =
660 SerializationTraits<ChannelT, FuncRetT,
661 UnwrappedArgType>::deserialize(C, Result))
662 return Err;
663 if (auto Err = C.endReceiveMessage())
664 return Err;
665 return Handler(std::move(Result));
666 }
667
668 // Abandon this response by calling the handler with an 'abandoned response'
669 // error.
670 void abandon() override {
671 if (auto Err = Handler(this->createAbandonedResponseError())) {
672 // Handlers should not fail when passed an abandoned response error.
673 report_fatal_error(std::move(Err));
674 }
675 }
676
677private:
678 HandlerT Handler;
679};
680
681// ResponseHandler subclass for RPC functions with void returns.
682template <typename ChannelT, typename HandlerT>
683class ResponseHandlerImpl<ChannelT, void, HandlerT>
684 : public ResponseHandler<ChannelT> {
685public:
686 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
687
688 // Handle the result (no actual value, just a notification that the function
689 // has completed on the remote end) by calling the user-defined handler with
690 // Error::success().
691 Error handleResponse(ChannelT &C) override {
692 if (auto Err = C.endReceiveMessage())
693 return Err;
694 return Handler(Error::success());
695 }
696
697 // Abandon this response by calling the handler with an 'abandoned response'
698 // error.
699 void abandon() override {
700 if (auto Err = Handler(this->createAbandonedResponseError())) {
701 // Handlers should not fail when passed an abandoned response error.
702 report_fatal_error(std::move(Err));
703 }
704 }
705
706private:
707 HandlerT Handler;
708};
709
710template <typename ChannelT, typename FuncRetT, typename HandlerT>
711class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
712 : public ResponseHandler<ChannelT> {
713public:
714 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
715
716 // Handle the result by deserializing it from the channel then passing it
717 // to the user defined handler.
718 Error handleResponse(ChannelT &C) override {
719 using HandlerArgType = typename ResponseHandlerArg<
720 typename HandlerTraits<HandlerT>::Type>::ArgType;
721 HandlerArgType Result((typename HandlerArgType::value_type()));
722
723 if (auto Err =
724 SerializationTraits<ChannelT, Expected<FuncRetT>,
725 HandlerArgType>::deserialize(C, Result))
726 return Err;
727 if (auto Err = C.endReceiveMessage())
728 return Err;
729 return Handler(std::move(Result));
730 }
731
732 // Abandon this response by calling the handler with an 'abandoned response'
733 // error.
734 void abandon() override {
735 if (auto Err = Handler(this->createAbandonedResponseError())) {
736 // Handlers should not fail when passed an abandoned response error.
737 report_fatal_error(std::move(Err));
738 }
739 }
740
741private:
742 HandlerT Handler;
743};
744
745template <typename ChannelT, typename HandlerT>
746class ResponseHandlerImpl<ChannelT, Error, HandlerT>
747 : public ResponseHandler<ChannelT> {
748public:
749 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
750
751 // Handle the result by deserializing it from the channel then passing it
752 // to the user defined handler.
753 Error handleResponse(ChannelT &C) override {
754 Error Result = Error::success();
755 if (auto Err =
756 SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
757 return Err;
758 if (auto Err = C.endReceiveMessage())
759 return Err;
760 return Handler(std::move(Result));
761 }
762
763 // Abandon this response by calling the handler with an 'abandoned response'
764 // error.
765 void abandon() override {
766 if (auto Err = Handler(this->createAbandonedResponseError())) {
767 // Handlers should not fail when passed an abandoned response error.
768 report_fatal_error(std::move(Err));
769 }
770 }
771
772private:
773 HandlerT Handler;
774};
775
776// Create a ResponseHandler from a given user handler.
777template <typename ChannelT, typename FuncRetT, typename HandlerT>
778std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
779 return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
780 std::move(H));
781}
782
// Helper for wrapping member functions up as functors. This is useful for
// installing methods as result handlers.
//
// The wrapper stores a reference to the instance, so the instance must
// outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Invoke the wrapped method on the stored instance.
  //
  // ArgTs is fixed by the class template, so 'ArgTs &&' collapses to an
  // lvalue reference whenever the method takes a reference parameter.
  // std::forward<ArgTs> hands each argument on with the value category the
  // method declares: by-value parameters are still moved, while 'T &'
  // parameters (which an unconditional std::move cannot bind to) now work.
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
799
800// Helper that provides a Functor for deserializing arguments.
801template <typename... ArgTs> class ReadArgs {
802public:
803 Error operator()() { return Error::success(); }
804};
805
806template <typename ArgT, typename... ArgTs>
807class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
808public:
809 ReadArgs(ArgT &Arg, ArgTs &... Args)
810 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
811
812 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
813 this->Arg = std::move(ArgVal);
814 return ReadArgs<ArgTs...>::operator()(ArgVals...);
815 }
816
817private:
818 ArgT &Arg;
819};
820
// Hands out sequence numbers and takes them back for re-use. Every operation
// holds SeqNoLock while it touches the manager's state.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Forget all outstanding and released numbers; numbering restarts at 0.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Get the next available sequence number. The most recently released
  // number is re-used first; a fresh number is minted only when none has been
  // released.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    SequenceNumberT Fresh = NextSequenceNumber;
    ++NextSequenceNumber;
    return Fresh;
  }

  // Release a sequence number, making it available for re-use.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
853
// Checks that predicate P holds for each corresponding pair of element types
// taken from the T1 and T2 tuples. Only tuples of equal length are handled.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Two empty tuples: there is nothing left to check.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Check the leading pair, then recurse on the remaining element types.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using CheckTail =
      RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && CheckTail::value;
};
873
874template <template <class, class> class P, typename T1Sig, typename T2Sig>
875class RPCArgTypeCheck {
876public:
877 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
878 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
879
880 static_assert(std::tuple_size<T1Tuple>::value >=
881 std::tuple_size<T2Tuple>::value,
882 "Too many arguments to RPC call");
883 static_assert(std::tuple_size<T1Tuple>::value <=
884 std::tuple_size<T2Tuple>::value,
885 "Too few arguments to RPC call");
886
887 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
888};
889
890template <typename ChannelT, typename WireT, typename ConcreteT>
891class CanSerialize {
892private:
893 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
894
895 template <typename T>
896 static std::true_type
897 check(typename std::enable_if<
898 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
899 std::declval<const ConcreteT &>())),
900 Error>::value,
901 void *>::type);
902
903 template <typename> static std::false_type check(...);
904
905public:
906 static const bool value = decltype(check<S>(0))::value;
907};
908
909template <typename ChannelT, typename WireT, typename ConcreteT>
910class CanDeserialize {
911private:
912 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
913
914 template <typename T>
915 static std::true_type
916 check(typename std::enable_if<
917 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
918 std::declval<ConcreteT &>())),
919 Error>::value,
920 void *>::type);
921
922 template <typename> static std::false_type check(...);
923
924public:
925 static const bool value = decltype(check<S>(0))::value;
926};
927
928/// Contains primitive utilities for defining, calling and handling calls to
929/// remote procedures. ChannelT is a bidirectional stream conforming to the
930/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
931/// identifier type that must be serializable on ChannelT, and SequenceNumberT
932/// is an integral type that will be used to number in-flight function calls.
933///
934/// These utilities support the construction of very primitive RPC utilities.
935/// Their intent is to ensure correct serialization and deserialization of
936/// procedure arguments, and to keep the client and server's view of the API in
937/// sync.
938template <typename ImplT, typename ChannelT, typename FunctionIdT,
939 typename SequenceNumberT>
940class RPCEndpointBase {
941protected:
// Built-in function descriptor named "__orc_rpc$invalid", taking no arguments
// and returning void.
// NOTE(review): presumably this reserves the slot for the invalid function id
// handed out by the id allocator -- confirm against RPCFunctionIdAllocator.
class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
public:
  static const char *getName() { return "__orc_rpc$invalid"; }
};
946
// Built-in function descriptor named "__orc_rpc$response". The endpoint
// constructor maps its prototype to the response id, and handleOne() treats
// any incoming message carrying that id as the reply to an earlier call.
class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
public:
  static const char *getName() { return "__orc_rpc$response"; }
};
951
// Built-in function descriptor named "__orc_rpc$negotiate": takes a
// std::string and returns a FunctionIdT. The endpoint constructor registers
// handleNegotiate() as its handler, and getRemoteFunctionId() calls it on the
// remote to obtain the id for a function prototype.
class OrcRPCNegotiate
    : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
public:
  static const char *getName() { return "__orc_rpc$negotiate"; }
};
957
// Helper predicate for testing for the presence of SerializationTraits
// serializers.
//
// Re-exports detail::CanSerialize<ChannelT, WireT, ConcreteT>::value for this
// endpoint's channel type and additionally fails compilation, with a
// descriptive message, as soon as the class is instantiated with a pair of
// types for which that value is false.
template <typename WireT, typename ConcreteT>
class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
public:
  using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;

  static_assert(value, "Missing serializer for argument (Can't serialize the "
                       "first template type argument of CanSerializeCheck "
                       "from the second)");
};
969
// Helper predicate for testing for the presence of SerializationTraits
// deserializers.
//
// Re-exports detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value for
// this endpoint's channel type and additionally fails compilation, with a
// descriptive message, as soon as the class is instantiated with a pair of
// types for which that value is false.
template <typename WireT, typename ConcreteT>
class CanDeserializeCheck
    : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
public:
  using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;

  static_assert(value, "Missing deserializer for argument (Can't deserialize "
                       "the second template type argument of "
                       "CanDeserializeCheck from the first)");
};
982
983public:
984 /// Construct an RPC instance on a channel.
985 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
986 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
987 // Hold ResponseId in a special variable, since we expect Response to be
988 // called relatively frequently, and want to avoid the map lookup.
989 ResponseId = FnIdAllocator.getResponseId();
990 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
991
992 // Register the negotiate function id and handler.
993 auto NegotiateId = FnIdAllocator.getNegotiateId();
994 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
995 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
996 [this](const std::string &Name) { return handleNegotiate(Name); });
997 }
998
999
1000 /// Negotiate a function id for Func with the other end of the channel.
1001 template <typename Func> Error negotiateFunction(bool Retry = false) {
1002 return getRemoteFunctionId<Func>(true, Retry).takeError();
1003 }
1004
1005 /// Append a call Func, does not call send on the channel.
1006 /// The first argument specifies a user-defined handler to be run when the
1007 /// function returns. The handler should take an Expected<Func::ReturnType>,
1008 /// or an Error (if Func::ReturnType is void). The handler will be called
1009 /// with an error if the return value is abandoned due to a channel error.
1010 template <typename Func, typename HandlerT, typename... ArgTs>
1011 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1012
1013 static_assert(
1014 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1015 void(ArgTs...)>::value,
1016 "");
1017
1018 // Look up the function ID.
1019 FunctionIdT FnId;
1020 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1021 FnId = *FnIdOrErr;
1022 else {
1023 // Negotiation failed. Notify the handler then return the negotiate-failed
1024 // error.
1025 cantFail(Handler(make_error<ResponseAbandoned>()));
1026 return FnIdOrErr.takeError();
1027 }
1028
1029 SequenceNumberT SeqNo; // initialized in locked scope below.
1030 {
1031 // Lock the pending responses map and sequence number manager.
1032 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1033
1034 // Allocate a sequence number.
1035 SeqNo = SequenceNumberMgr.getSequenceNumber();
1036 assert(!PendingResponses.count(SeqNo) &&
1037 "Sequence number already allocated");
1038
1039 // Install the user handler.
1040 PendingResponses[SeqNo] =
1041 detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1042 std::move(Handler));
1043 }
1044
1045 // Open the function call message.
1046 if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1047 abandonPendingResponses();
1048 return Err;
1049 }
1050
1051 // Serialize the call arguments.
1052 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1053 C, Args...)) {
1054 abandonPendingResponses();
1055 return Err;
1056 }
1057
1058 // Close the function call messagee.
1059 if (auto Err = C.endSendMessage()) {
1060 abandonPendingResponses();
1061 return Err;
1062 }
1063
1064 return Error::success();
1065 }
1066
1067 Error sendAppendedCalls() { return C.send(); };
1068
1069 template <typename Func, typename HandlerT, typename... ArgTs>
1070 Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1071 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1072 return Err;
1073 return C.send();
1074 }
1075
1076 /// Handle one incoming call.
1077 Error handleOne() {
1078 FunctionIdT FnId;
1079 SequenceNumberT SeqNo;
1080 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1081 abandonPendingResponses();
1082 return Err;
1083 }
1084 if (FnId == ResponseId)
1085 return handleResponse(SeqNo);
1086 auto I = Handlers.find(FnId);
1087 if (I != Handlers.end())
1088 return I->second(C, SeqNo);
1089
1090 // else: No handler found. Report error to client?
1091 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1092 SeqNo);
1093 }
1094
1095 /// Helper for handling setter procedures - this method returns a functor that
1096 /// sets the variables referred to by Args... to values deserialized from the
1097 /// channel.
1098 /// E.g.
1099 ///
1100 /// typedef Function<0, bool, int> Func1;
1101 ///
1102 /// ...
1103 /// bool B;
1104 /// int I;
1105 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1106 /// /* Handle Args */ ;
1107 ///
1108 template <typename... ArgTs>
1109 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1110 return detail::ReadArgs<ArgTs...>(Args...);
1111 }
1112
1113 /// Abandon all outstanding result handlers.
1114 ///
1115 /// This will call all currently registered result handlers to receive an
1116 /// "abandoned" error as their argument. This is used internally by the RPC
1117 /// in error situations, but can also be called directly by clients who are
1118 /// disconnecting from the remote and don't or can't expect responses to their
1119 /// outstanding calls. (Especially for outstanding blocking calls, calling
1120 /// this function may be necessary to avoid dead threads).
1121 void abandonPendingResponses() {
1122 // Lock the pending responses map and sequence number manager.
1123 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1124
1125 for (auto &KV : PendingResponses)
1126 KV.second->abandon();
1127 PendingResponses.clear();
1128 SequenceNumberMgr.reset();
1129 }
1130
1131 /// Remove the handler for the given function.
1132 /// A handler must currently be registered for this function.
1133 template <typename Func>
1134 void removeHandler() {
1135 auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1136 assert(IdItr != LocalFunctionIds.end() &&
1137 "Function does not have a registered handler");
1138 auto HandlerItr = Handlers.find(IdItr->second);
1139 assert(HandlerItr != Handlers.end() &&
1140 "Function does not have a registered handler");
1141 Handlers.erase(HandlerItr);
1142 }
1143
/// Clear all handlers.
///
/// Empties the handler map. This includes the negotiate handler installed by
/// the constructor, since it lives in the same map. The local function id
/// map is not touched.
void clearHandlers() {
  Handlers.clear();
}
1148
1149protected:
1150
/// Returns the id that the function id allocator designates as invalid. The
/// endpoint compares negotiated ids against this value to recognise a
/// function that the remote does not provide.
FunctionIdT getInvalidFunctionId() const {
  return FnIdAllocator.getInvalidId();
}
1154
1155 /// Add the given handler to the handler map and make it available for
1156 /// autonegotiation and execution.
1157 template <typename Func, typename HandlerT>
1158 void addHandlerImpl(HandlerT Handler) {
1159
1160 static_assert(detail::RPCArgTypeCheck<
1161 CanDeserializeCheck, typename Func::Type,
1162 typename detail::HandlerTraits<HandlerT>::Type>::value,
1163 "");
1164
1165 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1166 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1167 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1168 }
1169
1170 template <typename Func, typename HandlerT>
1171 void addAsyncHandlerImpl(HandlerT Handler) {
1172
1173 static_assert(detail::RPCArgTypeCheck<
1174 CanDeserializeCheck, typename Func::Type,
1175 typename detail::AsyncHandlerTraits<
1176 typename detail::HandlerTraits<HandlerT>::Type
1177 >::Type>::value,
1178 "");
1179
1180 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1181 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1182 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1183 }
1184
1185 Error handleResponse(SequenceNumberT SeqNo) {
1186 using Handler = typename decltype(PendingResponses)::mapped_type;
1187 Handler PRHandler;
1188
1189 {
1190 // Lock the pending responses map and sequence number manager.
1191 std::unique_lock<std::mutex> Lock(ResponsesMutex);
1192 auto I = PendingResponses.find(SeqNo);
1193
1194 if (I != PendingResponses.end()) {
1195 PRHandler = std::move(I->second);
1196 PendingResponses.erase(I);
1197 SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1198 } else {
1199 // Unlock the pending results map to prevent recursive lock.
1200 Lock.unlock();
1201 abandonPendingResponses();
1202 return make_error<
1203 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1204 }
1205 }
1206
1207 assert(PRHandler &&
1208 "If we didn't find a response handler we should have bailed out");
1209
1210 if (auto Err = PRHandler->handleResponse(C)) {
1211 abandonPendingResponses();
1212 return Err;
1213 }
1214
1215 return Error::success();
1216 }
1217
1218 FunctionIdT handleNegotiate(const std::string &Name) {
1219 auto I = LocalFunctionIds.find(Name);
1220 if (I == LocalFunctionIds.end())
1221 return getInvalidFunctionId();
1222 return I->second;
1223 }
1224
1225 // Find the remote FunctionId for the given function.
1226 template <typename Func>
1227 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1228 bool NegotiateIfInvalid) {
1229 bool DoNegotiate;
1230
1231 // Check if we already have a function id...
1232 auto I = RemoteFunctionIds.find(Func::getPrototype());
1233 if (I != RemoteFunctionIds.end()) {
1234 // If it's valid there's nothing left to do.
1235 if (I->second != getInvalidFunctionId())
1236 return I->second;
1237 DoNegotiate = NegotiateIfInvalid;
1238 } else
1239 DoNegotiate = NegotiateIfNotInMap;
1240
1241 // We don't have a function id for Func yet, but we're allowed to try to
1242 // negotiate one.
1243 if (DoNegotiate) {
1244 auto &Impl = static_cast<ImplT &>(*this);
1245 if (auto RemoteIdOrErr =
1246 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1247 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1248 if (*RemoteIdOrErr == getInvalidFunctionId())
1249 return make_error<CouldNotNegotiate>(Func::getPrototype());
1250 return *RemoteIdOrErr;
1251 } else
1252 return RemoteIdOrErr.takeError();
1253 }
1254
1255 // No key was available in the map and we weren't allowed to try to
1256 // negotiate one, so return an unknown function error.
1257 return make_error<CouldNotNegotiate>(Func::getPrototype());
1258 }
1259
1260 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1261
1262 // Wrap the given user handler in the necessary argument-deserialization code,
1263 // result-serialization code, and call to the launch policy (if present).
1264 template <typename Func, typename HandlerT>
1265 WrappedHandlerFn wrapHandler(HandlerT Handler) {
1266 return [this, Handler](ChannelT &Channel,
1267 SequenceNumberT SeqNo) mutable -> Error {
1268 // Start by deserializing the arguments.
1269 using ArgsTuple =
1270 typename detail::FunctionArgsTuple<
1271 typename detail::HandlerTraits<HandlerT>::Type>::Type;
1272 auto Args = std::make_shared<ArgsTuple>();
1273
1274 if (auto Err =
1275 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1276 Channel, *Args))
1277 return Err;
1278
1279 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1280 // for RPCArgs. Void cast RPCArgs to work around this for now.
1281 // FIXME: Remove this workaround once we can assume a working GCC version.
1282 (void)Args;
1283
1284 // End receieve message, unlocking the channel for reading.
1285 if (auto Err = Channel.endReceiveMessage())
1286 return Err;
1287
1288 using HTraits = detail::HandlerTraits<HandlerT>;
1289 using FuncReturn = typename Func::ReturnType;
1290 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1291 HTraits::unpackAndRun(Handler, *Args));
1292 };
1293 }
1294
1295 // Wrap the given user handler in the necessary argument-deserialization code,
1296 // result-serialization code, and call to the launch policy (if present).
1297 template <typename Func, typename HandlerT>
1298 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1299 return [this, Handler](ChannelT &Channel,
1300 SequenceNumberT SeqNo) mutable -> Error {
1301 // Start by deserializing the arguments.
1302 using AHTraits = detail::AsyncHandlerTraits<
1303 typename detail::HandlerTraits<HandlerT>::Type>;
1304 using ArgsTuple =
1305 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1306 auto Args = std::make_shared<ArgsTuple>();
1307
1308 if (auto Err =
1309 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1310 Channel, *Args))
1311 return Err;
1312
1313 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1314 // for RPCArgs. Void cast RPCArgs to work around this for now.
1315 // FIXME: Remove this workaround once we can assume a working GCC version.
1316 (void)Args;
1317
1318 // End receieve message, unlocking the channel for reading.
1319 if (auto Err = Channel.endReceiveMessage())
1320 return Err;
1321
1322 using HTraits = detail::HandlerTraits<HandlerT>;
1323 using FuncReturn = typename Func::ReturnType;
1324 auto Responder =
1325 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1326 return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1327 std::move(RetVal));
1328 };
1329
1330 return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1331 };
1332 }
1333
// The channel all messages are read from and written to. Held by reference.
ChannelT &C;

// Passed by appendCallAsync to getRemoteFunctionId as the "negotiate if not
// in map" flag, i.e. whether an unknown function may be negotiated on first
// call.
bool LazyAutoNegotiation;

// Source of function ids: the built-in response/negotiate/invalid ids and
// freshly allocated ids for handlers added locally.
RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;

// Cached id of the response function; compared against every incoming
// message's function id in handleOne.
FunctionIdT ResponseId;
// Prototype -> id for functions that have had a local handler added; consulted
// when the remote negotiates a function.
std::map<std::string, FunctionIdT> LocalFunctionIds;
// Prototype -> id as known to the remote. The key is a `const char *`, so
// entries are ordered and matched by pointer value, not by string contents.
// NOTE(review): this relies on Func::getPrototype() returning the same
// pointer on every call -- confirm.
std::map<const char *, FunctionIdT> RemoteFunctionIds;

// Local function id -> wrapped handler to run for an incoming call.
std::map<FunctionIdT, WrappedHandlerFn> Handlers;

// Guards SequenceNumberMgr and PendingResponses.
std::mutex ResponsesMutex;
detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
// Sequence number -> handler awaiting the response to that call.
std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
    PendingResponses;
1350};
1351
1352} // end namespace detail
1353
1354template <typename ChannelT, typename FunctionIdT = uint32_t,
1355 typename SequenceNumberT = uint32_t>
1356class MultiThreadedRPCEndpoint
1357 : public detail::RPCEndpointBase<
1358 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1359 ChannelT, FunctionIdT, SequenceNumberT> {
1360private:
1361 using BaseClass =
1362 detail::RPCEndpointBase<
1363 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1364 ChannelT, FunctionIdT, SequenceNumberT>;
1365
1366public:
1367 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1368 : BaseClass(C, LazyAutoNegotiation) {}
1369
1370 /// Add a handler for the given RPC function.
1371 /// This installs the given handler functor for the given RPC Function, and
1372 /// makes the RPC function available for negotiation/calling from the remote.
1373 template <typename Func, typename HandlerT>
1374 void addHandler(HandlerT Handler) {
1375 return this->template addHandlerImpl<Func>(std::move(Handler));
1376 }
1377
1378 /// Add a class-method as a handler.
1379 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1380 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1381 addHandler<Func>(
1382 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1383 }
1384
1385 template <typename Func, typename HandlerT>
1386 void addAsyncHandler(HandlerT Handler) {
1387 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1388 }
1389
1390 /// Add a class-method as a handler.
1391 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1392 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1393 addAsyncHandler<Func>(
1394 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1395 }
1396
1397 /// Return type for non-blocking call primitives.
1398 template <typename Func>
1399 using NonBlockingCallResult = typename detail::ResultTraits<
1400 typename Func::ReturnType>::ReturnFutureType;
1401
1402 /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1403 /// of a future result and the sequence number assigned to the result.
1404 ///
1405 /// This utility function is primarily used for single-threaded mode support,
1406 /// where the sequence number can be used to wait for the corresponding
1407 /// result. In multi-threaded mode the appendCallNB method, which does not
1408 /// return the sequence numeber, should be preferred.
1409 template <typename Func, typename... ArgTs>
1410 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1411 using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1412 using ErrorReturn = typename RTraits::ErrorReturnType;
1413 using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1414
1415 // FIXME: Stack allocate and move this into the handler once LLVM builds
1416 // with C++14.
1417 auto Promise = std::make_shared<ErrorReturnPromise>();
1418 auto FutureResult = Promise->get_future();
1419
1420 if (auto Err = this->template appendCallAsync<Func>(
1421 [Promise](ErrorReturn RetOrErr) {
1422 Promise->set_value(std::move(RetOrErr));
1423 return Error::success();
1424 },
1425 Args...)) {
1426 RTraits::consumeAbandoned(FutureResult.get());
1427 return std::move(Err);
1428 }
1429 return std::move(FutureResult);
1430 }
1431
1432 /// The same as appendCallNBWithSeq, except that it calls C.send() to
1433 /// flush the channel after serializing the call.
1434 template <typename Func, typename... ArgTs>
1435 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1436 auto Result = appendCallNB<Func>(Args...);
1437 if (!Result)
1438 return Result;
1439 if (auto Err = this->C.send()) {
1440 this->abandonPendingResponses();
1441 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1442 std::move(Result->get()));
1443 return std::move(Err);
1444 }
1445 return Result;
1446 }
1447
1448 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1449 /// for void functions or an Expected<T> for functions returning a T.
1450 ///
1451 /// This function is for use in threaded code where another thread is
1452 /// handling responses and incoming calls.
1453 template <typename Func, typename... ArgTs,
1454 typename AltRetT = typename Func::ReturnType>
1455 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1456 callB(const ArgTs &... Args) {
1457 if (auto FutureResOrErr = callNB<Func>(Args...))
1458 return FutureResOrErr->get();
1459 else
1460 return FutureResOrErr.takeError();
1461 }
1462
1463 /// Handle incoming RPC calls.
1464 Error handlerLoop() {
1465 while (true)
1466 if (auto Err = this->handleOne())
1467 return Err;
1468 return Error::success();
1469 }
1470};
1471
1472template <typename ChannelT, typename FunctionIdT = uint32_t,
1473 typename SequenceNumberT = uint32_t>
1474class SingleThreadedRPCEndpoint
1475 : public detail::RPCEndpointBase<
1476 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1477 ChannelT, FunctionIdT, SequenceNumberT> {
1478private:
1479 using BaseClass =
1480 detail::RPCEndpointBase<
1481 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1482 ChannelT, FunctionIdT, SequenceNumberT>;
1483
1484public:
1485 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1486 : BaseClass(C, LazyAutoNegotiation) {}
1487
1488 template <typename Func, typename HandlerT>
1489 void addHandler(HandlerT Handler) {
1490 return this->template addHandlerImpl<Func>(std::move(Handler));
1491 }
1492
1493 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1494 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1495 addHandler<Func>(
1496 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1497 }
1498
1499 template <typename Func, typename HandlerT>
1500 void addAsyncHandler(HandlerT Handler) {
1501 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1502 }
1503
1504 /// Add a class-method as a handler.
1505 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1506 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1507 addAsyncHandler<Func>(
1508 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1509 }
1510
1511 template <typename Func, typename... ArgTs,
1512 typename AltRetT = typename Func::ReturnType>
1513 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1514 callB(const ArgTs &... Args) {
1515 bool ReceivedResponse = false;
1516 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1517 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1518
1519 // We have to 'Check' result (which we know is in a success state at this
1520 // point) so that it can be overwritten in the async handler.
1521 (void)!!Result;
1522
1523 if (auto Err = this->template appendCallAsync<Func>(
1524 [&](ResultType R) {
1525 Result = std::move(R);
1526 ReceivedResponse = true;
1527 return Error::success();
1528 },
1529 Args...)) {
1530 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1531 std::move(Result));
1532 return std::move(Err);
1533 }
1534
1535 while (!ReceivedResponse) {
1536 if (auto Err = this->handleOne()) {
1537 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1538 std::move(Result));
1539 return std::move(Err);
1540 }
1541 }
1542
1543 return Result;
1544 }
1545};
1546
1547/// Asynchronous dispatch for a function on an RPC endpoint.
1548template <typename RPCClass, typename Func>
1549class RPCAsyncDispatch {
1550public:
1551 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1552
1553 template <typename HandlerT, typename... ArgTs>
1554 Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1555 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1556 }
1557
1558private:
1559 RPCClass &Endpoint;
1560};
1561
1562/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1563template <typename Func, typename RPCEndpointT>
1564RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1565 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1566}
1567
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001568/// Allows a set of asynchrounous calls to be dispatched, and then
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001569/// waited on as a group.
1570class ParallelCallGroup {
1571public:
1572
1573 ParallelCallGroup() = default;
1574 ParallelCallGroup(const ParallelCallGroup &) = delete;
1575 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1576
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001577 /// Make as asynchronous call.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001578 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1579 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1580 const ArgTs &... Args) {
1581 // Increment the count of outstanding calls. This has to happen before
1582 // we invoke the call, as the handler may (depending on scheduling)
1583 // be run immediately on another thread, and we don't want the decrement
1584 // in the wrapped handler below to run before the increment.
1585 {
1586 std::unique_lock<std::mutex> Lock(M);
1587 ++NumOutstandingCalls;
1588 }
1589
1590 // Wrap the user handler in a lambda that will decrement the
1591 // outstanding calls count, then poke the condition variable.
1592 using ArgType = typename detail::ResponseHandlerArg<
1593 typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1594 // FIXME: Move handler into wrapped handler once we have C++14.
1595 auto WrappedHandler = [this, Handler](ArgType Arg) {
1596 auto Err = Handler(std::move(Arg));
1597 std::unique_lock<std::mutex> Lock(M);
1598 --NumOutstandingCalls;
1599 CV.notify_all();
1600 return Err;
1601 };
1602
1603 return AsyncDispatch(std::move(WrappedHandler), Args...);
1604 }
1605
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001606 /// Blocks until all calls have been completed and their return value
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001607 /// handlers run.
1608 void wait() {
1609 std::unique_lock<std::mutex> Lock(M);
1610 while (NumOutstandingCalls > 0)
1611 CV.wait(Lock);
1612 }
1613
1614private:
1615 std::mutex M;
1616 std::condition_variable CV;
1617 uint32_t NumOutstandingCalls = 0;
1618};
1619
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001620/// Convenience class for grouping RPC Functions into APIs that can be
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001621/// negotiated as a block.
1622///
1623template <typename... Funcs>
1624class APICalls {
1625public:
1626
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001627 /// Test whether this API contains Function F.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001628 template <typename F>
1629 class Contains {
1630 public:
1631 static const bool value = false;
1632 };
1633
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001634 /// Negotiate all functions in this API.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001635 template <typename RPCEndpoint>
1636 static Error negotiate(RPCEndpoint &R) {
1637 return Error::success();
1638 }
1639};
1640
1641template <typename Func, typename... Funcs>
1642class APICalls<Func, Funcs...> {
1643public:
1644
1645 template <typename F>
1646 class Contains {
1647 public:
1648 static const bool value = std::is_same<F, Func>::value |
1649 APICalls<Funcs...>::template Contains<F>::value;
1650 };
1651
1652 template <typename RPCEndpoint>
1653 static Error negotiate(RPCEndpoint &R) {
1654 if (auto Err = R.template negotiateFunction<Func>())
1655 return Err;
1656 return APICalls<Funcs...>::negotiate(R);
1657 }
1658
1659};
1660
1661template <typename... InnerFuncs, typename... Funcs>
1662class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1663public:
1664
1665 template <typename F>
1666 class Contains {
1667 public:
1668 static const bool value =
1669 APICalls<InnerFuncs...>::template Contains<F>::value |
1670 APICalls<Funcs...>::template Contains<F>::value;
1671 };
1672
1673 template <typename RPCEndpoint>
1674 static Error negotiate(RPCEndpoint &R) {
1675 if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1676 return Err;
1677 return APICalls<Funcs...>::negotiate(R);
1678 }
1679
1680};
1681
1682} // end namespace rpc
1683} // end namespace orc
1684} // end namespace llvm
1685
1686#endif