#include <cctype>
#include <cerrno>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "soft_cpu.h"

using namespace std;

class ASM {
   string infile;
   string outfile;
public:
   ASM(const string & infile, const string & outfile) {
     this->infile = infile;
     this->outfile = outfile;
   }
 
   string to_upper(const string & str) {
      string ret;
      for(size_t i = 0; i < str.size(); i++) {
         ret += toupper(str[i]);
      }
      return ret;
   }

   string read_line(fstream & fin) {
      string line;
      char ch;
      ch = fin.get();
      while(fin) {
         if(ch == '\n') 
            break;
         line += ch;
         ch = fin.get();
      }
      return to_upper(line);
   }

   void run() {
     fstream af(infile.c_str(), ios::in);
     if(af) {
        fstream of(outfile.c_str(), ios::out | ios::ate | ios::trunc | ios::binary);
        if(of) {
           string opcode;
           string arg;
           string line;
           int lineno = 1;
           line = read_line(af);
           data_word_type acode;
           data_word_type code;
           data_word_type count = 0,n = 0;
           vector< data_word_type > the_code;
           map< string, data_word_type > label_map;
           map< data_word_type, string > jmp_map;
           bool error = false;
           while(af) {
              stringstream str;
              str << line;
              str >> opcode;
              n++;
              if(str) {
                 str >> arg;
                 n++;
              }
              if(opcode.find(":") != string::npos) {
                 opcode = opcode.substr(0,opcode.size()-1);
                 label_map[opcode] = count * sizeof(data_word_type);
                 line = read_line(af);
                 n = 0;
                 continue;
              }
              else if(opcode == "PUSH_R2") {
                 code = _PUSH_R2;
              }
              else if(opcode == "POP_R2") {
                 code = _POP_R2;
              }
              else if(opcode == "ADD") {
                 code = _ADD;
              }
              else if(opcode == "SUB") {
                 code = _SUB;
              }
              else if(opcode == "MUL") {
                 code = _MUL;
              }
              else if(opcode == "DIV") {
                 code = _DIV;
              }
              else if(opcode == "MOD") {
                 code = _MOD;
              }
              else if(opcode == "LT") {
                 code = _LT;
              }
              else if(opcode == "GT") {
                 code = _GT;
              }
              else if(opcode == "EQ") {
                 code = _EQ;
              }
              else if(opcode == "NEQ") {
                 code = _NEQ;
              }
              else if(opcode == "LTEQ") {
                 code = _LTEQ;
              }
              else if(opcode == "GTEQ") {
                 code = _GTEQ;
              }
              else if(opcode == "JMP") {
                 code = _JMP;
                 jmp_map[count+1] = arg;
                 acode = 0;
              }
              else if(opcode == "JZ") {
                 code = _JZ;
                 jmp_map[count+1] = arg;
                 acode = 0;
              }
              else if(opcode == "JE") {
                 code = _JE;
                 jmp_map[count+1] = arg;
                 acode = 0;
              }
              else if(opcode == "LOAD_R1") {
                 code = _LOAD1;
                 acode = atoi(arg.c_str());
              }
              else if(opcode == "LOAD_R2") {
                 code = _LOAD2;
                 acode = atoi(arg.c_str());
              }
              else if(opcode == "CMP") {
                 code = _CMP;
                 acode = atoi(arg.c_str());
              }
              else if(opcode == "PRINT") {
                 code = _PRINT;
              }
              else if(opcode == "HALT") {
                 code = _HALT;
              }
              else {
                 cerr << "error: bad instruction..." << endl;
                 cerr << "on line: " << lineno << endl;
                 error = true;
                 break;
              }
              count++;
              the_code.push_back(code);
              if(arg.size()) {
                 the_code.push_back(acode);
                 count++;
              }
              opcode = "";
              arg = "";
              line = read_line(af);
              lineno++;
              n = 0;
           }

           if(error == false) {
              code = the_code.size();
              of.write((char*)&code, sizeof(data_word_type));
              of.flush();
              for(map< data_word_type, string >::iterator ptr = jmp_map.begin(); ptr != jmp_map.end(); ptr++) {
                 the_code[ptr->first] = label_map[ptr->second];
              }
              for(size_t i = 0; i < the_code.size(); i++) {
                 code = the_code[i]; 
                 of.write((char*)&code, sizeof(data_word_type));
                 of.flush();
              }
              of.close();
           }
        }
        af.close();
     }
   }
};

int
main(int argc, char ** argv)
{
   if(argc == 3) {
      ASM assembler(argv[1], argv[2]);
      assembler.run();
   }
   else {
      cout << "usage: asm <infile> <outfile>" << endl;
   }
   return 0;
}
