#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <iostream>
#include <string>
#include <string_view>
#include <vector>

#include <signal.h>
#include <unistd.h>
#include <liburing.h>

// Tag stored in request_t::type; the event loop switches on it to pick the
// handler for a completed operation.
enum event_t {
	EVENT_END     = 0,	// read on the shutdown descriptor completed
	EVENT_ACCEPT  = 1,	// accept on the listening socket completed
	EVENT_REQUEST = 2,	// read of a client request completed
	EVENT_ANSWER  = 3	// write of a response completed
};

// Per-operation context attached to a submission as user data and handed back
// with the matching completion.
//
// `buff`/`length` describe the bytes the operation works on. `create` points
// `buff` at the trailing `data` payload that is allocated together with the
// header; callers may re-point `buff` at other storage afterwards.
struct __attribute__((aligned(128))) request_t {
	event_t type;	// selects the completion handler
	int fd;		// descriptor the operation targets; -1 until the caller sets it
	size_t length;	// number of bytes addressed by buff
	char * buff;	// start of the I/O buffer (initially the inline payload)
	char data[0];	// inline payload (GNU zero-length array)

	// Allocates a request of the given type with `extra` payload bytes.
	// The result must be released with free(). Terminates the process when
	// memory is exhausted, matching the fatal-error style used in this file.
	static struct request_t * create(event_t type, size_t extra) {
		// The type is over-aligned, which plain malloc does not guarantee;
		// aligned_alloc wants the size rounded up to a multiple of the alignment.
		const size_t align = alignof(request_t);
		const size_t bytes = (sizeof(request_t) + extra + align - 1) / align * align;

		auto ret = static_cast<request_t *>(aligned_alloc(align, bytes));
		if(!ret) {
			std::cerr << "Out of memory allocating request" << std::endl;
			exit(EXIT_FAILURE);
		}

		ret->type = type;
		ret->fd = -1;
		ret->length = extra;
		ret->buff = ret->data;
		return ret;
	}

	// Allocates a request that carries no payload.
	static struct request_t * create(event_t type) {
		return create(type, 0);
	}
};

// Per-thread configuration handed to proc_loop, plus the counters that thread
// reports back. Aligned to 128 bytes.
struct alignas(128) options_t {
	// Arguments forwarded unchanged to every accept submission.
	struct {
		int sockfd;
		struct sockaddr *addr;
		socklen_t *addrlen;
		int flags;
	} acpt;

	// Descriptor the loop reads from to learn that it should stop.
	int endfd;
	// Ring this thread submits to and reaps completions from.
	struct io_uring * ring;

	// Statistics accumulated by the loop and read by main after joining.
	struct {
		size_t subs{0};	// sum of the submit call results
		size_t cnts{0};	// number of submit calls made
	} result;
};

//=========================================================
// Returns the next free submission entry of `ring`; terminates the process
// when the ring has none left.
static struct io_uring_sqe * get_sqe(struct io_uring * ring) {
	if(struct io_uring_sqe * const entry = io_uring_get_sqe(ring)) {
		return entry;
	}

	std::cerr << "Insufficient entries in ring" << std::endl;
	exit(EXIT_FAILURE);
}

// Deliberately does nothing: the per-operation io_uring_submit call is
// disabled, so prepared entries stay queued until the event loop submits them.
static void submit(struct io_uring * /* ring */) {
}

//=========================================================
static void ring_end(struct io_uring * ring, int fd, char * buffer, size_t len) {
	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_read(sqe, fd, buffer, len, 0);
	io_uring_sqe_set_data(sqe, request_t::create(EVENT_END));
	submit(ring);
}

static void ring_accept(struct io_uring * ring, int sockfd, struct sockaddr *addr, socklen_t *addrlen, int flags) {
	auto req = request_t::create(EVENT_ACCEPT);
	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_accept(sqe, sockfd, addr, addrlen, flags);
	io_uring_sqe_set_data(sqe, req);
	submit(ring);
	// std::cout << "Submitted accept: " << req << std::endl;
}

static void ring_request(struct io_uring * ring, int fd) {
	size_t size = 1024;
	auto req = request_t::create(EVENT_REQUEST, size);
	req->fd = fd;

	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_read(sqe, fd, req->buff, size, 0);
	io_uring_sqe_set_data(sqe, req);
	submit(ring);
	// std::cout << "Submitted request: " << req << " (" << (void*)req->buffer << ")"<<std::endl;
}

//=========================================================
// Index into the http_msgs / http_codes tables. KNOWN_CODES is not a status;
// it is the number of entries and must remain the last enumerator.
enum HttpCode {
	OK200 = 0,	// 200 OK
	E400,		// 400 Bad Request
	E404,		// 404 Not Found
	E405,		// 405 Method Not Allowed
	E408,		// 408 Request Timeout
	E413,		// 413 Payload Too Large
	E414,		// 414 URI Too Long
	KNOWN_CODES
};

// printf-style response templates, indexed by HttpCode.
// Every template consumes the Date header value as its first %s. The 200
// template additionally consumes the body length (%zu) and the body (%s);
// the error templates declare an empty body.
const char * http_msgs[] = {
	"HTTP/1.1 200 OK\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: %zu \n\n%s",
	"HTTP/1.1 400 Bad Request\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
	"HTTP/1.1 404 Not Found\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
	"HTTP/1.1 405 Method Not Allowed\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
	"HTTP/1.1 408 Request Timeout\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
	"HTTP/1.1 413 Payload Too Large\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
	"HTTP/1.1 414 URI Too Long\nServer: HttoForall\nDate: %s \nContent-Type: text/plain\nContent-Length: 0 \n\n",
};

// Fails the build when a status is added to HttpCode without a template here.
static_assert( KNOWN_CODES == (sizeof(http_msgs ) / sizeof(http_msgs [0])));

// Numeric HTTP status for each HttpCode, listed in enumerator order.
const int http_codes[] = { 200, 400, 404, 405, 408, 413, 414 };

// Fails the build when the table and the HttpCode enumerators disagree in size.
static_assert( (sizeof(http_codes) / sizeof(*http_codes)) == KNOWN_CODES );

// Maps an HttpCode to its numeric status via the http_codes table.
// `code` must be a real status, i.e. less than KNOWN_CODES.
int code_val(HttpCode code) {
	const int * const entry = http_codes + code;
	return *entry;
}

static void ring_answer(struct io_uring * ring, int fd, HttpCode code) {
	size_t size = 256;
	auto req = request_t::create(EVENT_ANSWER, size);
	req->fd = fd;

	const char * fmt = http_msgs[code];
	const char * date = "";
	size = snprintf(req->buff, size, fmt, date, size);

	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_write(sqe, fd, req->buff, size, 0);
	io_uring_sqe_set_data(sqe, req);
	submit(ring);
	// std::cout << "Submitted good answer: " << req << " (" << (void*)req->buffer << ")"<<std::endl;
}

// Queues the fixed "Hello, World!" 200 response for `fd`.
//
// The string argument is ignored: the response is a precomputed constant.
// The request therefore carries no payload of its own; its `buff` points at
// the static text, which is only ever read.
static void ring_answer(struct io_uring * ring, int fd, const std::string &) {
	static constexpr char RESPONSE[] =
		"HTTP/1.1 200 OK\r\n"
		"Content-Length: 15\r\n"
		"Content-Type: text/html\r\n"
		"Connection: keep-alive\r\n"
		"Server: testserver\r\n"
		"\r\n"
		"Hello, World!\r\n";

	// Length without the terminating NUL, known at compile time.
	constexpr size_t RLEN = sizeof(RESPONSE) - 1;

	auto req = request_t::create(EVENT_ANSWER);
	req->fd = fd;
	req->buff = const_cast<char *>(RESPONSE);
	req->length = RLEN;

	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_write(sqe, fd, RESPONSE, RLEN, 0);
	io_uring_sqe_set_data(sqe, req);
	submit(ring);
}

//=========================================================
// Handles the result of an accept completion. `fd` is the completion result:
// the new connection's descriptor, or a negated errno value on failure.
// An aborted connection is ignored, any other failure is fatal, and a
// successful accept gets its first request read queued.
static void handle_new_conn(struct io_uring * ring, int fd) {
	if( fd < 0 ) {
		const int err = -fd;
		if( err == ECONNABORTED ) return;

		// The error travels in the completion result; errno is not set.
		std::cerr << "accept error: (" << err << ") " << strerror(err) << std::endl;
		exit(EXIT_FAILURE);
	}

	ring_request(ring, fd);
}

static void handle_request(struct io_uring * ring, struct request_t * in, int res) {
	if( res < 0 ) {
		int err = -res;
		switch(err) {
			case EPIPE:
			case ECONNRESET:
				close(in->fd);
				free(in);
				return;
			default:
				std::cerr << "request error: (" << err << ") " << strerror(err) << std::endl;
				exit(EXIT_FAILURE);
		}
	}

	if(res == 0) {
		close(in->fd);
		free(in);
		return;
	}

	const char * it = in->buff;
	if( !strstr( it, "\r\n\r\n" ) ) {
		std::cout << "Incomplete request" << std::endl;
		close(in->fd);
		free(in);
		return;
	}

	it = in->buff;
	const std::string reply = "Hello, World!\n";
	int ret = memcmp(it, "GET ", 4);
	if( ret != 0 ) {
		ring_answer(ring, in->fd, E400);
		goto NEXT;
	}

	it += 4;
	ret = memcmp(it, "/plaintext", 10);
	if( ret != 0 ) {
		ring_answer(ring, in->fd, E404);
		goto NEXT;
	}

	ring_answer(ring, in->fd, reply);

	NEXT:
		ring_request(ring, in->fd);
		return;
}

// Handles the completion of a response write.
//
// `in` is the EVENT_ANSWER request and `res` the completion result: the
// number of bytes written or a negated errno value. A closed or reset
// connection closes the descriptor and releases the request. A complete
// write releases the request. A short write re-queues the unsent remainder
// with the same request.
static void handle_answer(struct io_uring * ring, struct request_t * in, int res) {
	if( res < 0 ) {
		const int err = -res;
		if( err != EPIPE && err != ECONNRESET ) {
			std::cerr << "answer error: (" << err << ") " << strerror(err) << std::endl;
			exit(EXIT_FAILURE);
		}
		close(in->fd);
		free(in);
		return;
	}

	const size_t sent = static_cast<size_t>(res);
	if( sent >= in->length ) {
		free(in);
		return;
	}

	// Short write: move the window past what was sent so that the next
	// completion is measured against the bytes that are still outstanding.
	in->buff   += sent;
	in->length -= sent;

	struct io_uring_sqe * sqe = get_sqe(ring);
	io_uring_prep_write(sqe, in->fd, in->buff, in->length, 0);
	io_uring_sqe_set_data(sqe, in);
	submit(ring);

	// No read is queued here: handle_request already queued the next read for
	// this connection when it produced the answer, and a second pending read
	// would let both completions close the descriptor at end of stream.
}

//=========================================================
// Declaration of a liburing function with C linkage; nothing in the visible
// code calls it.
extern "C" int __io_uring_flush_sq(struct io_uring *ring);

// Thread entry point: runs one event loop on the ring described by `arg`.
//
// `arg` points to this thread's options_t. The loop queues the shutdown read
// and the first accept, then dispatches completions by request type until the
// shutdown read completes. Submit statistics are accumulated in opt.result.
// Always returns a null pointer.
void * proc_loop(void * arg) {
	struct options_t & opt = *static_cast<struct options_t *>(arg);
	struct io_uring * ring = opt.ring;

	// Target of the shutdown read; it lives as long as the loop does, which
	// is as long as that read can be pending.
	char endfd_buf[8];
	ring_end(ring, opt.endfd, endfd_buf, sizeof(endfd_buf));

	ring_accept(ring, opt.acpt.sockfd, opt.acpt.addr, opt.acpt.addrlen, opt.acpt.flags);

	bool done = false;
	while(!done) {
		struct io_uring_cqe * cqe = nullptr;
		int ret;

		// Look for a ready completion first; while there is none, submit
		// everything queued so far and wait for at least one to arrive.
		while(-EAGAIN == (ret = io_uring_wait_cqe_nr(ring, &cqe, 0))) {
			ret = io_uring_submit_and_wait(ring, 1);
			if (ret < 0) {
				fprintf( stderr, "io_uring get error: (%d) %s\n", (int)-ret, strerror(-ret) );
				exit(EXIT_FAILURE);
			}
			opt.result.subs += ret;
			opt.result.cnts++;
		}

		// The loop above only exits with a value other than -EAGAIN.
		if (ret < 0) {
			fprintf( stderr, "io_uring peek error: (%d) %s\n", (int)-ret, strerror(-ret) );
			exit(EXIT_FAILURE);
		}

		auto req = reinterpret_cast<struct request_t *>(cqe->user_data);

		switch(req->type) {
			case EVENT_END:
				// The marker request has served its purpose.
				free(req);
				done = true;
				break;
			case EVENT_ACCEPT:
				handle_new_conn(ring, cqe->res);
				free(req);
				// Keep exactly one accept outstanding.
				ring_accept(ring, opt.acpt.sockfd, opt.acpt.addr, opt.acpt.addrlen, opt.acpt.flags);
				break;
			case EVENT_REQUEST:
				// The handler takes over the request.
				handle_request(ring, req, cqe->res);
				break;
			case EVENT_ANSWER:
				// The handler takes over the request.
				handle_answer(ring, req, cqe->res);
				break;
		}

		io_uring_cqe_seen(ring, cqe);
	}

	return nullptr;
}

//=========================================================
// Wrapper that gives each io_uring instance 128-byte alignment when several
// of them are stored next to each other.
struct alignas(128) aligned_ring {
	struct io_uring storage;
};

#include <bit>

#include <pthread.h>
extern "C" {
	#include <signal.h>
	#include <sys/eventfd.h>
	#include <sys/socket.h>
	#include <netinet/in.h>
}

int main(int argc, char * argv[]) {
	signal(SIGPIPE, SIG_IGN);

	unsigned nthreads = 1;
	unsigned port = 8800;
	unsigned entries = 256;
	unsigned backlog = 10;
	bool attach = false;

	//===================
	// Arguments
	int c;
	while ((c = getopt (argc, argv, "t:p:e:b:a")) != -1) {
		switch (c)
		{
		case 't':
			nthreads = atoi(optarg);
			break;
		case 'p':
			port = atoi(optarg);
			break;
		case 'e':
			entries = atoi(optarg);
			break;
		case 'b':
			backlog = atoi(optarg);
			break;
		case 'a':
			attach = true;
			break;
		case '?':
		default:
			std::cerr << "Usage: -t <threads> -p <port> -e <entries> -b <backlog> -a" << std::endl;
			return EXIT_FAILURE;
		}
	}

	if( !std::ispow2(entries) ) {
		unsigned v = entries;
		v--;
		v |= v >> 1;
		v |= v >> 2;
		v |= v >> 4;
		v |= v >> 8;
		v |= v >> 16;
		v++;
		std::cerr << "Warning: num_entries not a power of 2 (" << entries << ") raising to " << v << std::endl;
		entries = v;
	}

	//===================
	// End FD
	int efd = eventfd(0, EFD_SEMAPHORE);
	if (efd < 0) {
		std::cerr << "eventfd error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}

	//===================
	// Open Socket
	std::cout << getpid() << " : Listening on port " << port << std::endl;
	int server_fd = socket(AF_INET, SOCK_STREAM, 0);
	if(server_fd < 0) {
		std::cerr << "socket error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}

	int ret = 0;
	struct sockaddr_in address;
	int addrlen = sizeof(address);
	memset( (char *)&address, '\0', addrlen );
	address.sin_family = AF_INET;
	address.sin_addr.s_addr = htonl(INADDR_ANY);
	address.sin_port = htons( port );

	int waited = 0;
	while(true) {
		ret = bind( server_fd, (struct sockaddr *)&address, sizeof(address) );
		if(ret < 0) {
			if(errno == EADDRINUSE) {
				if(waited == 0) {
					std::cerr << "Waiting for port" << std::endl;
				} else {
					std::cerr << "\r" << waited;
					std::cerr.flush();
				}
				waited ++;
				usleep( 1000000 );
				continue;
			}
			std::cerr << "bind error: (" << errno << ") " << strerror(errno) << std::endl;
			exit(EXIT_FAILURE);
		}
		break;
	}

	ret = listen( server_fd, backlog );
	if(ret < 0) {
		std::cerr << "listen error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}

	//===================
	// Run Server Threads
	std::cout << "Starting " << nthreads << " Threads";
	if(attach) {
		std::cout << " with attached Rings";
	}
	std::cout << std::endl;

	aligned_ring thrd_rings[nthreads];
	pthread_t    thrd_hdls[nthreads];
	options_t    thrd_opts[nthreads];
	for(unsigned i = 0; i < nthreads; i++) {
		if(!attach || i == 0) {
			io_uring_queue_init(entries, &thrd_rings[i].storage, 0);
		}
		else {
			struct io_uring_params p;
			memset(&p, 0, sizeof(p));
			p.flags = IORING_SETUP_ATTACH_WQ;
			p.wq_fd = thrd_rings[0].storage.ring_fd;
			io_uring_queue_init_params(entries, &thrd_rings[i].storage, &p);
		}

		thrd_opts[i].acpt.sockfd  = server_fd;
		thrd_opts[i].acpt.addr    = (struct sockaddr *)&address;
		thrd_opts[i].acpt.addrlen = (socklen_t*)&addrlen;
		thrd_opts[i].acpt.flags   = 0;
		thrd_opts[i].endfd        = efd;
		thrd_opts[i].ring         = &thrd_rings[i].storage;

		int ret = pthread_create(&thrd_hdls[i], nullptr, proc_loop, &thrd_opts[i]);
		if (ret < 0) {
			std::cerr << "pthread create error: (" << errno << ") " << strerror(errno) << std::endl;
			exit(EXIT_FAILURE);
		}
	}

	//===================
	// Server Started
	std::cout << "Server Started" << std::endl;
	{
		char buffer[128];
		int ret;
		do {
			ret = read(STDIN_FILENO, buffer, 128);
			if(ret < 0) {
				std::cerr << "main read error: (" << errno << ") " << strerror(errno) << std::endl;
				exit(EXIT_FAILURE);
			}
			else if(ret > 0) {
				std::cout << "User inputed '";
				std::cout.write(buffer, ret);
				std::cout << "'" << std::endl;
			}
		} while(ret != 0);

		std::cout << "Shutdown received" << std::endl;
	}

	//===================
	(std::cout << "Sending Shutdown to Threads... ").flush();
	ret = eventfd_write(efd, nthreads);
	if (ret < 0) {
		std::cerr << "eventfd close error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}
	std::cout << "done" << std::endl;

	//===================
	(std::cout << "Stopping Threads Done... ").flush();
	size_t total = 0;
	size_t count = 0;
	for(unsigned i = 0; i < nthreads; i++) {
		void * retval;
		int ret = pthread_join(thrd_hdls[i], &retval);
		if (ret < 0) {
			std::cerr << "pthread create error: (" << errno << ") " << strerror(errno) << std::endl;
			exit(EXIT_FAILURE);
		}
		// total += (size_t)retval;
		total += thrd_opts[i].result.subs;
		count += thrd_opts[i].result.cnts;

		io_uring_queue_exit(thrd_opts[i].ring);
	}
	std::cout << "done" << std::endl;
	std::cout << "Submit average: " << total << "/" << count << "(" << (((double)total) / count) << ")" << std::endl;

	//===================
	(std::cout << "Closing Socket... ").flush();
	ret = shutdown( server_fd, SHUT_RD );
	if( ret < 0 ) {
		std::cerr << "shutdown socket error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}

	ret = close(server_fd);
	if (ret < 0) {
		std::cerr << "close socket error: (" << errno << ") " << strerror(errno) << std::endl;
		exit(EXIT_FAILURE);
	}
	std::cout << "done" << std::endl;
}