const std = @import("std");
const elf = std.elf;
const Allocator = std.mem.Allocator;
const root = @import("root");
const compile = root.compile;

/// RV64 integer registers, numbered 0-31 per the standard RISC-V ABI.
const Register = enum(u5) {
    // zig fmt: off
    zero, ra, sp, gp, tp,
    t0, t1, t2,
    s0, s1,
    a0, a1, a2, a3, a4, a5, a6, a7,
    s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
    t3, t4, t5, t6,
    // zig fmt: on

    /// The frame pointer is an ABI alias for s0.
    const fp: @This() = .s0;

    /// Look up a register by its numeric (x0..x31) index.
    fn x(number: u5) @This() {
        return @enumFromInt(number);
    }
};

/// Every base-ISA opcode occupies the low 7 bits of an instruction.
const Opcode = u7;

/// A single 32-bit RISC-V instruction, represented as a union over the six
/// base encoding formats. Every variant is a `packed struct(u32)`, so an
/// `Instruction` can be emitted with a plain `@bitCast` to `u32`.
const Instruction = packed union {
    r: R,
    i: I,
    s: S,
    b: B,
    u: U,
    j: J,

    /// Register-register format.
    const R = packed struct(u32) {
        opcode: Opcode,
        rd: Register,
        funct3: u3,
        rs1: Register,
        rs2: Register,
        funct7: u7,

        fn init(opcode: Opcode, rd: Register, funct3: u3, rs1: Register, rs2: Register, funct7: u7) Self {
            return .{ .r = .{
                .opcode = opcode,
                .rd = rd,
                .funct3 = funct3,
                .rs1 = rs1,
                .rs2 = rs2,
                .funct7 = funct7,
            } };
        }
    };

    /// rd = rs1 + rs2
    fn add(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 0, rs1, rs2, 0);
    }

    /// rd = rs1 + rs2
    fn addw(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0111011, rd, 0, rs1, rs2, 0);
    }

    /// rd = rs1 - rs2
    fn sub(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 0, rs1, rs2, 32);
    }

    /// rd = rs1 - rs2
    fn subw(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0111011, rd, 0, rs1, rs2, 32);
    }

    /// Bitwise xor.
    /// rd = rs1 ^ rs2
    fn xor(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 4, rs1, rs2, 0);
    }

    /// Bitwise or.
    /// rd = rs1 | rs2
    fn or_(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 6, rs1, rs2, 0);
    }

    /// Bitwise and.
    /// rd = rs1 & rs2
    fn and_(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 7, rs1, rs2, 0);
    }

    /// Shift left logical.
    /// rd = rs1 << rs2
    fn sll(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 1, rs1, rs2, 0);
    }

    /// Shift left logical word.
    /// rd = rs1 << rs2
    fn sllw(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0111011, rd, 1, rs1, rs2, 0);
    }

    /// Shift right logical.
    /// rd = rs1 >> rs2
    fn srl(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 5, rs1, rs2, 0);
    }

    /// Shift right logical word.
    /// rd = rs1 >> rs2
    fn srlw(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0111011, rd, 5, rs1, rs2, 0);
    }

    /// Shift right arithmetic (preserves sign bit).
    /// rd = (rs1 >> rs2) | (rs1 & (1 << (bits - 1)))
    fn sra(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 5, rs1, rs2, 32);
    }

    /// Shift right arithmetic (preserves sign bit) word.
    /// rd = (rs1 >> rs2) | (rs1 & (1 << (bits - 1)))
    fn sraw(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0111011, rd, 5, rs1, rs2, 32);
    }

    /// Set less than, signed.
    /// rd = rs1 s< rs2 ? 1 : 0
    fn slt(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 2, rs1, rs2, 0);
    }

    /// Set less than, unsigned.
    /// rd = rs1 u< rs2 ? 1 : 0
    fn sltu(rd: Register, rs1: Register, rs2: Register) Self {
        return R.init(0b0110011, rd, 3, rs1, rs2, 0);
    }

    /// Immediate format (also used for loads, jalr and ecall).
    const I = packed struct(u32) {
        opcode: Opcode,
        rd: Register,
        funct3: u3,
        rs1: Register,
        imm: u12,

        fn init(opcode: Opcode, rd: Register, funct3: u3, rs1: Register, imm: i12) Self {
            return .{ .i = .{
                .opcode = opcode,
                .rd = rd,
                .funct3 = funct3,
                .rs1 = rs1,
                // Stored as raw bits; the hardware sign-extends imm itself.
                .imm = @bitCast(imm),
            } };
        }
    };

    /// Add immediate.
    /// rd = rs1 + imm
    fn addi(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 0, rs1, imm);
    }

    /// Add immediate word.
    /// rd = rs1 + imm
    fn addiw(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0011011, rd, 0, rs1, imm);
    }

    /// Xor immediate.
    /// rd = rs1 ^ imm
    fn xori(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 4, rs1, imm);
    }

    /// Or immediate.
    /// rd = rs1 | imm
    fn ori(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 6, rs1, imm);
    }

    /// And immediate.
    /// rd = rs1 & imm
    fn andi(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 7, rs1, imm);
    }

    /// Shift left logical immediate.
    /// rd = rs1 << by
    fn slli(rd: Register, rs1: Register, by: u6) Self {
        return I.init(0b0010011, rd, 1, rs1, @intCast(by));
    }

    /// Shift left logical word immediate.
    /// rd = rs1 << by
    fn slliw(rd: Register, rs1: Register, by: u5) Self {
        return I.init(0b0011011, rd, 1, rs1, @intCast(by));
    }

    /// Shift right logical immediate.
    /// rd = rs1 >> by
    fn srli(rd: Register, rs1: Register, by: u6) Self {
        return I.init(0b0010011, rd, 5, rs1, @intCast(by));
    }

    /// Shift right logical word immediate.
    /// rd = rs1 >> by
    /// Word shifts have a 5-bit shamt; a 6-bit value would encode a reserved instruction.
    fn srliw(rd: Register, rs1: Register, by: u5) Self {
        return I.init(0b0011011, rd, 5, rs1, @intCast(by));
    }

    /// Shift right arithmetic immediate (preserves sign bit).
    /// rd = ((((~0 * (rs1 >> (bits - 1))) << bits) | rs1) >> by)
    fn srai(rd: Register, rs1: Register, by: u6) Self {
        // funct7 0b0100000 lives in imm[11:5]; imm[5:0] is the 6-bit RV64 shamt.
        return I.init(0b0010011, rd, 5, rs1, 0x20 << 5 | @as(i12, @intCast(by)));
    }

    /// Shift right arithmetic word immediate (preserves sign bit).
    /// rd = ((((~0 * (rs1 >> (bits - 1))) << bits) | rs1) >> by)
    /// Word shifts have a 5-bit shamt; a 6-bit value would encode a reserved instruction.
    fn sraiw(rd: Register, rs1: Register, by: u5) Self {
        return I.init(0b0011011, rd, 5, rs1, 0x20 << 5 | @as(i12, @intCast(by)));
    }

    /// Set less than immediate, signed.
    /// rd = rs1 s< rs2 ? 1 : 0
    fn slti(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 2, rs1, imm);
    }

    /// Set less than sign extended immediate, unsigned comparison.
    /// rd = rs1 u< rs2 ? 1 : 0
    fn sltiu(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0010011, rd, 3, rs1, imm);
    }

    /// Load byte and sign extend.
    /// rd = @as(*i8, @ptrFromInt(rs1 + imm)).*
    fn lb(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 0, rs1, imm);
    }

    /// Load half and sign extend.
    /// rd = @as(*i16, @ptrFromInt(rs1 + imm)).*
    fn lh(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 1, rs1, imm);
    }

    /// Load word and sign extend.
    /// rd = @as(*i32, @ptrFromInt(rs1 + imm)).*
    fn lw(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 2, rs1, imm);
    }

    /// Load double.
    /// rd = @as(*u64, @ptrFromInt(rs1 + imm)).*
    fn ld(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 3, rs1, imm);
    }

    /// Load byte and zero extend.
    /// rd = @as(*u8, @ptrFromInt(rs1 + imm)).*
    fn lbu(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 4, rs1, imm);
    }

    /// Load half and zero extend.
    /// rd = @as(*u16, @ptrFromInt(rs1 + imm)).*
    fn lhu(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 5, rs1, imm);
    }

    /// Load word and zero extend.
    /// rd = @as(*u32, @ptrFromInt(rs1 + imm)).*
    fn lwu(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b0000011, rd, 6, rs1, imm);
    }

    /// Jump and link register.
    /// rd = pc + 4; pc = rs1 + imm
    fn jalr(rd: Register, rs1: Register, imm: i12) Self {
        return I.init(0b1100111, rd, 0, rs1, imm);
    }

    /// Environment call. Issue a syscall on linux.
    fn ecall() Self {
        return I.init(0b1110011, .zero, 0, .zero, 0);
    }

    /// Store format. The 12-bit immediate is split around rs1/rs2.
    const S = packed struct(u32) {
        opcode: Opcode,
        imm4_0: u5,
        funct3: u3,
        rs1: Register,
        rs2: Register,
        imm11_5: u7,

        fn init(opcode: Opcode, funct3: u3, rs1: Register, rs2: Register, imm: i12) Self {
            const umm: u12 = @bitCast(imm);
            return .{ .s = .{
                .opcode = opcode,
                .imm4_0 = @truncate(umm),
                .funct3 = funct3,
                .rs1 = rs1,
                .rs2 = rs2,
                // u12 does not implicitly coerce to u7; only 7 significant bits remain after the shift.
                .imm11_5 = @truncate(umm >> 5),
            } };
        }
    };

    /// Store byte.
    /// @as(*u8, @ptrFromInt(rs1 + imm)).* = @truncate(rs2)
    fn sb(rs1: Register, imm: i12, rs2: Register) Self {
        return S.init(0b0100011, 0, rs1, rs2, imm);
    }

    /// Store half.
    /// @as(*u16, @ptrFromInt(rs1 + imm)).* = @truncate(rs2)
    fn sh(rs1: Register, imm: i12, rs2: Register) Self {
        return S.init(0b0100011, 1, rs1, rs2, imm);
    }

    /// Store word.
    /// @as(*u32, @ptrFromInt(rs1 + imm)).* = @truncate(rs2)
    fn sw(rs1: Register, imm: i12, rs2: Register) Self {
        return S.init(0b0100011, 2, rs1, rs2, imm);
    }

    /// Store double.
    /// @as(*u64, @ptrFromInt(rs1 + imm)).* = @truncate(rs2)
    fn sd(rs1: Register, imm: i12, rs2: Register) Self {
        return S.init(0b0100011, 3, rs1, rs2, imm);
    }

    /// Branch format.
    /// The lowest bit of the immediate value, imm, is not stored as imm will always be even.
    const B = packed struct(u32) {
        opcode: Opcode,
        imm11: u1,
        imm4_1: u4,
        funct3: u3,
        rs1: Register,
        rs2: Register,
        imm10_5: u6,
        imm12: u1,

        fn init(opcode: Opcode, funct3: u3, rs1: Register, rs2: Register, imm: i13) Self {
            // Branch targets are instruction aligned. @rem is required: the
            // % operator is a compile error on runtime signed integers.
            std.debug.assert(@rem(imm, 2) == 0);
            const umm: u13 = @bitCast(imm);
            return .{ .b = .{
                .opcode = opcode,
                .imm11 = @truncate(umm >> 11),
                .imm4_1 = @truncate(umm >> 1),
                .funct3 = funct3,
                .rs1 = rs1,
                .rs2 = rs2,
                .imm10_5 = @truncate(umm >> 5),
                .imm12 = @truncate(umm >> 12),
            } };
        }
    };

    /// Branch if equal.
    /// pc = if (rs1 == rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn beq(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 0, rs1, rs2, imm);
    }

    /// Branch if not equal.
    /// pc = if (rs1 != rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn bne(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 1, rs1, rs2, imm);
    }

    /// Branch if less than, signed.
    /// pc = if (rs1 s< rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn blt(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 4, rs1, rs2, imm);
    }

    /// Branch if greater than or equal, signed.
    /// pc = if (rs1 s>= rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn bge(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 5, rs1, rs2, imm);
    }

    /// Branch if less than, unsigned.
    /// pc = if (rs1 u< rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn bltu(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 6, rs1, rs2, imm);
    }

    /// Branch if greater than or equal, unsigned.
    /// pc = if (rs1 u>= rs2) then pc + imm else pc + 4
    ///
    /// Advancing pc by 4 is done by all instructions, but shown here for clarity.
    /// Note that `imm` must be even since function addresses always are.
    fn bgeu(rs1: Register, rs2: Register, imm: i13) Self {
        return B.init(0b1100011, 7, rs1, rs2, imm);
    }

    /// Upper-immediate format.
    const U = packed struct(u32) {
        opcode: Opcode,
        rd: Register,
        imm12_31: u20,

        fn init(opcode: Opcode, rd: Register, imm: i20) Self {
            return .{ .u = .{
                .opcode = opcode,
                .rd = rd,
                .imm12_31 = @bitCast(imm),
            } };
        }
    };

    /// Load upper immediate.
    /// rd = imm << 12
    fn lui(rd: Register, imm: i20) Self {
        return U.init(0b0110111, rd, imm);
    }

    /// Add upper immediate to pc.
    /// rd = pc + (imm << 12)
    fn auipc(rd: Register, imm: i20) Self {
        return U.init(0b0010111, rd, imm);
    }

    /// Jump format. As with B, bit 0 of the immediate is implicitly zero.
    const J = packed struct(u32) {
        opcode: Opcode,
        rd: Register,
        imm12_19: u8,
        imm11: u1,
        imm1_10: u10,
        imm20: u1,

        fn init(opcode: Opcode, rd: Register, imm: i21) Self {
            // Jump targets are instruction aligned. @rem is required: the
            // % operator is a compile error on runtime signed integers.
            std.debug.assert(@rem(imm, 2) == 0);
            const umm: u21 = @bitCast(imm);
            return .{ .j = .{
                .opcode = opcode,
                .rd = rd,
                .imm12_19 = @truncate(umm >> 12),
                .imm11 = @truncate(umm >> 11),
                .imm1_10 = @truncate(umm >> 1),
                .imm20 = @truncate(umm >> 20),
            } };
        }
    };

    /// Jump and link.
    /// rd = pc + 4; pc = pc + imm
    fn jal(rd: Register, imm: i21) Self {
        return J.init(0b1101111, rd, imm);
    }

    const Self = @This();
};

/// Maps virtual registers from the compile stage onto the temporaries t0-t6.
/// Capacity is fixed at construction, so allocate/free never reallocate.
const RegisterAllocator = struct {
    allocated: std.AutoHashMap(compile.VReg, Register),
    available: std.ArrayList(Register),

    fn init(allocator: Allocator) !RegisterAllocator {
        var available: std.ArrayList(Register) = .init(allocator);
        // Pushed in reverse so t0 is popped (handed out) first.
        for ([_]Register{ .t6, .t5, .t4, .t3, .t2, .t1, .t0 }) |reg| {
            try available.append(reg);
        }
        var allocated: std.AutoHashMap(compile.VReg, Register) = .init(allocator);
        // Pre-size the map so allocate() can use putAssumeCapacityNoClobber.
        try allocated.ensureTotalCapacity(@intCast(available.items.len));
        return .{
            .allocated = allocated,
            .available = available,
        };
    }

    fn deinit(self: *RegisterAllocator) void {
        self.allocated.deinit();
        self.available.deinit();
    }

    /// Asserts that `vreg` is currently allocated.
    fn get(self: *const RegisterAllocator, vreg: compile.VReg) Register {
        return self.allocated.get(vreg).?;
    }

    /// Returns null when all temporaries are in use.
    fn allocate(self: *RegisterAllocator, vreg: compile.VReg) ?Register {
        const reg = self.available.pop() orelse return null;
        self.allocated.putAssumeCapacityNoClobber(vreg, reg);
        return reg;
    }

    /// Returns `vreg`'s hardware register to the free pool.
    /// Asserts that `vreg` is allocated and not double-freed.
    fn free(self: *RegisterAllocator, vreg: compile.VReg) void {
        const ent = self.allocated.fetchRemove(vreg).?;
        const reg = ent.value;
        std.debug.assert(std.mem.indexOfScalar(Register, self.available.items, reg) == null);
        return self.available.appendAssumeCapacity(reg);
    }
};

/// Per-codegen-run state: the register allocator, the emitted instruction
/// stream, and the block/instruction currently being lowered.
const Context = struct {
    register_allocator: RegisterAllocator,
    instructions: std.ArrayList(Instruction),

    // Current stuff that changes often, basically here to avoid prop drilling.
    block: ?*const compile.Block = null,
    current_instruction_index: ?usize = null,

    fn deinit(self: *Context) void {
        self.register_allocator.deinit();
        self.instructions.deinit();
    }

    fn emit(self: *Context, inst: Instruction) !void {
        try self.instructions.append(inst);
    }

    /// Frees every source vreg whose last use is the instruction being lowered.
    fn maybeFreeSources(self: *Context, vregs: compile.Instr.Sources) !void {
        for (vregs.slice()) |src| {
            if (self.block.?.vreg_last_use.get(src) == self.current_instruction_index.?) {
                self.register_allocator.free(src);
            }
        }
    }

    /// Materializes an unsigned constant into a fresh register.
    /// Values up to 12 bits use a single addi; values up to 31 bits use
    /// lui+addi. Anything wider is not implemented yet.
    fn genConstant(self: *Context, constant: compile.Instr.Constant) !void {
        const reg = self.register_allocator.allocate(constant.dest) orelse return error.OutOfRegisters;
        if (constant.value <= std.math.maxInt(i12)) {
            try self.emit(.addi(reg, .zero, @intCast(constant.value)));
        } else if (constant.value +| 0x800 <= std.math.maxInt(i32)) {
            // If the highest bit of addi's immediate is set, the hardware
            // sign-extends it (subtracting 0x1000), so pre-compensate by adding
            // one to the lui immediate. The saturating guard above keeps that
            // +1 from overflowing the i20 lui immediate: values within 0x800 of
            // maxInt(i32) fall through to the unimplemented branch instead of
            // panicking in @intCast.
            try self.emit(.lui(reg, @intCast((constant.value >> 12) + if (constant.value & (1 << 11) != 0) @as(u64, 1) else 0)));
            try self.emit(.addi(reg, reg, @bitCast(@as(u12, @truncate(constant.value)))));
        } else {
            unreachable; // TODO: materialize constants wider than 31 bits
        }
    }

    /// Lowers a binary operation; operands must already be in registers.
    fn genBinOp(self: *Context, bin_op: compile.Instr.BinOp) !void {
        const lhs = self.register_allocator.get(bin_op.lhs);
        const rhs = self.register_allocator.get(bin_op.rhs);
        // Free sources before allocating the destination so the destination
        // may reuse one of their registers.
        try self.maybeFreeSources(bin_op.sources());
        const reg = self.register_allocator.allocate(bin_op.dest) orelse return error.OutOfRegisters;
        switch (bin_op.op) {
            .add => try self.emit(.add(reg, lhs, rhs)),
        }
    }

    /// Dispatches one IR instruction to its gen* handler, resolved at comptime
    /// from the payload type name (compile.Instr.Constant -> genConstant, ...).
    fn codegenInstr(self: *Context, instr: compile.Instr) !void {
        switch (instr.type) {
            inline else => |ty| {
                const func = comptime blk: {
                    const typeName = @typeName(@TypeOf(ty));
                    var it = std.mem.splitBackwardsScalar(u8, typeName, '.');
                    const base = it.first();
                    if (!@hasDecl(Context, "gen" ++ base)) {
                        @compileError(std.fmt.comptimePrint(
                            "codegen.Context must have a member named 'gen{s}' " ++
                                "since compile.Instr.Type has a variant named {s}",
                            .{ base, typeName },
                        ));
                    }
                    break :blk @field(Context, "gen" ++ base);
                };
                try func(self, ty);
            },
        }
    }

    fn codegenBlock(self: *Context, block: compile.Block) !void {
        self.block = &block;
        defer self.block = null;
        for (block.instrs, 0..) |instr, i| {
            self.current_instruction_index = i;
            try self.codegenInstr(instr);
        }
    }
};

/// Lowers `block` to RV64 machine code and wraps it in a minimal statically
/// linked ELF64 executable that exits with the block's final value.
/// Caller owns the returned slice.
pub fn create_elf(allocator: Allocator, block: compile.Block) ![]u8 {
    var ctx: Context = .{ .register_allocator = try .init(allocator), .instructions = .init(allocator) };
    defer ctx.deinit();
    try ctx.codegenBlock(block);

    // Epilogue: exit(result) via the Linux exit syscall (a7 = 93).
    try ctx.instructions.appendSlice(&[_]Instruction{
        .addi(.a0, ctx.register_allocator.get(block.instrs[block.instrs.len - 1].dest()), 0),
        .addi(.a7, .zero, 93),
        .ecall(),
    });

    var output_buffer: std.ArrayList(u8) = .init(allocator);
    errdefer output_buffer.deinit();

    // Reserve space for the headers; they are filled in at the end once the
    // total file size is known.
    try output_buffer.appendNTimes(undefined, @sizeOf(elf.Elf64_Ehdr) + @sizeOf(elf.Elf64_Phdr));

    const output = output_buffer.writer();
    for (ctx.instructions.items) |instr| {
        try output.writeInt(u32, @bitCast(instr), .little);
    }

    const base_addr = 0x10000000;

    const elf_header: elf.Elf64_Ehdr = .{
        .e_ident = elf.MAGIC.* ++ [_]u8{
            elf.ELFCLASS64, // EI_CLASS
            elf.ELFDATA2LSB, // EI_DATA
            1, // EI_VERSION
            @intFromEnum(elf.OSABI.NONE), // EI_OSABI
            0, // EI_ABIVERSION
        } ++ [1]u8{0} ** 7, // EI_PAD
        .e_type = .EXEC,
        .e_machine = elf.EM.RISCV,
        .e_version = 1,
        .e_entry = base_addr + @sizeOf(elf.Elf64_Ehdr) + @sizeOf(elf.Elf64_Phdr), // we place the code directly after the headers
        .e_phoff = @sizeOf(elf.Elf64_Ehdr), // we place the program header(s) directly after the elf header
        .e_shoff = 0,
        .e_flags = 0, // ?
        .e_ehsize = @sizeOf(elf.Elf64_Ehdr),
        .e_phentsize = @sizeOf(elf.Elf64_Phdr),
        .e_phnum = 1,
        .e_shentsize = @sizeOf(elf.Elf64_Shdr),
        .e_shnum = 0,
        .e_shstrndx = 0,
    };

    // One PT_LOAD segment maps the whole file (headers included) R+X.
    const program_header: elf.Elf64_Phdr = .{
        .p_type = elf.PT_LOAD,
        .p_flags = elf.PF_R | elf.PF_X,
        .p_offset = 0,
        .p_vaddr = base_addr,
        .p_paddr = base_addr,
        .p_filesz = output_buffer.items.len,
        .p_memsz = output_buffer.items.len,
        .p_align = 0x1000,
    };

    @memcpy(output_buffer.items[0..@sizeOf(elf.Elf64_Ehdr)], std.mem.asBytes(&elf_header));
    @memcpy(output_buffer.items[@sizeOf(elf.Elf64_Ehdr)..][0..@sizeOf(elf.Elf64_Phdr)], std.mem.asBytes(&program_header));

    return output_buffer.toOwnedSlice();
}