#ifndef SOCKET_UWS_H #define SOCKET_UWS_H #include "Networking.h" namespace uS { struct TransferData { // Connection state uv_os_sock_t fd; SSL *ssl; // Poll state void (*pollCb)(Poll *, int, int); int pollEvents; // User state void *userData; // Destination NodeData *destination; void (*transferCb)(Poll *); }; // perfectly 64 bytes (4 + 60) struct WIN32_EXPORT Socket : Poll { protected: struct { int poll : 4; int shuttingDown : 4; } state = {0, false}; SSL *ssl; void *user = nullptr; NodeData *nodeData; // this is not needed by HttpSocket! struct Queue { struct Message { const char *data; size_t length; Message *nextMessage = nullptr; void (*callback)(void *socket, void *data, bool cancelled, void *reserved) = nullptr; void *callbackData = nullptr, *reserved = nullptr; }; Message *head = nullptr, *tail = nullptr; void pop() { Message *nextMessage; if ((nextMessage = head->nextMessage)) { delete [] (char *) head; head = nextMessage; } else { delete [] (char *) head; head = tail = nullptr; } } bool empty() {return head == nullptr;} Message *front() {return head;} void push(Message *message) { message->nextMessage = nullptr; if (tail) { tail->nextMessage = message; tail = message; } else { head = message; tail = message; } } } messageQueue; int getPoll() { return state.poll; } int setPoll(int poll) { state.poll = poll; return poll; } void setShuttingDown(bool shuttingDown) { state.shuttingDown = shuttingDown; } void transfer(NodeData *nodeData, void (*cb)(Poll *)) { // userData is invalid from now on till onTransfer setUserData(new TransferData({getFd(), ssl, getCb(), getPoll(), getUserData(), nodeData, cb})); stop(this->nodeData->loop); close(this->nodeData->loop, [](Poll *p) { Socket *s = (Socket *) p; TransferData *transferData = (TransferData *) s->getUserData(); transferData->destination->asyncMutex->lock(); bool wasEmpty = transferData->destination->transferQueue.empty(); transferData->destination->transferQueue.push_back(s); transferData->destination->asyncMutex->unlock(); if (wasEmpty) { transferData->destination->async->send(); } }); } void changePoll(Socket *socket) { if (!threadSafeChange(nodeData->loop, this, socket->getPoll())) { if (socket->nodeData->tid != pthread_self()) { socket->nodeData->asyncMutex->lock(); socket->nodeData->changePollQueue.push_back(socket); socket->nodeData->asyncMutex->unlock(); socket->nodeData->async->send(); } else { change(socket->nodeData->loop, socket, socket->getPoll()); } } } // clears user data! template void startTimeout(int timeoutMs = 15000) { Timer *timer = new Timer(nodeData->loop); timer->setData(this); timer->start([](Timer *timer) { Socket *s = (Socket *) timer->getData(); s->cancelTimeout(); onTimeout(s); }, timeoutMs, 0); user = timer; } void cancelTimeout() { Timer *timer = (Timer *) getUserData(); if (timer) { timer->stop(); timer->close(); user = nullptr; } } template static void sslIoHandler(Poll *p, int status, int events) { Socket *socket = (Socket *) p; if (status < 0) { STATE::onEnd((Socket *) p); return; } if (!socket->messageQueue.empty() && ((events & UV_WRITABLE) || SSL_want(socket->ssl) == SSL_READING)) { socket->cork(true); while (true) { Queue::Message *messagePtr = socket->messageQueue.front(); int sent = SSL_write(socket->ssl, messagePtr->data, messagePtr->length); if (sent == (ssize_t) messagePtr->length) { if (messagePtr->callback) { messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved); } socket->messageQueue.pop(); if (socket->messageQueue.empty()) { if ((socket->state.poll & UV_WRITABLE) && SSL_want(socket->ssl) != SSL_WRITING) { socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE)); } break; } } else if (sent <= 0) { switch (SSL_get_error(socket->ssl, sent)) { case SSL_ERROR_WANT_READ: break; case SSL_ERROR_WANT_WRITE: if ((socket->getPoll() & UV_WRITABLE) == 0) { socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE)); } break; default: STATE::onEnd((Socket *) p); return; } break; } } socket->cork(false); } if (events & UV_READABLE) { do { int length = SSL_read(socket->ssl, socket->nodeData->recvBuffer, socket->nodeData->recvLength); if (length <= 0) { switch (SSL_get_error(socket->ssl, length)) { case SSL_ERROR_WANT_READ: break; case SSL_ERROR_WANT_WRITE: if ((socket->getPoll() & UV_WRITABLE) == 0) { socket->change(socket->nodeData->loop, socket, socket->setPoll(socket->getPoll() | UV_WRITABLE)); } break; default: STATE::onEnd((Socket *) p); return; } break; } else { // Warning: onData can delete the socket! Happens when HttpSocket upgrades socket = STATE::onData((Socket *) p, socket->nodeData->recvBuffer, length); if (socket->isClosed() || socket->isShuttingDown()) { return; } } } while (SSL_pending(socket->ssl)); } } template static void ioHandler(Poll *p, int status, int events) { Socket *socket = (Socket *) p; NodeData *nodeData = socket->nodeData; Context *netContext = nodeData->netContext; if (status < 0) { STATE::onEnd((Socket *) p); return; } if (events & UV_WRITABLE) { if (!socket->messageQueue.empty() && (events & UV_WRITABLE)) { socket->cork(true); while (true) { Queue::Message *messagePtr = socket->messageQueue.front(); ssize_t sent = ::send(socket->getFd(), messagePtr->data, messagePtr->length, MSG_NOSIGNAL); if (sent == (ssize_t) messagePtr->length) { if (messagePtr->callback) { messagePtr->callback(p, messagePtr->callbackData, false, messagePtr->reserved); } socket->messageQueue.pop(); if (socket->messageQueue.empty()) { // todo, remove bit, don't set directly socket->change(socket->nodeData->loop, socket, socket->setPoll(UV_READABLE)); break; } } else if (sent == SOCKET_ERROR) { if (!netContext->wouldBlock()) { STATE::onEnd((Socket *) p); return; } break; } else { messagePtr->length -= sent; messagePtr->data += sent; break; } } socket->cork(false); } } if (events & UV_READABLE) { int length = recv(socket->getFd(), nodeData->recvBuffer, nodeData->recvLength, 0); if (length > 0) { STATE::onData((Socket *) p, nodeData->recvBuffer, length); } else if (length <= 0 || (length == SOCKET_ERROR && !netContext->wouldBlock())) { STATE::onEnd((Socket *) p); } } } template void setState() { if (ssl) { setCb(sslIoHandler); } else { setCb(ioHandler); } } bool hasEmptyQueue() { return messageQueue.empty(); } void enqueue(Queue::Message *message) { messageQueue.push(message); } Queue::Message *allocMessage(size_t length, const char *data = 0) { Queue::Message *messagePtr = (Queue::Message *) new char[sizeof(Queue::Message) + length]; messagePtr->length = length; messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message); messagePtr->nextMessage = nullptr; if (data) { memcpy((char *) messagePtr->data, data, messagePtr->length); } return messagePtr; } void freeMessage(Queue::Message *message) { delete [] (char *) message; } bool write(Queue::Message *message, bool &wasTransferred) { ssize_t sent = 0; if (messageQueue.empty()) { if (ssl) { sent = SSL_write(ssl, message->data, message->length); if (sent == (ssize_t) message->length) { wasTransferred = false; return true; } else if (sent < 0) { switch (SSL_get_error(ssl, sent)) { case SSL_ERROR_WANT_READ: break; case SSL_ERROR_WANT_WRITE: if ((getPoll() & UV_WRITABLE) == 0) { setPoll(getPoll() | UV_WRITABLE); changePoll(this); } break; default: return false; } } } else { sent = ::send(getFd(), message->data, message->length, MSG_NOSIGNAL); if (sent == (ssize_t) message->length) { wasTransferred = false; return true; } else if (sent == SOCKET_ERROR) { if (!nodeData->netContext->wouldBlock()) { return false; } } else { message->length -= sent; message->data += sent; } if ((getPoll() & UV_WRITABLE) == 0) { setPoll(getPoll() | UV_WRITABLE); changePoll(this); } } } messageQueue.push(message); wasTransferred = true; return true; } template void sendTransformed(const char *message, size_t length, void(*callback)(void *socket, void *data, bool cancelled, void *reserved), void *callbackData, D transformData) { size_t estimatedLength = T::estimate(message, length) + sizeof(Queue::Message); if (hasEmptyQueue()) { if (estimatedLength <= uS::NodeData::preAllocMaxSize) { int memoryLength = estimatedLength; int memoryIndex = nodeData->getMemoryBlockIndex(memoryLength); Queue::Message *messagePtr = (Queue::Message *) nodeData->getSmallMemoryBlock(memoryIndex); messagePtr->data = ((char *) messagePtr) + sizeof(Queue::Message); messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); bool wasTransferred; if (write(messagePtr, wasTransferred)) { if (!wasTransferred) { nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex); if (callback) { callback(this, callbackData, false, nullptr); } } else { messagePtr->callback = callback; messagePtr->callbackData = callbackData; } } else { nodeData->freeSmallMemoryBlock((char *) messagePtr, memoryIndex); if (callback) { callback(this, callbackData, true, nullptr); } } } else { Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message)); messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); bool wasTransferred; if (write(messagePtr, wasTransferred)) { if (!wasTransferred) { freeMessage(messagePtr); if (callback) { callback(this, callbackData, false, nullptr); } } else { messagePtr->callback = callback; messagePtr->callbackData = callbackData; } } else { freeMessage(messagePtr); if (callback) { callback(this, callbackData, true, nullptr); } } } } else { Queue::Message *messagePtr = allocMessage(estimatedLength - sizeof(Queue::Message)); messagePtr->length = T::transform(message, (char *) messagePtr->data, length, transformData); messagePtr->callback = callback; messagePtr->callbackData = callbackData; enqueue(messagePtr); } } public: Socket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Poll(loop, fd), ssl(ssl), nodeData(nodeData) { if (ssl) { // OpenSSL treats SOCKETs as int SSL_set_fd(ssl, (int) fd); SSL_set_mode(ssl, SSL_MODE_RELEASE_BUFFERS); } } NodeData *getNodeData() { return nodeData; } Poll *next = nullptr, *prev = nullptr; void *getUserData() { return user; } void setUserData(void *user) { this->user = user; } struct Address { unsigned int port; const char *address; const char *family; }; Address getAddress(); void setNoDelay(int enable) { setsockopt(getFd(), IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); } void cork(int enable) { #if defined(TCP_CORK) // Linux & SmartOS have proper TCP_CORK setsockopt(getFd(), IPPROTO_TCP, TCP_CORK, &enable, sizeof(int)); #elif defined(TCP_NOPUSH) // Mac OS X & FreeBSD have TCP_NOPUSH setsockopt(getFd(), IPPROTO_TCP, TCP_NOPUSH, &enable, sizeof(int)); if (!enable) { // Tested on OS X, FreeBSD situation is unclear ::send(getFd(), "", 0, MSG_NOSIGNAL); } #endif } void shutdown() { if (ssl) { //todo: poll in/out - have the io_cb recall shutdown if failed SSL_shutdown(ssl); } else { ::shutdown(getFd(), SHUT_WR); } } template void closeSocket() { uv_os_sock_t fd = getFd(); Context *netContext = nodeData->netContext; stop(nodeData->loop); netContext->closeSocket(fd); if (ssl) { SSL_free(ssl); } Poll::close(nodeData->loop, [](Poll *p) { delete (T *) p; }); } bool isShuttingDown() { return state.shuttingDown; } friend class Node; friend struct NodeData; }; struct ListenSocket : Socket { ListenSocket(NodeData *nodeData, Loop *loop, uv_os_sock_t fd, SSL *ssl) : Socket(nodeData, loop, fd, ssl) { } Timer *timer = nullptr; uS::TLS::Context sslContext; }; } #endif // SOCKET_UWS_H