import sys
import re
import struct
import IPython
import copy


class AssemblerException(Exception):
    """Base class of every assembler-specific error defined in this module."""


class InvalidRegister(AssemblerException):
    """Error reporting an operand that is not a recognised register."""

    def __init__(self, register):
        message = "Invalid register: {reg}".format(reg=register)
        super().__init__(message)


class InvalidOperation(AssemblerException):
    """Error reporting an unrecognised operation mnemonic."""

    def __init__(self, operation):
        message = "Invalid operation: {op}".format(op=operation)
        super().__init__(message)


class ExpectedImmediate(AssemblerException):
    """Error for an operand slot that requires an immediate but got something else."""

    def __init__(self, value):
        message = "Expected immediate, got {got}".format(got=value)
        super().__init__(message)


class ExpectedRegister(AssemblerException):
    """Error for an operand slot that requires a register but got something else."""

    def __init__(self, value):
        message = "Expected register, got {got}".format(got=value)
        super().__init__(message)


class IPOverwrite(AssemblerException):
    """
    Error for an instruction that uses the IP register where it is forbidden.

    The offending instruction, when given and truthy, is appended to the
    message.
    """

    def __init__(self, instruction=None):
        text = "IP can't be overwritten."
        if instruction:
            text = "IP can't be overwritten. Instruction: {}".format(instruction)
        super().__init__(text)


class InvalidValue(AssemblerException):
    """Error for an instruction whose opcode or operands have no valid encoding."""

    def __init__(self, instruction):
        message = "Invalid value while assembling: {instr}".format(instr=instruction)
        super().__init__(message)

def rol(val, r_bits, max_bits):
    """
    Rotate ``val`` left by ``r_bits`` inside a ``max_bits``-wide word.

    ``r_bits`` is reduced modulo ``max_bits``. Bits of ``val`` above the
    word width are discarded. A rotation by 0 returns ``val`` masked to
    ``max_bits``.

    >>> rol(0b10000001, 1, 8)
    3
    """
    mask = 2**max_bits - 1
    shift = r_bits % max_bits
    # Bits shifted out on the left re-enter on the right.
    return (val << shift) & mask | ((val & mask) >> (max_bits - shift))


class VMAssembler:
    """
    Line-by-line assembler for the VM.

    When the assembler is created, the opcode byte of every operation is
    derived from ``key`` (see ``define_ops``). Each source line becomes a
    VMInstruction. It is encoded by the method named in its opcode's
    ``method`` attribute, and the bytes accumulate in ``assembled_code``.
    """

    def __init__(self, key):
        self.assembled_code = bytearray()
        self.define_ops(key)

    def parse(self, instruction):
        """Encode ``instruction`` with the method its opcode names (e.g. ``imm2reg``)."""
        action = getattr(self, instruction.opcode.method)
        action(instruction)

    def process_code_line(self, line):
        """
        Tokenise one source line, echo it, and assemble it.

        Lines without any word token (blank or punctuation only) are skipped.

        Raises:
            InvalidOperation: the mnemonic is not a known operation.
            InvalidRegister: a non-immediate operand is not a known register.
            AssemblerException: (or a subclass) for any encoding error.
        """
        # Any non-word character (space, comma, brackets, ...) separates tokens.
        components = [x for x in re.split(r'\W', line) if x]
        if not components:
            return
        instruction = VMInstruction(components[0], components[1:])
        # Validate before printing: VMInstruction.__repr__ dereferences the
        # opcode and every argument, which would fail on a None lookup result.
        if instruction.opcode is None:
            raise InvalidOperation(components[0])
        for token, arg in zip(components[1:], instruction.args):
            if arg is None:
                raise InvalidRegister(token)
        sys.stdout.write("CODE: " + str(instruction) + "\n")
        self.parse(instruction)

    @staticmethod
    def _operands(instruction, count):
        """
        Return the first ``count`` operands of ``instruction``.

        Extra trailing operands are ignored. Too few operands raise
        AssemblerException rather than an IndexError.
        """
        args = instruction.args
        if len(args) < count:
            raise AssemblerException(
                "Expected {} operand(s), got {}: {}".format(
                    count, len(args), instruction))
        return args[:count]

    def imm2reg(self, instruction):
        """
        Encode ``OP REG, IMM`` (Intel operand order: destination first).

        Layout: opcode (1 byte) | register (1 byte) | immediate (uint16, LE).
        """
        opcode = instruction.opcode
        reg, imm = self._operands(instruction, 2)
        if reg.name == "ip":
            raise IPOverwrite(instruction)
        if not imm.isimm():
            raise ExpectedImmediate(imm)
        if not reg.isreg():
            raise ExpectedRegister(reg)
        op_b, reg_b, imm_b = opcode.uint8(), reg.uint8(), imm.uint16()
        if not op_b or not reg_b or not imm_b:
            raise InvalidValue(instruction)
        self.assembled_code += op_b + reg_b + imm_b

    def reg2reg(self, instruction):
        """
        Encode ``OP DST_REG, SRC_REG`` (Intel operand order).

        Layout: opcode (1 byte) | (dst << 4) ^ (src & 0x0F) (1 byte). The
        destination sits in the high nibble and the source in the low nibble.
        """
        opcode = instruction.opcode
        dst_reg, src_reg = self._operands(instruction, 2)
        # IP is rejected as either operand.
        if dst_reg.name == "ip" or src_reg.name == "ip":
            raise IPOverwrite(instruction)
        if not dst_reg.isreg():
            raise ExpectedRegister(dst_reg)
        if not src_reg.isreg():
            raise ExpectedRegister(src_reg)
        op_b, dst_b, src_b = opcode.uint8(), dst_reg.uint8(), src_reg.uint8()
        if not op_b or not dst_b or not src_b:
            raise InvalidValue(instruction)
        byte_with_nibbles = struct.pack(
            "<B", dst_b[0] << 4 ^ (src_b[0] & 0b00001111))
        self.assembled_code += op_b + byte_with_nibbles

    def reg2imm(self, instruction):
        """
        Encode ``OP IMM, REG`` (Intel operand order: the immediate comes first).

        Layout: opcode (1 byte) | immediate (uint16, LE) | register (1 byte).
        """
        opcode = instruction.opcode
        imm, reg = self._operands(instruction, 2)
        if reg.name == "ip":
            raise IPOverwrite(instruction)
        if not imm.isimm():
            raise ExpectedImmediate(imm)
        if not reg.isreg():
            raise ExpectedRegister(reg)
        op_b, reg_b, imm_b = opcode.uint8(), reg.uint8(), imm.uint16()
        if not op_b or not reg_b or not imm_b:
            raise InvalidValue(instruction)
        self.assembled_code += op_b + imm_b + reg_b

    def byt2reg(self, instruction):
        """
        Encode ``OP REG, [BYTE]IMM``, whose immediate is a single byte.

        Layout: opcode (1 byte) | register (1 byte) | immediate (1 byte).
        """
        opcode = instruction.opcode
        reg, imm = self._operands(instruction, 2)
        if reg.name == "ip":
            raise IPOverwrite(instruction)
        if not imm.isimm():
            raise ExpectedImmediate(imm)
        if not reg.isreg():
            raise ExpectedRegister(reg)
        op_b, reg_b, imm_b = opcode.uint8(), reg.uint8(), imm.uint8()
        if not op_b or not reg_b or not imm_b:
            raise InvalidValue(instruction)
        self.assembled_code += op_b + reg_b + imm_b

    def regonly(self, instruction):
        """
        Encode an instruction whose only operand is a register.

        Layout: opcode (1 byte) | register (1 byte).
        """
        opcode = instruction.opcode
        reg = self._operands(instruction, 1)[0]
        if reg.name == "ip":
            raise IPOverwrite(instruction)
        if not reg.isreg():
            raise ExpectedRegister(reg)
        op_b, reg_b = opcode.uint8(), reg.uint8()
        if not op_b or not reg_b:
            raise InvalidValue(instruction)
        self.assembled_code += op_b + reg_b

    def immonly(self, instruction):
        """
        Encode an instruction whose only operand is an immediate.

        Layout: opcode (1 byte) | immediate (uint16, LE).
        """
        opcode = instruction.opcode
        imm = self._operands(instruction, 1)[0]
        if not imm.isimm():
            raise ExpectedImmediate(imm)
        op_b, imm_b = opcode.uint8(), imm.uint16()
        if not op_b or not imm_b:
            raise InvalidValue(instruction)
        self.assembled_code += op_b + imm_b

    def single(self, instruction):
        """Encode an operand-less instruction as its opcode byte alone."""
        self.assembled_code += instruction.opcode.uint8()

    def define_ops(self, key):
        """
        Scramble the opcode byte of every operation using ``key``.

        ``key`` is UTF-8 encoded. For each key byte ``b``, every operation's
        value becomes ``rol(b ^ value, b % 8, 8)``. Then, for each operation
        ``i`` in table order, every operation ``j`` is rotated left by
        ``i.value % 8``. The old -> new mapping is printed.

        NOTE: the values are rewritten in place on the module-level ``ops``
        list, so constructing a second VMAssembler scrambles already
        scrambled values.

        If two operations end up with the same byte, a warning is written
        to stderr, because the output could not then be decoded
        unambiguously.
        """
        key_ba = bytearray(key, 'utf-8')
        olds = copy.deepcopy(ops)
        for b in key_ba:
            for op_com in ops:
                op_com.set_value(rol(b ^ op_com.value, b % 8, 8))
        # Order-dependent: when j is i, i.value changes mid-loop, so the ops
        # after i are rotated by a different amount than those before it.
        # Presumably the VM derives opcodes the same way -- keep this
        # byte-for-byte identical unless the VM is updated too.
        for i in ops:
            for j in ops:
                j.set_value(rol(j.value, i.value % 8, 8))
        for o, n in zip(olds, ops):
            print("{} : {}->{}".format(o.name, hex(o.value), hex(n.value)))
        # The first pass is a bijection (XOR, then rotate), but the second
        # pass rotates different ops by different amounts and may collide.
        values = [op.value for op in ops]
        if len(set(values)) != len(values):
            sys.stderr.write(
                "WARNING: key {!r} maps several operations to the same "
                "opcode byte; the assembled code will be ambiguous.\n".format(key))


class VMComponent:
    """
    Represents a register, operation or an immediate the VM recognizes.

    For registers and operations, ``value`` is an int (their encoding). For
    immediates, ``name`` and ``value`` are both the source token string,
    which is what ``isimm`` relies on.
    """

    # A decimal literal: one or more ASCII digits.
    _DECIMAL_RE = re.compile("^[0-9]+$")

    def __init__(self, name, value, method=None):
        self.name = name.casefold()
        self.value = value
        # For operations: name of the VMAssembler method that encodes it.
        self.method = method

    def __repr__(self):
        return "{}".format(self.name)

    def set_name(self, name):
        # NOTE(review): unlike __init__, this does not casefold ``name``,
        # so isreg()/isop() may stop matching; confirm whether intended.
        self.name = name

    def set_value(self, value):
        self.value = value

    def _as_int(self):
        """
        Return ``value`` as an int, or None if it is not a numeric literal.

        Accepts an int as-is, a ``0x``-prefixed hex string, or a string of
        decimal digits.
        """
        if isinstance(self.value, int):
            return self.value
        if self.value.startswith("0x"):
            try:
                return int(self.value, 16)
            except ValueError:  # e.g. "0x" or "0xzz"
                return None
        if self._DECIMAL_RE.match(self.value):
            return int(self.value)
        return None

    def _pack(self, fmt):
        """Pack ``value`` with struct format ``fmt``; None if invalid or out of range."""
        number = self._as_int()
        if number is None:
            return None
        try:
            return struct.pack(fmt, number)
        except struct.error:  # does not fit the field width (or negative)
            return None

    def uint8(self):
        """Return ``value`` as 1 unsigned byte, or None if it cannot be encoded."""
        return self._pack("<B")

    def uint16(self):
        """Return ``value`` as a little-endian uint16, or None if it cannot be encoded."""
        return self._pack("<H")

    def isreg(self):
        """True if ``name`` is one of the register names (case-insensitive)."""
        return self.name in {x.casefold() for x in reg_names}

    def isop(self):
        """True if ``name`` is one of the operation mnemonics (case-insensitive)."""
        return self.name in {x[0].casefold() for x in op_names}

    def isimm(self):
        """True for immediates, whose name and value are the same token."""
        return self.name == self.value


class VMInstruction:
    """
    Represents an instruction the VM recognizes.
    e.g: MOVI [R0, 2]
          ^       ^
        opcode  args

    ``opcode`` is looked up by name in ``ops``. Each operand token is
    either an immediate (kept as a VMComponent whose name and value are the
    token) or a register looked up by name in ``regs``.

    Raises:
        InvalidOperation: ``opcode`` is not a known operation.
        InvalidRegister: a non-immediate operand is not a known register.
    """

    # Immediates: "0x" followed by hex digits, or plain decimal digits. The
    # prefix must be a lowercase "0x", since VMComponent only decodes that form.
    _IMMEDIATE_RE = re.compile(r"0x[0-9a-fA-F]+|[0-9]+")

    def __init__(self, opcode, instr_list):
        self.opcode = next((x for x in ops if x.name == opcode), None)
        if self.opcode is None:
            raise InvalidOperation(opcode)
        self.args = []
        for el in instr_list:
            if self._IMMEDIATE_RE.fullmatch(el):
                # directly append the immediate
                self.args.append(VMComponent(el, el))
                continue
            reg_comp = next((x for x in regs if x.name == el), None)
            if reg_comp is None:
                raise InvalidRegister(el)
            self.args.append(reg_comp)

    def __repr__(self):
        return "{} {}".format(self.opcode.name, ", ".join([x.name for x in self.args]))

# Operation table: [mnemonic, name of the VMAssembler encoder method].
# An entry's position is its base opcode, which define_ops() later scrambles.
op_names = [
    ["MOVI", "imm2reg"],
    ["MOVR", "reg2reg"],
    ["LOAD", "imm2reg"],
    ["STOR", "reg2imm"],
    ["ADDI", "imm2reg"],
    ["ADDR", "reg2reg"],
    ["SUBI", "imm2reg"],
    ["SUBR", "reg2reg"],
    ["XORB", "byt2reg"],
    ["XORW", "imm2reg"],
    ["XORR", "reg2reg"],
    ["NOTR", "regonly"],
    ["MULI", "imm2reg"],
    ["MULR", "reg2reg"],
    ["DIVI", "imm2reg"],
    ["DIVR", "reg2reg"],
    ["PUSH", "regonly"],
    ["POOP", "regonly"],
    ["CMPI", "imm2reg"],
    ["CMPR", "reg2reg"],
    ["JMPI", "immonly"],
    ["JMPR", "regonly"],
    ["JPAI", "immonly"],
    ["JPAR", "regonly"],
    ["JPBI", "immonly"],
    ["JPBR", "regonly"],
    ["JPEI", "immonly"],
    ["JPER", "regonly"],
    ["SHIT", "single"],
    ["NOPE", "single"],
    ["GRMN", "single"],
]

# Register names; a register's position is its encoding.
reg_names = ["R0", "R1", "R2", "R3", "S0", "S1", "S2", "S3", "IP", "BP", "SP"]
# Section delimiter lines, mapped (case-folded) to 1-based section flags.
section_names = ["DATA:", "CODE:", "STACK:"]
section_flags = {name.casefold(): number
                 for number, name in enumerate(section_names, start=1)}
# Runtime components built from the tables above. define_ops() later
# rewrites each operation's value in place.
ops = [VMComponent(mnemonic, index, method)
       for index, (mnemonic, method) in enumerate(op_names)]
regs = [VMComponent(register.casefold(), index)
        for index, register in enumerate(reg_names)]


def assemble_data(line):
    """Echo a data-section line, with commas stripped from both ends, to stdout.

    Only echoes; no bytecode is produced.
    """
    out = sys.stdout
    out.write("DATA:\t")
    out.write("{}\n".format(line.strip(",")))


def main():
    """
    Command-line entry point: ``prog opcodes_key file_to_assemble output``.

    Reads the source file line by line, case-folded and stripped. Section
    delimiter lines switch the current section. Lines in the data and code
    sections are assembled, and the resulting bytes are written to
    ``output``.
    """
    if len(sys.argv) < 4:
        print("Usage: {} opcodes_key file_to_assemble output".format(
            sys.argv[0]))
        return
    vma = VMAssembler(sys.argv[1])
    # NOTE(review): data-section lines go through process_code_line exactly
    # like code lines; assemble_data() is never called -- confirm intent.
    assembled_sections = (section_flags["data:"], section_flags["code:"])
    flag = None
    with open(sys.argv[2], 'r') as f:
        for raw_line in f:
            line = raw_line.casefold().strip()
            # Skip blank and whitespace-only lines (including CRLF endings),
            # which would otherwise reach the tokenizer with no tokens.
            if not line:
                continue
            if line in section_flags:
                flag = section_flags[line]
                continue
            if flag in assembled_sections:
                vma.process_code_line(line)
    if flag is None:
        sys.stderr.write(
            "Nothing was assembled! Did you use the section delimiters?\n")
    with open(sys.argv[3], 'wb') as f:
        f.write(vma.assembled_code)

if __name__ == "__main__":
    # Assemble only when run as a script, not when imported.
    main()