// Program that sends a file many times, as fast as it can, over a TCP connection.
// Compares read + write, splice, sendfile and io_uring-based splice.

#define _GNU_SOURCE

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <errno.h>
#include <locale.h>
#include <time.h>
#include <unistd.h>

#include <sys/ioctl.h>
#include <sys/sendfile.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <fcntl.h>

#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#include <liburing.h>

// Process exit codes; each value identifies the step that failed.
enum {
	USAGE_ERROR = 1,  // bad command line, unopenable file, or a host lookup failure that falls back to the usage text
	HOST_ERROR,       // getaddrinfo() failure reported without the usage text
	PIPE_ERROR,       // pipe() failed
	FSTAT_ERROR,      // fstat() on the input file failed
	SOCKET_ERROR,     // socket() failed
	CONNECT_ERROR,    // connect() failed
	SENDFILE_ERROR,   // sendfile() failed
	SPLICEIN_ERROR,   // splice from the file into the pipe failed
	SPLICEOUT_ERROR,  // splice from the pipe into the socket failed
	URINGWAIT_ERROR   // io_uring submit / wait / completion failure
};

// Size, in bytes, of the static buffer below.
enum { buffer_len = 10240 };
// NOTE(review): this buffer is not referenced anywhere in the visible code — confirm before removing.
char buffer[buffer_len];

// TIMEGRAN: nanoseconds per second, used to flatten a struct timespec into a single count.
// TIMES: how many times each sender transmits the file within one run.
enum { TIMEGRAN = 1000000000LL, TIMES = 100000 };

// Pipe used as the intermediate stage by the splice-based senders
// (data is spliced into pipefd[1] and out of pipefd[0]).
int pipefd[2];
// io_uring instance shared by the io_uring-based senders; initialized in main().
struct io_uring ring;

// Heap buffer sized to the input file, allocated in main() and used by the read + write sender.
char * buf;

// Counters accumulated by a sender over one run and printed by run().
struct stats {
	size_t calls;   // number of transfers counted by the sender; run() reports it as the number of files
	size_t bytes;   // total number of bytes sent
	struct {
		struct {
			size_t cnt;     // number of operations that transferred less than what was still outstanding
			size_t bytes;   // bytes moved by those partial operations
		} r, w;             // r: input side (file), w: output side (socket)
	} shorts;
};
// Sender implementations. Each one transmits `size` bytes of the file `in`
// to the socket `out` once and records what it did in the stats structure.
static void my_sendfile(int out, int in, size_t size, struct stats *);   // sendfile(2)
static void my_splice  (int out, int in, size_t size, struct stats *);   // splice(2) through the global pipe
static void my_iouring (int out, int in, size_t size, struct stats *);   // splice through io_uring, one request at a time
static void my_ringlink(int out, int in, size_t size, struct stats *);   // splice through io_uring, using a linked request
static void my_readwrit(int out, int in, size_t size, struct stats *);   // pread(2) + write(2) through the global buffer
// Common signature of all senders so run() can drive any of them.
typedef void (*sender_t)(int out, int in, size_t size, struct stats *);

// Connects to `addr`, calls `sender` TIMES times on (socket, infd, size) and prints timing statistics.
static void run(sender_t sender, struct addrinfo * addr, int infd, size_t size);

int main(int argc, char * argv[]) {
	setlocale(LC_ALL, "");
	const char * file_path;
	struct addrinfo * addr;
	int file_fd;
	int ret;
	switch(argc) {
	case 3:
		{
			// Open the file
			const char * const path = argv[2];
			ret = open(path, 0, O_RDONLY);
			if(ret < 0) {
				fprintf( stderr, "cannot open file '%s': %s\n\n", path, strerror(errno) );
				goto USAGE;
			}

			file_path = path;
			file_fd = ret;


			// connect to the address
			char * state = 0;
			char * str = argv[1];
			const char * const host = strtok_r(str, ":", &state);
			if(NULL == host) {
				fprintf( stderr, "Invalid host:port specification, no host.\n\n" );
				goto USAGE;
			}

			const char * const port = strtok_r(NULL, ":", &state);
			if(NULL == port) {
				fprintf( stderr, "Invalid host:port specification, no port.\n\n" );
				goto USAGE;
			}

			printf("looking up '%s:%s'\n", host, port);

			struct addrinfo hints = {};
			struct addrinfo * pResultList = NULL;

			hints.ai_family = AF_INET;
			hints.ai_socktype = SOCK_STREAM;
			hints.ai_flags = AI_NUMERICSERV;

			ret = getaddrinfo(host, port, &hints, &pResultList);

			switch(ret) {
			case 0:
				addr = pResultList;
				goto DONE;

			case EAI_ADDRFAMILY:
				fprintf( stderr, "The specified network host does not have any network addresses in the requested address family.\n\n" );
				break;

			case EAI_AGAIN:
				fprintf( stderr, "The name server returned a temporary failure indication. Try again later.\n\n" );
				exit( HOST_ERROR );

			case EAI_BADFLAGS:
				fprintf( stderr, "hints.ai_flags  contains invalid flags; or, hints.ai_flags included AI_CANONNAME and name was NULL.\n\n" );
				exit( HOST_ERROR );

			case EAI_FAIL:
				fprintf( stderr, "The name server returned a permanent failure indication.\n\n" );
				break;

			case EAI_FAMILY:
				fprintf( stderr, "The requested address family is not supported.\n\n" );
				exit( HOST_ERROR );

			case EAI_MEMORY:
				fprintf( stderr, "Out of memory.\n\n" );
				exit( HOST_ERROR );

			case EAI_NODATA:
				fprintf( stderr, "The specified network host exists, but does not have any network addresses defined.\n\n" );
				break;

			case EAI_NONAME:
				fprintf( stderr, "The unkonwn host or invalid port.\n\n" );
				break;

			case EAI_SERVICE:
				fprintf( stderr, "The requested service is not available for the requested socket type.\n\n" );
				break;

			case EAI_SOCKTYPE:
				fprintf( stderr, "The requested  socket  type  is  not  supported.\n\n" );
				exit( HOST_ERROR );

			case EAI_SYSTEM:
				// Other system error, check errno for details.
			default:
				fprintf( stderr, "Unnown hostname error: (%d) %s\n\n", (int)errno, strerror(errno) );
				exit( HOST_ERROR );
			}
			if(pResultList) freeaddrinfo(pResultList);
			goto USAGE;
		}
	USAGE:
	default:
		fprintf( stderr, "USAGE: %s host:port file\n", argv[0] );
		exit( USAGE_ERROR );
	}

	DONE:

	io_uring_queue_init(16, &ring, 0);

	size_t file_size = 0;
	{
		struct stat buf;
   		ret = fstat(file_fd, &buf);
		if(0 != ret) {
			fprintf( stderr, "fstat error: (%d) %s\n\n", (int)errno, strerror(errno) );
			exit( FSTAT_ERROR );
		}
		file_size = buf.st_size;
	}

	{
		char addr_str[INET_ADDRSTRLEN];
		struct sockaddr_in * address = (struct sockaddr_in *) addr->ai_addr;
		inet_ntop( AF_INET, &address->sin_addr, addr_str, INET_ADDRSTRLEN );
		printf("sending '%s' (%zu bytes) to '%s:%i'\n", file_path, file_size, addr_str, ntohs(address->sin_port));
	}

	ret = pipe(pipefd);
	if( ret < 0 ) {
		fprintf( stderr, "pipe error: (%d) %s\n\n", (int)errno, strerror(errno) );
		exit( PIPE_ERROR );
	}

	buf = malloc(file_size);

	printf("--- read + write ---\n");
	run(my_readwrit, addr, file_fd, file_size);
	printf("--- splice ---\n");
	run(my_splice  , addr, file_fd, file_size);
	printf("--- sendfile ---\n");
	run(my_sendfile, addr, file_fd, file_size);
	printf("--- io_uring ---\n");
	run(my_iouring, addr, file_fd, file_size);
	printf("--- io_uring + link ---\n");
	run(my_ringlink, addr, file_fd, file_size);

	close(pipefd[0]);
	close(pipefd[1]);
	close(file_fd);
	return 0;
}

/*
 * Benchmarks one sender.
 *
 * Opens a fresh connection to `addr`, calls `sender(sock, infd, size, &st)`
 * TIMES times while measuring the elapsed CLOCK_MONOTONIC time, closes the
 * socket and prints the throughput and short read/write statistics.
 *
 * Exits with SOCKET_ERROR or CONNECT_ERROR if the connection cannot be made.
 */
static void run(sender_t sender, struct addrinfo * addr, int infd, size_t size) {

	int sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
	if(sock < 0) {
		fprintf( stderr, "socket error: (%d) %s\n\n", (int)errno, strerror(errno) );
		exit( SOCKET_ERROR );
	}

	int ret = connect(sock, addr->ai_addr, addr->ai_addrlen);
	if(ret < 0) {
		fprintf( stderr, "connect error: (%d) %s\n\n", (int)errno, strerror(errno) );
		exit( CONNECT_ERROR );
	}

	// All counters start at zero
	struct stats st = {0};

	struct timespec after, before;

	clock_gettime(CLOCK_MONOTONIC, &before);

	for(long long int i = 0; i < TIMES; i++) {
		sender( sock, infd, size, &st );
	}

	clock_gettime(CLOCK_MONOTONIC, &after);

	close(sock);

	// Flatten both timestamps to nanoseconds; subtract as integers before
	// converting so no precision is lost on large monotonic values.
	uint64_t tb = ((int64_t)before.tv_sec * TIMEGRAN) + before.tv_nsec;
	uint64_t ta = ((int64_t)after.tv_sec * TIMEGRAN) + after.tv_nsec;
	double secs = (double)(ta - tb) / TIMEGRAN;

	printf("Sent %'zu bytes in %'zu files, %f seconds\n", st.bytes, st.calls, secs);
	printf(" - %'3.3f bytes per second\n", (((double)st.bytes) / secs));
	printf(" - %'f seconds per file\n", secs / st.calls);
	printf(" - %'3.3f bytes per calls\n", (((double)st.bytes) / st.calls));
	if(st.shorts.r.cnt ){
		printf(" - %'zu short reads\n", st.shorts.r.cnt);
		printf(" - %'3.3f bytes per short read\n", (((double)st.shorts.r.bytes) / st.shorts.r.cnt));
	} else printf("No short reads\n");
	if(st.shorts.w.cnt ){
		// These are the write-side counters, so label them as writes
		printf(" - %'zu short writes\n", st.shorts.w.cnt);
		printf(" - %'3.3f bytes per short write\n", (((double)st.shorts.w.bytes) / st.shorts.w.cnt));
	} else printf("No short writes\n");
}

/*
 * Sends `size` bytes of file `in`, starting at offset 0, to socket `out`
 * with sendfile(2), retrying until everything has been sent.
 *
 * Updates `st`: bytes sent, one call per completed file (matching the other
 * senders, since run() reports `calls` as the number of files), and the
 * short-transfer counters for every sendfile() that left data outstanding.
 *
 * Exits with SENDFILE_ERROR if sendfile() fails or hits end of file early.
 */
static void my_sendfile(int out, int in, size_t size, struct stats * st) {
	off_t off = 0;
	for(;;) {

		ssize_t ret = sendfile(out, in, &off, size);
		if(ret < 0) {
			fprintf( stderr, "sendfile error: (%d) %s\n\n", (int)errno, strerror(errno) );
			exit( SENDFILE_ERROR );
		}

		// sendfile() has already moved `off` past the bytes it sent;
		// adding `ret` again would skip data after a short send.
		st->bytes += ret;
		size -= ret;
		if( size == 0 ) break;

		// Nothing sent although data is still expected: the file is shorter
		// than `size`, retrying would spin forever.
		if( ret == 0 ) {
			fprintf( stderr, "sendfile error: unexpected end of file, %zu bytes missing\n\n", size );
			exit( SENDFILE_ERROR );
		}

		st->shorts.r.cnt++;
		st->shorts.r.bytes += ret;
	}
	st->calls++;
}

/*
 * Sends `size` bytes of file `in` to socket `out` with splice(2), staging
 * the data in the global pipe: file -> pipefd[1], then pipefd[0] -> socket.
 *
 * The pipe is fully drained after every fill. A fill or a drain that leaves
 * data outstanding is recorded as a short read or short write respectively.
 * On completion `st->calls` grows by one and `st->bytes` by the bytes sent.
 *
 * Exits with SPLICEIN_ERROR / SPLICEOUT_ERROR when a splice fails.
 */
static void my_splice  (int out, int in, size_t size, struct stats * st) {
	const unsigned no_flags = 0;
	off_t file_pos = 0;     // advanced by splice() itself
	size_t sent = 0;

	for(;;) {
		// Stage 1: move as much as possible from the file into the pipe
		const ssize_t filled = splice(in, &file_pos, pipefd[1], NULL, size, no_flags);
		if( filled < 0 ) {
			fprintf( stderr, "splice in error: (%d) %s\n\n", (int)errno, strerror(errno) );
			exit( SPLICEIN_ERROR );
		}
		size -= filled;

		// Stage 2: empty the pipe into the socket
		size_t pending = filled;
		do {
			const ssize_t drained = splice(pipefd[0], NULL, out, NULL, pending, no_flags);
			if( drained < 0 ) {
				fprintf( stderr, "splice out error: (%d) %s\n\n", (int)errno, strerror(errno) );
				exit( SPLICEOUT_ERROR );
			}
			pending -= drained;
			sent += drained;
			if( 0 != pending ) {
				st->shorts.w.cnt++;
				st->shorts.w.bytes += drained;
			}
		} while( 0 != pending );

		if( 0 == size ) break;
		st->shorts.r.cnt++;
		st->shorts.r.bytes += filled;
	}

	st->bytes += sent;
	st->calls++;
}

/*
 * Blocking stand-in for splice(2) implemented on top of the global io_uring:
 * queues one splice request, submits it and waits for its completion.
 *
 * Like splice(2):
 *  - a NULL offset pointer means "use the descriptor's own position";
 *  - a non-NULL offset pointer is read for the starting offset and advanced
 *    by the number of bytes transferred (the kernel only sees the offset by
 *    value, so this has to be done here);
 *  - returns the number of bytes transferred, or -1 with errno set.
 *
 * Exits with URINGWAIT_ERROR if waiting for the completion itself fails.
 */
static ssize_t naive_splice(int fd_in, loff_t *off_in, int fd_out, loff_t *off_out, size_t len, unsigned int flags) {
	struct io_uring_sqe * sqe = io_uring_get_sqe(&ring);
	if( NULL == sqe ) {
		// Submission queue is full
		errno = EBUSY;
		return -1;
	}

	io_uring_prep_splice(sqe, fd_in, NULL != off_in ? *off_in: -1, fd_out, NULL != off_out ? *off_out: -1, len, flags);

	// io_uring_submit reports failure as a negative errno value
	int ret = io_uring_submit(&ring);
	if( ret < 0 ) {
		errno = -ret;
		return -1;
	}

	struct io_uring_cqe * cqe = NULL;
	/* wait for the sqe to complete */
	ret = io_uring_wait_cqe_nr(&ring, &cqe, 1);
	if( 0 != ret ) {
		fprintf( stderr, "io_uring_wait error: (%d) %s\n\n", (int)-ret, strerror(-ret) );
		exit( URINGWAIT_ERROR );
	}

	/* read the result, then always release the cqe, even on failure */
	const ssize_t val = cqe->res;
	io_uring_cqe_seen(&ring, cqe);

	if( val < 0 ) {
		// Report the failure the way splice(2) would, so callers' `< 0` checks work
		errno = (int)-val;
		return -1;
	}

	if( NULL != off_in  ) *off_in  += val;
	if( NULL != off_out ) *off_out += val;
	return val;
}

/*
 * Same two-stage transfer as the splice sender (file -> pipefd[1], then
 * pipefd[0] -> socket), but every splice goes through naive_splice(), i.e.
 * through the global io_uring, one request at a time.
 *
 * A fill or a drain that leaves data outstanding is recorded as a short read
 * or short write respectively. On completion `st->calls` grows by one and
 * `st->bytes` by the number of bytes sent.
 *
 * Exits with SPLICEIN_ERROR / SPLICEOUT_ERROR when naive_splice() returns a
 * negative value.
 */
static void my_iouring (int out, int in, size_t size, struct stats * st) {
	const unsigned no_flags = 0;
	off_t file_pos = 0;
	size_t sent = 0;

	do {
		// Stage 1: file -> pipe
		const ssize_t filled = naive_splice(in, &file_pos, pipefd[1], NULL, size, no_flags);
		if( filled < 0 ) {
			fprintf( stderr, "splice in error: (%d) %s\n\n", (int)errno, strerror(errno) );
			exit( SPLICEIN_ERROR );
		}
		size -= filled;

		// Stage 2: pipe -> socket, until the pipe is empty again
		size_t pending = filled;
		do {
			const ssize_t drained = naive_splice(pipefd[0], NULL, out, NULL, pending, no_flags);
			if( drained < 0 ) {
				fprintf( stderr, "splice out error: (%d) %s\n\n", (int)errno, strerror(errno) );
				exit( SPLICEOUT_ERROR );
			}
			pending -= drained;
			sent += drained;
			if( 0 != pending ) {
				st->shorts.w.cnt++;
				st->shorts.w.bytes += drained;
			}
		} while( 0 != pending );

		// More of the file left: the fill above was a short read
		if( 0 != size ) {
			st->shorts.r.cnt++;
			st->shorts.r.bytes += filled;
		}
	} while( 0 != size );

	st->bytes += sent;
	st->calls++;
}

/*
 * Sends `size` bytes of file `in` to socket `out` using two kinds of
 * io_uring splice requests on the global ring:
 *   SPLICE_IN  : file -> pipefd[1], flagged IOSQE_IO_LINK
 *   SPLICE_OUT : pipefd[0] -> socket
 * At most one request of each kind is outstanding at a time. The function
 * loops until the SPLICE_OUT completions account for every byte, adding
 * those bytes to `st->bytes` as they complete and incrementing `st->calls`
 * once at the end.
 *
 * Any negative completion result or ring error terminates the process with
 * URINGWAIT_ERROR.
 *
 * NOTE(review): `bool`/`true`/`false` are used here but <stdbool.h> is not
 * among the headers this file includes directly — presumably it arrives via
 * <liburing.h>; confirm.
 * NOTE(review): the result of io_uring_get_sqe() is used without a NULL check.
 */
static void my_ringlink(int out, int in, size_t size, struct stats * st) {
	// Tags stored in sqe->user_data to tell the two request kinds apart on completion
	enum { SPLICE_IN, SPLICE_OUT };

	// `size`    : bytes still to be spliced from the file into the pipe
	// `in_pipe` : bytes still to be spliced from the pipe into the socket
	size_t in_pipe = size;
	off_t offset = 0;
	// Whether a request of each kind has been queued and not yet completed
	bool has_in = false;
	bool has_out = false;
	while(true) {
		// Queue a file -> pipe splice for everything that is left, if none is pending
		if(!has_in && size > 0) {
			struct io_uring_sqe * sqe = io_uring_get_sqe(&ring);
			io_uring_prep_splice(sqe, in, offset, pipefd[1], -1, size, 0);
			sqe->user_data = SPLICE_IN;
			// Link this request to the next sqe queued (the SPLICE_OUT below, when one is queued)
			// NOTE(review): presumably a short or failed SPLICE_IN cancels the linked request,
			// which would surface as a negative completion below — confirm the link semantics.
			sqe->flags = IOSQE_IO_LINK;
			has_in = true;
		}
		// Queue a pipe -> socket splice for everything not yet sent, if none is pending
		if(!has_out) {
			struct io_uring_sqe * sqe = io_uring_get_sqe(&ring);
			io_uring_prep_splice(sqe, pipefd[0], -1, out, -1, in_pipe, 0);
			sqe->user_data = SPLICE_OUT;
			has_out = true;
		}

		// Submit whatever was queued and wait for at least one completion
		int ret = io_uring_submit_and_wait(&ring, 1);
		if(ret < 0) {
			fprintf( stderr, "io_uring_submit error: (%d) %s\n\n", (int)-ret, strerror(-ret) );
			exit( URINGWAIT_ERROR );
		}

		/* drain every completion that is currently available */
		while(true) {
			struct io_uring_cqe * cqe = NULL;
			/* wait_nr is 0, so this does not block; -EAGAIN below is treated as "nothing left" */
			int ret = io_uring_wait_cqe_nr(&ring, &cqe, 0);

			/* read and process cqe event */
			switch(ret) {
			case 0:
				if( cqe->res < 0 ) {
					printf("Completion Error : %s\n", strerror( -cqe->res ));
					exit( URINGWAIT_ERROR );
				}

				// Copy what is needed out of the cqe before releasing it
				ssize_t write = cqe->res;
				int which = cqe->user_data;
				io_uring_cqe_seen(&ring, cqe);
				switch( which ) {
				case SPLICE_IN:
					has_in = false;
					size -= write;
					offset += write;
					if(0 == size) break;
					// File not fully read yet: count this completion as a short read
					st->shorts.r.cnt++;
					st->shorts.r.bytes += write;
					break;
				case SPLICE_OUT:
					has_out = false;
					in_pipe -= write;
					st->bytes += write;
					if(0 == in_pipe) break;
					// Data still to be sent: count this completion as a short write
					st->shorts.w.cnt++;
					st->shorts.w.bytes += write;
					break;
				default:
					printf("Completion Error : unknown user data\n");
					exit( URINGWAIT_ERROR );
				}
				continue;
			case -EAGAIN:
				// No more completions available: leave the drain loop
				goto OUTER;
			default:
				fprintf( stderr, "io_uring_get_cqe error: (%d) %s\n\n", (int)-ret, strerror(-ret) );
				exit( URINGWAIT_ERROR );
			}
		}
		OUTER:
		// Done once every byte has gone out to the socket
		if(0 == in_pipe) break;
	}
	st->calls++;
}

/*
 * Sends `size` bytes of file `in`, starting at offset 0, to socket `out`
 * with pread(2) into the global buffer `buf` followed by write(2).
 *
 * Each chunk read is written out completely before the next read; after a
 * short write the retry continues from the first unsent byte. A read or a
 * write that leaves data outstanding is recorded as a short read or short
 * write. On completion `st->calls` grows by one and `st->bytes` by the
 * number of bytes written.
 *
 * Exits with status 1 if pread() or write() fails, or if the file ends
 * before `size` bytes have been read.
 */
static void my_readwrit(int out, int in, size_t size, struct stats * st) {
	off_t offset = 0;
	size_t writes = 0;
	for(;;) {
		ssize_t reti = pread(in, buf, size, offset);
		if( reti < 0 ) {
			fprintf( stderr, "Read in Error : (%d) %s\n\n", (int)errno, strerror(errno) );
			exit( 1 );
		}
		// End of file with data still expected: retrying would spin forever
		if( 0 == reti && size > 0 ) {
			fprintf( stderr, "Read in Error : unexpected end of file, %zu bytes missing\n\n", size );
			exit( 1 );
		}

		offset += reti;
		size -= reti;

		// Write the chunk out, resuming after whatever a short write already sent
		const char * cursor = buf;
		size_t in_buf = reti;
		for(;;) {
			ssize_t reto = write(out, cursor, in_buf);
			if( reto < 0 ) {
				fprintf( stderr, "Write out Error : (%d) %s\n\n", (int)errno, strerror(errno) );
				exit( 1 );
			}

			cursor += reto;
			in_buf -= reto;
			writes += reto;
			if(0 == in_buf) break;
			st->shorts.w.cnt++;
			st->shorts.w.bytes += reto;
		}
		if(0 == size) break;
		st->shorts.r.cnt++;
		st->shorts.r.bytes += reti;
	}
	st->calls++;
	st->bytes += writes;
}