blob: db810f4ef2e53a4d12983a4647b5cab6fc702fb7 [file] [log] [blame]
Andrew Scull5e1ddfa2018-08-14 10:06:54 +01001//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
11#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H
12
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/Error.h"
#include <cstdint>
#include <limits>
#include <mutex>
#include <string>
#include <type_traits>
21
22namespace llvm {
23namespace orc {
24namespace rpc {
25
26/// Interface for byte-streams to be used with RPC.
27class RawByteChannel {
28public:
29 virtual ~RawByteChannel() = default;
30
31 /// Read Size bytes from the stream into *Dst.
32 virtual Error readBytes(char *Dst, unsigned Size) = 0;
33
34 /// Read size bytes from *Src and append them to the stream.
35 virtual Error appendBytes(const char *Src, unsigned Size) = 0;
36
37 /// Flush the stream if possible.
38 virtual Error send() = 0;
39
40 /// Notify the channel that we're starting a message send.
41 /// Locks the channel for writing.
42 template <typename FunctionIdT, typename SequenceIdT>
43 Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
44 writeLock.lock();
45 if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
46 writeLock.unlock();
47 return Err;
48 }
49 return Error::success();
50 }
51
52 /// Notify the channel that we're ending a message send.
53 /// Unlocks the channel for writing.
54 Error endSendMessage() {
55 writeLock.unlock();
56 return Error::success();
57 }
58
59 /// Notify the channel that we're starting a message receive.
60 /// Locks the channel for reading.
61 template <typename FunctionIdT, typename SequenceNumberT>
62 Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
63 readLock.lock();
64 if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
65 readLock.unlock();
66 return Err;
67 }
68 return Error::success();
69 }
70
71 /// Notify the channel that we're ending a message receive.
72 /// Unlocks the channel for reading.
73 Error endReceiveMessage() {
74 readLock.unlock();
75 return Error::success();
76 }
77
78 /// Get the lock for stream reading.
79 std::mutex &getReadLock() { return readLock; }
80
81 /// Get the lock for stream writing.
82 std::mutex &getWriteLock() { return writeLock; }
83
84private:
85 std::mutex readLock, writeLock;
86};
87
88template <typename ChannelT, typename T>
89class SerializationTraits<
90 ChannelT, T, T,
91 typename std::enable_if<
92 std::is_base_of<RawByteChannel, ChannelT>::value &&
93 (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
94 std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value ||
95 std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value ||
96 std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value ||
97 std::is_same<T, char>::value)>::type> {
98public:
99 static Error serialize(ChannelT &C, T V) {
100 support::endian::byte_swap<T, support::big>(V);
101 return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T));
102 };
103
104 static Error deserialize(ChannelT &C, T &V) {
105 if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T)))
106 return Err;
107 support::endian::byte_swap<T, support::big>(V);
108 return Error::success();
109 };
110};
111
112template <typename ChannelT>
113class SerializationTraits<ChannelT, bool, bool,
114 typename std::enable_if<std::is_base_of<
115 RawByteChannel, ChannelT>::value>::type> {
116public:
117 static Error serialize(ChannelT &C, bool V) {
118 uint8_t Tmp = V ? 1 : 0;
119 if (auto Err =
120 C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
121 return Err;
122 return Error::success();
123 }
124
125 static Error deserialize(ChannelT &C, bool &V) {
126 uint8_t Tmp = 0;
127 if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
128 return Err;
129 V = Tmp != 0;
130 return Error::success();
131 }
132};
133
134template <typename ChannelT>
135class SerializationTraits<ChannelT, std::string, StringRef,
136 typename std::enable_if<std::is_base_of<
137 RawByteChannel, ChannelT>::value>::type> {
138public:
139 /// RPC channel serialization for std::strings.
140 static Error serialize(RawByteChannel &C, StringRef S) {
141 if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size())))
142 return Err;
143 return C.appendBytes((const char *)S.data(), S.size());
144 }
145};
146
147template <typename ChannelT, typename T>
148class SerializationTraits<ChannelT, std::string, T,
149 typename std::enable_if<
150 std::is_base_of<RawByteChannel, ChannelT>::value &&
151 (std::is_same<T, const char*>::value ||
152 std::is_same<T, char*>::value)>::type> {
153public:
154 static Error serialize(RawByteChannel &C, const char *S) {
155 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
156 S);
157 }
158};
159
160template <typename ChannelT>
161class SerializationTraits<ChannelT, std::string, std::string,
162 typename std::enable_if<std::is_base_of<
163 RawByteChannel, ChannelT>::value>::type> {
164public:
165 /// RPC channel serialization for std::strings.
166 static Error serialize(RawByteChannel &C, const std::string &S) {
167 return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,
168 S);
169 }
170
171 /// RPC channel deserialization for std::strings.
172 static Error deserialize(RawByteChannel &C, std::string &S) {
173 uint64_t Count = 0;
174 if (auto Err = deserializeSeq(C, Count))
175 return Err;
176 S.resize(Count);
177 return C.readBytes(&S[0], Count);
178 }
179};
180
181} // end namespace rpc
182} // end namespace orc
183} // end namespace llvm
184
185#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H