aboutsummaryrefslogtreecommitdiffhomepage
path: root/node_modules/uws/src/Socket.h
diff options
context:
space:
mode:
Diffstat (limited to 'node_modules/uws/src/Socket.h')
-rw-r--r--node_modules/uws/src/Socket.h507
1 files changed, 507 insertions, 0 deletions
diff --git a/node_modules/uws/src/Socket.h b/node_modules/uws/src/Socket.h
new file mode 100644
index 0000000..2179ff8
--- /dev/null
+++ b/node_modules/uws/src/Socket.h
@@ -0,0 +1,507 @@
+#ifndef SOCKET_UWS_H
+#define SOCKET_UWS_H
+
+#include "Networking.h"
+
+namespace uS {
+
+struct TransferData {
+ // Connection state
+ uv_os_sock_t fd;
+ SSL *ssl;
+
+ // Poll state
+ void (*pollCb)(Poll *, int, int);
+ int pollEvents;
+
+ // User state
+ void *userData;
+
+ // Destination
+ NodeData *destination;
+ void (*transferCb)(Poll *);
+};
+
+// perfectly 64 bytes (4 + 60)
+struct WIN32_EXPORT Socket : Poll {
+protected:
+ struct {
+ int poll : 4;
+ int shuttingDown : 4;
+ } state = {0, false};
+
+ SSL *ssl;
+ void *user = nullptr;
+ NodeData *nodeData;
+
+ // this is not needed by HttpSocket!
+ struct Queue {
+ struct Message {
+ const char *data;
+ size_t length;
+ Message *nextMessage = nullptr;
+ void (*callback)(void *socket, void *data, bool cancelled, void *reserved) = nullptr;
+ void *callbackData = nullptr, *reserved = nullptr;
+ };
+
+ Message *head = nullptr, *tail = nullptr;
+ void pop()
+ {
+ Message *nextMessage;
+ if ((nextMessage = head->nextMessage)) {
+ delete [] (char *) head;
+ head = nextMessage;
+ } else {
+ delete [] (char *) head;
+ head = tail = nullptr;
+ }
+ }
+
+ bool empty() {return head == nullptr;}
+ Message *front() {return head;}
+
+ void push(Message *message)
+ {
+ message->nextMessage = nullptr;
+ if (tail) {
+ tail->nextMessage = message;
+ tail = message;
+ } else {
+ head = message;
+ tail = message;
+ }
+ }
+ } messageQueue;
+
+ int getPoll() {
+ return state.poll;
+ }
+
+ int setPoll(int poll) {
+ state.poll = poll;
+ return poll;
+ }
+
+ void setShuttingDown(bool shuttingDown) {
+ state.shuttingDown = shuttingDown;
+ }
+
+ void transfer(NodeData *nodeData, void (*cb)(Poll *)) {
+ // userData is invalid from now on till onTransfer
+ setUserData(new TransferData({getFd(), ssl, getCb(), getPoll(), getUserData(), nodeData, cb}));
+ stop(this->nodeData->loop);
+ close(this->nodeData->loop, [](Poll *p) {
+ Socket *s = (Socket *) p;
+ TransferData *transferData = (TransferData *) s->getUserData();
+
+ transferData->destination->asyncMutex->lock();
+ bool wasEmpty = transferData->destination->transferQueue.empty();
+ transferData->destination->transferQueue.push_back(s);
+ transferData->destination->asyncMutex->unlock();
+
+ if (wasEmpty) {
+ transferData->destination->async->send();
+ }
+ });
+ }
+
+ void changePoll(Socket *socket) {
+ if (!threadSafeChange(nodeData->loop, this, socket->getPoll())) {
+ if (socket->nodeData->tid != pthread_self()) {
+ socket->nodeData->asyncMutex->lock();
+ socket->nodeData->changePollQueue.push_back(socket);
+ socket->nodeData->asyncMutex->unlock();
+ socket->nodeData->async->send();
+ } else {
+ change(socket->nodeData->loop, socket, socket->getPoll());
+ }
+ }
+ }
+
+ // clears user data!
+ template <void onTimeout(Socket *)>
+ void startTimeout(int timeoutMs = 15000) {
+ Timer *timer = new Timer(nodeData->loop);
+ timer->setData(this);
+ timer->start([](Timer *timer) {
+ Socket *s = (Socket *) timer->getData();
+ s->cancelTimeout();
+ onTimeout(s);
+ }, timeoutMs, 0);
+
+ user = timer;
+ }
+
+ void cancelTimeout() {
+ Timer *timer = (Timer *) getUserData();
+ if (timer) {
+ timer->stop();
+ timer->close();
+ user = nullptr;
+ }
+ }
+
+ template <class STATE>
+ static void sslIoHandler(Poll *p, int status, int events) {
+ Socket *socket = (Socket *) p;
+
+ if (status < 0) {
+ STATE::onEnd((Socket *) p);
+ return;
+ }
+
+ if (!socket->messageQueue.empty() && ((events & UV_WRITABLE) || SSL_want(socket->ssl) == SSL_READING)) {
+ socket->cork(true);
+ while (true) {
+ Queue::Message *messagePtr = socket->messageQueue.front();
+ int sent = SSL_write(socket->ssl, messagePtr->data, messagePtr->length);
+ if (sent == (ssize_t) messagePtr->length) {
+ if (messagePtr->callback) {
+ messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved);
+ }
+ socket->messageQueue.pop();
+ if (socket->messageQueue.empty()) {
+ if ((socket->state.poll & UV_WRITABLE) && SSL_want(socket->ssl) != SSL_WRITING) {
+ socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE));
+ }
+ break;
+ }
+ } else if (sent <= 0) {
+ switch (SSL_get_error(socket->ssl, sent)) {
+ case SSL_ERROR_WANT_READ:
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ if ((socket->getPoll() & UV_WRITABLE) == 0) {
+ socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE));
+ }
+ break;
+ default:
+ STATE::onEnd((Socket *) p);
+ return;
+ }
+ break;
+ }
+ }
+ socket->cork(false);
+ }
+
+ if (events & UV_READABLE) {
+ do {
+ int length = SSL_read(socket->ssl, socket->nodeData->recvBuffer, socket->nodeData->recvLength);
+ if (length <= 0) {
+ switch (SSL_get_error(socket->ssl, length)) {
+ case SSL_ERROR_WANT_READ:
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ if ((socket->getPoll() & UV_WRITABLE) == 0) {
+ socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE));
+ }
+ break;
+ default:
+ STATE::onEnd((Socket *) p);
+ return;
+ }
+ break;
+ } else {
+ // Warning: onData can delete the socket! Happens when HttpSocket upgrades
+ socket = STATE::onData((Socket *) p, socket->nodeData->recvBuffer, length);
+ if (socket->isClosed() || socket->isShuttingDown()) {
+ return;
+ }
+ }
+ } while (SSL_pending(socket->ssl));
+ }
+ }
+
+ template <class STATE>
+ static void ioHandler(Poll *p, int status, int events) {
+ Socket *socket = (Socket *) p;
+ NodeData *nodeData = socket->nodeData;
+ Context *netContext = nodeData->netContext;
+
+ if (status < 0) {
+ STATE::onEnd((Socket *) p);
+ return;
+ }
+
+ if (events & UV_WRITABLE) {
+ if (!socket->messageQueue.empty() && (events & UV_WRITABLE)) {
+ socket->cork(true);
+ while (true) {
+ Queue::Message *messagePtr = socket->messageQueue.front();
+ ssize_t sent = ::send(socket->getFd(), messagePtr->data, messagePtr->length, MSG_NOSIGNAL);
+ if (sent == (ssize_t) messagePtr->length) {
+ if (messagePtr->callback) {
+ messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved);
+ }
+ socket->messageQueue.pop();
+ if (socket->messageQueue.empty()) {
+ // todo, remove bit, don't set directly
+ socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE));
+ break;
+ }
+ } else if (sent == SOCKET_ERROR) {
+ if (!netContext->wouldBlock()) {
+ STATE::onEnd((Socket *) p);
+ return;
+ }
+ break;
+ } else {
+ messagePtr->length -= sent;
+ messagePtr->data += sent;
+ break;
+ }
+ }
+ socket->cork(false);
+ }
+ }
+
+ if (events & UV_READABLE) {
+ int length = recv(socket->getFd(), nodeData->recvBuffer, nodeData->recvLength, 0);
+ if (length > 0) {
+ STATE::onData((Socket *) p, nodeData->recvBuffer, length);
+ } else if (length <= 0 || (length == SOCKET_ERROR && !netContext->wouldBlock())) {
+ STATE::onEnd((Socket *) p);
+ }
+ }
+
+ }
+
+ template<class STATE>
+ void setState() {
+ if (ssl) {
+ setCb(sslIoHandler<STATE>);
+ } else {
+ setCb(ioHandler<STATE>);
+ }
+ }
+
+ bool hasEmptyQueue() {
+ return messageQueue.empty();
+ }
+
+ void enqueue(Queue::Message *message) {
+ messageQueue.push(message);
+ }
+
+ Queue::Message *allocMessage(size_t length, const char *data = 0) {
+ Queue::Message *messagePtr = (Queue::Message *) new char[sizeof(Queue::Message) + length];
+ messagePtr->length = length;
+ messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message);
+ messagePtr->nextMessage = nullptr;
+
+ if (data) {
+ memcpy((char *) messagePtr->data, data, messagePtr->length);
+ }
+
+ return messagePtr;
+ }
+
+ void freeMessage(Queue::Message *message) {
+ delete [] (char *) message;
+ }
+
+ bool write(Queue::Message *message, bool &wasTransferred) {
+ ssize_t sent = 0;
+ if (messageQueue.empty()) {
+
+ if (ssl) {
+ sent = SSL_write(ssl, message->data, message->length);
+ if (sent == (ssize_t) message->length) {
+ wasTransferred = false;
+ return true;
+ } else if (sent < 0) {
+ switch (SSL_get_error(ssl, sent)) {
+ case SSL_ERROR_WANT_READ:
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ if ((getPoll() & UV_WRITABLE) == 0) {
+ setPoll(getPoll() | UV_WRITABLE);
+ changePoll(this);
+ }
+ break;
+ default:
+ return false;
+ }
+ }
+ } else {
+ sent = ::send(getFd(), message->data, message->length, MSG_NOSIGNAL);
+ if (sent == (ssize_t) message->length) {
+ wasTransferred = false;
+ return true;
+ } else if (sent == SOCKET_ERROR) {
+ if (!nodeData->netContext->wouldBlock()) {
+ return false;
+ }
+ } else {
+ message->length -= sent;
+ message->data += sent;
+ }
+
+ if ((getPoll() & UV_WRITABLE) == 0) {
+ setPoll(getPoll() | UV_WRITABLE);
+ changePoll(this);
+ }
+ }
+ }
+ messageQueue.push(message);
+ wasTransferred = true;
+ return true;
+ }
+
+ template <class T, class D>
+ void sendTransformed(const char *message, size_t length, void(*callback)(void *socket, void *data, bool cancelled, void *reserved), void *callbackData, D transformData) {
+ size_t estimatedLength = T::estimate(message, length) + sizeof(Queue::Message);
+
+ if (hasEmptyQueue()) {
+ if (estimatedLength <= uS::NodeData::preAllocMaxSize) {
+ int memoryLength = estimatedLength;
+ int memoryIndex = nodeData->getMemoryBlockIndex(memoryLength);
+
+ Queue::Message *messagePtr = (Queue::Message *) nodeData->getSmallMemoryBlock(memoryIndex);
+ messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message);
+ messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
+
+ bool wasTransferred;
+ if (write(messagePtr, wasTransferred)) {
+ if (!wasTransferred) {
+ nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex);
+ if (callback) {
+ callback(this, callbackData, false, nullptr);
+ }
+ } else {
+ messagePtr->callback = callback;
+ messagePtr->callbackData = callbackData;
+ }
+ } else {
+ nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex);
+ if (callback) {
+ callback(this, callbackData, true, nullptr);
+ }
+ }
+ } else {
+ Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message));
+ messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
+
+ bool wasTransferred;
+ if (write(messagePtr, wasTransferred)) {
+ if (!wasTransferred) {
+ freeMessage(messagePtr);
+ if (callback) {
+ callback(this, callbackData, false, nullptr);
+ }
+ } else {
+ messagePtr->callback = callback;
+ messagePtr->callbackData = callbackData;
+ }
+ } else {
+ freeMessage(messagePtr);
+ if (callback) {
+ callback(this, callbackData, true, nullptr);
+ }
+ }
+ }
+ } else {
+ Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message));
+ messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData);
+ messagePtr->callback = callback;
+ messagePtr->callbackData = callbackData;
+ enqueue(messagePtr);
+ }
+ }
+
+public:
+ Socket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Poll(loop, fd), ssl(ssl), nodeData(nodeData) {
+ if (ssl) {
+ // OpenSSL treats SOCKETs as int
+ SSL_set_fd(ssl, (int) fd);
+ SSL_set_mode(ssl, SSL_MODE_RELEASE_BUFFERS);
+ }
+ }
+
+ NodeData *getNodeData() {
+ return nodeData;
+ }
+
+ Poll *next = nullptr, *prev = nullptr;
+
+ void *getUserData() {
+ return user;
+ }
+
+ void setUserData(void *user) {
+ this->user = user;
+ }
+
+ struct Address {
+ unsigned int port;
+ const char *address;
+ const char *family;
+ };
+
+ Address getAddress();
+
+ void setNoDelay(int enable) {
+ setsockopt(getFd(), IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
+ }
+
+ void cork(int enable) {
+#if defined(TCP_CORK)
+ // Linux & SmartOS have proper TCP_CORK
+ setsockopt(getFd(), IPPROTO_TCP, TCP_CORK, &enable, sizeof(int));
+#elif defined(TCP_NOPUSH)
+ // Mac OS X & FreeBSD have TCP_NOPUSH
+ setsockopt(getFd(), IPPROTO_TCP, TCP_NOPUSH, &enable, sizeof(int));
+ if (!enable) {
+ // Tested on OS X, FreeBSD situation is unclear
+ ::send(getFd(), "", 0, MSG_NOSIGNAL);
+ }
+#endif
+ }
+
+ void shutdown() {
+ if (ssl) {
+ //todo: poll in/out - have the io_cb recall shutdown if failed
+ SSL_shutdown(ssl);
+ } else {
+ ::shutdown(getFd(), SHUT_WR);
+ }
+ }
+
+ template <class T>
+ void closeSocket() {
+ uv_os_sock_t fd = getFd();
+ Context *netContext = nodeData->netContext;
+ stop(nodeData->loop);
+ netContext->closeSocket(fd);
+
+ if (ssl) {
+ SSL_free(ssl);
+ }
+
+ Poll::close(nodeData->loop, [](Poll *p) {
+ delete (T *) p;
+ });
+ }
+
+ bool isShuttingDown() {
+ return state.shuttingDown;
+ }
+
+ friend class Node;
+ friend struct NodeData;
+};
+
+struct ListenSocket : Socket {
+
+ ListenSocket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Socket(nodeData, loop, fd, ssl) {
+
+ }
+
+ Timer *timer = nullptr;
+ uS::TLS::Context sslContext;
+};
+
+}
+
+#endif // SOCKET_UWS_H