35#if !defined(OPENSSL_NO_DEPRECATED)
36 #define OPENSSL_NO_DEPRECATED
39#include <openssl/err.h>
40#include <openssl/ssl.h>
43 #if defined(EWOULDBLOCK)
46 #if defined(SHUT_RDWR)
64 #define EWOULDBLOCK WSAEWOULDBLOCK
65 #define SHUT_RDWR SD_BOTH
66 #define pollfd WSAPOLLFD
67 #define connect(x, y, z) WSAConnect(x, y, z, nullptr, nullptr, nullptr, nullptr)
68 #define errno WSAGetLastError()
69 #define close closesocket
70 #define poll(x, y, z) WSAPoll(x, y, z)
71 #pragma comment(lib, "Ws2_32.lib")
74DCA_INLINE
bool isValidSocket(SOCKET s) {
75 return s != INVALID_SOCKET;
79using SOCKET = int32_t;
80DCA_INLINE
bool isValidSocket(SOCKET s) {
83 #include <netinet/tcp.h>
84 #include <netinet/in.h>
85 #include <sys/socket.h>
86 #include <sys/types.h>
87 #include <arpa/inet.h>
95#if !defined(SOCKET_ERROR)
96 #define SOCKET_ERROR SOCKET(-1)
99#if !defined(INVALID_SOCKET)
100 #define INVALID_SOCKET (-1)
105 namespace discord_core_internal {
107 enum class connection_status {
109 CONNECTION_Error = 1,
118 DCA_INLINE jsonifier::string reportSSLError(jsonifier::string_view errorPosition, int32_t errorValue = 0, SSL* SSL =
nullptr) {
119 std::stringstream
stream{};
120 stream << errorPosition <<
" error: ";
124 stream << ERR_error_string(ERR_get_error(),
nullptr);
126 return jsonifier::string{
stream.str() };
129 DCA_INLINE jsonifier::string reportError(jsonifier::string_view errorPosition) {
130 std::stringstream
stream{};
131 stream << errorPosition <<
" error: ";
138 FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
nullptr,
static_cast<DWORD
>(WSAGetLastError()), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
139 static_cast<LPTSTR
>(
string.get()), 1024,
nullptr);
140 stream << WSAGetLastError() <<
", " << string;
142 stream << strerror(errno);
144 return jsonifier::string{
stream.str() };
148 struct wsadata_wrapper {
149 struct wsadata_deleter {
150 DCA_INLINE
void operator()(WSADATA* other) {
156 DCA_INLINE wsadata_wrapper() {
157 auto returnData = WSAStartup(MAKEWORD(2, 2), ptr.get());
159 message_printer::printError<print_message_type::general>(reportError(
"wsadata_wrapper::wsadata_wrapper()").data());
164 unique_ptr<WSADATA, wsadata_deleter> ptr{ makeUnique<WSADATA, wsadata_deleter>() };
168 struct poll_fd_wrapper {
169 jsonifier::vector<uint64_t> indices{};
170 jsonifier::vector<pollfd> polls{};
173 struct ssl_ctx_wrapper {
174 struct ssl_ctx_deleter {
175 DCA_INLINE
void operator()(SSL_CTX* other) {
183 DCA_INLINE ssl_ctx_wrapper& operator=(SSL_CTX* other) {
188 DCA_INLINE
operator SSL_CTX*() {
193 unique_ptr<SSL_CTX, ssl_ctx_deleter> ptr{};
199 DCA_INLINE
void operator()(SSL* other) {
208 DCA_INLINE ssl_wrapper() =
default;
210 DCA_INLINE ssl_wrapper& operator=(ssl_wrapper&& other)
noexcept {
211 ptr = std::move(other.ptr);
215 DCA_INLINE ssl_wrapper(ssl_wrapper&& other)
noexcept {
216 *
this = std::move(other);
219 DCA_INLINE ssl_wrapper& operator=(SSL* other) {
224 DCA_INLINE
explicit operator bool() {
225 return ptr.operator bool();
228 DCA_INLINE
operator SSL*() {
233 unique_ptr<SSL, ssl_deleter> ptr{};
236 class socket_wrapper {
238 struct socket_deleter {
239 DCA_INLINE
void operator()(SOCKET* ptrNew) {
240 if (ptrNew && *ptrNew != INVALID_SOCKET) {
241 shutdown(*ptrNew, SHUT_RDWR);
243 *ptrNew = INVALID_SOCKET;
250 DCA_INLINE socket_wrapper() =
default;
252 DCA_INLINE socket_wrapper& operator=(socket_wrapper&& other)
noexcept {
253 ptr = std::move(other.ptr);
257 DCA_INLINE socket_wrapper(socket_wrapper&& other)
noexcept {
258 *
this = std::move(other);
261 DCA_INLINE socket_wrapper& operator=(SOCKET other) {
262 ptr.reset(
new SOCKET{ other });
266 DCA_INLINE socket_wrapper(SOCKET other) {
270 DCA_INLINE
explicit operator bool() {
271 return ptr.operator bool();
274 DCA_INLINE
operator SOCKET() {
275 if (ptr.operator
bool()) {
278 return INVALID_SOCKET;
283 unique_ptr<SOCKET, socket_deleter> ptr{};
286 struct addrinfo_wrapper {
287 DCA_INLINE addrinfo* operator->() {
291 DCA_INLINE
operator addrinfo**() {
295 DCA_INLINE
operator addrinfo*() {
301 addrinfo* ptr{ &value };
304 class ssl_context_holder {
306 DCA_INLINE
static ssl_ctx_wrapper context{};
307 DCA_INLINE
static std::mutex accessMutex{};
309 DCA_INLINE
static bool initialize() {
310 if (ssl_context_holder::context = SSL_CTX_new(TLS_client_method()); !ssl_context_holder::context) {
314 if (!SSL_CTX_set_min_proto_version(ssl_context_holder::context, TLS1_2_VERSION)) {
318#if defined(SSL_OP_IGNORE_UNEXPECTED_EOF)
319 auto originalOptions{ SSL_CTX_get_options(ssl_context_holder::context) | SSL_OP_IGNORE_UNEXPECTED_EOF };
320 if (SSL_CTX_set_options(ssl_context_holder::context, SSL_OP_IGNORE_UNEXPECTED_EOF) != originalOptions) {
328 template<
typename value_type>
class ssl_data_interface {
330 template<
typename value_type2>
friend class tcp_connection;
331 friend class https_client;
333 ssl_data_interface& operator=(ssl_data_interface<value_type>&& other)
noexcept {
334 outputBuffer = std::move(other.outputBuffer);
335 inputBuffer = std::move(other.inputBuffer);
336 bytesRead = other.bytesRead;
340 ssl_data_interface(ssl_data_interface<value_type>&& other)
noexcept {
341 *
this = std::move(other);
344 template<
typename value_type_new> DCA_INLINE
void writeData(jsonifier::string_view_base<value_type_new> dataToWrite,
bool priority) {
345 if (
static_cast<value_type*
>(
this)->areWeStillConnected()) {
346 if (dataToWrite.size() > 0 &&
static_cast<value_type*
>(
this)->ssl) {
347 if (priority && dataToWrite.size() < maxBufferSize) {
348 outputBuffer.clear();
349 outputBuffer.writeData(dataToWrite.data(), dataToWrite.size());
350 static_cast<value_type*
>(
this)->processWriteData();
352 uint64_t remainingBytes{ dataToWrite.size() };
353 while (remainingBytes > 0) {
354 uint64_t amountToCollect{ dataToWrite.size() >= maxBufferSize ? maxBufferSize : dataToWrite.size() };
355 outputBuffer.writeData(dataToWrite.data(), amountToCollect);
356 dataToWrite = jsonifier::string_view_base{ dataToWrite.data() + amountToCollect, dataToWrite.size() - amountToCollect };
357 remainingBytes = dataToWrite.size();
367 DCA_INLINE
auto getInputBuffer() {
368 return inputBuffer.readData();
371 DCA_INLINE int64_t getBytesRead() {
375 DCA_INLINE
void reset() {
376 outputBuffer.clear();
382 const uint64_t maxBufferSize{ (1024 * 16) };
383 ring_buffer<uint8_t, 16> outputBuffer{};
384 ring_buffer<uint8_t, 64> inputBuffer{};
387 DCA_INLINE ssl_data_interface() =
default;
389 virtual ~ssl_data_interface() =
default;
392 template<
typename value_type>
class tcp_connection :
public ssl_data_interface<tcp_connection<value_type>> {
394 connection_status currentStatus{ connection_status::NO_Error };
395 socket_wrapper socket{};
396 bool writeWantWrite{};
397 bool writeWantRead{};
398 bool readWantWrite{};
402 tcp_connection& operator=(tcp_connection&& other) =
default;
403 tcp_connection(tcp_connection&& other) =
default;
404 tcp_connection& operator=(
const tcp_connection& other) =
default;
405 tcp_connection(
const tcp_connection& other) =
default;
407 DCA_INLINE tcp_connection(
const jsonifier::string& baseUrlNew,
const uint16_t portNew) {
408 jsonifier::string addressString{};
409 auto httpsFind = baseUrlNew.find(
"https://");
410 auto comFind = baseUrlNew.find(
".com");
411 auto orgFind = baseUrlNew.find(
".org");
412 if (httpsFind != jsonifier::string::npos && comFind != jsonifier::string::npos) {
413 addressString = baseUrlNew.substr(httpsFind + jsonifier::string_view{
"https://" }.size(),
414 comFind + jsonifier::string_view{
".com" }.size() - jsonifier::string_view{
"https://" }.size());
415 }
else if (httpsFind != jsonifier::string::npos && orgFind != jsonifier::string::npos) {
416 addressString = baseUrlNew.substr(httpsFind + jsonifier::string_view{
"https://" }.size(),
417 orgFind + jsonifier::string_view{
".org" }.size() - jsonifier::string_view{
"https://" }.size());
419 addressString = baseUrlNew;
421 addrinfo_wrapper hints{}, address{};
422 hints->ai_family = AF_INET;
423 hints->ai_socktype = SOCK_STREAM;
424 hints->ai_protocol = IPPROTO_TCP;
426 if (getaddrinfo(addressString.data(), jsonifier::toString(portNew).data(), hints, address)) {
428 currentStatus = connection_status::CONNECTION_Error;
429 socket = INVALID_SOCKET;
433 if (socket = ::socket(address->ai_family, address->ai_socktype, address->ai_protocol); !isValidSocket(socket.operator SOCKET())) {
435 currentStatus = connection_status::CONNECTION_Error;
436 socket = INVALID_SOCKET;
440 if (
::connect(socket, address->ai_addr,
static_cast<int32_t
>(address->ai_addrlen)) == SOCKET_ERROR) {
442 currentStatus = connection_status::CONNECTION_Error;
443 socket = INVALID_SOCKET;
447 std::unique_lock lock{ ssl_context_holder::accessMutex };
448 if (ssl = SSL_new(ssl_context_holder::context); !ssl) {
450 reportSSLError(
"Tcp_connection::connect::SSL_new(), to: " + baseUrlNew) +
"\n" + reportError(
"Tcp_connection::connect::SSL_new(), to: " + baseUrlNew));
451 currentStatus = connection_status::CONNECTION_Error;
452 socket = INVALID_SOCKET;
458 if (
auto result{ SSL_set_fd(ssl,
static_cast<int32_t
>(socket)) }; result != 1) {
460 reportError(
"Tcp_connection::connect::SSL_set_fd(), to: " + baseUrlNew));
461 currentStatus = connection_status::CONNECTION_Error;
462 socket = INVALID_SOCKET;
468 if (
auto result{ SSL_set_tlsext_host_name(ssl, addressString.data()) }; result != 1) {
470 reportError(
"Tcp_connection::connect::SSL_set_tlsext_host_name(), to: " + baseUrlNew));
471 currentStatus = connection_status::CONNECTION_Error;
472 socket = INVALID_SOCKET;
477 if (
auto result{ SSL_connect(ssl) }; result != 1) {
479 reportError(
"Tcp_connection::connect::SSL_connect(), to: " + baseUrlNew));
480 currentStatus = connection_status::CONNECTION_Error;
481 socket = INVALID_SOCKET;
488 if (
auto returnData{ ioctlsocket(socket, FIONBIO, &value02) }; returnData == SOCKET_ERROR) {
490 currentStatus = connection_status::CONNECTION_Error;
491 socket = INVALID_SOCKET;
496 if (
auto returnData{ fcntl(socket, F_SETFL, fcntl(socket, F_GETFL, 0) | O_NONBLOCK) }; returnData == SOCKET_ERROR) {
498 currentStatus = connection_status::CONNECTION_Error;
499 socket = INVALID_SOCKET;
504 currentStatus = connection_status::NO_Error;
507 DCA_INLINE connection_status processIO(int32_t waitTimeInMs) {
508 if (!areWeStillConnected()) {
509 return currentStatus;
511 pollfd readWriteSet{};
512 readWriteSet.fd =
static_cast<SOCKET
>(socket);
513 if (writeWantRead || readWantRead) {
514 readWriteSet.events = POLLIN;
515 }
else if (writeWantWrite || readWantWrite) {
516 readWriteSet.events = POLLOUT;
517 }
else if (
static_cast<value_type*
>(
this)->outputBuffer.getUsedSpace() > 0) {
518 readWriteSet.events = POLLIN | POLLOUT;
520 readWriteSet.events = POLLIN;
522 if (
auto returnValue = poll(&readWriteSet, 1, waitTimeInMs); returnValue == SOCKET_ERROR) {
524 reportSSLError(
"Tcp_connection::processIO() 00") +
"\n" + reportError(
"Tcp_connection::processIO() 00"));
525 socket = INVALID_SOCKET;
527 currentStatus = connection_status::SOCKET_Error;
528 return currentStatus;
529 }
else if (returnValue == 0) {
530 return currentStatus;
532 if (readWriteSet.revents & POLLOUT || (POLLIN && writeWantRead)) {
533 if (!processWriteData()) {
535 reportSSLError(
"Tcp_connection::processIO() 01") +
"\n" + reportError(
"Tcp_connection::processIO() 01"));
536 currentStatus = connection_status::WRITE_Error;
537 socket = INVALID_SOCKET;
539 return currentStatus;
542 if (readWriteSet.revents & POLLIN || (POLLOUT && readWantWrite)) {
543 if (!processReadData()) {
545 reportSSLError(
"Tcp_connection::processIO() 02") +
"\n" + reportError(
"Tcp_connection::processIO() 02"));
546 currentStatus = connection_status::READ_Error;
547 socket = INVALID_SOCKET;
549 return currentStatus;
552 if (readWriteSet.revents & POLLERR) {
554 reportSSLError(
"Tcp_connection::processIO() 03") +
"\n" + reportError(
"Tcp_connection::processIO() 03"));
555 currentStatus = connection_status::POLLERR_Error;
556 socket = INVALID_SOCKET;
559 if (readWriteSet.revents & POLLNVAL) {
561 reportSSLError(
"Tcp_connection::processIO() 04") +
"\n" + reportError(
"Tcp_connection::processIO() 04"));
562 currentStatus = connection_status::POLLNVAL_Error;
563 socket = INVALID_SOCKET;
566 if (readWriteSet.revents & POLLHUP) {
567 currentStatus = connection_status::POLLHUP_Error;
568 socket = INVALID_SOCKET;
572 return currentStatus;
575 DCA_INLINE
bool areWeStillConnected() {
576 if (socket.operator
bool() && socket.operator SOCKET() != INVALID_SOCKET && currentStatus == connection_status::NO_Error && ssl.operator
bool()) {
579 fdEvent.events = POLLOUT;
580 int32_t result = poll(&fdEvent, 1, 1);
581 if (result == SOCKET_ERROR || fdEvent.revents & POLLHUP || fdEvent.revents & POLLNVAL || fdEvent.revents & POLLERR) {
582 socket = INVALID_SOCKET;
592 DCA_INLINE
bool processWriteData() {
593 writeWantRead =
false;
594 writeWantWrite =
false;
595 if (
static_cast<value_type*
>(
this)->outputBuffer.getUsedSpace() > 0 && areWeStillConnected()) {
596 uint64_t bytesToWrite{
static_cast<value_type*
>(
this)->outputBuffer.getCurrentTail()->getUsedSpace() };
598 size_t writtenBytes{};
599 auto returnData{ SSL_write_ex(ssl,
static_cast<value_type*
>(
this)->outputBuffer.readData().data(), bytesToWrite, &writtenBytes) };
600 auto errorValue{ SSL_get_error(ssl, returnData) };
601 switch (errorValue) {
602 case SSL_ERROR_WANT_READ: {
603 writeWantRead =
true;
606 case SSL_ERROR_WANT_WRITE: {
607 writeWantWrite =
true;
610 case SSL_ERROR_NONE: {
613 case SSL_ERROR_ZERO_RETURN: {
614 socket = INVALID_SOCKET;
626 DCA_INLINE
bool processReadData() {
627 readWantRead =
false;
628 readWantWrite =
false;
629 if (!
static_cast<value_type*
>(
this)->inputBuffer.isItFull() && areWeStillConnected()) {
632 uint64_t bytesToRead{
static_cast<value_type*
>(
this)->maxBufferSize };
633 auto returnData{ SSL_read_ex(ssl,
static_cast<value_type*
>(
this)->inputBuffer.getCurrentHead()->getCurrentHead(), bytesToRead, &readBytes) };
634 auto errorValue{ SSL_get_error(ssl, returnData) };
635 if (
static_cast<int64_t
>(readBytes) > 0) {
636 static_cast<value_type*
>(
this)->inputBuffer.getCurrentHead()->modifyReadOrWritePosition(ring_buffer_access_type::write, readBytes);
637 static_cast<value_type*
>(
this)->inputBuffer.modifyReadOrWritePosition(ring_buffer_access_type::write, 1);
638 static_cast<value_type*
>(
this)->bytesRead += readBytes;
639 static_cast<value_type*
>(
this)->handleBuffer();
641 switch (errorValue) {
642 case SSL_ERROR_WANT_READ: {
646 case SSL_ERROR_WANT_WRITE: {
647 readWantWrite =
true;
650 case SSL_ERROR_NONE: {
653 case SSL_ERROR_ZERO_RETURN: {
654 socket = INVALID_SOCKET;
662 }
while (areWeStillConnected() && SSL_pending(ssl) && !
static_cast<value_type*
>(
this)->inputBuffer.isItFull() && !readWantRead);
667 template<
typename value_type2> DCA_INLINE
static unordered_map<uint64_t, value_type2*> processIO(unordered_map<uint64_t, value_type2*>& shardMap) {
668 unordered_map<uint64_t, value_type2*> returnData{};
669 poll_fd_wrapper readWriteSet{};
670 for (
auto& [key, value]: shardMap) {
671 if (value->areWeStillConnected()) {
673 fdSet.fd =
static_cast<SOCKET
>(value->socket);
674 if (value->writeWantRead || value->readWantRead) {
675 fdSet.events = POLLIN;
676 }
else if (value->writeWantWrite || value->readWantWrite) {
677 fdSet.events = POLLOUT;
678 }
else if (value->outputBuffer.getUsedSpace() > 0) {
679 fdSet.events = POLLIN | POLLOUT;
681 fdSet.events = POLLIN;
683 readWriteSet.indices.emplace_back(key);
684 readWriteSet.polls.emplace_back(fdSet);
686 returnData.emplace(key, value);
690 if (readWriteSet.polls.size() == 0) {
693 if (
auto returnDataNew = poll(readWriteSet.polls.data(),
static_cast<u_long
>(readWriteSet.polls.size()), 1); returnDataNew == SOCKET_ERROR) {
694 bool didWeFindTheSocket{};
695 for (uint64_t x = 0; x < readWriteSet.polls.size(); ++x) {
696 if (readWriteSet.polls.at(x).revents & POLLERR || readWriteSet.polls.at(x).revents & POLLHUP || readWriteSet.polls.at(x).revents & POLLNVAL) {
697 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::SOCKET_Error;
698 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
699 readWriteSet.indices.erase(readWriteSet.indices.begin() +
static_cast<int64_t
>(x));
700 readWriteSet.polls.erase(readWriteSet.polls.begin() +
static_cast<int64_t
>(x));
701 didWeFindTheSocket =
true;
704 if (!didWeFindTheSocket) {
705 for (uint64_t x = 0; x < readWriteSet.polls.size(); ++x) {
706 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::SOCKET_Error;
707 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
712 }
else if (returnDataNew == 0) {
715 for (uint64_t x = 0; x < readWriteSet.polls.size(); ++x) {
716 if (readWriteSet.polls.at(x).revents & POLLOUT || (POLLIN && shardMap.at(readWriteSet.indices.at(x))->writeWantRead)) {
717 if (!shardMap.at(readWriteSet.indices.at(x))->processWriteData()) {
718 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::WRITE_Error;
719 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
723 if (readWriteSet.polls.at(x).revents & POLLIN || (POLLOUT && shardMap.at(readWriteSet.indices.at(x))->readWantWrite)) {
724 if (!shardMap.at(readWriteSet.indices.at(x))->processReadData()) {
725 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::READ_Error;
726 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
730 if (readWriteSet.polls.at(x).revents & POLLERR) {
731 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::POLLERR_Error;
732 shardMap.at(readWriteSet.indices.at(x))->socket = INVALID_SOCKET;
733 shardMap.at(readWriteSet.indices.at(x))->ssl =
nullptr;
734 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
737 if (readWriteSet.polls.at(x).revents & POLLNVAL) {
738 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::POLLNVAL_Error;
739 shardMap.at(readWriteSet.indices.at(x))->socket = INVALID_SOCKET;
740 shardMap.at(readWriteSet.indices.at(x))->ssl =
nullptr;
741 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
744 if (readWriteSet.polls.at(x).revents & POLLHUP) {
745 shardMap.at(readWriteSet.indices.at(x))->currentStatus = connection_status::POLLHUP_Error;
746 shardMap.at(readWriteSet.indices.at(x))->socket = INVALID_SOCKET;
747 shardMap.at(readWriteSet.indices.at(x))->ssl =
nullptr;
748 returnData.emplace(readWriteSet.indices.at(x), shardMap.at(readWriteSet.indices.at(x)));
755 virtual DCA_INLINE
void handleBuffer() = 0;
757 DCA_INLINE
void disconnect() {
758 currentStatus = connection_status::CONNECTION_Error;
759 static_cast<value_type*
>(
this)->reset();
760 socket = INVALID_SOCKET;
764 virtual DCA_INLINE ~tcp_connection() =
default;
767 DCA_INLINE tcp_connection() =
default;
static DCA_INLINE void printError(const string_type &what, std::source_location where=std::source_location::current())
Print an error message of the specified type.
@ connect
Allows for joining of a voice channel.
@ stream
Allows the user to go live.
DCA_INLINE unique_ptr< value_type, deleter > makeUnique(arg_types &&... args)
Helper function to create a unique_ptr for a non-array object.
The main namespace for the forward-facing interfaces.