diff --git a/gateway/httplib.h b/gateway/httplib.h index 7fe82e7..f90dc90 100644 --- a/gateway/httplib.h +++ b/gateway/httplib.h @@ -8,6 +8,20 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H +/* + * Configuration + */ +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (std::numeric_limits::max)() +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#define CPPHTTPLIB_THREAD_POOL_COUNT 8 + #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS @@ -17,8 +31,16 @@ #define _CRT_NONSTDC_NO_DEPRECATE #endif //_CRT_NONSTDC_NO_DEPRECATE -#if defined(_MSC_VER) && _MSC_VER < 1900 +#if defined(_MSC_VER) +#ifdef _WIN64 +typedef __int64 ssize_t; +#else +typedef int ssize_t; +#endif + +#if _MSC_VER < 1900 #define snprintf _snprintf_s +#endif #endif // _MSC_VER #ifndef S_ISREG @@ -37,18 +59,32 @@ #include #include +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER #pragma comment(lib, "ws2_32.lib") +#endif #ifndef strcasecmp #define strcasecmp _stricmp #endif // strcasecmp typedef SOCKET socket_t; -#else +#ifndef CPPHTTPLIB_USE_SELECT +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + #include #include #include #include +#ifndef CPPHTTPLIB_USE_SELECT +#include +#endif #include #include #include @@ -60,13 +96,18 @@ typedef int socket_t; #endif //_WIN32 #include +#include +#include +#include #include #include #include +#include #include #include #include -#include +#include +#include #include #include #include @@ -74,21 +115,25 @@ typedef int socket_t; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT #include #include +#include + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) +{ + return M_ASN1_STRING_data(asn1); +} +#endif #endif #ifdef CPPHTTPLIB_ZLIB_SUPPORT #include #endif -/* - * Configuration - */ -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 -#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 -#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 -#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 -#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH std::numeric_limits::max() - namespace httplib { @@ -114,13 +159,23 @@ enum class HttpVersion typedef std::multimap Headers; -template -std::pair make_range_header(uint64_t value, Args... args); - typedef std::multimap Params; -typedef boost::smatch Match; +typedef std::smatch Match; + +typedef std::function DataSink; + +typedef std::function Done; + +typedef std::function ContentProvider; + +typedef std::function + ContentReceiver; + typedef std::function Progress; +struct Response; +typedef std::function ResponseHandler; + struct MultipartFile { std::string filename; @@ -130,24 +185,48 @@ struct MultipartFile }; typedef std::multimap MultipartFiles; +struct MultipartFormData +{ + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +typedef std::vector MultipartFormDataItems; + +typedef std::pair Range; +typedef std::vector Ranges; + struct Request { - std::string version; std::string method; - std::string target; std::string path; Headers headers; std::string body; + + // for server + std::string version; + std::string target; Params params; MultipartFiles files; + Ranges ranges; Match matches; + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); bool has_param(const char *key) const; std::string get_param_value(const char *key, size_t id = 0) const; @@ -163,20 +242,40 @@ struct Response int status; Headers headers; std::string body; - std::function streamcb; bool has_header(const char *key) const; std::string get_header_value(const char *key, size_t id = 0) const; size_t get_header_value_count(const char *key) const; void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); void set_redirect(const char *uri); void set_content(const char *s, size_t n, const char *content_type); void set_content(const std::string &s, const char *content_type); - Response() : status(-1) + void set_content_provider( + size_t length, std::function provider, + std::function resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function provider, + std::function resource_releaser = [] {}); + + Response() : status(-1), content_provider_resource_length(0) { } + + ~Response() + { + if(content_provider_resource_releaser) + { + content_provider_resource_releaser(); + } + } + + size_t content_provider_resource_length; + ContentProvider content_provider; + std::function content_provider_resource_releaser; }; class Stream @@ -188,9 +287,10 @@ class Stream virtual int read(char *ptr, size_t size) = 0; virtual int write(const char *ptr, size_t size1) = 0; virtual int write(const char *ptr) = 0; + virtual int write(const std::string &s) = 0; virtual std::string get_remote_addr() const = 0; - template void write_format(const char *fmt, const Args &...args); + template int write_format(const char *fmt, const Args &...args); }; class SocketStream : public Stream @@ -202,6 +302,7 @@ class SocketStream : public Stream virtual int read(char *ptr, size_t size); virtual int write(const char *ptr, size_t size); virtual int write(const char *ptr); + virtual int write(const std::string &s); virtual std::string get_remote_addr() const; private: @@ -221,6 +322,7 @@ class BufferStream : public Stream virtual int read(char *ptr, size_t size); virtual int write(const char *ptr, size_t size); virtual int write(const char *ptr); + virtual int write(const std::string &s); virtual std::string get_remote_addr() const; const std::string &get_buffer() const; @@ -229,6 +331,152 @@ class BufferStream : public Stream std::string buffer; }; +class TaskQueue +{ + public: + TaskQueue() + { + } + virtual ~TaskQueue() + { + } + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; +}; + +#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 +class ThreadPool : public TaskQueue +{ + public: + ThreadPool(size_t n) : shutdown_(false) + { + while(n) + { + auto t = std::make_shared(worker(*this)); + threads_.push_back(t); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + virtual ~ThreadPool() + { + } + + virtual void enqueue(std::function fn) override + { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + virtual void shutdown() override + { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for(auto t : threads_) + { + t->join(); + } + } + + private: + struct worker + { + worker(ThreadPool &pool) : pool_(pool) + { + } + + void operator()() + { + for(;;) + { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if(pool_.shutdown_ && pool_.jobs_.empty()) + { + break; + } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector> threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; +#else +class Threads : public TaskQueue +{ + public: + Threads() : running_threads_(0) + { + } + virtual ~Threads() + { + } + + virtual void enqueue(std::function fn) override + { + std::thread([=]() { + { + std::lock_guard guard(running_threads_mutex_); + running_threads_++; + } + + fn(); + + { + std::lock_guard guard(running_threads_mutex_); + running_threads_--; + } + }).detach(); + } + + virtual void shutdown() override + { + for(;;) + { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::lock_guard guard(running_threads_mutex_); + if(!running_threads_) + { + break; + } + } + } + + private: + std::mutex running_threads_mutex_; + int running_threads_; +}; +#endif + class Server { public: @@ -250,12 +498,13 @@ class Server Server &Options(const char *pattern, Handler handler); bool set_base_dir(const char *path); + void set_file_request_handler(Handler handler); void set_error_handler(Handler handler); void set_logger(Logger logger); void set_keep_alive_max_count(size_t count); - void set_payload_max_length(uint64_t length); + void set_payload_max_length(size_t length); int bind_to_any_port(const char *host, int socket_flags = 0); bool listen_after_bind(); @@ -265,14 +514,17 @@ class Server bool is_running() const; void stop(); + std::function new_task_queue; + protected: - bool process_request(Stream &strm, bool last_connection, bool &connection_close); + bool process_request(Stream &strm, bool last_connection, bool &connection_close, + std::function setup_request); size_t keep_alive_max_count_; size_t payload_max_length_; private: - typedef std::vector> Handlers; + typedef std::vector> Handlers; socket_t create_server_socket(const char *host, int port, int socket_flags) const; int bind_internal(const char *host, int port, int socket_flags); @@ -283,13 +535,16 @@ class Server bool dispatch_request(Request &req, Response &res, Handlers &handlers); bool parse_request_line(const char *s, Request &req); - void write_response(Stream &strm, bool last_connection, const Request &req, Response &res); + bool write_response(Stream &strm, bool last_connection, const Request &req, Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type); - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - bool is_running_; - socket_t svr_sock_; + std::atomic is_running_; + std::atomic svr_sock_; std::string base_dir_; + Handler file_request_handler_; Handlers get_handlers_; Handlers post_handlers_; Handlers put_handlers_; @@ -298,10 +553,6 @@ class Server Handlers options_handlers_; Handler error_handler_; Logger logger_; - - // TODO: Use thread pool... - std::mutex running_threads_mutex_; - int running_threads_; }; class Client @@ -313,52 +564,131 @@ class Client virtual bool is_valid() const; - std::shared_ptr Get(const char *path, Progress progress = nullptr); - std::shared_ptr Get(const char *path, const Headers &headers, Progress progress = nullptr); + std::shared_ptr Get(const char *path); + + std::shared_ptr Get(const char *path, const Headers &headers); + + std::shared_ptr Get(const char *path, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, Progress progress); + + std::shared_ptr Get(const char *path, ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, ContentReceiver content_receiver, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); std::shared_ptr Head(const char *path); + std::shared_ptr Head(const char *path, const Headers &headers); std::shared_ptr Post(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Post(const char *path, const Headers &headers, const std::string &body, const char *content_type); std::shared_ptr Post(const char *path, const Params ¶ms); + std::shared_ptr Post(const char *path, const Headers &headers, const Params ¶ms); + std::shared_ptr Post(const char *path, const MultipartFormDataItems &items); + + std::shared_ptr Post(const char *path, const Headers &headers, const MultipartFormDataItems &items); + std::shared_ptr Put(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Put(const char *path, const Headers &headers, const std::string &body, const char *content_type); std::shared_ptr Patch(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Patch(const char *path, const Headers &headers, const std::string &body, const char *content_type); std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, const char *content_type); + std::shared_ptr Delete(const char *path, const Headers &headers); + std::shared_ptr Delete(const char *path, const Headers &headers, const std::string &body, + const char *content_type); + std::shared_ptr Options(const char *path); + std::shared_ptr Options(const char *path, const Headers &headers); - bool send(Request &req, Response &res); + bool send(const Request &req, Response &res); + + bool send(const std::vector &requests, std::vector &responses); + + void set_keep_alive_max_count(size_t count); + + void follow_location(bool on); protected: - bool process_request(Stream &strm, Request &req, Response &res, bool &connection_close); + bool process_request(Stream &strm, const Request &req, Response &res, bool last_connection, bool &connection_close); const std::string host_; const int port_; time_t timeout_sec_; const std::string host_and_port_; + size_t keep_alive_max_count_; + size_t follow_location_; private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); - void write_request(Stream &strm, Request &req); + void write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function callback); - virtual bool read_and_close_socket(socket_t sock, Request &req, Response &res); virtual bool is_ssl() const; }; +inline void Get(std::vector &requests, const char *path, const Headers &headers) +{ + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector &requests, const char *path) +{ + Get(requests, path, Headers()); +} + +inline void Post(std::vector &requests, const char *path, const Headers &headers, const std::string &body, + const char *content_type) +{ + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector &requests, const char *path, const std::string &body, const char *content_type) +{ + Post(requests, path, Headers(), body, content_type); +} + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT class SSLSocketStream : public Stream { @@ -369,6 +699,7 @@ class SSLSocketStream : public Stream virtual int read(char *ptr, size_t size); virtual int write(const char *ptr, size_t size); virtual int write(const char *ptr); + virtual int write(const std::string &s); virtual std::string get_remote_addr() const; private: @@ -379,14 +710,15 @@ class SSLSocketStream : public Stream class SSLServer : public Server { public: - SSLServer(const char *cert_path, const char *private_key_path); + SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); virtual ~SSLServer(); virtual bool is_valid() const; private: - virtual bool read_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); SSL_CTX *ctx_; std::mutex ctx_mutex_; @@ -395,18 +727,38 @@ class SSLServer : public Server class SSLClient : public Client { public: - SSLClient(const char *host, int port = 443, time_t timeout_sec = 300); + SSLClient(const char *host, int port = 443, time_t timeout_sec = 300, const char *client_cert_path = nullptr, + const char *client_key_path = nullptr); virtual ~SSLClient(); virtual bool is_valid() const; + void set_ca_cert_path(const char *ca_ceert_file_path, const char *ca_cert_dir_path = nullptr); + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const noexcept; + private: - virtual bool read_and_close_socket(socket_t sock, Request &req, Response &res); + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function callback); virtual bool is_ssl() const; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + SSL_CTX *ctx_; std::mutex ctx_mutex_; + std::vector host_components_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; }; #endif @@ -416,6 +768,228 @@ class SSLClient : public Client namespace detail { +inline bool is_hex(char c, int &v) +{ + if(0x20 <= c && isdigit(c)) + { + v = c - '0'; + return true; + } + else if('A' <= c && c <= 'F') + { + v = c - 'A' + 10; + return true; + } + else if('a' <= c && c <= 'f') + { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) +{ + if(i >= s.size()) + { + return false; + } + + val = 0; + for(; cnt; i++, cnt--) + { + if(!s[i]) + { + return false; + } + int v = 0; + if(is_hex(s[i], v)) + { + val = val * 16 + v; + } + else + { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) +{ + const char *charset = "0123456789abcdef"; + std::string ret; + do + { + ret = charset[n & 15] + ret; + n >>= 4; + } while(n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) +{ + if(code < 0x0080) + { + buff[0] = (code & 0x7F); + return 1; + } + else if(code < 0x0800) + { + buff[0] = (0xC0 | ((code >> 6) & 0x1F)); + buff[1] = (0x80 | (code & 0x3F)); + return 2; + } + else if(code < 0xD800) + { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } + else if(code < 0xE000) + { // D800 - DFFF is invalid... + return 0; + } + else if(code < 0x10000) + { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } + else if(code < 0x110000) + { + buff[0] = (0xF0 | ((code >> 18) & 0x7)); + buff[1] = (0x80 | ((code >> 12) & 0x3F)); + buff[2] = (0x80 | ((code >> 6) & 0x3F)); + buff[3] = (0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) +{ + static const auto lookup = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for(uint8_t c : in) + { + val = (val << 8) + c; + valb += 8; + while(valb >= 0) + { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if(valb > -6) + { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while(out.size() % 4) + { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) +{ + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) +{ + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) +{ + size_t level = 0; + size_t i = 0; + + // Skip slash + while(i < path.size() && path[i] == '/') + { + i++; + } + + while(i < path.size()) + { + // Read component + auto beg = i; + while(i < path.size() && path[i] != '/') + { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if(!path.compare(beg, len, ".")) + { + ; + } + else if(!path.compare(beg, len, "..")) + { + if(level == 0) + { + return false; + } + level--; + } + else + { + level++; + } + + // Skip slash + while(i < path.size() && path[i] == '/') + { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) +{ + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) +{ + std::smatch m; + auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); + if(std::regex_search(path, m, pat)) + { + return m[1].str(); + } + return std::string(); +} + template void split(const char *b, const char *e, char d, Fn fn) { int i = 0; @@ -545,6 +1119,7 @@ inline int close_socket(socket_t sock) inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_SELECT fd_set fds; FD_ZERO(&fds); FD_SET(sock, &fds); @@ -553,11 +1128,21 @@ inline int select_read(socket_t sock, time_t sec, time_t usec) tv.tv_sec = static_cast(sec); tv.tv_usec = static_cast(usec); - return select(static_cast(sock + 1), &fds, NULL, NULL, &tv); + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); +#else + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return poll(&pfd_read, 1, timeout); +#endif } inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_SELECT fd_set fdsr; FD_ZERO(&fdsr); FD_SET(sock, &fdsr); @@ -569,36 +1154,43 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) tv.tv_sec = static_cast(sec); tv.tv_usec = static_cast(usec); - if(select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) < 0) - { - return false; - } - else if(FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw)) + if(select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { int error = 0; socklen_t len = sizeof(error); - if(getsockopt(sock, SOL_SOCKET, SO_ERROR, (char *)&error, &len) < 0 || error) - { - return false; - } - } - else - { - return false; + return getsockopt(sock, SOL_SOCKET, SO_ERROR, (char *)&error, &len) >= 0 && !error; } + return false; +#else + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; - return true; + auto timeout = static_cast(sec * 1000 + usec / 1000); + + if(poll(&pfd_read, 1, timeout) > 0 && pfd_read.revents & (POLLIN | POLLOUT)) + { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) >= 0 && !error; + } + return false; +#endif } -template inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, T callback) +template +inline bool process_and_close_socket(bool is_client_request, socket_t sock, size_t keep_alive_max_count, T callback) { + assert(keep_alive_max_count > 0); + bool ret = false; - if(keep_alive_max_count > 0) + if(keep_alive_max_count > 1) { auto count = keep_alive_max_count; - while(count > 0 && - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) + while(count > 0 && (is_client_request || detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { SocketStream strm(sock); auto last_connection = count == 1; @@ -663,15 +1255,29 @@ template socket_t create_socket(const char *host, int port, Fn fn, for(auto rp = result; rp; rp = rp->ai_next) { // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); +#else auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif if(sock == INVALID_SOCKET) { continue; } +#ifndef _WIN32 + if(fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) + { + continue; + } +#endif + // Make 'reuse address' option available int yes = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&yes, sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), sizeof(yes)); +#endif // bind or connect if(fn(sock, *rp)) @@ -712,11 +1318,12 @@ inline std::string get_remote_addr(socket_t sock) struct sockaddr_storage addr; socklen_t len = sizeof(addr); - if(!getpeername(sock, (struct sockaddr *)&addr, &len)) + if(!getpeername(sock, reinterpret_cast(&addr), &len)) { char ipstr[NI_MAXHOST]; - if(!getnameinfo((struct sockaddr *)&addr, len, ipstr, sizeof(ipstr), nullptr, 0, NI_NUMERICHOST)) + if(!getnameinfo(reinterpret_cast(&addr), len, ipstr, sizeof(ipstr), nullptr, 0, + NI_NUMERICHOST)) { return ipstr; } @@ -725,89 +1332,6 @@ inline std::string get_remote_addr(socket_t sock) return std::string(); } -inline bool is_file(const std::string &path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); -} - -inline bool is_dir(const std::string &path) -{ - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - -inline bool is_valid_path(const std::string &path) -{ - size_t level = 0; - size_t i = 0; - - // Skip slash - while(i < path.size() && path[i] == '/') - { - i++; - } - - while(i < path.size()) - { - // Read component - auto beg = i; - while(i < path.size() && path[i] != '/') - { - i++; - } - - auto len = i - beg; - assert(len > 0); - - if(!path.compare(beg, len, ".")) - { - ; - } - else if(!path.compare(beg, len, "..")) - { - if(level == 0) - { - return false; - } - level--; - } - else - { - level++; - } - - // Skip slash - while(i < path.size() && path[i] == '/') - { - i++; - } - } - - return true; -} - -inline void read_file(const std::string &path, std::string &out) -{ - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], size); -} - -inline std::string file_extension(const std::string &path) -{ - boost::smatch m; - auto pat = boost::regex("\\.([a-zA-Z0-9]+)$"); - if(boost::regex_search(path, m, pat)) - { - return m[1].str(); - } - return std::string(); -} - inline const char *find_content_type(const std::string &path) { auto ext = file_extension(path); @@ -872,6 +1396,8 @@ inline const char *status_message(int status) { case 200: return "OK"; + case 206: + return "Partial Content"; case 301: return "Moved Permanently"; case 302: @@ -892,12 +1418,126 @@ inline const char *status_message(int status) return "Request-URI Too Long"; case 415: return "Unsupported Media Type"; + case 416: + return "Range Not Satisfiable"; + default: case 500: return "Internal Server Error"; } } +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline bool can_compress(const std::string &content_type) +{ + return !content_type.find("text/") || content_type == "image/svg+xml" || content_type == "application/javascript" || + content_type == "application/json" || content_type == "application/xml" || + content_type == "application/xhtml+xml"; +} + +inline bool compress(std::string &content) +{ + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY); + if(ret != Z_OK) + { + return false; + } + + strm.avail_in = content.size(); + strm.next_in = const_cast(reinterpret_cast(content.data())); + + std::string compressed; + + const auto bufsiz = 16384; + char buff[bufsiz]; + do + { + strm.avail_out = bufsiz; + strm.next_out = reinterpret_cast(buff); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff, bufsiz - strm.avail_out); + } while(strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; +} + +class decompressor +{ + public: + decompressor() + { + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 16 specifies + // that the stream to decompress will be formatted with a gzip wrapper. + is_valid_ = inflateInit2(&strm, 16 + 15) == Z_OK; + } + + ~decompressor() + { + inflateEnd(&strm); + } + + bool is_valid() const + { + return is_valid_; + } + + template bool decompress(const char *data, size_t data_length, T callback) + { + int ret = Z_OK; + + strm.avail_in = data_length; + strm.next_in = const_cast(reinterpret_cast(data)); + + const auto bufsiz = 16384; + char buff[bufsiz]; + do + { + strm.avail_out = bufsiz; + strm.next_out = reinterpret_cast(buff); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch(ret) + { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm); + return false; + } + + if(!callback(buff, bufsiz - strm.avail_out)) + { + return false; + } + } while(strm.avail_out == 0); + + return ret == Z_STREAM_END; + } + + private: + bool is_valid_; + z_stream strm; +}; +#endif + inline bool has_header(const Headers &headers, const char *key) { return headers.find(key) != headers.end(); @@ -926,7 +1566,7 @@ inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, inline bool read_headers(Stream &strm, Headers &headers) { - static boost::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); + static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); const auto bufsiz = 2048; char buf[bufsiz]; @@ -943,8 +1583,8 @@ inline bool read_headers(Stream &strm, Headers &headers) { break; } - boost::cmatch m; - if(boost::regex_match(reader.ptr(), m, re)) + std::cmatch m; + if(std::regex_match(reader.ptr(), m, re)) { auto key = std::string(m[1]); auto val = std::string(m[2]); @@ -955,18 +1595,27 @@ inline bool read_headers(Stream &strm, Headers &headers) return true; } -inline bool read_content_with_length(Stream &strm, std::string &out, size_t len, Progress progress) +typedef std::function ContentReceiverCore; + +inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiverCore out) { - out.assign(len, 0); - size_t r = 0; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; while(r < len) { - auto n = strm.read(&out[r], len - r); + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); if(n <= 0) { return false; } + if(!out(buf, n)) + { + return false; + } + r += n; if(progress) @@ -981,12 +1630,28 @@ inline bool read_content_with_length(Stream &strm, std::string &out, size_t len, return true; } -inline bool read_content_without_length(Stream &strm, std::string &out) +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while(r < len) + { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if(n <= 0) + { + return; + } + r += n; + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiverCore out) +{ + char buf[CPPHTTPLIB_RECV_BUFSIZ]; for(;;) { - char byte; - auto n = strm.read(&byte, 1); + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); if(n < 0) { return false; @@ -995,13 +1660,16 @@ inline bool read_content_without_length(Stream &strm, std::string &out) { return true; } - out += byte; + if(!out(buf, n)) + { + return false; + } } return true; } -inline bool read_content_chunked(Stream &strm, std::string &out) +inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) { const auto bufsiz = 16; char buf[bufsiz]; @@ -1017,8 +1685,7 @@ inline bool read_content_chunked(Stream &strm, std::string &out) while(chunk_len > 0) { - std::string chunk; - if(!read_content_with_length(strm, chunk, chunk_len, nullptr)) + if(!read_content_with_length(strm, chunk_len, nullptr, out)) { return false; } @@ -1033,8 +1700,6 @@ inline bool read_content_chunked(Stream &strm, std::string &out) break; } - out += chunk; - if(!reader.getline()) { return false; @@ -1053,51 +1718,178 @@ inline bool read_content_chunked(Stream &strm, std::string &out) return true; } -template -bool read_content(Stream &strm, T &x, uint64_t payload_max_length, bool &exceed_payload_max_length, - Progress progress = Progress()) +inline bool is_chunked_transfer_encoding(const Headers &headers) { - if(has_header(x.headers, "Content-Length")) + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), "chunked"); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, + ContentReceiverCore receiver) +{ + + ContentReceiverCore out = [&](const char *buf, size_t n) { return receiver(buf, n); }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + detail::decompressor decompressor; + + if(!decompressor.is_valid()) { - auto len = get_header_value_uint64(x.headers, "Content-Length", 0); - if(len == 0) - { - const auto &encoding = get_header_value(x.headers, "Transfer-Encoding", 0, ""); - if(!strcasecmp(encoding, "chunked")) - { - return read_content_chunked(strm, x.body); - } - } + status = 500; + return false; + } - if((len > payload_max_length) || - // For 32-bit platform - (sizeof(size_t) < sizeof(uint64_t) && len > std::numeric_limits::max())) - { - exceed_payload_max_length = true; - return false; - } + if(x.get_header_value("Content-Encoding") == "gzip") + { + out = [&](const char *buf, size_t n) { + return decompressor.decompress(buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if(x.get_header_value("Content-Encoding") == "gzip") + { + status = 415; + return false; + } +#endif - return read_content_with_length(strm, x.body, len, progress); + auto ret = true; + auto exceed_payload_max_length = false; + + if(is_chunked_transfer_encoding(x.headers)) + { + ret = read_content_chunked(strm, out); + } + else if(!has_header(x.headers, "Content-Length")) + { + ret = read_content_without_length(strm, out); } else { - const auto &encoding = get_header_value(x.headers, "Transfer-Encoding", 0, ""); - if(!strcasecmp(encoding, "chunked")) + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if(len > payload_max_length) { - return read_content_chunked(strm, x.body); + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } + else if(len > 0) + { + ret = read_content_with_length(strm, len, progress, out); } - return read_content_without_length(strm, x.body); } - return true; + + if(!ret) + { + status = exceed_payload_max_length ? 413 : 400; + } + + return ret; } -template inline void write_headers(Stream &strm, const T &info) +template inline int write_headers(Stream &strm, const T &info, const Headers &headers) { + auto write_len = 0; for(const auto &x : info.headers) { - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if(len < 0) + { + return len; + } + write_len += len; } - strm.write("\r\n"); + for(const auto &x : headers) + { + auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if(len < 0) + { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if(len < 0) + { + return len; + } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, size_t offset, size_t length) +{ + size_t begin_offset = offset; + size_t end_offset = offset + length; + while(offset < end_offset) + { + ssize_t written_length = 0; + content_provider( + offset, end_offset - offset, + [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }, + [&](void) { written_length = -1; }); + if(written_length < 0) + { + return written_length; + } + } + return static_cast(offset - begin_offset); +} + +inline ssize_t write_content_chunked(Stream &strm, ContentProvider content_provider) +{ + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while(data_available) + { + ssize_t written_length = 0; + content_provider( + offset, 0, + [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }, + [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }); + + if(written_length < 0) + { + return written_length; + } + total_written_length += written_length; + } + return total_written_length; +} + +template inline bool redirect(T &cli, const Request &req, Response &res, const std::string &path) +{ + Request new_req; + new_req.method = req.method; + new_req.path = path; + new_req.headers = req.headers; + new_req.body = req.body; + new_req.redirect_count = req.redirect_count - 1; + new_req.response_handler = req.response_handler; + new_req.content_receiver = req.content_receiver; + new_req.progress = req.progress; + + Response new_res; + auto ret = cli.send(new_req, new_res); + if(ret) + { + res = new_res; + } + return ret; } inline std::string encode_url(const std::string &s) @@ -1153,109 +1945,6 @@ inline std::string encode_url(const std::string &s) return result; } -inline bool is_hex(char c, int &v) -{ - if(0x20 <= c && isdigit(c)) - { - v = c - '0'; - return true; - } - else if('A' <= c && c <= 'F') - { - v = c - 'A' + 10; - return true; - } - else if('a' <= c && c <= 'f') - { - v = c - 'a' + 10; - return true; - } - return false; -} - -inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) -{ - if(i >= s.size()) - { - return false; - } - - val = 0; - for(; cnt; i++, cnt--) - { - if(!s[i]) - { - return false; - } - int v = 0; - if(is_hex(s[i], v)) - { - val = val * 16 + v; - } - else - { - return false; - } - } - return true; -} - -inline std::string from_i_to_hex(uint64_t n) -{ - const char *charset = "0123456789abcdef"; - std::string ret; - do - { - ret = charset[n & 15] + ret; - n >>= 4; - } while(n > 0); - return ret; -} - -inline size_t to_utf8(int code, char *buff) -{ - if(code < 0x0080) - { - buff[0] = (code & 0x7F); - return 1; - } - else if(code < 0x0800) - { - buff[0] = (0xC0 | ((code >> 6) & 0x1F)); - buff[1] = (0x80 | (code & 0x3F)); - return 2; - } - else if(code < 0xD800) - { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } - else if(code < 0xE000) - { // D800 - DFFF is invalid... - return 0; - } - else if(code < 0x10000) - { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } - else if(code < 0x110000) - { - buff[0] = (0xF0 | ((code >> 18) & 0x7)); - buff[1] = (0x80 | ((code >> 12) & 0x3F)); - buff[2] = (0x80 | ((code >> 6) & 0x3F)); - buff[3] = (0x80 | (code & 0x3F)); - return 4; - } - - // NOTREACHED - return 0; -} - inline std::string decode_url(const std::string &s) { std::string result; @@ -1347,10 +2036,10 @@ inline bool parse_multipart_formdata(const std::string &boundary, const std::str static std::string dash = "--"; static std::string crlf = "\r\n"; - static boost::regex re_content_type("Content-Type: (.*?)", boost::regex_constants::icase); + static std::regex re_content_type("Content-Type: (.*?)", std::regex_constants::icase); - static boost::regex re_content_disposition( - "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", boost::regex_constants::icase); + static std::regex re_content_disposition("Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", + std::regex_constants::icase); auto dash_boundary = dash + boundary; @@ -1385,12 +2074,12 @@ inline bool parse_multipart_formdata(const std::string &boundary, const std::str while(pos != next_pos) { - boost::smatch m; - if(boost::regex_match(header, m, re_content_type)) + std::smatch m; + if(std::regex_match(header, m, re_content_type)) { file.content_type = m[1]; } - else if(boost::regex_match(header, m, re_content_disposition)) + else if(std::regex_match(header, m, re_content_disposition)) { name = m[1]; file.filename = m[2]; @@ -1435,6 +2124,50 @@ inline bool parse_multipart_formdata(const std::string &boundary, const std::str return true; } +inline bool parse_range_header(const std::string &s, Ranges &ranges) +{ + try + { + static auto re = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if(std::regex_match(s, m, re)) + { + auto pos = m.position(1); + auto len = m.length(1); + detail::split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + static auto re = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if(std::regex_match(b, e, m, re)) + { + ssize_t first = -1; + if(!m.str(1).empty()) + { + first = static_cast(std::stoll(m.str(1))); + } + + ssize_t last = -1; + if(!m.str(2).empty()) + { + last = static_cast(std::stoll(m.str(2))); + } + + if(first != -1 && last != -1 && first > last) + { + throw std::runtime_error("invalid range error"); + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return true; + } + return false; + } + catch(...) + { + return false; + } +} + inline std::string to_lower(const char *beg, const char *end) { std::string out; @@ -1447,107 +2180,149 @@ inline std::string to_lower(const char *beg, const char *end) return out; } -inline void make_range_header_core(std::string &) +inline std::string make_multipart_data_boundary() { -} + static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; -template inline void make_range_header_core(std::string &field, uint64_t value) -{ - if(!field.empty()) + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for(auto i = 0; i < 16; i++) { - field += ", "; - } - field += std::to_string(value) + "-"; -} - -template -inline void make_range_header_core(std::string &field, uint64_t value1, uint64_t value2, Args... args) -{ - if(!field.empty()) - { - field += ", "; - } - field += std::to_string(value1) + "-" + std::to_string(value2); - make_range_header_core(field, args...); -} - -#ifdef CPPHTTPLIB_ZLIB_SUPPORT -inline bool can_compress(const std::string &content_type) -{ - return !content_type.find("text/") || content_type == "image/svg+xml" || content_type == "application/javascript" || - content_type == "application/json" || content_type == "application/xml" || - content_type == "application/xhtml+xml"; -} - -inline void compress(std::string &content) -{ - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; - - auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY); - if(ret != Z_OK) - { - return; + result += data[engine() % (sizeof(data) - 1)]; } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); - - std::string compressed; - - const auto bufsiz = 16384; - char buff[bufsiz]; - do - { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - deflate(&strm, Z_FINISH); - compressed.append(buff, bufsiz - strm.avail_out); - } while(strm.avail_out == 0); - - content.swap(compressed); - - deflateEnd(&strm); + return result; } -inline void decompress(std::string &content) +inline std::pair get_range_offset_and_length(const Request &req, size_t content_length, size_t index) { - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; + auto r = req.ranges[index]; - // 15 is the value of wbits, which should be at the maximum possible value to - // ensure that any gzip stream can be decoded. The offset of 16 specifies that - // the stream to decompress will be formatted with a gzip wrapper. - auto ret = inflateInit2(&strm, 16 + 15); - if(ret != Z_OK) + if(r.first == -1 && r.second == -1) { - return; + return std::make_pair(0, content_length); } - strm.avail_in = content.size(); - strm.next_in = (Bytef *)content.data(); - - std::string decompressed; - - const auto bufsiz = 16384; - char buff[bufsiz]; - do + if(r.first == -1) { - strm.avail_out = bufsiz; - strm.next_out = (Bytef *)buff; - inflate(&strm, Z_NO_FLUSH); - decompressed.append(buff, bufsiz - strm.avail_out); - } while(strm.avail_out == 0); + r.first = content_length - r.second; + r.second = content_length - 1; + } - content.swap(decompressed); + if(r.second == -1) + { + r.second = content_length - 1; + } - inflateEnd(&strm); + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) +{ + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, + const std::string &content_type, SToken stoken, CToken ctoken, Content content) +{ + for(size_t i = 0; i < req.ranges.size(); i++) + { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if(!content_type.empty()) + { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if(!content(offset, length)) + { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline std::string make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, + const std::string &content_type) +{ + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; +} + +inline size_t get_multipart_ranges_data_length(const Request &req, Response &res, const std::string &boundary, + const std::string &content_type) +{ + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type) +{ + return process_multipart_ranges_data( + req, res, boundary, content_type, [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return detail::write_content(strm, res.content_provider, offset, length) >= 0; + }); +} + +inline std::pair get_range_offset_and_length(const Request &req, const Response &res, size_t index) +{ + auto r = req.ranges[index]; + + if(r.second == -1) + { + r.second = res.content_provider_resource_length - 1; + } + + return std::make_pair(r.first, r.second - r.first + 1); } -#endif #ifdef _WIN32 class WSInit @@ -1571,15 +2346,37 @@ static WSInit wsinit_; } // namespace detail // Header utilities -template -inline std::pair make_range_header(uint64_t value, Args... args) +inline std::pair make_range_header(Ranges ranges) { - std::string field; - detail::make_range_header_core(field, value, args...); - field.insert(0, "bytes="); + std::string field = "bytes="; + auto i = 0; + for(auto r : ranges) + { + if(i != 0) + { + field += ", "; + } + if(r.first != -1) + { + field += std::to_string(r.first); + } + field += '-'; + if(r.second != -1) + { + field += std::to_string(r.second); + } + i++; + } return std::make_pair("Range", field); } +inline std::pair make_basic_authentication_header(const std::string &username, + const std::string &password) +{ + auto field = "Basic " + detail::base64_encode(username + ":" + password); + return std::make_pair("Authorization", field); +} + // Request implementation inline bool Request::has_header(const char *key) const { @@ -1602,6 +2399,11 @@ inline void Request::set_header(const char *key, const char *val) headers.emplace(key, val); } +inline void Request::set_header(const char *key, const std::string &val) +{ + headers.emplace(key, val); +} + inline bool Request::has_param(const char *key) const { return params.find(key) != params.end(); @@ -1661,6 +2463,11 @@ inline void Response::set_header(const char *key, const char *val) headers.emplace(key, val); } +inline void Response::set_header(const char *key, const std::string &val) +{ + headers.emplace(key, val); +} + inline void Response::set_redirect(const char *url) { set_header("Location", url); @@ -1679,8 +2486,28 @@ inline void Response::set_content(const std::string &s, const char *content_type set_header("Content-Type", content_type); } +inline void Response::set_content_provider(size_t length, + std::function provider, + std::function resource_releaser) +{ + assert(length > 0); + content_provider_resource_length = length; + content_provider = [provider](size_t offset, size_t length, DataSink sink, Done) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function provider, std::function resource_releaser) +{ + content_provider_resource_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink sink, Done done) { provider(offset, sink, done); }; + content_provider_resource_releaser = resource_releaser; +} + // Rstream implementation -template inline void Stream::write_format(const char *fmt, const Args &...args) +template inline int Stream::write_format(const char *fmt, const Args &...args) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -1690,27 +2517,29 @@ template inline void Stream::write_format(const char *fmt, co #else auto n = snprintf(buf, bufsiz - 1, fmt, args...); #endif - if(n > 0) + if(n <= 0) { - if(n >= bufsiz - 1) - { - std::vector glowable_buf(bufsiz); + return n; + } - while(n >= static_cast(glowable_buf.size() - 1)) - { - glowable_buf.resize(glowable_buf.size() * 2); -#if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); -#else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); -#endif - } - write(&glowable_buf[0], n); - } - else + if(n >= bufsiz - 1) + { + std::vector glowable_buf(bufsiz); + + while(n >= static_cast(glowable_buf.size() - 1)) { - write(buf, n); + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); +#else + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); +#endif } + return write(&glowable_buf[0], n); + } + else + { + return write(buf, n); } } @@ -1725,7 +2554,11 @@ inline SocketStream::~SocketStream() inline int SocketStream::read(char *ptr, size_t size) { - return recv(sock_, ptr, static_cast(size), 0); + if(detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) + { + return recv(sock_, ptr, static_cast(size), 0); + } + return -1; } inline int SocketStream::write(const char *ptr, size_t size) @@ -1738,6 +2571,11 @@ inline int SocketStream::write(const char *ptr) return write(ptr, strlen(ptr)); } +inline int SocketStream::write(const std::string &s) +{ + return write(s.data(), s.size()); +} + inline std::string SocketStream::get_remote_addr() const { return detail::get_remote_addr(sock_); @@ -1761,9 +2599,12 @@ inline int BufferStream::write(const char *ptr, size_t size) inline int BufferStream::write(const char *ptr) { - size_t size = strlen(ptr); - buffer.append(ptr, size); - return static_cast(size); + return write(ptr, strlen(ptr)); +} + +inline int BufferStream::write(const std::string &s) +{ + return write(s.data(), s.size()); } inline std::string BufferStream::get_remote_addr() const @@ -1779,11 +2620,18 @@ inline const std::string &BufferStream::get_buffer() const // HTTP server implementation inline Server::Server() : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), - is_running_(false), svr_sock_(INVALID_SOCKET), running_threads_(0) + is_running_(false), svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif + new_task_queue = [] { +#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); +#else + return new Threads(); +#endif + }; } inline Server::~Server() @@ -1792,37 +2640,37 @@ inline Server::~Server() inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(boost::regex(pattern), handler)); + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } @@ -1836,6 +2684,11 @@ inline bool Server::set_base_dir(const char *path) return false; } +inline void Server::set_file_request_handler(Handler handler) +{ + file_request_handler_ = handler; +} + inline void Server::set_error_handler(Handler handler) { error_handler_ = handler; @@ -1851,7 +2704,7 @@ inline void Server::set_keep_alive_max_count(size_t count) keep_alive_max_count_ = count; } -inline void Server::set_payload_max_length(uint64_t length) +inline void Server::set_payload_max_length(size_t length) { payload_max_length_ = length; } @@ -1883,8 +2736,7 @@ inline void Server::stop() if(is_running_) { assert(svr_sock_ != INVALID_SOCKET); - auto sock = svr_sock_; - svr_sock_ = INVALID_SOCKET; + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); detail::shutdown_socket(sock); detail::close_socket(sock); } @@ -1892,11 +2744,11 @@ inline void Server::stop() inline bool Server::parse_request_line(const char *s, Request &req) { - static boost::regex re("(GET|HEAD|POST|PUT|PATCH|DELETE|OPTIONS) " - "(([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); + static std::regex re("(GET|HEAD|POST|PUT|PATCH|DELETE|OPTIONS) " + "(([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); - boost::cmatch m; - if(boost::regex_match(s, m, re)) + std::cmatch m; + if(std::regex_match(s, m, re)) { req.version = std::string(m[5]); req.method = std::string(m[1]); @@ -1916,7 +2768,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) return false; } -inline void Server::write_response(Stream &strm, bool last_connection, const Request &req, Response &res) +inline bool Server::write_response(Stream &strm, bool last_connection, const Request &req, Response &res) { assert(res.status != -1); @@ -1926,7 +2778,10 @@ inline void Server::write_response(Stream &strm, bool last_connection, const Req } // Response line - strm.write_format("HTTP/1.1 %d %s\r\n", res.status, detail::status_message(res.status)); + if(!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, detail::status_message(res.status))) + { + return false; + } // Headers if(last_connection || req.get_header_value("Connection") == "close") @@ -1939,13 +2794,61 @@ inline void Server::write_response(Stream &strm, bool last_connection, const Req res.set_header("Connection", "Keep-Alive"); } + if(!res.has_header("Content-Type")) + { + res.set_header("Content-Type", "text/plain"); + } + + if(!res.has_header("Accept-Ranges")) + { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if(req.ranges.size() > 1) + { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if(it != res.headers.end()) + { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", "multipart/byteranges; boundary=" + boundary); + } + if(res.body.empty()) { - if(!res.has_header("Content-Length")) + if(res.content_provider_resource_length > 0) { - if(res.streamcb) + size_t length = 0; + if(req.ranges.empty()) + { + length = res.content_provider_resource_length; + } + else if(req.ranges.size() == 1) + { + auto offsets = detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = + detail::make_content_range_header_field(offset, length, res.content_provider_resource_length); + res.set_header("Content-Range", content_range); + } + else + { + length = detail::get_multipart_ranges_data_length(req, res, boundary, content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } + else + { + if(res.content_provider) { - // Streamed response res.set_header("Transfer-Encoding", "chunked"); } else @@ -1956,49 +2859,60 @@ inline void Server::write_response(Stream &strm, bool last_connection, const Req } else { + if(req.ranges.empty()) + { + ; + } + else if(req.ranges.size() == 1) + { + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field(offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } + else + { + res.body = detail::make_multipart_ranges_data(req, res, boundary, content_type); + } + #ifdef CPPHTTPLIB_ZLIB_SUPPORT // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 const auto &encodings = req.get_header_value("Accept-Encoding"); if(encodings.find("gzip") != std::string::npos && detail::can_compress(res.get_header_value("Content-Type"))) { - detail::compress(res.body); - res.set_header("Content-Encoding", "gzip"); + if(detail::compress(res.body)) + { + res.set_header("Content-Encoding", "gzip"); + } } #endif - if(!res.has_header("Content-Type")) - { - res.set_header("Content-Type", "text/plain"); - } - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length.c_str()); + res.set_header("Content-Length", length); } - detail::write_headers(strm, res); + if(!detail::write_headers(strm, res, Headers())) + { + return false; + } // Body if(req.method != "HEAD") { if(!res.body.empty()) { - strm.write(res.body.c_str(), res.body.size()); - } - else if(res.streamcb) - { - bool chunked_response = !res.has_header("Content-Length"); - uint64_t offset = 0; - bool data_available = true; - while(data_available) + if(!strm.write(res.body)) { - std::string chunk = res.streamcb(offset); - offset += chunk.size(); - data_available = !chunk.empty(); - // Emit chunked response header and footer for each chunk - if(chunked_response) - chunk = detail::from_i_to_hex(chunk.size()) + "\r\n" + chunk + "\r\n"; - if(strm.write(chunk.c_str(), chunk.size()) < 0) - break; // Stop on error + return false; + } + } + else if(res.content_provider) + { + if(!write_content_with_provider(strm, req, res, boundary, content_type)) + { + return false; } } } @@ -2008,6 +2922,48 @@ inline void Server::write_response(Stream &strm, bool last_connection, const Req { logger_(req, res); } + + return true; +} + +inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, + const std::string &boundary, const std::string &content_type) +{ + if(res.content_provider_resource_length) + { + if(req.ranges.empty()) + { + if(detail::write_content(strm, res.content_provider, 0, res.content_provider_resource_length) < 0) + { + return false; + } + } + else if(req.ranges.size() == 1) + { + auto offsets = detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if(detail::write_content(strm, res.content_provider, offset, length) < 0) + { + return false; + } + } + else + { + if(!detail::write_multipart_ranges_data(strm, req, res, boundary, content_type)) + { + return false; + } + } + } + else + { + if(detail::write_content_chunked(strm, res.content_provider) < 0) + { + return false; + } + } + return true; } inline bool Server::handle_file_request(Request &req, Response &res) @@ -2030,6 +2986,10 @@ inline bool Server::handle_file_request(Request &req, Response &res) res.set_header("Content-Type", type); } res.status = 200; + if(file_request_handler_) + { + file_request_handler_(req, res); + } return true; } } @@ -2098,68 +3058,56 @@ inline int Server::bind_internal(const char *host, int port, int socket_flags) inline bool Server::listen_internal() { auto ret = true; - is_running_ = true; - for(;;) { - auto val = detail::select_read(svr_sock_, 0, 100000); + std::unique_ptr task_queue(new_task_queue()); - if(val == 0) - { // Timeout + for(;;) + { if(svr_sock_ == INVALID_SOCKET) { // The server socket was closed by 'stop' method. break; } - continue; + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if(val == 0) + { // Timeout + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if(sock == INVALID_SOCKET) + { + if(errno == EMFILE) + { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if(svr_sock_ != INVALID_SOCKET) + { + detail::close_socket(svr_sock_); + ret = false; + } + else + { + ; // The server socket was closed by user. + } + break; + } + + task_queue->enqueue([=]() { process_and_close_socket(sock); }); } - socket_t sock = accept(svr_sock_, NULL, NULL); - - if(sock == INVALID_SOCKET) - { - if(svr_sock_ != INVALID_SOCKET) - { - detail::close_socket(svr_sock_); - ret = false; - } - else - { - ; // The server socket was closed by user. - } - break; - } - - // TODO: Use thread pool... - std::thread([=]() { - { - std::lock_guard guard(running_threads_mutex_); - running_threads_++; - } - - read_and_close_socket(sock); - - { - std::lock_guard guard(running_threads_mutex_); - running_threads_--; - } - }).detach(); - } - - // TODO: Use thread pool... - for(;;) - { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - std::lock_guard guard(running_threads_mutex_); - if(!running_threads_) - { - break; - } + task_queue->shutdown(); } is_running_ = false; - return ret; } @@ -2204,7 +3152,7 @@ inline bool Server::dispatch_request(Request &req, Response &res, Handlers &hand const auto &pattern = x.first; const auto &handler = x.second; - if(boost::regex_match(req.path, req.matches, pattern)) + if(std::regex_match(req.path, req.matches, pattern)) { handler(req, res); return true; @@ -2213,7 +3161,8 @@ inline bool Server::dispatch_request(Request &req, Response &res, Handlers &hand return false; } -inline bool Server::process_request(Stream &strm, bool last_connection, bool &connection_close) +inline bool Server::process_request(Stream &strm, bool last_connection, bool &connection_close, + std::function setup_request) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -2234,17 +3183,17 @@ inline bool Server::process_request(Stream &strm, bool last_connection, bool &co // Check if the request URI doesn't exceed the limit if(reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); res.status = 414; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } // Request line and headers if(!parse_request_line(reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { res.status = 400; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } if(req.get_header_value("Connection") == "close") @@ -2252,32 +3201,31 @@ inline bool Server::process_request(Stream &strm, bool last_connection, bool &co connection_close = true; } - req.set_header("REMOTE_ADDR", strm.get_remote_addr().c_str()); + if(req.version == "HTTP/1.0" && req.get_header_value("Connection") != "Keep-Alive") + { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); // Body if(req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - bool exceed_payload_max_length = false; - if(!detail::read_content(strm, req, payload_max_length_, exceed_payload_max_length)) + if(!detail::read_content(strm, req, payload_max_length_, res.status, Progress(), + [&](const char *buf, size_t n) { + if(req.body.size() + n > req.body.max_size()) + { + return false; + } + req.body.append(buf, n); + return true; + })) { - res.status = exceed_payload_max_length ? 413 : 400; - write_response(strm, last_connection, req, res); - return !exceed_payload_max_length; + return write_response(strm, last_connection, req, res); } const auto &content_type = req.get_header_value("Content-Type"); - if(req.get_header_value("Content-Encoding") == "gzip") - { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(req.body); -#else - res.status = 415; - write_response(strm, last_connection, req, res); - return true; -#endif - } - if(!content_type.find("application/x-www-form-urlencoded")) { detail::parse_query_text(req.body, req.params); @@ -2289,17 +3237,30 @@ inline bool Server::process_request(Stream &strm, bool last_connection, bool &co !detail::parse_multipart_formdata(boundary, req.body, req.files)) { res.status = 400; - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } } } + if(req.has_header("Range")) + { + const auto &range_header_value = req.get_header_value("Range"); + if(!detail::parse_range_header(range_header_value, req.ranges)) + { + // TODO: error + } + } + + if(setup_request) + { + setup_request(req); + } + if(routing(req, res)) { if(res.status == -1) { - res.status = 200; + res.status = req.ranges.empty() ? 200 : 206; } } else @@ -2307,8 +3268,7 @@ inline bool Server::process_request(Stream &strm, bool last_connection, bool &co res.status = 404; } - write_response(strm, last_connection, req, res); - return true; + return write_response(strm, last_connection, req, res); } inline bool Server::is_valid() const @@ -2316,17 +3276,18 @@ inline bool Server::is_valid() const return true; } -inline bool Server::read_and_close_socket(socket_t sock) +inline bool Server::process_and_close_socket(socket_t sock) { - return detail::read_and_close_socket(sock, keep_alive_max_count_, - [this](Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, last_connection, connection_close); - }); + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, nullptr); + }); } // HTTP client implementation inline Client::Client(const char *host, int port, time_t timeout_sec) - : host_(host), port_(port), timeout_sec_(timeout_sec), host_and_port_(host_ + ":" + std::to_string(port_)) + : host_(host), port_(port), timeout_sec_(timeout_sec), host_and_port_(host_ + ":" + std::to_string(port_)), + keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), follow_location_(false) { } @@ -2371,10 +3332,10 @@ inline bool Client::read_response_line(Stream &strm, Response &res) return false; } - const static boost::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); - boost::cmatch m; - if(boost::regex_match(reader.ptr(), m, re)) + std::cmatch m; + if(std::regex_match(reader.ptr(), m, re)) { res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); @@ -2383,7 +3344,7 @@ inline bool Client::read_response_line(Stream &strm, Response &res) return true; } -inline bool Client::send(Request &req, Response &res) +inline bool Client::send(const Request &req, Response &res) { if(req.path.empty()) { @@ -2396,10 +3357,121 @@ inline bool Client::send(Request &req, Response &res) return false; } - return read_and_close_socket(sock, req, res); + auto ret = process_and_close_socket(sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, req, res, last_connection, connection_close); + }); + + if(ret && follow_location_ && (300 < res.status && res.status < 400)) + { + ret = redirect(req, res); + } + + return ret; } -inline void Client::write_request(Stream &strm, Request &req) +inline bool Client::send(const std::vector &requests, std::vector &responses) +{ + size_t i = 0; + while(i < requests.size()) + { + auto sock = create_client_socket(); + if(sock == INVALID_SOCKET) + { + return false; + } + + if(!process_and_close_socket(sock, requests.size() - i, + [&](Stream &strm, bool last_connection, bool &connection_close) -> bool { + auto &req = requests[i]; + auto res = Response(); + i++; + + if(req.path.empty()) + { + return false; + } + auto ret = process_request(strm, req, res, last_connection, connection_close); + + if(ret && follow_location_ && (300 < res.status && res.status < 400)) + { + ret = redirect(req, res); + } + + if(ret) + { + responses.emplace_back(std::move(res)); + } + + return ret; + })) + { + return false; + } + } + + return true; +} + +inline bool Client::redirect(const Request &req, Response &res) +{ + if(req.redirect_count == 0) + { + return false; + } + + auto location = res.get_header_value("location"); + if(location.empty()) + { + return false; + } + + std::regex re(R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + auto scheme = is_ssl() ? "https" : "http"; + + std::smatch m; + if(regex_match(location, m, re)) + { + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if(next_host.empty()) + { + next_host = host_; + } + if(next_path.empty()) + { + next_path = "/"; + } + + if(next_scheme == scheme && next_host == host_) + { + return detail::redirect(*this, req, res, next_path); + } + else + { + if(next_scheme == "https") + { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } + else + { + Client cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); + } + } + } + return false; +} + +inline void Client::write_request(Stream &strm, const Request &req, bool last_connection) { BufferStream bstrm; @@ -2408,75 +3480,76 @@ inline void Client::write_request(Stream &strm, Request &req) bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - // Headers + // Additonal headers + Headers headers; + if(last_connection) + { + headers.emplace("Connection", "close"); + } + if(!req.has_header("Host")) { if(is_ssl()) { if(port_ == 443) { - req.set_header("Host", host_.c_str()); + headers.emplace("Host", host_); } else { - req.set_header("Host", host_and_port_.c_str()); + headers.emplace("Host", host_and_port_); } } else { if(port_ == 80) { - req.set_header("Host", host_.c_str()); + headers.emplace("Host", host_); } else { - req.set_header("Host", host_and_port_.c_str()); + headers.emplace("Host", host_and_port_); } } } if(!req.has_header("Accept")) { - req.set_header("Accept", "*/*"); + headers.emplace("Accept", "*/*"); } if(!req.has_header("User-Agent")) { - req.set_header("User-Agent", "cpp-httplib/0.2"); + headers.emplace("User-Agent", "cpp-httplib/0.2"); } - // TODO: Support KeepAlive connection - // if (!req.has_header("Connection")) { - req.set_header("Connection", "close"); - // } - if(req.body.empty()) { if(req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - req.set_header("Content-Length", "0"); + headers.emplace("Content-Length", "0"); } } else { if(!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); + headers.emplace("Content-Type", "text/plain"); } if(!req.has_header("Content-Length")) { auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length.c_str()); + headers.emplace("Content-Length", length); } } - detail::write_headers(bstrm, req); + detail::write_headers(bstrm, req, headers); // Body if(!req.body.empty()) { - bstrm.write(req.body.c_str(), req.body.size()); + bstrm.write(req.body); } // Flush buffer @@ -2484,10 +3557,11 @@ inline void Client::write_request(Stream &strm, Request &req) strm.write(data.data(), data.size()); } -inline bool Client::process_request(Stream &strm, Request &req, Response &res, bool &connection_close) +inline bool Client::process_request(Stream &strm, const Request &req, Response &res, bool last_connection, + bool &connection_close) { // Send request - write_request(strm, req); + write_request(strm, req, last_connection); // Receive response and headers if(!read_response_line(strm, res) || !detail::read_headers(strm, res.headers)) @@ -2500,34 +3574,54 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, b connection_close = true; } - // Body - if(req.method != "HEAD") + if(req.response_handler) { - bool exceed_payload_max_length = false; - if(!detail::read_content(strm, res, std::numeric_limits::max(), exceed_payload_max_length, - req.progress)) + if(!req.response_handler(res)) { return false; } + } - if(res.get_header_value("Content-Encoding") == "gzip") + // Body + if(req.method != "HEAD") + { + detail::ContentReceiverCore out = [&](const char *buf, size_t n) { + if(res.body.size() + n > res.body.max_size()) + { + return false; + } + res.body.append(buf, n); + return true; + }; + + if(req.content_receiver) + { + auto offset = std::make_shared(); + auto length = get_header_value_uint64(res.headers, "Content-Length", 0); + auto receiver = req.content_receiver; + out = [offset, length, receiver](const char *buf, size_t n) { + auto ret = receiver(buf, n, *offset, length); + (*offset) += n; + return ret; + }; + } + + int dummy_status; + if(!detail::read_content(strm, res, std::numeric_limits::max(), dummy_status, req.progress, out)) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompress(res.body); -#else return false; -#endif } } return true; } -inline bool Client::read_and_close_socket(socket_t sock, Request &req, Response &res) +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function callback) { - return detail::read_and_close_socket(sock, 0, [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - return process_request(strm, req, res, connection_close); - }); + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, callback); } inline bool Client::is_ssl() const @@ -2535,11 +3629,23 @@ inline bool Client::is_ssl() const return false; } +inline std::shared_ptr Client::Get(const char *path) +{ + Progress dummy; + return Get(path, Headers(), dummy); +} + inline std::shared_ptr Client::Get(const char *path, Progress progress) { return Get(path, Headers(), progress); } +inline std::shared_ptr Client::Get(const char *path, const Headers &headers) +{ + Progress dummy; + return Get(path, headers, dummy); +} + inline std::shared_ptr Client::Get(const char *path, const Headers &headers, Progress progress) { Request req; @@ -2549,7 +3655,51 @@ inline std::shared_ptr Client::Get(const char *path, const Headers &he req.progress = progress; auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} +inline std::shared_ptr Client::Get(const char *path, ContentReceiver content_receiver) +{ + Progress dummy; + return Get(path, Headers(), nullptr, content_receiver, dummy); +} + +inline std::shared_ptr Client::Get(const char *path, ContentReceiver content_receiver, Progress progress) +{ + return Get(path, Headers(), nullptr, content_receiver, progress); +} + +inline std::shared_ptr Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver) +{ + Progress dummy; + return Get(path, headers, nullptr, content_receiver, dummy); +} + +inline std::shared_ptr Client::Get(const char *path, const Headers &headers, ContentReceiver content_receiver, + Progress progress) +{ + return Get(path, headers, nullptr, content_receiver, progress); +} + +inline std::shared_ptr Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver) +{ + Progress dummy; + return Get(path, headers, response_handler, content_receiver, dummy); +} + +inline std::shared_ptr Client::Get(const char *path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) +{ + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + req.progress = progress; + + auto res = std::make_shared(); return send(req, *res) ? res : nullptr; } @@ -2607,12 +3757,53 @@ inline std::shared_ptr Client::Post(const char *path, const Headers &h } query += it->first; query += "="; - query += it->second; + query += detail::encode_url(it->second); } return Post(path, headers, query, "application/x-www-form-urlencoded"); } +inline std::shared_ptr Client::Post(const char *path, const MultipartFormDataItems &items) +{ + return Post(path, Headers(), items); +} + +inline std::shared_ptr Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) +{ + Request req; + req.method = "POST"; + req.headers = headers; + req.path = path; + + auto boundary = detail::make_multipart_data_boundary(); + + req.headers.emplace("Content-Type", "multipart/form-data; boundary=" + boundary); + + for(const auto &item : items) + { + req.body += "--" + boundary + "\r\n"; + req.body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if(!item.filename.empty()) + { + req.body += "; filename=\"" + item.filename + "\""; + } + req.body += "\r\n"; + if(!item.content_type.empty()) + { + req.body += "Content-Type: " + item.content_type + "\r\n"; + } + req.body += "\r\n"; + req.body += item.content + "\r\n"; + } + + req.body += "--" + boundary + "--\r\n"; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + inline std::shared_ptr Client::Put(const char *path, const std::string &body, const char *content_type) { return Put(path, Headers(), body, content_type); @@ -2657,15 +3848,32 @@ inline std::shared_ptr Client::Patch(const char *path, const Headers & inline std::shared_ptr Client::Delete(const char *path) { - return Delete(path, Headers()); + return Delete(path, Headers(), std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, const std::string &body, const char *content_type) +{ + return Delete(path, Headers(), body, content_type); } inline std::shared_ptr Client::Delete(const char *path, const Headers &headers) +{ + return Delete(path, headers, std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, const Headers &headers, const std::string &body, + const char *content_type) { Request req; req.method = "DELETE"; - req.path = path; req.headers = headers; + req.path = path; + + if(content_type) + { + req.headers.emplace("Content-Type", content_type); + } + req.body = body; auto res = std::make_shared(); @@ -2689,6 +3897,16 @@ inline std::shared_ptr Client::Options(const char *path, const Headers return send(req, *res) ? res : nullptr; } +inline void Client::set_keep_alive_max_count(size_t count) +{ + keep_alive_max_count_ = count; +} + +inline void Client::follow_location(bool on) +{ + follow_location_ = on; +} + /* * SSL Implementation */ @@ -2697,59 +3915,71 @@ namespace detail { template -inline bool read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, - // TODO: OpenSSL 1.0.2 occasionally crashes... - // The upcoming 1.1.0 is going to be thread safe. - SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) +inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock, size_t keep_alive_max_count, + SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, + T callback) { + assert(keep_alive_max_count > 0); + SSL *ssl = nullptr; { std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - if(!ssl) - { - return false; - } } - auto bio = BIO_new_socket(sock, BIO_NOCLOSE); + if(!ssl) + { + close_socket(sock); + return false; + } + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); SSL_set_bio(ssl, bio, bio); - setup(ssl); + if(!setup(ssl)) + { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } - SSL_connect_or_accept(ssl); + close_socket(sock); + return false; + } bool ret = false; - if(keep_alive_max_count > 0) + if(SSL_connect_or_accept(ssl) == 1) { - auto count = keep_alive_max_count; - while(count > 0 && - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) + if(keep_alive_max_count > 1) + { + auto count = keep_alive_max_count; + while(count > 0 && (is_client_request || detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) + { + SSLSocketStream strm(sock, ssl); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if(!ret || connection_close) + { + break; + } + + count--; + } + } + else { SSLSocketStream strm(sock, ssl); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if(!ret || connection_close) - { - break; - } - - count--; + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); } } - else - { - SSLSocketStream strm(sock, ssl); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } SSL_shutdown(ssl); - { std::lock_guard guard(ctx_mutex); SSL_free(ssl); @@ -2760,6 +3990,40 @@ inline bool read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count return ret; } +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr> openSSL_locks_; + +class SSLThreadLocks +{ + public: + SSLThreadLocks() + { + openSSL_locks_ = std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() + { + CRYPTO_set_locking_callback(nullptr); + } + + private: + static void locking_callback(int mode, int type, const char * /*file*/, int /*line*/) + { + auto &locks = *openSSL_locks_; + if(mode & CRYPTO_LOCK) + { + locks[type].lock(); + } + else + { + locks[type].unlock(); + } + } +}; + +#endif + class SSLInit { public: @@ -2773,6 +4037,11 @@ class SSLInit { ERR_free_strings(); } + + private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif }; static SSLInit sslinit_; @@ -2790,12 +4059,17 @@ inline SSLSocketStream::~SSLSocketStream() inline int SSLSocketStream::read(char *ptr, size_t size) { - return SSL_read(ssl_, ptr, size); + if(SSL_pending(ssl_) > 0 || + detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) + { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; } inline int SSLSocketStream::write(const char *ptr, size_t size) { - return SSL_write(ssl_, ptr, size); + return SSL_write(ssl_, ptr, static_cast(size)); } inline int SSLSocketStream::write(const char *ptr) @@ -2803,13 +4077,19 @@ inline int SSLSocketStream::write(const char *ptr) return write(ptr, strlen(ptr)); } +inline int SSLSocketStream::write(const std::string &s) +{ + return write(s.data(), s.size()); +} + inline std::string SSLSocketStream::get_remote_addr() const { return detail::get_remote_addr(sock_); } // SSL HTTP server implementation -inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path) +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { ctx_ = SSL_CTX_new(SSLv23_server_method()); @@ -2828,6 +4108,19 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path) SSL_CTX_free(ctx_); ctx_ = nullptr; } + else if(client_ca_cert_file_path || client_ca_cert_dir_path) + { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, client_ca_cert_dir_path); + + SSL_CTX_set_verify(ctx_, + SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } } } @@ -2844,19 +4137,33 @@ inline bool SSLServer::is_valid() const return ctx_; } -inline bool SSLServer::read_and_close_socket(socket_t sock) +inline bool SSLServer::process_and_close_socket(socket_t sock) { - return detail::read_and_close_socket_ssl( - sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) {}, - [this](Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, last_connection, connection_close); + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, [&](Request &req) { req.ssl = ssl; }); }); } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec) : Client(host, port, timeout_sec) +inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec, const char *client_cert_path, + const char *client_key_path) + : Client(host, port, timeout_sec) { ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { host_components_.emplace_back(std::string(b, e)); }); + if(client_cert_path && client_key_path) + { + if(SSL_CTX_use_certificate_file(ctx_, client_cert_path, SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != 1) + { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } } inline SSLClient::~SSLClient() @@ -2872,24 +4179,253 @@ inline bool SSLClient::is_valid() const return ctx_; } -inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req, Response &res) +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, const char *ca_cert_dir_path) { - return is_valid() && - detail::read_and_close_socket_ssl( - sock, 0, ctx_, ctx_mutex_, SSL_connect, [&](SSL *ssl) { SSL_set_tlsext_host_name(ssl, host_.c_str()); }, - [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - return process_request(strm, req, res, connection_close); - }); + if(ca_cert_file_path) + { + ca_cert_file_path_ = ca_cert_file_path; + } + if(ca_cert_dir_path) + { + ca_cert_dir_path_ = ca_cert_dir_path; + } +} + +inline void SSLClient::enable_server_certificate_verification(bool enabled) +{ + server_certificate_verification_ = enabled; +} + +inline long SSLClient::get_openssl_verify_result() const +{ + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const noexcept +{ + return ctx_; +} + +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function callback) +{ + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && detail::process_and_close_socket_ssl( + true, sock, request_count, ctx_, ctx_mutex_, + [&](SSL *ssl) { + if(ca_cert_file_path_.empty()) + { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } + else + { + if(!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), nullptr)) + { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if(SSL_connect(ssl) != 1) + { + return false; + } + + if(server_certificate_verification_) + { + verify_result_ = SSL_get_verify_result(ssl); + + if(verify_result_ != X509_V_OK) + { + return false; + } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if(server_cert == nullptr) + { + return false; + } + + if(!verify_host(server_cert)) + { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); } inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const +{ + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || verify_host_with_common_name(server_cert); +} + +inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const +{ + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if(inet_pton(AF_INET6, host_.c_str(), &addr6)) + { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } + else if(inet_pton(AF_INET, host_.c_str(), &addr)) + { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if(alt_names) + { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for(auto i = 0; i < count && !dsn_matched; i++) + { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if(val->type == type) + { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if(strlen(name) == name_len) + { + switch(type) + { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if(!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) + { + ip_mached = true; + } + break; + } + } + } + } + + if(dsn_matched || ip_mached) + { + ret = true; + } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const +{ + const auto subject_name = X509_get_subject_name(server_cert); + + if(subject_name != nullptr) + { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, name, sizeof(name)); + + if(name_len != -1) + { + return check_host_name(name, name_len); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const +{ + if(host_.size() == pattern_len && host_ == pattern) + { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { pattern_components.emplace_back(std::string(b, e)); }); + + if(host_components_.size() != pattern_components.size()) + { + return false; + } + + auto itr = pattern_components.begin(); + for(const auto &h : host_components_) + { + auto &p = *itr; + if(p != h && p != "*") + { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && !p.compare(0, p.size() - 1, h)); + if(!partial_match) + { + return false; + } + } + ++itr; + } + + return true; +} #endif } // namespace httplib #endif // CPPHTTPLIB_HTTPLIB_H - -// vim: et ts=4 sw=4 cin cino={1s ff=unix