From 8b2323108f484c259d863e68a23f9766e658c07d Mon Sep 17 00:00:00 2001 From: Mathias Magnusson Date: Tue, 22 Jul 2025 22:38:01 +0200 Subject: begin implementing procedure calls the register allocator does not consider the fact that called procedures probably clobber t-registers. also, the way i refer to the built in functions is cursed. it barely works now and won't when you can define procedures --- src/codegen.zig | 250 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 146 insertions(+), 104 deletions(-) (limited to 'src') diff --git a/src/codegen.zig b/src/codegen.zig index 701a4e8..a1c2171 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -561,20 +561,11 @@ const Relocation = struct { }; const Context = struct { - register_allocator: RegisterAllocator(compile.VReg), - lvar_allocator: RegisterAllocator(compile.LVar), instructions: std.ArrayList(Instruction), relocations: std.ArrayList(Relocation), block_starts: std.ArrayList(usize), - - // Current stuff that changes often, basically here to avoid prop drilling. - block: ?*const compile.BasicBlock = null, - current_instruction_index: ?usize = null, - - fn deinit(self: *Context) void { - self.register_allocator.deinit(); - self.instructions.deinit(); - } + print_block: compile.BlockRef, + read_int_block: compile.BlockRef, fn addRelocation(self: *Context, target: compile.BlockRef) !void { try self.relocations.append(.{ @@ -587,12 +578,34 @@ const Context = struct { try self.instructions.append(inst); } + fn deinit(self: *Context) void { + self.instructions.deinit(); + } +}; + +const ProcedureContext = struct { + register_allocator: RegisterAllocator(compile.VReg), + lvar_allocator: RegisterAllocator(compile.LVar), + ctx: *Context, + + // Current stuff that changes often, basically here to avoid prop drilling. + block: ?*const compile.BasicBlock = null, + current_instruction_index: ?usize = null, + + fn deinit(self: *Self) void { + self.register_allocator.deinit(); + } + + fn emit(self: *Self, inst: Instruction) !void { + try self.ctx.emit(inst); + } + /// Frees all virtual registers who's last use is the current `compile.Instr` or earlier. This /// must be called after the current compile.Instr's sources have been retrieved, since this /// will deallocate them, and after allocating auxiliary registers since otherwise they may /// collide with the sources. Should be called before allocating results to allow for more /// register re-use. - fn freeUnusedVRegs(self: *Context) !void { + fn freeUnusedVRegs(self: *Self) !void { var it = self.register_allocator.allocated.keyIterator(); while (it.next()) |vreg| { if (self.block.?.vreg_last_use.get(vreg.*)) |last_use| { @@ -603,7 +616,7 @@ const Context = struct { } } - fn genConstantInner(self: *Context, reg: Register, value: u64) !void { + fn genConstantInner(self: *Self, reg: Register, value: u64) !void { if (value <= std.math.maxInt(i12)) { // If the highest bit is set, we will get sign extension from this, but it will be // cleared by the next addiw. @@ -628,13 +641,13 @@ const Context = struct { } } - fn genConstant(self: *Context, constant: compile.Instr.Constant) !void { + fn genConstant(self: *Self, constant: compile.Instr.Constant) !void { try self.freeUnusedVRegs(); const reg = try self.register_allocator.allocate(constant.dest); try self.genConstantInner(reg, constant.value); } - fn genBinOp(self: *Context, bin_op: compile.Instr.BinOp) !void { + fn genBinOp(self: *Self, bin_op: compile.Instr.BinOp) !void { const lhs = self.register_allocator.get(bin_op.lhs); const rhs = self.register_allocator.get(bin_op.rhs); try self.freeUnusedVRegs(); @@ -653,124 +666,68 @@ const Context = struct { } } - fn genProcCall(self: *Context, call: compile.Instr.ProcCall) !void { + fn genProcCall(self: *Self, call: compile.Instr.ProcCall) !void { switch (call.proc) { .print => { const arg = self.register_allocator.get(call.arg); try self.freeUnusedVRegs(); - const num = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(num); - - if (arg != num) try self.emit(.addi(num, arg, 0)); - - const quot = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(quot); - const digit = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(digit); - const count = try self.register_allocator.allocate(call.dest); - - try self.emit(.addi(digit, .zero, '\n')); - try self.emit(.addi(.sp, .sp, -1)); - try self.emit(.sb(.sp, 0, digit)); - try self.emit(.addi(count, .zero, 1)); - try self.emit(.addi(quot, .zero, 10)); - - try self.emit(.addi(digit, num, 0)); - try self.emit(.divu(num, digit, quot)); - try self.emit(.remu(digit, digit, quot)); - try self.emit(.addi(digit, digit, '0')); - try self.emit(.addi(.sp, .sp, -1)); - try self.emit(.sb(.sp, 0, digit)); - try self.emit(.addi(count, count, 1)); - try self.emit(.bne(num, .zero, -4 * 7)); - - try self.emit(.addi(.a7, .zero, @intFromEnum(std.os.linux.syscalls.RiscV64.write))); - try self.emit(.addi(.a0, .zero, 1)); // fd = stdout - try self.emit(.addi(.a1, .sp, 0)); // buf = sp - try self.emit(.addi(.a2, count, 0)); // count = count - try self.emit(.ecall()); // syscall(no, fd, buf, count) - try self.emit(.add(.sp, .sp, count)); - - try self.emit(.addi(count, count, -1)); + const result = try self.register_allocator.allocate(call.dest); + + try self.emit(.addi(.a0, arg, 0)); + try self.ctx.addRelocation(self.ctx.print_block); + try self.emit(.jal(.ra, 0)); + try self.emit(.addi(result, .a0, 0)); }, .read_int => { try self.freeUnusedVRegs(); const result = try self.register_allocator.allocate(call.dest); - const ptr = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(ptr); - const char = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(char); - const newline = try self.register_allocator.allocateAux(); - defer self.register_allocator.freeAux(newline); - - try self.emit(.addi(.sp, .sp, -21)); - try self.emit(.addi(newline, .zero, '\n')); - try self.emit(.addi(result, .zero, 0)); - - try self.emit(.addi(.a7, .zero, @intFromEnum(std.os.linux.syscalls.RiscV64.read))); - try self.emit(.addi(.a0, .zero, 0)); // fd = stdin - try self.emit(.addi(.a1, .sp, 0)); // buf = sp - try self.emit(.addi(.a2, .zero, 21)); // count = count - try self.emit(.ecall()); // syscall(no, fd, buf, count) - - try self.emit(.addi(ptr, .sp, 0)); - try self.emit(.add(.a0, .a0, .sp)); // a0 = end - - // loop start - try self.emit(.bgeu(ptr, .a0, 4 * 8)); // done - try self.emit(.lb(char, ptr, 0)); - try self.emit(.beq(char, newline, 4 * 6)); // done - try self.emit(.mul(result, result, newline)); // '\n' happens to also be 10 - try self.emit(.addi(char, char, -'0')); - // assert 0 <= char <= 9 - try self.emit(.add(result, result, char)); - - try self.emit(.addi(ptr, ptr, 1)); - try self.emit(.jal(.zero, -4 * 7)); // -> loop start - // done - - try self.emit(.addi(.sp, .sp, 21)); + + try self.ctx.addRelocation(self.ctx.read_int_block); + try self.emit(.jal(.ra, 0)); + try self.emit(.addi(result, .a0, 0)); }, } } - fn genBranch(self: *Context, branch: compile.Instr.Branch) !void { + fn genBranch(self: *Self, branch: compile.Instr.Branch) !void { const cond = self.register_allocator.get(branch.cond); try self.freeUnusedVRegs(); - try self.addRelocation(branch.false); + try self.ctx.addRelocation(branch.false); try self.emit(.beq(cond, .zero, 0)); - try self.addRelocation(branch.true); + try self.ctx.addRelocation(branch.true); try self.emit(.jal(.zero, 0)); } - fn genJump(self: *Context, jump: compile.Instr.Jump) !void { + fn genJump(self: *Self, jump: compile.Instr.Jump) !void { try self.freeUnusedVRegs(); - try self.addRelocation(jump.to); + try self.ctx.addRelocation(jump.to); try self.emit(.jal(.zero, 0)); } - fn genExit(self: *Context, _: compile.Instr.Exit) !void { + fn genExit(self: *Self, _: compile.Instr.Exit) !void { try self.freeUnusedVRegs(); try self.emit(.addi(.a0, .zero, 0)); - try self.emit(.addi(.a7, .zero, 93)); + try self.emit(.addi(.a7, .zero, @intFromEnum(std.os.linux.syscalls.RiscV64.exit))); try self.emit(.ecall()); + // This will never be run, but makes binary ninja understand that this exits the function. + try self.emit(.jalr(.zero, .ra, 0)); } - fn genAssignLocal(self: *Context, assign_local: compile.Instr.AssignLocal) !void { + fn genAssignLocal(self: *Self, assign_local: compile.Instr.AssignLocal) !void { const src = self.register_allocator.get(assign_local.val); try self.freeUnusedVRegs(); const reg = try self.lvar_allocator.getOrAllocate(assign_local.local); try self.emit(.addi(reg, src, 0)); } - fn genGetLocal(self: *Context, get_local: compile.Instr.GetLocal) !void { + fn genGetLocal(self: *Self, get_local: compile.Instr.GetLocal) !void { try self.freeUnusedVRegs(); const src = self.lvar_allocator.get(get_local.local); @@ -778,28 +735,28 @@ const Context = struct { try self.emit(.addi(reg, src, 0)); } - fn codegenInstr(self: *Context, instr: compile.Instr) !void { + fn codegenInstr(self: *Self, instr: compile.Instr) !void { switch (instr.type) { inline else => |ty| { const func = comptime blk: { const typeName = @typeName(@TypeOf(ty)); var it = std.mem.splitBackwardsScalar(u8, typeName, '.'); const base = it.first(); - if (!@hasDecl(Context, "gen" ++ base)) { + if (!@hasDecl(Self, "gen" ++ base)) { @compileError(std.fmt.comptimePrint( "codegen.Context must have a member named 'gen{s}' " ++ "since compile.Instr.Type has a variant named {s}", .{ base, typeName }, )); } - break :blk @field(Context, "gen" ++ base); + break :blk @field(Self, "gen" ++ base); }; try func(self, ty); }, } } - fn codegenBlock(self: *Context, block: compile.BasicBlock) !void { + fn codegenBlock(self: *Self, block: compile.BasicBlock) !void { self.block = █ defer self.block = null; for (block.instrs.items, 0..) |instr, i| { @@ -808,25 +765,112 @@ const Context = struct { } } - fn codegenProc(self: *Context, proc: compile.Procedure) !void { + fn codegenProc(self: *Self, proc: compile.Procedure) !void { for (proc.blocks) |block| { - try self.block_starts.append(self.instructions.items.len); + try self.ctx.block_starts.append(self.ctx.instructions.items.len); try self.codegenBlock(block); } } + + const Self = ProcedureContext; }; +fn codegenPrint(self: *Context) !void { + const num = .a0; + + const quot = .t0; + const digit = .t1; + const count = .t2; + + try self.emit(.addi(digit, .zero, '\n')); + try self.emit(.addi(.sp, .sp, -1)); + try self.emit(.sb(.sp, 0, digit)); + try self.emit(.addi(count, .zero, 1)); + try self.emit(.addi(quot, .zero, 10)); + + try self.emit(.addi(digit, num, 0)); + try self.emit(.divu(num, digit, quot)); + try self.emit(.remu(digit, digit, quot)); + try self.emit(.addi(digit, digit, '0')); + try self.emit(.addi(.sp, .sp, -1)); + try self.emit(.sb(.sp, 0, digit)); + try self.emit(.addi(count, count, 1)); + try self.emit(.bne(num, .zero, -4 * 7)); + + try self.emit(.addi(.a7, .zero, @intFromEnum(std.os.linux.syscalls.RiscV64.write))); + try self.emit(.addi(.a0, .zero, 1)); // fd = stdout + try self.emit(.addi(.a1, .sp, 0)); // buf = sp + try self.emit(.addi(.a2, count, 0)); // count = count + try self.emit(.ecall()); // syscall(no, fd, buf, count) + try self.emit(.add(.sp, .sp, count)); + + try self.emit(.addi(.a0, count, -1)); + try self.emit(.jalr(.zero, .ra, 0)); +} + +fn codegenReadInt(self: *Context) !void { + const result = .t0; + const ptr = .t1; + const char = .t2; + const newline = .t3; + + try self.emit(.addi(.sp, .sp, -21)); + try self.emit(.addi(newline, .zero, '\n')); + try self.emit(.addi(result, .zero, 0)); + + try self.emit(.addi(.a7, .zero, @intFromEnum(std.os.linux.syscalls.RiscV64.read))); + try self.emit(.addi(.a0, .zero, 0)); // fd = stdin + try self.emit(.addi(.a1, .sp, 0)); // buf = sp + try self.emit(.addi(.a2, .zero, 21)); // count = count + try self.emit(.ecall()); // syscall(no, fd, buf, count) + + try self.emit(.addi(ptr, .sp, 0)); + try self.emit(.add(.a0, .a0, .sp)); // a0 = end + + // loop start + try self.emit(.bgeu(ptr, .a0, 4 * 8)); // done + try self.emit(.lb(char, ptr, 0)); + try self.emit(.beq(char, newline, 4 * 6)); // done + try self.emit(.mul(result, result, newline)); // '\n' happens to also be 10 + try self.emit(.addi(char, char, -'0')); + // assert 0 <= char <= 9 + try self.emit(.add(result, result, char)); + + try self.emit(.addi(ptr, ptr, 1)); + try self.emit(.jal(.zero, -4 * 7)); // -> loop start + // done + + try self.emit(.addi(.sp, .sp, 21)); + try self.emit(.addi(.a0, result, 0)); + try self.emit(.jalr(.zero, .ra, 0)); +} + pub fn create_elf(allocator: Allocator, proc: compile.Procedure) ![]u8 { var ctx: Context = .{ - .register_allocator = try .init(allocator, &.{ .t6, .t5, .t4, .t3, .t2, .t1, .t0 }), - .lvar_allocator = try .init(allocator, &.{ .s11, .s10, .s9, .s8, .s7, .s6, .s5, .s4, .s3, .s2, .s1, .s0 }), .instructions = .init(allocator), .relocations = .init(allocator), .block_starts = .init(allocator), + .print_block = @enumFromInt(proc.blocks.len), + .read_int_block = @enumFromInt(proc.blocks.len + 1), }; defer ctx.deinit(); - try ctx.codegenProc(proc); + { + var proc_ctx: ProcedureContext = .{ + .register_allocator = try .init(allocator, &.{ .t6, .t5, .t4, .t3, .t2, .t1, .t0 }), + .lvar_allocator = try .init(allocator, &.{ .s11, .s10, .s9, .s8, .s7, .s6, .s5, .s4, .s3, .s2, .s1, .s0 }), + .ctx = &ctx, + }; + defer proc_ctx.deinit(); + + try proc_ctx.codegenProc(proc); + std.debug.assert(proc_ctx.register_allocator.allocated.count() == 0); + } + + try ctx.block_starts.append(ctx.instructions.items.len); + try codegenPrint(&ctx); + try ctx.block_starts.append(ctx.instructions.items.len); + try codegenReadInt(&ctx); // TODO: make this less sheiße for (ctx.relocations.items) |relocation| { @@ -848,8 +892,6 @@ pub fn create_elf(allocator: Allocator, proc: compile.Procedure) ![]u8 { } } - std.debug.assert(ctx.register_allocator.allocated.count() == 0); - var output_buffer: std.ArrayList(u8) = .init(allocator); errdefer output_buffer.deinit(); try output_buffer.appendNTimes(undefined, @sizeOf(elf.Elf64_Ehdr) + @sizeOf(elf.Elf64_Phdr)); -- cgit v1.2.3