blob: 47bd90bb1bada3960d66eebf64369f636fc38d0b [file] [log] [blame]
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001//===------- RPCUtils.h - Utilities for building RPC APIs -------*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Utilities to support construction of simple RPC APIs.
11//
12// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++
13// programmers, high performance, low memory overhead, and efficient use of the
14// communications channel.
15//
16//===----------------------------------------------------------------------===//
17
18#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
19#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H
20
21#include <map>
22#include <thread>
23#include <vector>
24
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ExecutionEngine/Orc/OrcError.h"
27#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
28
29#include <future>
30
31namespace llvm {
32namespace orc {
33namespace rpc {
34
35/// Base class of all fatal RPC errors (those that necessarily result in the
36/// termination of the RPC session).
37class RPCFatalError : public ErrorInfo<RPCFatalError> {
38public:
39 static char ID;
40};
41
42/// RPCConnectionClosed is returned from RPC operations if the RPC connection
43/// has already been closed due to either an error or graceful disconnection.
44class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
45public:
46 static char ID;
47 std::error_code convertToErrorCode() const override;
48 void log(raw_ostream &OS) const override;
49};
50
51/// BadFunctionCall is returned from handleOne when the remote makes a call with
52/// an unrecognized function id.
53///
54/// This error is fatal because Orc RPC needs to know how to parse a function
55/// call to know where the next call starts, and if it doesn't recognize the
56/// function id it cannot parse the call.
57template <typename FnIdT, typename SeqNoT>
58class BadFunctionCall
59 : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
60public:
61 static char ID;
62
63 BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
64 : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
65
66 std::error_code convertToErrorCode() const override {
67 return orcError(OrcErrorCode::UnexpectedRPCCall);
68 }
69
70 void log(raw_ostream &OS) const override {
71 OS << "Call to invalid RPC function id '" << FnId << "' with "
72 "sequence number " << SeqNo;
73 }
74
75private:
76 FnIdT FnId;
77 SeqNoT SeqNo;
78};
79
80template <typename FnIdT, typename SeqNoT>
81char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
82
83/// InvalidSequenceNumberForResponse is returned from handleOne when a response
84/// call arrives with a sequence number that doesn't correspond to any in-flight
85/// function call.
86///
87/// This error is fatal because Orc RPC needs to know how to parse the rest of
88/// the response call to know where the next call starts, and if it doesn't have
89/// a result parser for this sequence number it can't do that.
90template <typename SeqNoT>
91class InvalidSequenceNumberForResponse
92 : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
93public:
94 static char ID;
95
96 InvalidSequenceNumberForResponse(SeqNoT SeqNo)
97 : SeqNo(std::move(SeqNo)) {}
98
99 std::error_code convertToErrorCode() const override {
100 return orcError(OrcErrorCode::UnexpectedRPCCall);
101 };
102
103 void log(raw_ostream &OS) const override {
104 OS << "Response has unknown sequence number " << SeqNo;
105 }
106private:
107 SeqNoT SeqNo;
108};
109
110template <typename SeqNoT>
111char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
112
113/// This non-fatal error will be passed to asynchronous result handlers in place
114/// of a result if the connection goes down before a result returns, or if the
115/// function to be called cannot be negotiated with the remote.
116class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
117public:
118 static char ID;
119
120 std::error_code convertToErrorCode() const override;
121 void log(raw_ostream &OS) const override;
122};
123
124/// This error is returned if the remote does not have a handler installed for
125/// the given RPC function.
126class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
127public:
128 static char ID;
129
130 CouldNotNegotiate(std::string Signature);
131 std::error_code convertToErrorCode() const override;
132 void log(raw_ostream &OS) const override;
133 const std::string &getSignature() const { return Signature; }
134private:
135 std::string Signature;
136};
137
138template <typename DerivedFunc, typename FnT> class Function;
139
140// RPC Function class.
141// DerivedFunc should be a user defined class with a static 'getName()' method
142// returning a const char* representing the function's name.
143template <typename DerivedFunc, typename RetT, typename... ArgTs>
144class Function<DerivedFunc, RetT(ArgTs...)> {
145public:
146 /// User defined function type.
147 using Type = RetT(ArgTs...);
148
149 /// Return type.
150 using ReturnType = RetT;
151
152 /// Returns the full function prototype as a string.
153 static const char *getPrototype() {
154 std::lock_guard<std::mutex> Lock(NameMutex);
155 if (Name.empty())
156 raw_string_ostream(Name)
157 << RPCTypeName<RetT>::getName() << " " << DerivedFunc::getName()
158 << "(" << llvm::orc::rpc::RPCTypeNameSequence<ArgTs...>() << ")";
159 return Name.data();
160 }
161
162private:
163 static std::mutex NameMutex;
164 static std::string Name;
165};
166
167template <typename DerivedFunc, typename RetT, typename... ArgTs>
168std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
169
170template <typename DerivedFunc, typename RetT, typename... ArgTs>
171std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
172
/// Allocates RPC function ids during autonegotiation.
/// Every specialization of this class must provide four members:
///
///   static T getInvalidId()
///     A reserved id that represents missing functions during
///     autonegotiation.
///
///   static T getResponseId()
///     A reserved id used to send function responses (return values).
///
///   static T getNegotiateId()
///     A reserved id for the negotiate function, which is used to negotiate
///     ids for user-defined functions.
///
///   template <typename Func> T allocate()
///     Returns a unique id for function Func.
template <typename T, typename = void> class RPCFunctionIdAllocator;

/// Default RPCFunctionIdAllocator for integral id types. Ids 0, 1 and 2 are
/// reserved (invalid, response, negotiate). User functions are numbered
/// consecutively from 3.
template <typename T>
class RPCFunctionIdAllocator<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
  static T getInvalidId() { return static_cast<T>(0); }
  static T getResponseId() { return static_cast<T>(1); }
  static T getNegotiateId() { return static_cast<T>(2); }

  template <typename Func> T allocate() {
    T Allocated = NextId;
    ++NextId;
    return Allocated;
  }

private:
  // The next id to hand out; starts just past the reserved ids.
  T NextId = 3;
};
207
208namespace detail {
209
210// FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation
211// supports classes without default constructors.
212#ifdef _MSC_VER
213
214namespace msvc_hacks {
215
216// Work around MSVC's future implementation's use of default constructors:
217// A default constructed value in the promise will be overwritten when the
218// real error is set - so the default constructed Error has to be checked
219// already.
220class MSVCPError : public Error {
221public:
222 MSVCPError() { (void)!!*this; }
223
224 MSVCPError(MSVCPError &&Other) : Error(std::move(Other)) {}
225
226 MSVCPError &operator=(MSVCPError Other) {
227 Error::operator=(std::move(Other));
228 return *this;
229 }
230
231 MSVCPError(Error Err) : Error(std::move(Err)) {}
232};
233
234// Work around MSVC's future implementation, similar to MSVCPError.
235template <typename T> class MSVCPExpected : public Expected<T> {
236public:
237 MSVCPExpected()
238 : Expected<T>(make_error<StringError>("", inconvertibleErrorCode())) {
239 consumeError(this->takeError());
240 }
241
242 MSVCPExpected(MSVCPExpected &&Other) : Expected<T>(std::move(Other)) {}
243
244 MSVCPExpected &operator=(MSVCPExpected &&Other) {
245 Expected<T>::operator=(std::move(Other));
246 return *this;
247 }
248
249 MSVCPExpected(Error Err) : Expected<T>(std::move(Err)) {}
250
251 template <typename OtherT>
252 MSVCPExpected(
253 OtherT &&Val,
254 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
255 nullptr)
256 : Expected<T>(std::move(Val)) {}
257
258 template <class OtherT>
259 MSVCPExpected(
260 Expected<OtherT> &&Other,
261 typename std::enable_if<std::is_convertible<OtherT, T>::value>::type * =
262 nullptr)
263 : Expected<T>(std::move(Other)) {}
264
265 template <class OtherT>
266 explicit MSVCPExpected(
267 Expected<OtherT> &&Other,
268 typename std::enable_if<!std::is_convertible<OtherT, T>::value>::type * =
269 nullptr)
270 : Expected<T>(std::move(Other)) {}
271};
272
273} // end namespace msvc_hacks
274
275#endif // _MSC_VER
276
/// FunctionArgsTuple<RetT(ArgTs...)>::Type is a std::tuple of the function's
/// argument types with references and cv-qualifiers removed (std::decay),
/// suitable for holding argument values.
template <typename T> class FunctionArgsTuple;

template <typename RetT, typename... ArgTs>
class FunctionArgsTuple<RetT(ArgTs...)> {
public:
  // std::decay already strips references, so no separate remove_reference
  // step is needed.
  using Type = std::tuple<typename std::decay<ArgTs>::type...>;
};
286
287// ResultTraits provides typedefs and utilities specific to the return type
288// of functions.
289template <typename RetT> class ResultTraits {
290public:
291 // The return type wrapped in llvm::Expected.
292 using ErrorReturnType = Expected<RetT>;
293
294#ifdef _MSC_VER
295 // The ErrorReturnType wrapped in a std::promise.
296 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPExpected<RetT>>;
297
298 // The ErrorReturnType wrapped in a std::future.
299 using ReturnFutureType = std::future<msvc_hacks::MSVCPExpected<RetT>>;
300#else
301 // The ErrorReturnType wrapped in a std::promise.
302 using ReturnPromiseType = std::promise<ErrorReturnType>;
303
304 // The ErrorReturnType wrapped in a std::future.
305 using ReturnFutureType = std::future<ErrorReturnType>;
306#endif
307
308 // Create a 'blank' value of the ErrorReturnType, ready and safe to
309 // overwrite.
310 static ErrorReturnType createBlankErrorReturnValue() {
311 return ErrorReturnType(RetT());
312 }
313
314 // Consume an abandoned ErrorReturnType.
315 static void consumeAbandoned(ErrorReturnType RetOrErr) {
316 consumeError(RetOrErr.takeError());
317 }
318};
319
320// ResultTraits specialization for void functions.
321template <> class ResultTraits<void> {
322public:
323 // For void functions, ErrorReturnType is llvm::Error.
324 using ErrorReturnType = Error;
325
326#ifdef _MSC_VER
327 // The ErrorReturnType wrapped in a std::promise.
328 using ReturnPromiseType = std::promise<msvc_hacks::MSVCPError>;
329
330 // The ErrorReturnType wrapped in a std::future.
331 using ReturnFutureType = std::future<msvc_hacks::MSVCPError>;
332#else
333 // The ErrorReturnType wrapped in a std::promise.
334 using ReturnPromiseType = std::promise<ErrorReturnType>;
335
336 // The ErrorReturnType wrapped in a std::future.
337 using ReturnFutureType = std::future<ErrorReturnType>;
338#endif
339
340 // Create a 'blank' value of the ErrorReturnType, ready and safe to
341 // overwrite.
342 static ErrorReturnType createBlankErrorReturnValue() {
343 return ErrorReturnType::success();
344 }
345
346 // Consume an abandoned ErrorReturnType.
347 static void consumeAbandoned(ErrorReturnType Err) {
348 consumeError(std::move(Err));
349 }
350};
351
352// ResultTraits<Error> is equivalent to ResultTraits<void>. This allows
353// handlers for void RPC functions to return either void (in which case they
354// implicitly succeed) or Error (in which case their error return is
355// propagated). See usage in HandlerTraits::runHandlerHelper.
356template <> class ResultTraits<Error> : public ResultTraits<void> {};
357
358// ResultTraits<Expected<T>> is equivalent to ResultTraits<T>. This allows
359// handlers for RPC functions returning a T to return either a T (in which
360// case they implicitly succeed) or Expected<T> (in which case their error
361// return is propagated). See usage in HandlerTraits::runHandlerHelper.
362template <typename RetT>
363class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
364
365// Determines whether an RPC function's defined error return type supports
366// error return value.
367template <typename T>
368class SupportsErrorReturn {
369public:
370 static const bool value = false;
371};
372
373template <>
374class SupportsErrorReturn<Error> {
375public:
376 static const bool value = true;
377};
378
379template <typename T>
380class SupportsErrorReturn<Expected<T>> {
381public:
382 static const bool value = true;
383};
384
385// RespondHelper packages return values based on whether or not the declared
386// RPC function return type supports error returns.
387template <bool FuncSupportsErrorReturn>
388class RespondHelper;
389
390// RespondHelper specialization for functions that support error returns.
391template <>
392class RespondHelper<true> {
393public:
394
395 // Send Expected<T>.
396 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
397 typename FunctionIdT, typename SequenceNumberT>
398 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
399 SequenceNumberT SeqNo,
400 Expected<HandlerRetT> ResultOrErr) {
401 if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
402 return ResultOrErr.takeError();
403
404 // Open the response message.
405 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
406 return Err;
407
408 // Serialize the result.
409 if (auto Err =
410 SerializationTraits<ChannelT, WireRetT,
411 Expected<HandlerRetT>>::serialize(
412 C, std::move(ResultOrErr)))
413 return Err;
414
415 // Close the response message.
416 return C.endSendMessage();
417 }
418
419 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
420 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
421 SequenceNumberT SeqNo, Error Err) {
422 if (Err && Err.isA<RPCFatalError>())
423 return Err;
424 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
425 return Err2;
426 if (auto Err2 = serializeSeq(C, std::move(Err)))
427 return Err2;
428 return C.endSendMessage();
429 }
430
431};
432
433// RespondHelper specialization for functions that do not support error returns.
434template <>
435class RespondHelper<false> {
436public:
437
438 template <typename WireRetT, typename HandlerRetT, typename ChannelT,
439 typename FunctionIdT, typename SequenceNumberT>
440 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
441 SequenceNumberT SeqNo,
442 Expected<HandlerRetT> ResultOrErr) {
443 if (auto Err = ResultOrErr.takeError())
444 return Err;
445
446 // Open the response message.
447 if (auto Err = C.startSendMessage(ResponseId, SeqNo))
448 return Err;
449
450 // Serialize the result.
451 if (auto Err =
452 SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
453 C, *ResultOrErr))
454 return Err;
455
456 // Close the response message.
457 return C.endSendMessage();
458 }
459
460 template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
461 static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
462 SequenceNumberT SeqNo, Error Err) {
463 if (Err)
464 return Err;
465 if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
466 return Err2;
467 return C.endSendMessage();
468 }
469
470};
471
472
473// Send a response of the given wire return type (WireRetT) over the
474// channel, with the given sequence number.
475template <typename WireRetT, typename HandlerRetT, typename ChannelT,
476 typename FunctionIdT, typename SequenceNumberT>
477Error respond(ChannelT &C, const FunctionIdT &ResponseId,
478 SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
479 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
480 template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
481}
482
483// Send an empty response message on the given channel to indicate that
484// the handler ran.
485template <typename WireRetT, typename ChannelT, typename FunctionIdT,
486 typename SequenceNumberT>
487Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
488 Error Err) {
489 return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
490 sendResult(C, ResponseId, SeqNo, std::move(Err));
491}
492
493// Converts a given type to the equivalent error return type.
494template <typename T> class WrappedHandlerReturn {
495public:
496 using Type = Expected<T>;
497};
498
499template <typename T> class WrappedHandlerReturn<Expected<T>> {
500public:
501 using Type = Expected<T>;
502};
503
504template <> class WrappedHandlerReturn<void> {
505public:
506 using Type = Error;
507};
508
509template <> class WrappedHandlerReturn<Error> {
510public:
511 using Type = Error;
512};
513
514template <> class WrappedHandlerReturn<ErrorSuccess> {
515public:
516 using Type = Error;
517};
518
519// Traits class that strips the response function from the list of handler
520// arguments.
521template <typename FnT> class AsyncHandlerTraits;
522
523template <typename ResultT, typename... ArgTs>
524class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
525public:
526 using Type = Error(ArgTs...);
527 using ResultType = Expected<ResultT>;
528};
529
530template <typename... ArgTs>
531class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
532public:
533 using Type = Error(ArgTs...);
534 using ResultType = Error;
535};
536
537template <typename... ArgTs>
538class AsyncHandlerTraits<ErrorSuccess(std::function<Error(Error)>, ArgTs...)> {
539public:
540 using Type = Error(ArgTs...);
541 using ResultType = Error;
542};
543
544template <typename... ArgTs>
545class AsyncHandlerTraits<void(std::function<Error(Error)>, ArgTs...)> {
546public:
547 using Type = Error(ArgTs...);
548 using ResultType = Error;
549};
550
551template <typename ResponseHandlerT, typename... ArgTs>
552class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
553 public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
554 ArgTs...)> {};
555
// HandlerTraits<HandlerT> provides utilities related to RPC function handlers.
// This primary template covers non-function handler types (lambdas and other
// functors). It inherits from the specialization that matches the type of
// HandlerT's call operator; function types are handled by specializations.
template <typename HandlerT>
class HandlerTraits
    : public HandlerTraits<decltype(
          &std::remove_reference<HandlerT>::type::operator())> {};
564
565// Traits for handlers with a given function type.
566template <typename RetT, typename... ArgTs>
567class HandlerTraits<RetT(ArgTs...)> {
568public:
569 // Function type of the handler.
570 using Type = RetT(ArgTs...);
571
572 // Return type of the handler.
573 using ReturnType = RetT;
574
575 // Call the given handler with the given arguments.
576 template <typename HandlerT, typename... TArgTs>
577 static typename WrappedHandlerReturn<RetT>::Type
578 unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
579 return unpackAndRunHelper(Handler, Args,
580 llvm::index_sequence_for<TArgTs...>());
581 }
582
583 // Call the given handler with the given arguments.
584 template <typename HandlerT, typename ResponderT, typename... TArgTs>
585 static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
586 std::tuple<TArgTs...> &Args) {
587 return unpackAndRunAsyncHelper(Handler, Responder, Args,
588 llvm::index_sequence_for<TArgTs...>());
589 }
590
591 // Call the given handler with the given arguments.
592 template <typename HandlerT>
593 static typename std::enable_if<
594 std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
595 Error>::type
596 run(HandlerT &Handler, ArgTs &&... Args) {
597 Handler(std::move(Args)...);
598 return Error::success();
599 }
600
601 template <typename HandlerT, typename... TArgTs>
602 static typename std::enable_if<
603 !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
604 typename HandlerTraits<HandlerT>::ReturnType>::type
605 run(HandlerT &Handler, TArgTs... Args) {
606 return Handler(std::move(Args)...);
607 }
608
609 // Serialize arguments to the channel.
610 template <typename ChannelT, typename... CArgTs>
611 static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) {
612 return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
613 }
614
615 // Deserialize arguments from the channel.
616 template <typename ChannelT, typename... CArgTs>
617 static Error deserializeArgs(ChannelT &C, std::tuple<CArgTs...> &Args) {
618 return deserializeArgsHelper(C, Args,
619 llvm::index_sequence_for<CArgTs...>());
620 }
621
622private:
623 template <typename ChannelT, typename... CArgTs, size_t... Indexes>
624 static Error deserializeArgsHelper(ChannelT &C, std::tuple<CArgTs...> &Args,
625 llvm::index_sequence<Indexes...> _) {
626 return SequenceSerialization<ChannelT, ArgTs...>::deserialize(
627 C, std::get<Indexes>(Args)...);
628 }
629
630 template <typename HandlerT, typename ArgTuple, size_t... Indexes>
631 static typename WrappedHandlerReturn<
632 typename HandlerTraits<HandlerT>::ReturnType>::Type
633 unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
634 llvm::index_sequence<Indexes...>) {
635 return run(Handler, std::move(std::get<Indexes>(Args))...);
636 }
637
638
639 template <typename HandlerT, typename ResponderT, typename ArgTuple,
640 size_t... Indexes>
641 static typename WrappedHandlerReturn<
642 typename HandlerTraits<HandlerT>::ReturnType>::Type
643 unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
644 ArgTuple &Args,
645 llvm::index_sequence<Indexes...>) {
646 return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
647 }
648};
649
650// Handler traits for free functions.
651template <typename RetT, typename... ArgTs>
652class HandlerTraits<RetT(*)(ArgTs...)>
653 : public HandlerTraits<RetT(ArgTs...)> {};
654
655// Handler traits for class methods (especially call operators for lambdas).
656template <typename Class, typename RetT, typename... ArgTs>
657class HandlerTraits<RetT (Class::*)(ArgTs...)>
658 : public HandlerTraits<RetT(ArgTs...)> {};
659
660// Handler traits for const class methods (especially call operators for
661// lambdas).
662template <typename Class, typename RetT, typename... ArgTs>
663class HandlerTraits<RetT (Class::*)(ArgTs...) const>
664 : public HandlerTraits<RetT(ArgTs...)> {};
665
666// Utility to peel the Expected wrapper off a response handler error type.
667template <typename HandlerT> class ResponseHandlerArg;
668
669template <typename ArgT> class ResponseHandlerArg<Error(Expected<ArgT>)> {
670public:
671 using ArgType = Expected<ArgT>;
672 using UnwrappedArgType = ArgT;
673};
674
675template <typename ArgT>
676class ResponseHandlerArg<ErrorSuccess(Expected<ArgT>)> {
677public:
678 using ArgType = Expected<ArgT>;
679 using UnwrappedArgType = ArgT;
680};
681
682template <> class ResponseHandlerArg<Error(Error)> {
683public:
684 using ArgType = Error;
685};
686
687template <> class ResponseHandlerArg<ErrorSuccess(Error)> {
688public:
689 using ArgType = Error;
690};
691
692// ResponseHandler represents a handler for a not-yet-received function call
693// result.
694template <typename ChannelT> class ResponseHandler {
695public:
696 virtual ~ResponseHandler() {}
697
698 // Reads the function result off the wire and acts on it. The meaning of
699 // "act" will depend on how this method is implemented in any given
700 // ResponseHandler subclass but could, for example, mean running a
701 // user-specified handler or setting a promise value.
702 virtual Error handleResponse(ChannelT &C) = 0;
703
704 // Abandons this outstanding result.
705 virtual void abandon() = 0;
706
707 // Create an error instance representing an abandoned response.
708 static Error createAbandonedResponseError() {
709 return make_error<ResponseAbandoned>();
710 }
711};
712
713// ResponseHandler subclass for RPC functions with non-void returns.
714template <typename ChannelT, typename FuncRetT, typename HandlerT>
715class ResponseHandlerImpl : public ResponseHandler<ChannelT> {
716public:
717 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
718
719 // Handle the result by deserializing it from the channel then passing it
720 // to the user defined handler.
721 Error handleResponse(ChannelT &C) override {
722 using UnwrappedArgType = typename ResponseHandlerArg<
723 typename HandlerTraits<HandlerT>::Type>::UnwrappedArgType;
724 UnwrappedArgType Result;
725 if (auto Err =
726 SerializationTraits<ChannelT, FuncRetT,
727 UnwrappedArgType>::deserialize(C, Result))
728 return Err;
729 if (auto Err = C.endReceiveMessage())
730 return Err;
731 return Handler(std::move(Result));
732 }
733
734 // Abandon this response by calling the handler with an 'abandoned response'
735 // error.
736 void abandon() override {
737 if (auto Err = Handler(this->createAbandonedResponseError())) {
738 // Handlers should not fail when passed an abandoned response error.
739 report_fatal_error(std::move(Err));
740 }
741 }
742
743private:
744 HandlerT Handler;
745};
746
747// ResponseHandler subclass for RPC functions with void returns.
748template <typename ChannelT, typename HandlerT>
749class ResponseHandlerImpl<ChannelT, void, HandlerT>
750 : public ResponseHandler<ChannelT> {
751public:
752 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
753
754 // Handle the result (no actual value, just a notification that the function
755 // has completed on the remote end) by calling the user-defined handler with
756 // Error::success().
757 Error handleResponse(ChannelT &C) override {
758 if (auto Err = C.endReceiveMessage())
759 return Err;
760 return Handler(Error::success());
761 }
762
763 // Abandon this response by calling the handler with an 'abandoned response'
764 // error.
765 void abandon() override {
766 if (auto Err = Handler(this->createAbandonedResponseError())) {
767 // Handlers should not fail when passed an abandoned response error.
768 report_fatal_error(std::move(Err));
769 }
770 }
771
772private:
773 HandlerT Handler;
774};
775
776template <typename ChannelT, typename FuncRetT, typename HandlerT>
777class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
778 : public ResponseHandler<ChannelT> {
779public:
780 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
781
782 // Handle the result by deserializing it from the channel then passing it
783 // to the user defined handler.
784 Error handleResponse(ChannelT &C) override {
785 using HandlerArgType = typename ResponseHandlerArg<
786 typename HandlerTraits<HandlerT>::Type>::ArgType;
787 HandlerArgType Result((typename HandlerArgType::value_type()));
788
789 if (auto Err =
790 SerializationTraits<ChannelT, Expected<FuncRetT>,
791 HandlerArgType>::deserialize(C, Result))
792 return Err;
793 if (auto Err = C.endReceiveMessage())
794 return Err;
795 return Handler(std::move(Result));
796 }
797
798 // Abandon this response by calling the handler with an 'abandoned response'
799 // error.
800 void abandon() override {
801 if (auto Err = Handler(this->createAbandonedResponseError())) {
802 // Handlers should not fail when passed an abandoned response error.
803 report_fatal_error(std::move(Err));
804 }
805 }
806
807private:
808 HandlerT Handler;
809};
810
811template <typename ChannelT, typename HandlerT>
812class ResponseHandlerImpl<ChannelT, Error, HandlerT>
813 : public ResponseHandler<ChannelT> {
814public:
815 ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
816
817 // Handle the result by deserializing it from the channel then passing it
818 // to the user defined handler.
819 Error handleResponse(ChannelT &C) override {
820 Error Result = Error::success();
821 if (auto Err =
822 SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
823 return Err;
824 if (auto Err = C.endReceiveMessage())
825 return Err;
826 return Handler(std::move(Result));
827 }
828
829 // Abandon this response by calling the handler with an 'abandoned response'
830 // error.
831 void abandon() override {
832 if (auto Err = Handler(this->createAbandonedResponseError())) {
833 // Handlers should not fail when passed an abandoned response error.
834 report_fatal_error(std::move(Err));
835 }
836 }
837
838private:
839 HandlerT Handler;
840};
841
842// Create a ResponseHandler from a given user handler.
843template <typename ChannelT, typename FuncRetT, typename HandlerT>
844std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
845 return llvm::make_unique<ResponseHandlerImpl<ChannelT, FuncRetT, HandlerT>>(
846 std::move(H));
847}
848
// Wraps a (non-const) member function pointer and an object reference into a
// functor, e.g. so that a method can be installed as a result handler. Only
// a reference to Instance is stored, so the object must outlive the wrapper.
template <typename ClassT, typename RetT, typename... ArgTs>
class MemberFnWrapper {
public:
  using MethodT = RetT (ClassT::*)(ArgTs...);

  MemberFnWrapper(ClassT &Instance, MethodT Method)
      : Instance(Instance), Method(Method) {}

  // Pass each argument on with the value category the method declares.
  // By-value and rvalue-reference parameters are moved from, as before.
  // Lvalue-reference parameters now bind to the caller's object; with
  // std::move they failed to compile.
  RetT operator()(ArgTs &&... Args) {
    return (Instance.*Method)(std::forward<ArgTs>(Args)...);
  }

private:
  ClassT &Instance;
  MethodT Method;
};
865
866// Helper that provides a Functor for deserializing arguments.
867template <typename... ArgTs> class ReadArgs {
868public:
869 Error operator()() { return Error::success(); }
870};
871
872template <typename ArgT, typename... ArgTs>
873class ReadArgs<ArgT, ArgTs...> : public ReadArgs<ArgTs...> {
874public:
875 ReadArgs(ArgT &Arg, ArgTs &... Args)
876 : ReadArgs<ArgTs...>(Args...), Arg(Arg) {}
877
878 Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) {
879 this->Arg = std::move(ArgVal);
880 return ReadArgs<ArgTs...>::operator()(ArgVals...);
881 }
882
883private:
884 ArgT &Arg;
885};
886
// Hands out sequence numbers. Every operation takes SeqNoLock. Released
// numbers are reused, most recently released first, before new ones are
// minted.
template <typename SequenceNumberT> class SequenceNumberManager {
public:
  // Make every sequence number available again, restarting from zero.
  void reset() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.clear();
    NextSequenceNumber = 0;
  }

  // Return a released number if there is one; otherwise mint the next one.
  SequenceNumberT getSequenceNumber() {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    if (!FreeSequenceNumbers.empty()) {
      SequenceNumberT Recycled = FreeSequenceNumbers.back();
      FreeSequenceNumbers.pop_back();
      return Recycled;
    }
    return NextSequenceNumber++;
  }

  // Return SequenceNumber to the pool so it can be handed out again.
  void releaseSequenceNumber(SequenceNumberT SequenceNumber) {
    std::lock_guard<std::mutex> Guard(SeqNoLock);
    FreeSequenceNumbers.push_back(SequenceNumber);
  }

private:
  std::mutex SeqNoLock;
  SequenceNumberT NextSequenceNumber = 0;
  std::vector<SequenceNumberT> FreeSequenceNumbers;
};
919
// RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>::value is
// true iff P<Ti, Ui>::value holds at every position i. Only equal-length
// tuples have a matching specialization.
template <template <class, class> class P, typename T1Tuple, typename T2Tuple>
class RPCArgTypeCheckHelper;

// Base case: two empty tuples trivially satisfy P.
template <template <class, class> class P>
class RPCArgTypeCheckHelper<P, std::tuple<>, std::tuple<>> {
public:
  static const bool value = true;
};

// Recursive case: check the leading pair, then the remaining elements.
template <template <class, class> class P, typename T, typename... Ts,
          typename U, typename... Us>
class RPCArgTypeCheckHelper<P, std::tuple<T, Ts...>, std::tuple<U, Us...>> {
  using Rest = RPCArgTypeCheckHelper<P, std::tuple<Ts...>, std::tuple<Us...>>;

public:
  static const bool value = P<T, U>::value && Rest::value;
};
939
940template <template <class, class> class P, typename T1Sig, typename T2Sig>
941class RPCArgTypeCheck {
942public:
943 using T1Tuple = typename FunctionArgsTuple<T1Sig>::Type;
944 using T2Tuple = typename FunctionArgsTuple<T2Sig>::Type;
945
946 static_assert(std::tuple_size<T1Tuple>::value >=
947 std::tuple_size<T2Tuple>::value,
948 "Too many arguments to RPC call");
949 static_assert(std::tuple_size<T1Tuple>::value <=
950 std::tuple_size<T2Tuple>::value,
951 "Too few arguments to RPC call");
952
953 static const bool value = RPCArgTypeCheckHelper<P, T1Tuple, T2Tuple>::value;
954};
955
956template <typename ChannelT, typename WireT, typename ConcreteT>
957class CanSerialize {
958private:
959 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
960
961 template <typename T>
962 static std::true_type
963 check(typename std::enable_if<
964 std::is_same<decltype(T::serialize(std::declval<ChannelT &>(),
965 std::declval<const ConcreteT &>())),
966 Error>::value,
967 void *>::type);
968
969 template <typename> static std::false_type check(...);
970
971public:
972 static const bool value = decltype(check<S>(0))::value;
973};
974
975template <typename ChannelT, typename WireT, typename ConcreteT>
976class CanDeserialize {
977private:
978 using S = SerializationTraits<ChannelT, WireT, ConcreteT>;
979
980 template <typename T>
981 static std::true_type
982 check(typename std::enable_if<
983 std::is_same<decltype(T::deserialize(std::declval<ChannelT &>(),
984 std::declval<ConcreteT &>())),
985 Error>::value,
986 void *>::type);
987
988 template <typename> static std::false_type check(...);
989
990public:
991 static const bool value = decltype(check<S>(0))::value;
992};
993
994/// Contains primitive utilities for defining, calling and handling calls to
995/// remote procedures. ChannelT is a bidirectional stream conforming to the
996/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure
997/// identifier type that must be serializable on ChannelT, and SequenceNumberT
998/// is an integral type that will be used to number in-flight function calls.
999///
1000/// These utilities support the construction of very primitive RPC utilities.
1001/// Their intent is to ensure correct serialization and deserialization of
1002/// procedure arguments, and to keep the client and server's view of the API in
1003/// sync.
1004template <typename ImplT, typename ChannelT, typename FunctionIdT,
1005 typename SequenceNumberT>
1006class RPCEndpointBase {
1007protected:
1008 class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
1009 public:
1010 static const char *getName() { return "__orc_rpc$invalid"; }
1011 };
1012
1013 class OrcRPCResponse : public Function<OrcRPCResponse, void()> {
1014 public:
1015 static const char *getName() { return "__orc_rpc$response"; }
1016 };
1017
1018 class OrcRPCNegotiate
1019 : public Function<OrcRPCNegotiate, FunctionIdT(std::string)> {
1020 public:
1021 static const char *getName() { return "__orc_rpc$negotiate"; }
1022 };
1023
1024 // Helper predicate for testing for the presence of SerializeTraits
1025 // serializers.
1026 template <typename WireT, typename ConcreteT>
1027 class CanSerializeCheck : detail::CanSerialize<ChannelT, WireT, ConcreteT> {
1028 public:
1029 using detail::CanSerialize<ChannelT, WireT, ConcreteT>::value;
1030
1031 static_assert(value, "Missing serializer for argument (Can't serialize the "
1032 "first template type argument of CanSerializeCheck "
1033 "from the second)");
1034 };
1035
1036 // Helper predicate for testing for the presence of SerializeTraits
1037 // deserializers.
1038 template <typename WireT, typename ConcreteT>
1039 class CanDeserializeCheck
1040 : detail::CanDeserialize<ChannelT, WireT, ConcreteT> {
1041 public:
1042 using detail::CanDeserialize<ChannelT, WireT, ConcreteT>::value;
1043
1044 static_assert(value, "Missing deserializer for argument (Can't deserialize "
1045 "the second template type argument of "
1046 "CanDeserializeCheck from the first)");
1047 };
1048
1049public:
1050 /// Construct an RPC instance on a channel.
1051 RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
1052 : C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
1053 // Hold ResponseId in a special variable, since we expect Response to be
1054 // called relatively frequently, and want to avoid the map lookup.
1055 ResponseId = FnIdAllocator.getResponseId();
1056 RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId;
1057
1058 // Register the negotiate function id and handler.
1059 auto NegotiateId = FnIdAllocator.getNegotiateId();
1060 RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
1061 Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
1062 [this](const std::string &Name) { return handleNegotiate(Name); });
1063 }
1064
1065
1066 /// Negotiate a function id for Func with the other end of the channel.
1067 template <typename Func> Error negotiateFunction(bool Retry = false) {
1068 return getRemoteFunctionId<Func>(true, Retry).takeError();
1069 }
1070
1071 /// Append a call Func, does not call send on the channel.
1072 /// The first argument specifies a user-defined handler to be run when the
1073 /// function returns. The handler should take an Expected<Func::ReturnType>,
1074 /// or an Error (if Func::ReturnType is void). The handler will be called
1075 /// with an error if the return value is abandoned due to a channel error.
1076 template <typename Func, typename HandlerT, typename... ArgTs>
1077 Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) {
1078
1079 static_assert(
1080 detail::RPCArgTypeCheck<CanSerializeCheck, typename Func::Type,
1081 void(ArgTs...)>::value,
1082 "");
1083
1084 // Look up the function ID.
1085 FunctionIdT FnId;
1086 if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
1087 FnId = *FnIdOrErr;
1088 else {
1089 // Negotiation failed. Notify the handler then return the negotiate-failed
1090 // error.
1091 cantFail(Handler(make_error<ResponseAbandoned>()));
1092 return FnIdOrErr.takeError();
1093 }
1094
1095 SequenceNumberT SeqNo; // initialized in locked scope below.
1096 {
1097 // Lock the pending responses map and sequence number manager.
1098 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1099
1100 // Allocate a sequence number.
1101 SeqNo = SequenceNumberMgr.getSequenceNumber();
1102 assert(!PendingResponses.count(SeqNo) &&
1103 "Sequence number already allocated");
1104
1105 // Install the user handler.
1106 PendingResponses[SeqNo] =
1107 detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
1108 std::move(Handler));
1109 }
1110
1111 // Open the function call message.
1112 if (auto Err = C.startSendMessage(FnId, SeqNo)) {
1113 abandonPendingResponses();
1114 return Err;
1115 }
1116
1117 // Serialize the call arguments.
1118 if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
1119 C, Args...)) {
1120 abandonPendingResponses();
1121 return Err;
1122 }
1123
1124 // Close the function call messagee.
1125 if (auto Err = C.endSendMessage()) {
1126 abandonPendingResponses();
1127 return Err;
1128 }
1129
1130 return Error::success();
1131 }
1132
1133 Error sendAppendedCalls() { return C.send(); };
1134
1135 template <typename Func, typename HandlerT, typename... ArgTs>
1136 Error callAsync(HandlerT Handler, const ArgTs &... Args) {
1137 if (auto Err = appendCallAsync<Func>(std::move(Handler), Args...))
1138 return Err;
1139 return C.send();
1140 }
1141
1142 /// Handle one incoming call.
1143 Error handleOne() {
1144 FunctionIdT FnId;
1145 SequenceNumberT SeqNo;
1146 if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
1147 abandonPendingResponses();
1148 return Err;
1149 }
1150 if (FnId == ResponseId)
1151 return handleResponse(SeqNo);
1152 auto I = Handlers.find(FnId);
1153 if (I != Handlers.end())
1154 return I->second(C, SeqNo);
1155
1156 // else: No handler found. Report error to client?
1157 return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
1158 SeqNo);
1159 }
1160
1161 /// Helper for handling setter procedures - this method returns a functor that
1162 /// sets the variables referred to by Args... to values deserialized from the
1163 /// channel.
1164 /// E.g.
1165 ///
1166 /// typedef Function<0, bool, int> Func1;
1167 ///
1168 /// ...
1169 /// bool B;
1170 /// int I;
1171 /// if (auto Err = expect<Func1>(Channel, readArgs(B, I)))
1172 /// /* Handle Args */ ;
1173 ///
1174 template <typename... ArgTs>
1175 static detail::ReadArgs<ArgTs...> readArgs(ArgTs &... Args) {
1176 return detail::ReadArgs<ArgTs...>(Args...);
1177 }
1178
1179 /// Abandon all outstanding result handlers.
1180 ///
1181 /// This will call all currently registered result handlers to receive an
1182 /// "abandoned" error as their argument. This is used internally by the RPC
1183 /// in error situations, but can also be called directly by clients who are
1184 /// disconnecting from the remote and don't or can't expect responses to their
1185 /// outstanding calls. (Especially for outstanding blocking calls, calling
1186 /// this function may be necessary to avoid dead threads).
1187 void abandonPendingResponses() {
1188 // Lock the pending responses map and sequence number manager.
1189 std::lock_guard<std::mutex> Lock(ResponsesMutex);
1190
1191 for (auto &KV : PendingResponses)
1192 KV.second->abandon();
1193 PendingResponses.clear();
1194 SequenceNumberMgr.reset();
1195 }
1196
1197 /// Remove the handler for the given function.
1198 /// A handler must currently be registered for this function.
1199 template <typename Func>
1200 void removeHandler() {
1201 auto IdItr = LocalFunctionIds.find(Func::getPrototype());
1202 assert(IdItr != LocalFunctionIds.end() &&
1203 "Function does not have a registered handler");
1204 auto HandlerItr = Handlers.find(IdItr->second);
1205 assert(HandlerItr != Handlers.end() &&
1206 "Function does not have a registered handler");
1207 Handlers.erase(HandlerItr);
1208 }
1209
1210 /// Clear all handlers.
1211 void clearHandlers() {
1212 Handlers.clear();
1213 }
1214
1215protected:
1216
1217 FunctionIdT getInvalidFunctionId() const {
1218 return FnIdAllocator.getInvalidId();
1219 }
1220
1221 /// Add the given handler to the handler map and make it available for
1222 /// autonegotiation and execution.
1223 template <typename Func, typename HandlerT>
1224 void addHandlerImpl(HandlerT Handler) {
1225
1226 static_assert(detail::RPCArgTypeCheck<
1227 CanDeserializeCheck, typename Func::Type,
1228 typename detail::HandlerTraits<HandlerT>::Type>::value,
1229 "");
1230
1231 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1232 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1233 Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
1234 }
1235
1236 template <typename Func, typename HandlerT>
1237 void addAsyncHandlerImpl(HandlerT Handler) {
1238
1239 static_assert(detail::RPCArgTypeCheck<
1240 CanDeserializeCheck, typename Func::Type,
1241 typename detail::AsyncHandlerTraits<
1242 typename detail::HandlerTraits<HandlerT>::Type
1243 >::Type>::value,
1244 "");
1245
1246 FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
1247 LocalFunctionIds[Func::getPrototype()] = NewFnId;
1248 Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
1249 }
1250
1251 Error handleResponse(SequenceNumberT SeqNo) {
1252 using Handler = typename decltype(PendingResponses)::mapped_type;
1253 Handler PRHandler;
1254
1255 {
1256 // Lock the pending responses map and sequence number manager.
1257 std::unique_lock<std::mutex> Lock(ResponsesMutex);
1258 auto I = PendingResponses.find(SeqNo);
1259
1260 if (I != PendingResponses.end()) {
1261 PRHandler = std::move(I->second);
1262 PendingResponses.erase(I);
1263 SequenceNumberMgr.releaseSequenceNumber(SeqNo);
1264 } else {
1265 // Unlock the pending results map to prevent recursive lock.
1266 Lock.unlock();
1267 abandonPendingResponses();
1268 return make_error<
1269 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
1270 }
1271 }
1272
1273 assert(PRHandler &&
1274 "If we didn't find a response handler we should have bailed out");
1275
1276 if (auto Err = PRHandler->handleResponse(C)) {
1277 abandonPendingResponses();
1278 return Err;
1279 }
1280
1281 return Error::success();
1282 }
1283
1284 FunctionIdT handleNegotiate(const std::string &Name) {
1285 auto I = LocalFunctionIds.find(Name);
1286 if (I == LocalFunctionIds.end())
1287 return getInvalidFunctionId();
1288 return I->second;
1289 }
1290
1291 // Find the remote FunctionId for the given function.
1292 template <typename Func>
1293 Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
1294 bool NegotiateIfInvalid) {
1295 bool DoNegotiate;
1296
1297 // Check if we already have a function id...
1298 auto I = RemoteFunctionIds.find(Func::getPrototype());
1299 if (I != RemoteFunctionIds.end()) {
1300 // If it's valid there's nothing left to do.
1301 if (I->second != getInvalidFunctionId())
1302 return I->second;
1303 DoNegotiate = NegotiateIfInvalid;
1304 } else
1305 DoNegotiate = NegotiateIfNotInMap;
1306
1307 // We don't have a function id for Func yet, but we're allowed to try to
1308 // negotiate one.
1309 if (DoNegotiate) {
1310 auto &Impl = static_cast<ImplT &>(*this);
1311 if (auto RemoteIdOrErr =
1312 Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
1313 RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
1314 if (*RemoteIdOrErr == getInvalidFunctionId())
1315 return make_error<CouldNotNegotiate>(Func::getPrototype());
1316 return *RemoteIdOrErr;
1317 } else
1318 return RemoteIdOrErr.takeError();
1319 }
1320
1321 // No key was available in the map and we weren't allowed to try to
1322 // negotiate one, so return an unknown function error.
1323 return make_error<CouldNotNegotiate>(Func::getPrototype());
1324 }
1325
1326 using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
1327
1328 // Wrap the given user handler in the necessary argument-deserialization code,
1329 // result-serialization code, and call to the launch policy (if present).
1330 template <typename Func, typename HandlerT>
1331 WrappedHandlerFn wrapHandler(HandlerT Handler) {
1332 return [this, Handler](ChannelT &Channel,
1333 SequenceNumberT SeqNo) mutable -> Error {
1334 // Start by deserializing the arguments.
1335 using ArgsTuple =
1336 typename detail::FunctionArgsTuple<
1337 typename detail::HandlerTraits<HandlerT>::Type>::Type;
1338 auto Args = std::make_shared<ArgsTuple>();
1339
1340 if (auto Err =
1341 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1342 Channel, *Args))
1343 return Err;
1344
1345 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1346 // for RPCArgs. Void cast RPCArgs to work around this for now.
1347 // FIXME: Remove this workaround once we can assume a working GCC version.
1348 (void)Args;
1349
1350 // End receieve message, unlocking the channel for reading.
1351 if (auto Err = Channel.endReceiveMessage())
1352 return Err;
1353
1354 using HTraits = detail::HandlerTraits<HandlerT>;
1355 using FuncReturn = typename Func::ReturnType;
1356 return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
1357 HTraits::unpackAndRun(Handler, *Args));
1358 };
1359 }
1360
1361 // Wrap the given user handler in the necessary argument-deserialization code,
1362 // result-serialization code, and call to the launch policy (if present).
1363 template <typename Func, typename HandlerT>
1364 WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
1365 return [this, Handler](ChannelT &Channel,
1366 SequenceNumberT SeqNo) mutable -> Error {
1367 // Start by deserializing the arguments.
1368 using AHTraits = detail::AsyncHandlerTraits<
1369 typename detail::HandlerTraits<HandlerT>::Type>;
1370 using ArgsTuple =
1371 typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
1372 auto Args = std::make_shared<ArgsTuple>();
1373
1374 if (auto Err =
1375 detail::HandlerTraits<typename Func::Type>::deserializeArgs(
1376 Channel, *Args))
1377 return Err;
1378
1379 // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
1380 // for RPCArgs. Void cast RPCArgs to work around this for now.
1381 // FIXME: Remove this workaround once we can assume a working GCC version.
1382 (void)Args;
1383
1384 // End receieve message, unlocking the channel for reading.
1385 if (auto Err = Channel.endReceiveMessage())
1386 return Err;
1387
1388 using HTraits = detail::HandlerTraits<HandlerT>;
1389 using FuncReturn = typename Func::ReturnType;
1390 auto Responder =
1391 [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
1392 return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
1393 std::move(RetVal));
1394 };
1395
1396 return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
1397 };
1398 }
1399
1400 ChannelT &C;
1401
1402 bool LazyAutoNegotiation;
1403
1404 RPCFunctionIdAllocator<FunctionIdT> FnIdAllocator;
1405
1406 FunctionIdT ResponseId;
1407 std::map<std::string, FunctionIdT> LocalFunctionIds;
1408 std::map<const char *, FunctionIdT> RemoteFunctionIds;
1409
1410 std::map<FunctionIdT, WrappedHandlerFn> Handlers;
1411
1412 std::mutex ResponsesMutex;
1413 detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
1414 std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
1415 PendingResponses;
1416};
1417
1418} // end namespace detail
1419
1420template <typename ChannelT, typename FunctionIdT = uint32_t,
1421 typename SequenceNumberT = uint32_t>
1422class MultiThreadedRPCEndpoint
1423 : public detail::RPCEndpointBase<
1424 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1425 ChannelT, FunctionIdT, SequenceNumberT> {
1426private:
1427 using BaseClass =
1428 detail::RPCEndpointBase<
1429 MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1430 ChannelT, FunctionIdT, SequenceNumberT>;
1431
1432public:
1433 MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1434 : BaseClass(C, LazyAutoNegotiation) {}
1435
1436 /// Add a handler for the given RPC function.
1437 /// This installs the given handler functor for the given RPC Function, and
1438 /// makes the RPC function available for negotiation/calling from the remote.
1439 template <typename Func, typename HandlerT>
1440 void addHandler(HandlerT Handler) {
1441 return this->template addHandlerImpl<Func>(std::move(Handler));
1442 }
1443
1444 /// Add a class-method as a handler.
1445 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1446 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1447 addHandler<Func>(
1448 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1449 }
1450
1451 template <typename Func, typename HandlerT>
1452 void addAsyncHandler(HandlerT Handler) {
1453 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1454 }
1455
1456 /// Add a class-method as a handler.
1457 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1458 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1459 addAsyncHandler<Func>(
1460 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1461 }
1462
1463 /// Return type for non-blocking call primitives.
1464 template <typename Func>
1465 using NonBlockingCallResult = typename detail::ResultTraits<
1466 typename Func::ReturnType>::ReturnFutureType;
1467
1468 /// Call Func on Channel C. Does not block, does not call send. Returns a pair
1469 /// of a future result and the sequence number assigned to the result.
1470 ///
1471 /// This utility function is primarily used for single-threaded mode support,
1472 /// where the sequence number can be used to wait for the corresponding
1473 /// result. In multi-threaded mode the appendCallNB method, which does not
1474 /// return the sequence numeber, should be preferred.
1475 template <typename Func, typename... ArgTs>
1476 Expected<NonBlockingCallResult<Func>> appendCallNB(const ArgTs &... Args) {
1477 using RTraits = detail::ResultTraits<typename Func::ReturnType>;
1478 using ErrorReturn = typename RTraits::ErrorReturnType;
1479 using ErrorReturnPromise = typename RTraits::ReturnPromiseType;
1480
1481 // FIXME: Stack allocate and move this into the handler once LLVM builds
1482 // with C++14.
1483 auto Promise = std::make_shared<ErrorReturnPromise>();
1484 auto FutureResult = Promise->get_future();
1485
1486 if (auto Err = this->template appendCallAsync<Func>(
1487 [Promise](ErrorReturn RetOrErr) {
1488 Promise->set_value(std::move(RetOrErr));
1489 return Error::success();
1490 },
1491 Args...)) {
1492 RTraits::consumeAbandoned(FutureResult.get());
1493 return std::move(Err);
1494 }
1495 return std::move(FutureResult);
1496 }
1497
1498 /// The same as appendCallNBWithSeq, except that it calls C.send() to
1499 /// flush the channel after serializing the call.
1500 template <typename Func, typename... ArgTs>
1501 Expected<NonBlockingCallResult<Func>> callNB(const ArgTs &... Args) {
1502 auto Result = appendCallNB<Func>(Args...);
1503 if (!Result)
1504 return Result;
1505 if (auto Err = this->C.send()) {
1506 this->abandonPendingResponses();
1507 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1508 std::move(Result->get()));
1509 return std::move(Err);
1510 }
1511 return Result;
1512 }
1513
1514 /// Call Func on Channel C. Blocks waiting for a result. Returns an Error
1515 /// for void functions or an Expected<T> for functions returning a T.
1516 ///
1517 /// This function is for use in threaded code where another thread is
1518 /// handling responses and incoming calls.
1519 template <typename Func, typename... ArgTs,
1520 typename AltRetT = typename Func::ReturnType>
1521 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1522 callB(const ArgTs &... Args) {
1523 if (auto FutureResOrErr = callNB<Func>(Args...))
1524 return FutureResOrErr->get();
1525 else
1526 return FutureResOrErr.takeError();
1527 }
1528
1529 /// Handle incoming RPC calls.
1530 Error handlerLoop() {
1531 while (true)
1532 if (auto Err = this->handleOne())
1533 return Err;
1534 return Error::success();
1535 }
1536};
1537
1538template <typename ChannelT, typename FunctionIdT = uint32_t,
1539 typename SequenceNumberT = uint32_t>
1540class SingleThreadedRPCEndpoint
1541 : public detail::RPCEndpointBase<
1542 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1543 ChannelT, FunctionIdT, SequenceNumberT> {
1544private:
1545 using BaseClass =
1546 detail::RPCEndpointBase<
1547 SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
1548 ChannelT, FunctionIdT, SequenceNumberT>;
1549
1550public:
1551 SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
1552 : BaseClass(C, LazyAutoNegotiation) {}
1553
1554 template <typename Func, typename HandlerT>
1555 void addHandler(HandlerT Handler) {
1556 return this->template addHandlerImpl<Func>(std::move(Handler));
1557 }
1558
1559 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1560 void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1561 addHandler<Func>(
1562 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1563 }
1564
1565 template <typename Func, typename HandlerT>
1566 void addAsyncHandler(HandlerT Handler) {
1567 return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
1568 }
1569
1570 /// Add a class-method as a handler.
1571 template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
1572 void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
1573 addAsyncHandler<Func>(
1574 detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
1575 }
1576
1577 template <typename Func, typename... ArgTs,
1578 typename AltRetT = typename Func::ReturnType>
1579 typename detail::ResultTraits<AltRetT>::ErrorReturnType
1580 callB(const ArgTs &... Args) {
1581 bool ReceivedResponse = false;
1582 using ResultType = typename detail::ResultTraits<AltRetT>::ErrorReturnType;
1583 auto Result = detail::ResultTraits<AltRetT>::createBlankErrorReturnValue();
1584
1585 // We have to 'Check' result (which we know is in a success state at this
1586 // point) so that it can be overwritten in the async handler.
1587 (void)!!Result;
1588
1589 if (auto Err = this->template appendCallAsync<Func>(
1590 [&](ResultType R) {
1591 Result = std::move(R);
1592 ReceivedResponse = true;
1593 return Error::success();
1594 },
1595 Args...)) {
1596 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1597 std::move(Result));
1598 return std::move(Err);
1599 }
1600
1601 while (!ReceivedResponse) {
1602 if (auto Err = this->handleOne()) {
1603 detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
1604 std::move(Result));
1605 return std::move(Err);
1606 }
1607 }
1608
1609 return Result;
1610 }
1611};
1612
1613/// Asynchronous dispatch for a function on an RPC endpoint.
1614template <typename RPCClass, typename Func>
1615class RPCAsyncDispatch {
1616public:
1617 RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
1618
1619 template <typename HandlerT, typename... ArgTs>
1620 Error operator()(HandlerT Handler, const ArgTs &... Args) const {
1621 return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
1622 }
1623
1624private:
1625 RPCClass &Endpoint;
1626};
1627
1628/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
1629template <typename Func, typename RPCEndpointT>
1630RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
1631 return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
1632}
1633
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001634/// Allows a set of asynchrounous calls to be dispatched, and then
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001635/// waited on as a group.
1636class ParallelCallGroup {
1637public:
1638
1639 ParallelCallGroup() = default;
1640 ParallelCallGroup(const ParallelCallGroup &) = delete;
1641 ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
1642
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001643 /// Make as asynchronous call.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001644 template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
1645 Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
1646 const ArgTs &... Args) {
1647 // Increment the count of outstanding calls. This has to happen before
1648 // we invoke the call, as the handler may (depending on scheduling)
1649 // be run immediately on another thread, and we don't want the decrement
1650 // in the wrapped handler below to run before the increment.
1651 {
1652 std::unique_lock<std::mutex> Lock(M);
1653 ++NumOutstandingCalls;
1654 }
1655
1656 // Wrap the user handler in a lambda that will decrement the
1657 // outstanding calls count, then poke the condition variable.
1658 using ArgType = typename detail::ResponseHandlerArg<
1659 typename detail::HandlerTraits<HandlerT>::Type>::ArgType;
1660 // FIXME: Move handler into wrapped handler once we have C++14.
1661 auto WrappedHandler = [this, Handler](ArgType Arg) {
1662 auto Err = Handler(std::move(Arg));
1663 std::unique_lock<std::mutex> Lock(M);
1664 --NumOutstandingCalls;
1665 CV.notify_all();
1666 return Err;
1667 };
1668
1669 return AsyncDispatch(std::move(WrappedHandler), Args...);
1670 }
1671
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001672 /// Blocks until all calls have been completed and their return value
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001673 /// handlers run.
1674 void wait() {
1675 std::unique_lock<std::mutex> Lock(M);
1676 while (NumOutstandingCalls > 0)
1677 CV.wait(Lock);
1678 }
1679
1680private:
1681 std::mutex M;
1682 std::condition_variable CV;
1683 uint32_t NumOutstandingCalls = 0;
1684};
1685
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001686/// Convenience class for grouping RPC Functions into APIs that can be
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001687/// negotiated as a block.
1688///
1689template <typename... Funcs>
1690class APICalls {
1691public:
1692
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001693 /// Test whether this API contains Function F.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001694 template <typename F>
1695 class Contains {
1696 public:
1697 static const bool value = false;
1698 };
1699
Andrew Scullcdfcccc2018-10-05 20:58:37 +01001700 /// Negotiate all functions in this API.
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001701 template <typename RPCEndpoint>
1702 static Error negotiate(RPCEndpoint &R) {
1703 return Error::success();
1704 }
1705};
1706
1707template <typename Func, typename... Funcs>
1708class APICalls<Func, Funcs...> {
1709public:
1710
1711 template <typename F>
1712 class Contains {
1713 public:
1714 static const bool value = std::is_same<F, Func>::value |
1715 APICalls<Funcs...>::template Contains<F>::value;
1716 };
1717
1718 template <typename RPCEndpoint>
1719 static Error negotiate(RPCEndpoint &R) {
1720 if (auto Err = R.template negotiateFunction<Func>())
1721 return Err;
1722 return APICalls<Funcs...>::negotiate(R);
1723 }
1724
1725};
1726
1727template <typename... InnerFuncs, typename... Funcs>
1728class APICalls<APICalls<InnerFuncs...>, Funcs...> {
1729public:
1730
1731 template <typename F>
1732 class Contains {
1733 public:
1734 static const bool value =
1735 APICalls<InnerFuncs...>::template Contains<F>::value |
1736 APICalls<Funcs...>::template Contains<F>::value;
1737 };
1738
1739 template <typename RPCEndpoint>
1740 static Error negotiate(RPCEndpoint &R) {
1741 if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
1742 return Err;
1743 return APICalls<Funcs...>::negotiate(R);
1744 }
1745
1746};
1747
1748} // end namespace rpc
1749} // end namespace orc
1750} // end namespace llvm
1751
1752#endif