/*
 * Tiny 8086 interpreter for flat .COM-style images.
 *
 * argv[1] is loaded at offset 0x100 of a single 64 KiB address space (no
 * segment registers are modelled); argv[2], if present, becomes the command
 * tail at 0x80. Only CF and ZF are tracked. The register union and the 16-bit
 * accesses made through rmptr assume a little-endian host.
 *
 * Uses the GNU C "case range" extension (case A ... B), so build with GCC or
 * Clang.
 */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Size of the emulated address space; all offsets are 16-bit.
#define MEM_SIZE 0x10000u

// Emulated memory plus one guard byte, so a 16-bit access through rmptr at
// offset 0xFFFF stays inside the array. (That access does not wrap its high
// byte back to offset 0; rd16/wr16 below do.)
static uint8_t mem[MEM_SIZE + 1];

// General registers. On a little-endian host r8[2n] / r8[2n+1] are the low /
// high bytes of r16[n], which is what the AL/CL/DL aliases below rely on.
static union regset {
    uint8_t r8[8];   // AL, AH, CL, CH, DL, DH, BL, BH
    uint16_t r16[8]; // AX, CX, DX, BX, SP, BP, SI, DI
} regset;

static int carry, zero;     // CF and ZF, each kept as 0 or 1
static uint16_t ip = 0x100; // .COM images start executing at 0x100
static uint16_t reg;        // reg field of the most recently decoded ModR/M byte
static void *rmptr;         // r/m operand selected by the most recent ModR/M byte

// The r/m operand as an lvalue. NOTE(review): for memory operands RM16 reads
// uint8_t storage through a (possibly misaligned) uint16_t*, which relies on
// x86-like hardware and a compiler that tolerates the aliasing.
#define RM8 (*(uint8_t *)rmptr)
#define RM16 (*(uint16_t *)rmptr)
#define R8(x) regset.r8[fixr8ref(x)]
#define R16(x) regset.r16[(x)]
#define AL regset.r8[0]
#define CL regset.r8[2]
#define DL regset.r8[4]
#define AX regset.r16[0]
#define BX regset.r16[3]
#define SP regset.r16[4]
#define BP regset.r16[5]
#define SI regset.r16[6]
#define DI regset.r16[7]

// ALU operations. ADD..CMP equal bits 5:3 of opcodes 00-39 and the reg field
// of opcode 83; TEST has no such encoding and is used for opcode 84.
enum { ADD, OR, ADC, SBB, AND, SUB, XOR, CMP, TEST };

// Operand sizes accepted by modrm().
enum { BIT8 = 8, BIT16 = 16 };

static const char *const regnames[] = {
    "AX", "CX", "DX", "BX", "SP", "BP", "SI", "DI", NULL
};

// Map an x86 8-bit register number (AL, CL, DL, BL, AH, CH, DH, BH) to its
// index in regset.r8 (AL, AH, CL, CH, DL, DH, BL, BH).
static int fixr8ref(int r)
{
    return ((r & 3) << 1) | ((r >> 2) & 1);
}

// Read a little-endian word at a 16-bit offset; the high byte wraps to 0.
static uint16_t rd16(uint16_t addr)
{
    return (uint16_t)(mem[addr] | (mem[(uint16_t)(addr + 1u)] << 8));
}

// Write a little-endian word at a 16-bit offset; the high byte wraps to 0.
static void wr16(uint16_t addr, uint16_t v)
{
    mem[addr] = (uint8_t)v;
    mem[(uint16_t)(addr + 1u)] = (uint8_t)(v >> 8);
}

// Print all eight 16-bit registers, then IP and the two bytes at IP.
static void dump(void)
{
    printf("\n");
    for (int i = 0; i < 8; i++) {
        printf("%s=%04X ", regnames[i], regset.r16[i]);
    }
    printf("\n");
    printf("IP=%04X : %02X %02X\n", ip, mem[ip], mem[(uint16_t)(ip + 1u)]);
}

// Shared ALU core: apply `op` to `a` and `b` truncated to `mask` (0xFF or
// 0xFFFF), set CF and ZF as the 8086 does, and return the truncated result.
// CF is the unsigned carry/borrow for ADD/ADC/SUB/SBB/CMP and is cleared by
// the logical operations.
static unsigned alu(int op, unsigned a, unsigned b, unsigned mask)
{
    unsigned cin = (unsigned)carry & 1u; // read before CF is overwritten
    unsigned r;

    a &= mask;
    b &= mask;
    switch (op) {
    case ADD: r = a + b;       carry = r > mask;    break;
    case ADC: r = a + b + cin; carry = r > mask;    break;
    case SUB:
    case CMP: r = a - b;       carry = a < b;       break;
    case SBB: r = a - b - cin; carry = a < b + cin; break;
    case OR:  r = a | b;       carry = 0;           break;
    case XOR: r = a ^ b;       carry = 0;           break;
    case AND:
    case TEST: r = a & b;      carry = 0;           break;
    default:
        fprintf(stderr, "Invalid ALU operation %d\n", op);
        exit(1);
    }
    r &= mask;
    zero = (r == 0);
    return r;
}

// 8-bit ALU op on *dst and src; CMP and TEST only update the flags.
static void arith8(int op, uint8_t *dst, int src)
{
    unsigned r = alu(op, *dst, (unsigned)src, 0xFFu);
    if (op != CMP && op != TEST) {
        *dst = (uint8_t)r;
    }
}

// 16-bit ALU op on *dst and src; CMP and TEST only update the flags.
static void arith16(int op, uint16_t *dst, int src)
{
    unsigned r = alu(op, *dst, (unsigned)src, 0xFFFFu);
    if (op != CMP && op != TEST) {
        *dst = (uint16_t)r;
    }
}

// FLAGS image for PUSHF: CF in bit 0, ZF in bit 6, every other bit 0.
static uint16_t get_flags(void)
{
    return (uint16_t)((carry & 1) | ((zero & 1) << 6));
}

// Fetch the next instruction byte.
static uint8_t imm8(void)
{
    return mem[ip++];
}

// Fetch the next little-endian instruction word.
static uint16_t imm16(void)
{
    uint16_t v = rd16(ip);
    ip = (uint16_t)(ip + 2u);
    return v;
}

// Decode a ModR/M byte at ip (plus any displacement): store its reg field in
// `reg` and point rmptr at the r/m operand (a register for mod=3, otherwise
// memory). Effective addresses are computed modulo 64 KiB, so rmptr always
// points inside mem[].
static void modrm(int size)
{
    uint8_t byte = imm8();
    unsigned mod = byte >> 6;
    unsigned rm = byte & 7u;
    reg = (uint16_t)((byte >> 3) & 7u);

    if (mod == 3) {
        if (size == BIT16) {
            rmptr = &R16(rm);
        } else {
            rmptr = &R8(rm);
        }
        return;
    }

    uint16_t ea = 0;
    if (mod == 0 && rm == 6) {
        ea = imm16(); // direct address: an unsigned 16-bit offset
    } else {
        switch (rm) {
        case 0: ea = (uint16_t)(BX + SI); break;
        case 1: ea = (uint16_t)(BX + DI); break;
        case 2: ea = (uint16_t)(BP + SI); break;
        case 3: ea = (uint16_t)(BP + DI); break;
        case 4: ea = SI; break;
        case 5: ea = DI; break;
        case 6: ea = BP; break;
        case 7: ea = BX; break;
        }
        if (mod == 1) {
            ea = (uint16_t)(ea + (int8_t)imm8()); // signed 8-bit displacement
        } else if (mod == 2) {
            ea = (uint16_t)(ea + imm16());        // 16-bit displacement
        }
    }
    rmptr = &mem[ea];
}

// Evaluate Jcc condition `num` (low nibble of opcodes 70-7F). Even numbers
// are the base condition and odd numbers its negation. Only the conditions
// that need nothing beyond CF and ZF are implemented: B/AE, Z/NZ, BE/A.
static int cond(int num)
{
    int r;
    switch (num >> 1) {
    case 1: r = carry; break;         // JB / JAE
    case 2: r = zero; break;          // JZ / JNZ
    case 3: r = carry || zero; break; // JBE / JA
    default:
        fprintf(stderr, "Condition code %d not implemented\n", num);
        exit(1);
    }
    // Each odd condition is the negation of the previous
    if (num & 1) {
        r = !r;
    }
    return r;
}

static void push(uint16_t v)
{
    SP -= 2;
    wr16(SP, v);
}

static uint16_t pop(void)
{
    uint16_t v = rd16(SP);
    SP += 2;
    return v;
}

// Emulated system calls: 0 terminates, 2 writes DL to stdout.
static void handle_syscall(int number)
{
    switch (number) {
    case 0:
        exit(0);
    case 2:
        printf("%c", DL);
        break;
    default:
        fprintf(stderr, "Fatal: Unhandled syscall CL=%Xh\n", (unsigned)number);
        exit(1);
    }
}

// INT n: 20h terminates, E0h dispatches the system call selected by CL.
static void handle_intr(int number)
{
    switch (number) {
    case 0x20:
        handle_syscall(0);
        break;
    case 0xE0:
        handle_syscall(CL);
        break;
    default:
        fprintf(stderr, "Fatal: Unhandled interrupt %Xh\n", (unsigned)number);
        exit(1);
    }
}

// Fetch, decode and execute one instruction at ip.
static void step(void)
{
    uint8_t opcode = imm8();

    switch (opcode) {
    case 0x00: case 0x08: case 0x10: case 0x18: // ALU r/m8,reg8
    case 0x20: case 0x28: case 0x30: case 0x38:
        modrm(BIT8);
        arith8(opcode >> 3, &RM8, R8(reg));
        break;
    case 0x01: case 0x09: case 0x11: case 0x19: // ALU r/m16,reg16
    case 0x21: case 0x29: case 0x31: case 0x39:
        modrm(BIT16);
        arith16(opcode >> 3, &RM16, R16(reg));
        break;
    case 0x50 ... 0x57: // PUSH reg16
        push(R16(opcode - 0x50));
        break;
    case 0x58 ... 0x5F: // POP reg16
        R16(opcode - 0x58) = pop();
        break;
    case 0x70 ... 0x7F: { // Jcc rel8
        // Fetch the displacement before reading ip: the target is relative to
        // the next instruction, and "ip + imm8()" leaves the read order unspecified.
        int8_t disp = (int8_t)imm8();
        if (cond(opcode - 0x70)) {
            ip = (uint16_t)(ip + disp);
        }
        break;
    }
    case 0x83: // ALU r/m16,imm8 (imm8 is sign-extended to 16 bits)
        modrm(BIT16);
        arith16(reg, &RM16, (int8_t)imm8());
        break;
    case 0x84: // TEST r/m8,reg8 (sets ZF, clears CF)
        modrm(BIT8);
        arith8(TEST, &RM8, R8(reg));
        break;
    case 0x88: // MOV r/m8,reg8
        modrm(BIT8);
        RM8 = R8(reg);
        break;
    case 0x89: // MOV r/m16,reg16
        modrm(BIT16);
        RM16 = R16(reg);
        break;
    case 0x8B: // MOV reg16,r/m16
        modrm(BIT16);
        R16(reg) = RM16;
        break;
    case 0x9C: // PUSHF
        push(get_flags());
        break;
    case 0xA1: // MOV AX,memoffs16
        AX = rd16(imm16());
        break;
    case 0xAC: // LODSB
        AL = mem[SI];
        SI++;
        break;
    case 0xB0 ... 0xB7: // MOV reg8,imm8
        R8(opcode - 0xB0) = imm8();
        break;
    case 0xB8 ... 0xBF: // MOV reg16,imm16
        R16(opcode - 0xB8) = imm16();
        break;
    case 0xC3: // RET
        ip = pop();
        break;
    case 0xCC: // INT3
        dump();
        break;
    case 0xCD: // INT imm8
        handle_intr(imm8());
        break;
    case 0xE8: { // CALL rel16; a call to offset 5 is handled as a system call (CL)
        int16_t disp = (int16_t)imm16();
        uint16_t target = (uint16_t)(ip + disp);
        if (target == 5) {
            handle_syscall(CL);
        } else {
            push(ip);
            ip = target;
        }
        break;
    }
    case 0xEB: { // JMP rel8
        int8_t disp = (int8_t)imm8();
        ip = (uint16_t)(ip + disp);
        break;
    }
    case 0xFA: // CLI
    case 0xFB: // STI
        break; // no-op, since we dont have interrupts
    default:
        ip--;
        dump();
        fprintf(stderr, "Invalid opcode at IP=%04X\n", ip);
        exit(1);
    }
}

// Store `str` as the command tail: length byte at 0x80, text at 0x81, then a
// 0x0D terminator. Exits if the text does not fit in the 126 available bytes.
static void copy_cmdline(const char *str)
{
    size_t len = strlen(str);

    // Reject (rather than truncate) anything that would not fit
    if (len > 0x7E) {
        fprintf(stderr, "Command line too long, max is 126 bytes\n");
        exit(1);
    }
    memcpy(&mem[0x81], str, len);
    mem[0x81 + len] = 0x0D;
    mem[0x80] = (uint8_t)len;
}

int main(int argc, char **argv)
{
    if (argc < 2) {
        fprintf(stderr, "Usage: %s PROGRAM.COM [COMMAND-TAIL]\n",
                argc > 0 ? argv[0] : "emu");
        return 1;
    }

    memset(mem, 0, sizeof mem);
    memset(&regset, 0, sizeof regset);

    // Offset 0 holds INT 20h and 0 is pushed as the return address, so a near
    // RET from the program's top level lands there and terminates it.
    mem[0] = 0xCD;
    mem[1] = 0x20;
    push(0);

    copy_cmdline(argc > 2 ? argv[2] : "");

    FILE *fd = fopen(argv[1], "rb"); // binary image: no newline translation
    if (fd == NULL) {
        perror(argv[1]);
        return 1;
    }
    fread(mem + ip, 1, MEM_SIZE - ip, fd);
    if (ferror(fd)) {
        perror(argv[1]);
        fclose(fd);
        return 1;
    }
    fclose(fd);

    for (;;) {
        step();
        dump();
    }
}