Line data Source code
1 : // Copyright (c) 2020-2021 The Bitcoin Core developers
2 : // Distributed under the MIT software license, see the accompanying
3 : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 :
5 : #include <compat/compat.h>
6 : #include <logging.h>
7 : #include <tinyformat.h>
8 : #include <util/sock.h>
9 : #include <util/syserror.h>
10 : #include <util/system.h>
11 : #include <util/threadinterrupt.h>
12 : #include <util/time.h>
13 :
14 : #include <memory>
15 : #include <stdexcept>
16 : #include <string>
17 :
18 : #ifdef USE_EPOLL
19 : #include <sys/epoll.h>
20 : #endif
21 :
22 : #ifdef USE_KQUEUE
23 : #include <sys/event.h>
24 : #endif
25 :
26 : #ifdef USE_POLL
27 : #include <poll.h>
28 : #endif
29 :
30 : SocketEventsMode g_socket_events_mode{SocketEventsMode::Unknown};
31 :
32 0 : static inline bool IOErrorIsPermanent(int err)
33 : {
34 0 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
35 : }
36 :
37 24471 : static inline bool IsSelectableSocket(const SOCKET& s, bool is_select)
38 : {
39 : #if defined(WIN32)
40 : return true;
41 : #else
42 24471 : return is_select ? (s < FD_SETSIZE) : true;
43 : #endif
44 : }
45 :
46 20 : Sock::Sock() : m_socket(INVALID_SOCKET) {}
47 :
48 38298 : Sock::Sock(SOCKET s) : m_socket(s) {}
49 :
50 4 : Sock::Sock(Sock&& other)
51 4 : {
52 2 : m_socket = other.m_socket;
53 2 : other.m_socket = INVALID_SOCKET;
54 4 : }
55 :
56 57471 : Sock::~Sock() { Close(); }
57 :
58 2 : Sock& Sock::operator=(Sock&& other)
59 : {
60 2 : Close();
61 2 : m_socket = other.m_socket;
62 2 : other.m_socket = INVALID_SOCKET;
63 2 : return *this;
64 : }
65 :
66 68724285 : SOCKET Sock::Get() const { return m_socket; }
67 :
68 1564636 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
69 : {
70 1564636 : return send(m_socket, static_cast<const char*>(data), len, flags);
71 : }
72 :
73 227 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
74 : {
75 227 : return recv(m_socket, static_cast<char*>(buf), len, flags);
76 : }
77 :
78 6038 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
79 : {
80 6038 : return connect(m_socket, addr, addr_len);
81 : }
82 :
83 5614 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
84 : {
85 5614 : return bind(m_socket, addr, addr_len);
86 : }
87 :
88 3036 : int Sock::Listen(int backlog) const
89 : {
90 3036 : return listen(m_socket, backlog);
91 : }
92 :
93 6638 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
94 : {
95 : #ifdef WIN32
96 : static constexpr auto ERR = INVALID_SOCKET;
97 : #else
98 : static constexpr auto ERR = SOCKET_ERROR;
99 : #endif
100 :
101 6638 : std::unique_ptr<Sock> sock;
102 :
103 6638 : const auto socket = accept(m_socket, addr, addr_len);
104 6638 : if (socket != ERR) {
105 : try {
106 6638 : sock = std::make_unique<Sock>(socket);
107 6638 : } catch (const std::exception&) {
108 : #ifdef WIN32
109 : closesocket(socket);
110 : #else
111 0 : close(socket);
112 : #endif
113 0 : }
114 6638 : }
115 :
116 6638 : return sock;
117 6638 : }
118 :
119 5992 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
120 : {
121 5992 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
122 : }
123 :
124 34804 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
125 : {
126 34804 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
127 : }
128 :
129 11427 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
130 : {
131 11427 : return getsockname(m_socket, name, name_len);
132 : }
133 :
134 11652 : bool Sock::SetNonBlocking() const
135 : {
136 : #ifdef WIN32
137 : u_long on{1};
138 : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
139 : return false;
140 : }
141 : #else
142 11652 : const int flags{fcntl(m_socket, F_GETFL, 0)};
143 11652 : if (flags == SOCKET_ERROR) {
144 0 : return false;
145 : }
146 11652 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
147 0 : return false;
148 : }
149 : #endif
150 11652 : return true;
151 11652 : }
152 :
153 18377 : bool Sock::IsSelectable(bool is_select) const
154 : {
155 18377 : return IsSelectableSocket(m_socket, is_select);
156 : }
157 :
158 6094 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, SocketEventsParams event_params, Event* occurred) const
159 : {
160 6094 : EventsPerSock events_per_sock{std::make_pair(m_socket, Events{requested})};
161 :
162 : // We need to ensure we are only using a level-triggered mode because we are expecting
163 : // a direct correlation between the events reported and the one socket we are querying
164 12188 : if (auto [sem, _, __] = event_params; sem != SocketEventsMode::Poll && sem != SocketEventsMode::Select) {
165 : // We will use a compatible fallback events mode if we didn't specify a valid option
166 6094 : event_params = SocketEventsParams{
167 : #ifdef USE_POLL
168 : SocketEventsMode::Poll
169 : #else
170 : SocketEventsMode::Select
171 : #endif /* USE_POLL */
172 : };
173 6094 : }
174 6094 : if (!WaitMany(timeout, events_per_sock, event_params)) {
175 0 : return false;
176 : }
177 :
178 6094 : if (occurred != nullptr) {
179 6021 : *occurred = events_per_sock.begin()->second.occurred;
180 6021 : }
181 :
182 6094 : return true;
183 6094 : }
184 :
185 6094 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock, SocketEventsParams event_params) const
186 : {
187 6094 : return WaitManyInternal(timeout, events_per_sock, event_params);
188 0 : }
189 :
190 29790657 : bool Sock::WaitManyInternal(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock, SocketEventsParams event_params)
191 : {
192 29790657 : switch (event_params.m_event_mode)
193 : {
194 : #ifdef USE_POLL
195 : case SocketEventsMode::Poll:
196 : return WaitManyPoll(timeout, events_per_sock, event_params.m_wrap_func);
197 : #endif /* USE_POLL */
198 : case SocketEventsMode::Select:
199 6094 : return WaitManySelect(timeout, events_per_sock, event_params.m_wrap_func);
200 : #ifdef USE_EPOLL
201 : case SocketEventsMode::EPoll:
202 : assert(event_params.m_event_fd != INVALID_SOCKET);
203 : return WaitManyEPoll(timeout, events_per_sock, event_params.m_event_fd, event_params.m_wrap_func);
204 : #endif /* USE_EPOLL */
205 : #ifdef USE_KQUEUE
206 : case SocketEventsMode::KQueue:
207 29784563 : assert(event_params.m_event_fd != INVALID_SOCKET);
208 29784563 : return WaitManyKQueue(timeout, events_per_sock, event_params.m_event_fd, event_params.m_wrap_func);
209 : #endif /* USE_KQUEUE */
210 : default:
211 0 : assert(false);
212 : }
213 29790657 : }
214 :
215 : #ifdef USE_EPOLL
216 : bool Sock::WaitManyEPoll(std::chrono::milliseconds timeout,
217 : EventsPerSock& events_per_sock,
218 : SOCKET epoll_fd,
219 : SocketEventsParams::wrap_fn wrap_func)
220 : {
221 : std::array<epoll_event, MAX_EVENTS> events{};
222 :
223 : int ret{SOCKET_ERROR};
224 : wrap_func([&](){
225 : ret = epoll_wait(epoll_fd, events.data(), events.size(), count_milliseconds(timeout));
226 : });
227 : if (ret == SOCKET_ERROR) {
228 : return false;
229 : }
230 :
231 : // Events reported do not correspond to sockets requested in edge-triggered modes, we will clear the
232 : // entire map before populating it with our events data.
233 : events_per_sock.clear();
234 :
235 : for (int idx = 0; idx < ret; idx++) {
236 : auto& ev = events[idx];
237 : Event occurred = 0;
238 : if (ev.events & (EPOLLERR | EPOLLHUP)) {
239 : occurred |= ERR;
240 : } else {
241 : if (ev.events & EPOLLIN) {
242 : occurred |= RECV;
243 : }
244 : if (ev.events & EPOLLOUT) {
245 : occurred |= SEND;
246 : }
247 : }
248 : events_per_sock.emplace(static_cast<SOCKET>(ev.data.fd), Sock::Events{/*req=*/RECV | SEND, occurred});
249 : }
250 :
251 : return true;
252 : }
253 : #endif /* USE_EPOLL */
254 :
255 : #ifdef USE_KQUEUE
256 29784563 : bool Sock::WaitManyKQueue(std::chrono::milliseconds timeout,
257 : EventsPerSock& events_per_sock,
258 : SOCKET kqueue_fd,
259 : SocketEventsParams::wrap_fn wrap_func)
260 : {
261 29784563 : std::array<struct kevent, MAX_EVENTS> events{};
262 29784563 : struct timespec ts = MillisToTimespec(timeout);
263 :
264 29784563 : int ret{SOCKET_ERROR};
265 59569126 : wrap_func([&](){
266 29784563 : ret = kevent(kqueue_fd, nullptr, 0, events.data(), events.size(), &ts);
267 29784563 : });
268 29784563 : if (ret == SOCKET_ERROR) {
269 0 : return false;
270 : }
271 :
272 : // Events reported do not correspond to sockets requested in edge-triggered modes, we will clear the
273 : // entire map before populating it with our events data.
274 29784563 : events_per_sock.clear();
275 :
276 59734339 : for (int idx = 0; idx < ret; idx++) {
277 29949776 : auto& ev = events[idx];
278 29949776 : Event occurred = 0;
279 29949776 : if (ev.flags & (EV_ERROR | EV_EOF)) {
280 9099 : occurred |= ERR;
281 9099 : } else {
282 29940677 : if (ev.filter == EVFILT_READ) {
283 29143413 : occurred |= RECV;
284 29143413 : }
285 29940677 : if (ev.filter == EVFILT_WRITE) {
286 797264 : occurred |= SEND;
287 797264 : }
288 : }
289 29949776 : if (auto it = events_per_sock.find(static_cast<SOCKET>(ev.ident)); it != events_per_sock.end()) {
290 35622 : it->second.occurred |= occurred;
291 35622 : } else {
292 29914154 : events_per_sock.emplace(static_cast<SOCKET>(ev.ident), Sock::Events{/*req=*/RECV | SEND, occurred});
293 : }
294 29949776 : }
295 :
296 29784563 : return true;
297 29784563 : }
298 : #endif /* USE_KQUEUE */
299 :
300 : #ifdef USE_POLL
301 : bool Sock::WaitManyPoll(std::chrono::milliseconds timeout,
302 : EventsPerSock& events_per_sock,
303 : SocketEventsParams::wrap_fn wrap_func)
304 : {
305 : if (events_per_sock.empty()) return true;
306 :
307 : std::vector<pollfd> pfds;
308 : for (const auto& [socket, events] : events_per_sock) {
309 : pfds.emplace_back();
310 : auto& pfd = pfds.back();
311 : pfd.fd = socket;
312 : if (events.requested & RECV) {
313 : pfd.events |= POLLIN;
314 : }
315 : if (events.requested & SEND) {
316 : pfd.events |= POLLOUT;
317 : }
318 : }
319 :
320 : int ret{SOCKET_ERROR};
321 : wrap_func([&](){
322 : ret = poll(pfds.data(), pfds.size(), count_milliseconds(timeout));
323 : });
324 : if (ret == SOCKET_ERROR) {
325 : return false;
326 : }
327 :
328 : assert(pfds.size() == events_per_sock.size());
329 : size_t i{0};
330 : for (auto& [socket, events] : events_per_sock) {
331 : assert(socket == static_cast<SOCKET>(pfds[i].fd));
332 : events.occurred = 0;
333 : if (pfds[i].revents & POLLIN) {
334 : events.occurred |= RECV;
335 : }
336 : if (pfds[i].revents & POLLOUT) {
337 : events.occurred |= SEND;
338 : }
339 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
340 : events.occurred |= ERR;
341 : }
342 : ++i;
343 : }
344 :
345 : return true;
346 : }
347 : #endif /* USE_POLL */
348 :
349 6094 : bool Sock::WaitManySelect(std::chrono::milliseconds timeout,
350 : EventsPerSock& events_per_sock,
351 : SocketEventsParams::wrap_fn wrap_func)
352 : {
353 6094 : if (events_per_sock.empty()) return true;
354 :
355 : fd_set recv;
356 : fd_set send;
357 : fd_set err;
358 6094 : FD_ZERO(&recv);
359 6094 : FD_ZERO(&send);
360 6094 : FD_ZERO(&err);
361 6094 : SOCKET socket_max{0};
362 :
363 18209 : for (const auto& [sock, events] : events_per_sock) {
364 6094 : if (!IsSelectableSocket(sock, /*is_select=*/true)) {
365 0 : return false;
366 : }
367 6094 : const auto& s = sock;
368 6094 : if (events.requested & RECV) {
369 6094 : FD_SET(s, &recv);
370 6094 : }
371 6094 : if (events.requested & SEND) {
372 6021 : FD_SET(s, &send);
373 6021 : }
374 6094 : FD_SET(s, &err);
375 6094 : socket_max = std::max(socket_max, s);
376 : }
377 :
378 6094 : timeval tv = MillisToTimeval(timeout);
379 :
380 6094 : int ret{SOCKET_ERROR};
381 12188 : wrap_func([&](){
382 6094 : ret = select(socket_max + 1, &recv, &send, &err, &tv);
383 6094 : });
384 6094 : if (ret == SOCKET_ERROR) {
385 0 : return false;
386 : }
387 :
388 18180 : for (auto& [sock, events] : events_per_sock) {
389 6094 : const auto& s = sock;
390 6094 : events.occurred = 0;
391 6094 : if (FD_ISSET(s, &recv)) {
392 453 : events.occurred |= RECV;
393 453 : }
394 6094 : if (FD_ISSET(s, &send)) {
395 5992 : events.occurred |= SEND;
396 5992 : }
397 6094 : if (FD_ISSET(s, &err)) {
398 0 : events.occurred |= ERR;
399 0 : }
400 : }
401 :
402 6094 : return true;
403 6094 : }
404 :
405 34 : void Sock::SendComplete(const std::string& data,
406 : std::chrono::milliseconds timeout,
407 : CThreadInterrupt& interrupt) const
408 : {
409 34 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
410 34 : size_t sent{0};
411 :
412 34 : for (;;) {
413 34 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
414 :
415 34 : if (ret > 0) {
416 34 : sent += static_cast<size_t>(ret);
417 34 : if (sent == data.size()) {
418 34 : break;
419 : }
420 0 : } else {
421 0 : const int err{WSAGetLastError()};
422 0 : if (IOErrorIsPermanent(err)) {
423 0 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
424 : }
425 : }
426 :
427 0 : const auto now = GetTime<std::chrono::milliseconds>();
428 :
429 0 : if (now >= deadline) {
430 0 : throw std::runtime_error(strprintf(
431 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
432 : }
433 :
434 0 : if (interrupt) {
435 0 : throw std::runtime_error(strprintf(
436 0 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
437 : }
438 :
439 : // Wait for a short while (or the socket to become ready for sending) before retrying
440 : // if nothing was sent.
441 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
442 0 : (void)Wait(wait_time, SEND, SocketEventsParams{::g_socket_events_mode});
443 : }
444 34 : }
445 :
446 38 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
447 : std::chrono::milliseconds timeout,
448 : CThreadInterrupt& interrupt,
449 : size_t max_data) const
450 : {
451 38 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
452 38 : std::string data;
453 38 : bool terminator_found{false};
454 :
455 : // We must not consume any bytes past the terminator from the socket.
456 : // One option is to read one byte at a time and check if we have read a terminator.
457 : // However that is very slow. Instead, we peek at what is in the socket and only read
458 : // as many bytes as possible without crossing the terminator.
459 : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
460 : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
461 : // at a time is about 50 times slower.
462 :
463 169 : for (;;) {
464 169 : if (data.size() >= max_data) {
465 2 : throw std::runtime_error(
466 2 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
467 : }
468 :
469 : char buf[512];
470 :
471 167 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
472 :
473 167 : switch (peek_ret) {
474 : case -1: {
475 0 : const int err{WSAGetLastError()};
476 0 : if (IOErrorIsPermanent(err)) {
477 0 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
478 : }
479 0 : break;
480 : }
481 : case 0:
482 0 : throw std::runtime_error("Connection unexpectedly closed by peer");
483 : default:
484 167 : auto end = buf + peek_ret;
485 167 : auto terminator_pos = std::find(buf, end, terminator);
486 167 : terminator_found = terminator_pos != end;
487 :
488 167 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
489 131 : static_cast<size_t>(peek_ret)};
490 :
491 167 : const ssize_t read_ret{Recv(buf, try_len, 0)};
492 :
493 167 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
494 0 : throw std::runtime_error(
495 0 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
496 : "peek claimed %u bytes are available",
497 : read_ret, try_len, peek_ret));
498 : }
499 :
500 : // Don't include the terminator in the output.
501 167 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
502 :
503 167 : data.append(buf, buf + append_len);
504 :
505 167 : if (terminator_found) {
506 36 : return data;
507 : }
508 131 : }
509 :
510 131 : const auto now = GetTime<std::chrono::milliseconds>();
511 :
512 131 : if (now >= deadline) {
513 0 : throw std::runtime_error(strprintf(
514 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
515 : }
516 :
517 131 : if (interrupt) {
518 0 : throw std::runtime_error(strprintf(
519 : "Receive interrupted (received %u bytes without terminator before that)",
520 0 : data.size()));
521 : }
522 :
523 : // Wait for a short while (or the socket to become ready for reading) before retrying.
524 131 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
525 131 : (void)Wait(wait_time, RECV, SocketEventsParams{::g_socket_events_mode});
526 : }
527 40 : }
528 :
529 0 : bool Sock::IsConnected(std::string& errmsg) const
530 : {
531 0 : if (m_socket == INVALID_SOCKET) {
532 0 : errmsg = "not connected";
533 0 : return false;
534 : }
535 :
536 : char c;
537 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
538 : case -1: {
539 0 : const int err = WSAGetLastError();
540 0 : if (IOErrorIsPermanent(err)) {
541 0 : errmsg = NetworkErrorString(err);
542 0 : return false;
543 : }
544 0 : return true;
545 : }
546 : case 0:
547 0 : errmsg = "closed";
548 0 : return false;
549 : default:
550 0 : return true;
551 : }
552 0 : }
553 :
554 19171 : void Sock::Close()
555 : {
556 19171 : if (m_socket == INVALID_SOCKET) {
557 866 : return;
558 : }
559 : #ifdef WIN32
560 : int ret = closesocket(m_socket);
561 : #else
562 18305 : int ret = close(m_socket);
563 : #endif
564 18305 : if (ret) {
565 0 : LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
566 0 : }
567 18305 : m_socket = INVALID_SOCKET;
568 19171 : }
569 :
570 493 : std::string NetworkErrorString(int err)
571 : {
572 : #if defined(WIN32)
573 : return Win32ErrorString(err);
574 : #else
575 : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
576 493 : return SysErrorString(err);
577 : #endif
578 : }
|