// first attempt at a simple x64 backend targeting 100 lines of code (far from it!)
// link the output executable via 'cl.exe out.obj /link /entry:main ucrt.lib legacy_stdio_definitions.lib'

// NOTE: spills can't take more than a page (4096 bytes), would require stack walk call
// NOTE: at most 32 bit constants
// NOTE: args/params only in regs (at most 4)
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Number of elements of a true array (must not be applied to a pointer).
#define COUNT_OF(array) (sizeof (array)/sizeof (array)[0])

// Token kinds and IR opcodes share one enum. The negative values (assign .. identifier) are token kinds
// returned by scan(); opcode_nop is 0 and every later value is also the index of its mnemonic in opcodes[].
enum opcode {assign = -5, character, newline, number, identifier, opcode_nop, opcode_label, opcode_param, opcode_const,
	opcode_add, opcode_sub, opcode_mul, opcode_div, opcode_copy, opcode_load, opcode_store,
	opcode_ret, opcode_brz, opcode_jmp, opcode_call, opcode_cqo, opcode_spill, opcode_reload, opcode_clobber, opcode_use};
// Mnemonics indexed by enum opcode (opcode_nop is the empty string); scan() matches keywords against this table.
char const *opcodes[] = {"", "label", "param", "const", "add", "sub", "mul", "div", "copy", "load", "store", "ret", "brz", "jmp", "call", "cqo", "spill", "reload", "clobber", "use"};
// `reg` (0) means "no register"; reg_op0 asks the allocator to reuse the register of operand 0.
// rax..r15 are consecutive so that (r - rax) is the register number the encoder masks into its bit fields.
enum reg {reg, reg_op0, rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15, reg_count};
// One IR instruction: res is the id it defines (0 if none), op0..op2 are operand ids (op0 holds the immediate
// for const), reg is the required or assigned register.
struct instruction { enum opcode opcode; int res, op0, op1, op2; enum reg reg;} instructions[1024];
// Counting starts at 1, leaving entry 0 as a placeholder; next_tmp is the next unused id for temporaries.
int instruction_count = 1, next_tmp = 1;

 enum opcode scan(char *cur, char **end, int *count) {
	while (*cur == ' ' || *cur == '\t' || *cur == '\r') cur++; // skip whitespace
	if (cur[0] == '/' && cur[1] == '/') while (*cur != '\n') cur++;
	if (*cur == '\n') {*count = 1, *end = cur + 1; return newline; }
	char *start = cur;
	if (*cur >= 'a' && *cur <= 'z') { // keyword/identifier
		while (*cur >= 'a' && *cur <= 'z' || *cur >= '0' && *cur <= '9' || *cur == '_') cur++;
		for (int i = 0; i < COUNT_OF(opcodes); i++) {
			if (cur - start == strlen(opcodes[i]) && memcmp(start, opcodes[i], cur - start) == 0) {
				*count = (int)(cur - start); *end = cur; return i;
			}
		}
		*count = (int)(cur - start), *end = cur; return identifier;
	} else if (*cur >= '0' && *cur <= '9') {
		while (*cur >= '0' && *cur <= '9') cur++;
		*count = (int)(cur - start); *end = cur; return number;
	} else if (cur[0] == ':' && cur[1] == '=') {
		cur += 2; *count = (int)(cur - start); *end = cur; return assign;
	} else if (cur[0] == '\'' && (cur[1] && cur[2] == '\'' || cur[1] == '\\' && cur[2] && cur[3] == '\'')) {
		cur += 3 + (cur[1] == '\\'); *count = (int)(cur - start); *end = cur; return character;
	}
	return opcode_nop;
}

// Interned identifier names (at most 8 characters plus terminator); id 0 is reserved for "not found".
char identifiers[256][9] = {0};
int identifier_count = 1;

// Returns the id of the name given by the first `count` characters of `identifier`, or 0 when it is unknown.
int lookup(char *identifier, int count) {
	int found = 0;
	for (int slot = 1; !found && slot < COUNT_OF(identifiers); slot++) {
		char const *name = identifiers[slot];
		if (strlen(name) != count) continue; // lengths differ, cannot be the same name
		if (memcmp(identifier, name, count) == 0) found = slot;
	}
	return found;
}
// Returns the id of the name given by the first `count` characters of `identifier`,
// interning it first when it has not been seen before.
// Asserts that the name fits an entry (at most 8 characters) and that the table still has room.
int insert(char *identifier, int count) {
	assert(count <= 8);
	int id = lookup(identifier, count);
	if (id) return id;
	assert(identifier_count < (int)COUNT_OF(identifiers)); // table full: refuse to write past the last entry
	snprintf(identifiers[identifier_count], sizeof identifiers[identifier_count], "%.*s", count, identifier);
	return identifier_count++;
}
int not_newline(char **cur) {
	char *end; int c; if (scan(*cur, &end, &c) == newline) { *cur = end; return 0;}
	return 1;
}
// Value of the escape sequence '\<escape>' inside a character literal; unknown escapes yield 0.
static int escape_value(char escape) {
	switch (escape) {
	case '\\': return '\\';
	case 'n': return '\n';
	case 'r': return '\r';
	case 't': return '\t';
	default: return 0;
	}
}

// Parses the NUL-terminated program text `cur` into instructions[], one instruction per line:
//	[res :=] opcode [op0 [op1 [op2]]]
//	[res :=] const  <number | character literal>
// Identifiers are interned with insert() and stored as ids. Afterwards next_tmp is set to
// instruction_count so that later passes can hand out fresh ids.
// Asserts when instructions[] is full; this also bounds the loop if the scanner cannot make progress.
void parse(char *cur) {
	int count = 0; char *end = cur;
	while (*cur != '\0') { // parse instruction
		enum opcode opcode = scan(cur, &end, &count); cur = end;
		int res = 0;
		if (opcode == newline) continue;
		if (opcode == identifier) { // "res := opcode ..."
			res = insert(&end[-count], count);
			scan(cur, &end, &count); cur = end; // assign
			opcode = scan(cur, &end, &count); cur = end;
		}
		assert(instruction_count < (int)COUNT_OF(instructions));
		if (opcode == opcode_const) {
			int n = 0;
			if (scan(cur, &end, &count) == number) { cur = end; n = atoi(&end[-count]); }
			if (scan(cur, &end, &count) == character) { cur = end;
				// cur[-1] is the closing quote, cur[-2] the character (or the escape letter when count is 4)
				n = (count == 4) ? escape_value(cur[-2]) : cur[-2];
			}
			instructions[instruction_count++] = (struct instruction){opcode, res, n};
		} else {
			// operands are optional; not_newline() consumes the newline that ends the operand list
			int op0 = 0; if (scan(cur, &end, &count) == identifier) { cur = end; op0 = insert(&end[-count], count); }
			int op1 = 0; if (op0 && not_newline(&cur) && scan(cur, &end, &count) == identifier) { cur = end; op1 = insert(&end[-count], count); }
			int op2 = 0; if (op1 && not_newline(&cur) && scan(cur, &end, &count) == identifier) { cur = end; op2 = insert(&end[-count], count); }
			instructions[instruction_count++] = (struct instruction){opcode, res, op0, op1, op2};
		}
	}
	next_tmp = instruction_count;
}

// Instruction selection: rewrites instructions[] in place into a form with explicit register constraints.
//	param       -> pinned to rcx, rdx, r8, r9 in order (at most 4)
//	add/sub/mul -> copy res, op0 followed by the two-address op tagged reg_op0
//	div         -> dividend copied into rax, cqo defining rdx, div defining rax, copy to res
//	call        -> up to two arguments copied into rcx/rdx, result in rax, followed by a clobber/use
//	               pair for each of rcx, rdx, r8, r9, r10, r11
//	ret         -> operand copied into rax
// Everything else is passed through unchanged. Fresh ids come from next_tmp.
// Asserts when the rewritten sequence would not fit into instructions[].
void select(void) {
	enum { max_expansion = 16 }; // a call: 2 argument copies + call + 6 clobber/use pairs + result copy
	static enum reg const param_regs[] = {rcx, rdx, r8, r9};
	struct instruction out[COUNT_OF(instructions)]; int count = 0, param = 0;
	for (int i = 0; i < instruction_count; i++) {
		struct instruction ins = instructions[i];
		assert(count + max_expansion <= (int)COUNT_OF(out));
		switch (ins.opcode) {
		case opcode_param:
			assert(param < (int)COUNT_OF(param_regs)); // TODO: stack params
			out[count++] = (struct instruction){opcode_param, ins.res, .reg = param_regs[param++]};
			break;
		case opcode_add:
		case opcode_sub:
		case opcode_mul:
			out[count++] = (struct instruction){opcode_copy, ins.res, ins.op0};
			out[count++] = ins; out[count - 1].op0 = ins.res; out[count - 1].reg = reg_op0;
			break;
		case opcode_div: {
			int t0 = next_tmp++, t1 = next_tmp++, t2 = next_tmp++;
			out[count++] = (struct instruction){opcode_copy, t0, ins.op0, .reg = rax};
			out[count++] = (struct instruction){opcode_cqo, t1, t0, .reg = rdx};
			out[count++] = (struct instruction){opcode_div, t2, t0, ins.op1, t1, .reg = rax};
			out[count++] = (struct instruction){opcode_copy, ins.res, t2};
			break;
		}
		case opcode_call: {
			int t0 = ins.op1, t1 = ins.op2;
			if (ins.op1) {t0 = next_tmp++; out[count++] = (struct instruction){opcode_copy, t0, ins.op1, .reg = rcx};}
			if (ins.op2) {t1 = next_tmp++; out[count++] = (struct instruction){opcode_copy, t1, ins.op2, .reg = rdx};}
			int res = next_tmp++;
			out[count++] = (struct instruction){opcode_call, res, ins.op0, t0, t1, .reg = rax};
			// each clobber defines a dummy value pinned to one of these registers and the use right
			// after it ends that value again
			enum reg volatile_[] = {rcx, rdx, r8, r9, r10, r11};
			for (int v = 0; v < (int)COUNT_OF(volatile_); v++) {
				out[count++] = (struct instruction){opcode_clobber, next_tmp++, .op0 = res, .reg = volatile_[v]};
				out[count++] = (struct instruction){opcode_use, .op0 = next_tmp - 1};
			}
			if (ins.res) out[count++] = (struct instruction){opcode_copy, ins.res, res};
			break;
		}
		case opcode_ret:
			if (ins.op0) {
				out[count++] = (struct instruction){opcode_copy, next_tmp++, ins.op0, .reg = rax};
				ins.op0 = next_tmp - 1;
			}
			out[count++] = ins;
			break;
		default: out[count++] = ins;
		}
	}
	memcpy(instructions, out, sizeof instructions[0]*count); instruction_count = count;
}

// Trivial live range finding for linear scan: one range per id, reaching from its first definition
// to its last use, extended over loops by a backward liveness fixed point.
//	All definitions of an id are folded into that single range (no renaming or range splitting).
// lrs[id] is the range of value `id`; a range with to == 0 is empty.
struct lr {int from, to;} lrs[256];

// Marks value `id` live in `row`; ids <= 0 (no operand, or a negated stack slot) are ignored.
static void mark_live(char *row, int id) {
	if (id <= 0) return;
	assert(id < (int)COUNT_OF(lrs));
	row[id] = 1;
}

// Recomputes lrs[] from instructions[]. Asserts when the program or its ids exceed the 256-entry tables.
void build_live_ranges(void) {
	char live_in[256][256] = {0}; char tmp[256];
	int positions[256] = {0}; int changed = 1;
	// row instruction_count is read as the (empty) set live after the last instruction
	assert(instruction_count < (int)COUNT_OF(live_in));
	for (int i = 1; i < instruction_count; i++) { // position of the (last) definition of every id, used for labels
		assert(instructions[i].res >= 0 && instructions[i].res < (int)COUNT_OF(positions));
		positions[instructions[i].res] = i;
	}
	while (changed) { // iterate until no row changes during a whole backward pass
		changed = 0;
		for (int i = instruction_count - 1; i > 0; i--) {
			struct instruction ins = instructions[i];
			memcpy(tmp, live_in[i], sizeof live_in[i]);
			memcpy(live_in[i], live_in[i + 1], sizeof live_in[i + 1]);
			if (ins.opcode != opcode_label) {
				if (ins.res) live_in[i][ins.res] = 0; // the definition ends liveness above this point
				if (ins.opcode != opcode_const) { // const: op0 is an immediate, not a value
					if (ins.opcode == opcode_jmp || ins.opcode == opcode_brz) { // op0 is the target label
						assert(ins.op0 >= 0 && ins.op0 < (int)COUNT_OF(positions));
						for (int j = 0; j < (int)COUNT_OF(live_in[i]); j++) live_in[i][j] |= live_in[positions[ins.op0]][j];
					} else if (ins.opcode != opcode_call) { // call: op0 names the callee
						mark_live(live_in[i], ins.op0);
					}
					mark_live(live_in[i], ins.op1);
					mark_live(live_in[i], ins.op2);
				}
			}
			if (memcmp(tmp, live_in[i], sizeof tmp) != 0) changed = 1;
		}
	}
	for (int i = 1; i < (int)COUNT_OF(lrs); i++) {lrs[i] = (struct lr){instruction_count, 0};}
	for (int i = 1; i < instruction_count; i++) {
		for (int j = 0; j < (int)COUNT_OF(live_in[i]); j++) {
			if (!live_in[i][j]) continue;
			lrs[j].from = (lrs[j].from > i) ? i : lrs[j].from;
			lrs[j].to = (lrs[j].to < i) ? i : lrs[j].to;
			if (instructions[lrs[j].from - 1].res == j) lrs[j].from--; // adjust live-in
		}
	}
}

// qsort comparator for struct lr: non-empty ranges come first, ordered by ascending `from`;
// empty ranges (to == 0) sort after every non-empty one and compare equal to each other,
// so the ordering stays consistent whichever way qsort asks.
int live_ranges_cmp(void const *lr0, void const *lr1) {
	struct lr const *a = lr0, *b = lr1;
	int a_empty = a->to == 0, b_empty = b->to == 0;
	if (a_empty || b_empty) return a_empty - b_empty;
	return (a->from > b->from) - (a->from < b->from);
}

int slot_count = 4; // home parameters
int used_regs[reg_count];
void allocate_registers(void) {
	while (1) {
		void print_ir(void); print_ir();
		build_live_ranges();
		void print_live_ranges(void); print_live_ranges();
		int slots[256] = {0}; int start_slot_count = slot_count;
		qsort(&lrs[1], COUNT_OF(lrs) - 1, sizeof lrs[0], live_ranges_cmp);
		print_live_ranges();
		enum reg regs[1024] = {0}; int active[reg_count] = {0};
		for (int i = 1; i < COUNT_OF(lrs); i++) {
			if (lrs[i].to == 0) break;
			for (int j = rax; j < reg_count; j++) { // expire old intervals
				if (active[j] && lrs[active[j]].to <= lrs[i].from) {
					active[j] = 0;
				}
			}
			int r = instructions[lrs[i].from].reg;
			if (r == reg_op0) r = regs[instructions[lrs[i].from].op0];
			if (r) { // precolored
				if (active[r]) slots[instructions[lrs[active[r]].from].res] = slot_count++;
				active[r] = i; regs[instructions[lrs[active[r]].from].res] = r;
			} else {
				for (int j = reg_count - 1; j >= rax; j--) {
					if (active[j] == 0 && j != rsp) { r = j; break; }
				}
				if (!r) { // spill furthest
					int furthest = 1;
					for (int j = rax; j < reg_count; j++) {
						if (lrs[active[j]].to > lrs[furthest].to) furthest = j;
					}
					slots[instructions[lrs[furthest].from].res] = slot_count++;
					r = furthest;
				}
				active[r] = i; regs[instructions[lrs[i].from].res] = r;
			}
		}
		if (slot_count == start_slot_count) {
			for (int i = 1; i < instruction_count; i++) {
				used_regs[instructions[i].reg = regs[instructions[i].res]] = 1;
			}
			break;
		}
		// insert spill/reload instructions
		struct instruction out[256] = {0}; int out_count = 1;
		int id = instruction_count;
		for (int i = 1; i < instruction_count; i++) {
			struct instruction ins = instructions[i];
			int op0 = ins.op0, op1 = ins.op1, op2 = ins.op2;
			if (op0 > 0 && slots[op0]) {
				out[out_count++] = (struct instruction){opcode_reload, id, -slots[op0]}; op0 = id++;
			}
			if (op1 > 0 && op0 != op1 && slots[op1]) {
				out[out_count++] = (struct instruction){opcode_reload, id, -slots[op1]}; op1 = id++;
			}
			if (op2 > 0 && op0 != op1 && op1 != op2 && slots[op2]) {
				out[out_count++] = (struct instruction){opcode_reload, id, -slots[op2]}; op2 = id++;
			}
			if (ins.res && slots[ins.res]) {
				out[out_count++] = (struct instruction){ins.opcode, id, op0, op1, op2, .reg = ins.reg};
				out[out_count++] = (struct instruction){opcode_spill, 0, -slots[ins.res], id++};
			} else {
				out[out_count++] = (struct instruction){ins.opcode, ins.res, op0, op1, op2, .reg = ins.reg};
			}
		}
		memcpy(instructions, out, sizeof instructions); instruction_count = out_count;
	}
}

// Builds the prefix byte 0x48, adding 0x04 when `r` is r8 or above and 0x01 when `b` is r8 or above.
char rex(int r, int b) {
	int prefix = 0x40 | 0x08;
	if (r >= r8) prefix |= 0x04;
	if (b >= r8) prefix |= 0x01;
	return prefix;
}
// Builds a byte with the top two bits set (0xC0), the low three bits of (r - rax) in bits 3..5
// and the low three bits of (b - rax) in bits 0..2.
char modrm(int r, int b) {
	int reg_field = (r - rax) & 0x07;
	int rm_field = (b - rax) & 0x07;
	return 0xC0 | reg_field << 3 | rm_field;
}
void write_coff(void) {
	struct sym { char name[9]; int offset, ext; } syms[256] = {0}; int sym_count = 0;
	syms[sym_count++] = (struct sym){{'m', 'a', 'i', 'n'}};
	struct reloc { int sym, offset;} relocs[256] = {0}; int reloc_count = 0;
	char coff[2048] = {0}; int offset = 0;
	int regs[256] = {0}; int offsets[256] = {0};
	struct patch { int offset, target; } patches[1024]; int patch_count = 0;
	enum reg nonvolatile[] = {r12, r13, r14, r15, rdi, rsi, rbx, rbp};
	int stack_size = slot_count*8;
	for (int i = 0; i < COUNT_OF(nonvolatile); i++) {
		if (used_regs[nonvolatile[i]]) {
			coff[offset++] = rex(rax, nonvolatile[i]);
			coff[offset++] = 0x50 | ((nonvolatile[i] - rax) & 0x07);
			stack_size += 8;
		}
	}
	if (stack_size%16 == 0) slot_count++; // align to ABI 16 byte requirement
	{ // stack alignment + allocation
		coff[offset++] = rex(rax + 5, rsp); coff[offset++] = 0x81;
		coff[offset++] = modrm(rax + 5, rsp);
		int size = slot_count*8;
		memcpy(&coff[offset], &size, sizeof size); offset += 4;
	}
	for (int i = 0; i < instruction_count; i++) {
		struct instruction ins = instructions[i]; regs[ins.res] = ins.reg;
		enum reg r = ins.reg, r0 = regs[ins.op0], r1 = regs[ins.op1];
		switch (ins.opcode) {
		case opcode_label: offsets[ins.res] = offset; break;
		case opcode_const: coff[offset++] = rex(rax, r); coff[offset++] = 0xC7; coff[offset++] = modrm(rax, r);
			coff[offset++] = ins.op0 & 0xFF; coff[offset++] = ins.op0 >> 8 & 0xFF; coff[offset++] = ins.op0 >> 16 & 0xFF; coff[offset++] = ins.op0 >> 24 & 0xFF; break;
		case opcode_add: coff[offset++] = rex(r, r1); coff[offset++] = 0x03; coff[offset++] = modrm(r, r1); break;
		case opcode_sub: coff[offset++] = rex(r, r1); coff[offset++] = 0x2B; coff[offset++] = modrm(r, r1); break;
		case opcode_mul: coff[offset++] = rex(r, r1); coff[offset++] = 0x0F; coff[offset++] = 0xAF; coff[offset++] = modrm(r, r1); break;
		case opcode_div: coff[offset++] = rex(r, r1); coff[offset++] = 0xF7; coff[offset++] = modrm(rax + 7, r1); break;
		case opcode_copy: if (r != r0) { coff[offset++] = rex(r, r0); coff[offset++] = 0x8B; coff[offset++] = modrm(r, r0); } break;
		case opcode_cqo: coff[offset++] = rex(r, rax); coff[offset++] = 0x99; break;
		case opcode_ret: {
			coff[offset++] = rex(rax, rsp); coff[offset++] = 0x81;
			coff[offset++] = modrm(rax, rsp);
			int size = slot_count*8;
			memcpy(&coff[offset], &size, sizeof size); offset += 4;
			for (int i = COUNT_OF(nonvolatile) - 1; i >= 0; --i) {
				if (used_regs[nonvolatile[i]]) {
					coff[offset++] = rex(rax, nonvolatile[i]);
					coff[offset++] = 0x58 | ((nonvolatile[i] - rax) & 0x07);
				}
			}
			coff[offset++] = 0xC3; break;
		}
		case opcode_brz:
		case opcode_jmp:
			if (ins.opcode == opcode_brz) { // test + jz
				coff[offset++] = rex(r1, r1); coff[offset++] = 0x85; coff[offset++] = modrm(r1, r1); // test
				coff[offset++] = 0x0F; coff[offset++] = 0x84; // jz
			} else coff[offset++] = 0xE9; // jmp
			patches[patch_count++] = (struct patch){offset, ins.op0}; offset += 4;
			break;
		case opcode_call:
			coff[offset++] = 0xE8; 
			relocs[reloc_count++] = (struct reloc){sym_count, offset}; offset += 4;
			snprintf(syms[sym_count].name, sizeof syms[sym_count].name, "%s", identifiers[ins.op0]); syms[sym_count].ext = 1; sym_count++;
			break;
		case opcode_reload: r1 = r;
		case opcode_spill: {
			coff[offset++] = rex(r1, rsp); coff[offset++] = (ins.opcode == opcode_reload) ? 0x8B: 0x89;
			coff[offset++] = 0x80 | ((r1 - rax) & 0x07) << 3 | 0x04; // modrm
			coff[offset++] = 0x04 << 3 | 0x04; // sib
			int displacement = ins.op0*8; memcpy(&coff[offset], &displacement, sizeof displacement); offset += 4;
			break;
		}
		case opcode_load: r1 = r;
		case opcode_store: // we use 8 byte displacement to avoid special cases (rsp, rbp, r12, r13)
			coff[offset++] = rex(r1, r0); coff[offset++] = (ins.opcode == opcode_load) ? 0x8B : 0x89;
			coff[offset++] = 0x40 | ((r1 - rax) & 0x07) << 3 | 0x04; // modrm
			coff[offset++] = 0x04 << 3 | (r0 - rax) & 0x07; // sib
			coff[offset++] = 0x00;
			break;
		}
	}
	for (int i = 0; i < patch_count; i++) { // patch jumps
		int displacement = offsets[patches[i].target] - (patches[i].offset + 4);
		memcpy(&coff[patches[i].offset], &displacement, sizeof (int));
	}
	int reloc_offset = 20 + 40 + offset; int sym_offset = reloc_offset + 10*reloc_count;
	FILE *fp = fopen("out.obj", "wb");
	fwrite((char[20]){[0] = 0x64, 0x86, 0x01, [8] = sym_offset & 0xFF, sym_offset >> 8 & 0xFF, sym_offset >> 16 & 0xFF, sym_offset >> 24 & 0xFF, [12] = sym_count & 0xFF, sym_count >> 8 & 0xFF, sym_count >> 16 & 0xFF, sym_count >> 24 & 0xFF}, 1, 20, fp); // file header
	fwrite((char[40]){[0] = '.', 't', 'e', 'x', 't', [16] = offset & 0xFF, offset >> 8 & 0xFF, offset >> 16 & 0xFF, offset >> 24 & 0xFF,
		[20] = 60, [24] = reloc_offset & 0xFF, reloc_offset >> 8 & 0xFF, reloc_offset >> 16 & 0xFF, reloc_offset >> 24 & 0xFF,
		[32] = reloc_count & 0xFF, reloc_count >> 8 & 0xFF, reloc_count >> 16 & 0xFF, reloc_count >> 24 & 0xFF,
		[36] = 0x20, 0x00, 0x50, 0x60}, 1, 40, fp); // .text section header
	fwrite(coff, sizeof coff[0], offset, fp);
	for (int i = 0; i < reloc_count; i++) { // relocations
		fwrite(&relocs[i].offset, 1, 4, fp);
		fwrite(&relocs[i].sym, 1, 4, fp);
		fwrite((char[]){0x04, 0x00}, 1, 2, fp);
	}
	for (int i = 0; i < sym_count; i++) { // symbol table
		struct sym s = syms[i]; assert(strlen(s.name) <= 8);
		fwrite(s.name, 1, 8, fp); fwrite(&s.offset, 4, 1, fp);
		fwrite((char[]){s.ext ? 0x00 : 0x01, 0x00, 0x20/*function*/, 0x00, 0x02, 0x00}, 1, 6, fp);
	}
	fwrite((char[4]){4}, 1, 4, fp); // no string table
	fclose(fp);
}

// Entry point: reads the program text from the file named by argv[1], runs every pass in order
// (parse, select, live ranges, register allocation) with the IR printed in between, and writes out.obj.
// Returns a failure status when the argument is missing or the input cannot be read.
int main(int argc, char **argv) {
	enum { program_capacity = 1000000 };
	if (argc < 2) {
		fprintf(stderr, "usage: %s <source file>\n", argc > 0 ? argv[0] : "backend");
		return EXIT_FAILURE;
	}
	char *program = calloc(program_capacity, sizeof program[0]);
	if (!program) { perror("calloc"); return EXIT_FAILURE; }
	FILE *fp = fopen(argv[1], "rb");
	if (!fp) { perror(argv[1]); free(program); return EXIT_FAILURE; }
	// read at most capacity - 1 bytes so the zeroed last byte always terminates the text
	fread(program, sizeof program[0], program_capacity - 1, fp); fclose(fp);
	parse(program);
	void print_ir(void); print_ir();
	select(); print_ir();
	build_live_ranges(); void print_live_ranges(void); print_live_ranges();
	allocate_registers(); print_ir();
	write_coff();
	free(program);
	return 0;
}

void print_ir(void) {
	char const *regs[] = {"_", "op0", "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"};
	for (int i = 1; i < instruction_count; i++) {
		struct instruction ins = instructions[i];
		printf("%d:\t", i);
		if (ins.opcode != opcode_label) printf("\t");
		if (ins.res) printf("%d", ins.res);
		if (ins.reg) printf(":%s = ", regs[ins.reg]); else if (ins.res) printf(" := ");
		printf("%s", opcodes[ins.opcode]);
		if (ins.opcode == opcode_const) printf(" %d", ins.op0); else if (ins.op0) printf(" %d", ins.op0);
		if (ins.op1) printf(" %d", ins.op1);
		if (ins.op2) printf(" %d", ins.op2);
		printf("\n");
	}
	printf("\n");
}
void print_live_ranges(void) {
	for (int i = 0; i < COUNT_OF(lrs); i++) {
		if (lrs[i].to) printf("%d: [%d, %d]\n", i, lrs[i].from, lrs[i].to);
	}
	printf("\n");
}