#ifndef __SOLVE_H
#define __SOLVE_H

/*

SOLVE EQUATIONS IN TERMS OF A VARIABLE...

Author: Matthew W. Coan
Date: Wed Dec  2 21:41:07 EST 2015

*/

#include <iostream>
#include <stack>
#include <vector>
#include <string>
#include <sstream>
#include <cstdlib>

namespace math_tools {

using namespace std;

// Empty tag types intended for error reporting. Nothing in this header throws
// or catches them (the simplifier signals rewrites with a const char* instead).
struct bad_token {};
struct math_error {};

/*
 * A lexical token: its source text plus its category.
 *
 * Token holds only value types, so the compiler-generated copy constructor,
 * copy assignment and destructor are correct (Rule of Zero).
 */
class Token {
public:
   enum token_type {
      ID,        // identifier (Solver::lex emits this for a run of letters)
      NUMBER,    // number (Solver::lex emits this for a run of digits and '.')
      OPERATOR,  // an operator or parenthesis recognised by the lexer
      KEYWORD,   // reserved category; Solver::lex never produces it
      NIL        // category of a default-constructed token
   };

private:
   std::string value;
   token_type type;

public:
   // Builds a token holding text `t` with category `tt`.
   Token(const std::string & t,
         const token_type & tt)
   : value(t), type(tt)
   {
   }

   // Empty token of category NIL.
   Token()
   : value(), type(NIL)
   {
   }

   const std::string & get_value() const { return value; }

   // Returned by value; a top-level const on the return type would be ignored.
   token_type get_type() const { return type; }

   void set_value(const std::string & val) {
      value = val;
   }
};

/*
 * Expression-tree node.
 *
 * `left` / `right` are plain, non-owning links: Node never frees its children
 * (Solver::delete_tree does). Copying a Node therefore copies the links, not
 * the subtrees; that is exactly what the compiler-generated copy operations do,
 * so none are declared here (Rule of Zero).
 */
class Node {
public:
   std::string value;
   Node * left;
   Node * right;

   // Leaf node holding `val`.
   Node(const std::string & val)
   : value(val), left(0), right(0)
   {
   }

   // Leaf node with an empty value.
   Node()
   : value(), left(0), right(0)
   {
   }
};

class Solver {
public:
   typedef vector< Token > token_vector_type;
   typedef stack< Node* > node_stack_type;

private:
   token_vector_type token_vec;
   string current_var;
   string var;
   string last_value;
   string result;
   node_stack_type node_stk;
   Node * tree;
   Node * node;
   bool change;
   bool expand;

   Node * pop() {
      Node * ret = 0;
      if(node_stk.size()) { 
         ret = node_stk.top();
         node_stk.pop();
      }
      return ret;
   }

   void push(Node * n) {
      node_stk.push(n);
   }

   void delete_tree(Node * n) {
      if(n) {
         delete_tree(n->left);
         delete_tree(n->right);
         delete n;
      }
   }

public:
   Solver() { 
      change = true;
      tree = 0;
      push(new Node(""));
      expand = false;
   }

   ~Solver() {
      delete_tree(tree);
   }

   bool is_digit(const char ch) 
   {
      bool ret = false;
      if((ch >= '0' && ch <= '9') || ch == '.') {
         ret = true; 
      }
      return ret;
   }

   bool is_alpha(const char ch) 
   {
      bool ret = false;
      if((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z')) {
         ret = true;
      }
      return ret;
   }

   bool is_operator(const string & str) 
   {
      bool ret = false;
      if(str == "+")
         ret = true;
      else if(str == ",")
         ret = true;
      else if(str == "-")
         ret = true;
      else if(str == "*")
         ret = true;
      else if(str == "/")
         ret = true;
      else if(str == "=") 
         ret = true; 
      else if(str == "^") 
         ret = true; 
      else if(str == "%") 
         ret = true; 
      else if(str == "!") 
         ret = true; 
      else if(str == "&&") 
         ret = true; 
      else if(str == "||") 
         ret = true; 
      else if(str == "=")
         ret = true;
      else if(str == "==")
         ret = true;
      else if(str == "!=")
         ret = true;
      else if(str == "<") 
         ret = true;
      else if(str == ">")
         ret = true;
      else if(str == "<=")
         ret = true;
      else if(str == ">=")
         ret = true;
      else if(str == "(") 
         ret = true; 
      else if(str == ")") 
         ret = true; 
      return ret;
   }

bool is_space(char ch) 
{
   bool ret = false;
   if(ch == '\r' || ch == '\n' || ch == '\t' || ch == ' ') {
      ret = true;
   }
   return ret;
}

string read_token(stringstream & str)
{
   string ret;
   char ch;
   if(str)
      ;
   else return ret;
   ch = str.peek();
   if(is_alpha(ch)) {
      while(is_alpha(ch) && str) {
         str.get();
         ret += ch;
         ch = str.peek();
      }
   }
   else if(is_digit(ch)) {
      while(is_digit(ch) && str) {
         str.get();
         ret += ch;
         ch = str.peek();
      }
   }
   else if(is_operator(string() + ch)) {
      ret += ch;
      ch = str.get();
   }
   else {
      ch = str.get();
      ret += ch;
   }
   return ret;
}

bool lex(const string & expression)
{
   bool ret = true;
   stringstream str;
   str << expression;
   string temp;
   char ch;
   token_vec.clear();
   while((temp = read_token(str)) != "" && str) {
      if(is_digit(temp[0])) {
         token_vec.push_back(Token(temp, Token::NUMBER));
      } 
      else if(is_alpha(temp[0])) {
         token_vec.push_back(Token(temp, Token::ID));
      }
      else if(is_operator(temp)) {
         token_vec.push_back(Token(temp, Token::OPERATOR));
      }
   }
   return ret;
}

void debug_dump()
{
   for(size_t i = 0; i < token_vec.size(); i++) {
      cout << "token_vec[" << i << "] == [" << token_vec[i].get_value() << "]" << endl;
      switch(token_vec[i].get_type()) {
      case Token::ID:
         cout << "type=ID" << endl;
      break;

      case Token::NUMBER:
         cout << "type=NUMBER" << endl;
      break;

      case Token::OPERATOR:
         cout << "type=OPERATOR" << endl;
      break;

      case Token::KEYWORD:
         cout << "type=KEYWORD" << endl;
      break;

      case Token::NIL:
         cout << "type=NIL" << endl;
      break;
      }
   }
}

string solve_eq(const string & var)
{
   string result = to_string();
   return result;
}

bool is_function(const string & fun) {
   bool ret = false;
   if(fun == "pow" 
      || fun == "sqrt") {
      ret = true;
   }
   return ret;
}

string to_string2() {
   string ret;
   for(int i = 0; i < token_vec.size(); i++) {
      if(ret.size()) ret += " ";
      ret += token_vec[i].get_value();
   }
   return ret;
}

string to_string(Node * n) {
   string result;
   if(n) {
      if(n->value == "()") {
         result = " ( " + to_string(n->left) + " ) ";
      }
      else if(n->value == ",") {
         result = to_string(n->left);
         result += " , ";
         result += to_string(n->right);
      }
      else if(is_function(n->value)) {
         result = n->value + " ( " + to_string(n->right) + " ) ";
      }
      else {
         result = to_string(n->left);
         result += " " + n->value;
         result += to_string(n->right);
      }
   }
   return result;
}

string to_string() {
   return to_string(tree); 
}

bool match(string text, int offset)
{
   bool ret = false;
   if(offset < token_vec.size()) {
      if(token_vec[offset].get_value() == text) {
         ret = true;
      }
   }
   return ret;
}

bool match(Token::token_type code, int offset)
{
    bool ret = false;
    if(offset < token_vec.size()) {
       if(token_vec[offset].get_type() == code) {
          ret = true;
       }
    }
    return ret;
}

bool expression(int & offset) {
   return term(offset);
}

bool primary(int & offset) 
{
   bool ret = false;
   if(match("(", offset)) {
      offset++;
      if(expression(offset)) {
         if(match(")", offset)) {
            node = new Node("()");
            node->left = pop();
            push(node);
            offset++;
            ret = true;
         }
      }
   }
   else if(match(Token::NUMBER, offset)) {
      last_value = token_vec[offset].get_value();
      Node * n = new Node(last_value);
      Node * temp = pop();
      if(temp)
         temp->right = n;
      push(n);
      //cout << "PUSH " << last_value << endl;
      offset++;
      ret = true;
   }
   else if(match(Token::ID, offset)) {
      last_value = token_vec[offset].get_value();
      push(new Node(last_value));
      offset++;
      ret = true;
      if(match("=", offset)) {
         current_var = last_value;
         offset++;
         Node * n1 = new Node(current_var);
         Node * n2 = new Node("=");
         n2->left = n1;
         if(expression(offset)) {
            Node * n3 = pop();
            n2->right = n3;
            push(n2);
            ret = true;
         }
         //cout << "APOP " << current_var << endl;
      }
      else if(match("(", offset)) {
         string fun_name = last_value;
         offset++;
         Node * fun = pop();
         Node * p = fun;
         while(expression(offset)) {
            Node * expr = pop();
            p->right = expr;
            p = expr;
            if(match(",", offset)) {
               expr = new Node(",");
               p->right = expr;
               p = expr;
               offset++;
            }
            else {
               break;
            }
         }
         if(match(")", offset)) {
            offset++;
            ret = true;
         }
         //cout << "CALL " << fun_name << endl;
         push(fun);
      }
      else { 
         //cout << "APUSH " << last_value << endl;
         push(new Node(last_value));
      }
   }
   return ret;
}

bool factor(int & offset) 
{
   bool ret = false;
   if(primary(offset)) {
      if(match("*", offset)) {
         Node * temp = pop();
         offset++;
         if(factor(offset)) {
            node = new Node("*");
            node->right = pop();
            node->left = temp;
            push(node);
            //cout << "MUL" << endl;
            ret = true;
         }
      }
      else if(match("/", offset)) {
         Node * top = pop();
         offset++;

         if(!expand) {
            if(factor(offset)) {
               Node* div = new Node("/");
               div->left = top;
               div->right = pop();
               push(div);
               ret = true;
            }
         }
         else {
            if(factor(offset)) {
               if(top->value == "1.0" || top->value == "1") {
                  node = new Node("/");
                  node->left = top;
                  node->right = pop();
                  push(node);
               }
               else {
                  Node * mul = new Node("*");
                  Node * div = new Node("/");
                  Node * peren = new Node("()");

                  div->left = new Node("1.0");
                  div->right = pop();

                  mul->left = top;
                  mul->right =  peren;

                  peren->left = div;
                  peren->right = 0;

                  push(mul);
               }
               //cout << "DIV" << endl;
               ret = true;
            }
         }
      }
      else {
         ret = true;
      }
   }
   return ret;
}

bool term(int & offset) 
{
   bool ret = false;
   while(factor(offset)) {
      if(match("+", offset)) {
         Node * temp = pop();
         offset++;
         if(term(offset)) {
            node = new Node("+");
            node->right = pop();
            node->left = temp;
            push(node);
            //cout << "ADD" << endl;
            ret = true;
         }
      }
      else if(match("-", offset)) {
         Node * temp = pop();
         offset++;
         if(term(offset)) {
            node = new Node("-");
            node->right = pop();
            node->left = temp;
            push(node);
            //cout << "SUB" << endl;
            ret = true;
         }
      }
      else {
         ret = true;
      }
   }
   return ret;
}

bool parse() 
{
   int offset = 0;
   bool ret = false;
   if(expression(offset)) {
      ret = true;
   }
   return ret;
}

template< class T > 
string to_string(const T & value)
{
   string ret;
   stringstream str;
   str << value;
   str >> ret;
   return trim(ret);
}

public:

string solve(const string & equation, 
             const string & var) 
{
   string ret;
   token_vector_type token_vec;
   //cout << "solve(\"" << equation << "\",\"" << var << "\")" << endl;
   if(lex(equation)) {
      //debug_dump();
      expand = true;
      if(parse()) {
         tree = pop();
         ret = solve_eq(var);
      }
      else {
         cerr << "bad parse...\n";
      }
   }
   return ret;
}

void remove(int off)
{
   token_vector_type temp;
   for(int i = 0; i < token_vec.size(); i++) {
      if(i == off) {

      }
      else {
         temp.push_back(token_vec[i]);
      }
   }
   token_vec = temp;
}

void insert(int off, const Token & tok) 
{
   token_vector_type temp;
   for(int i = 0; i < token_vec.size(); i++) {
      if(i == off) {
         temp.push_back(tok);
      }
      temp.push_back(token_vec[i]);
   }
   token_vec = temp;
}

bool is_number(const string & str) 
{
   bool ret = false;
   if(str.size()) {
      if(is_digit(str[0])) {
         ret = true;
      }
   }
   return ret;
}

bool is_var(const string & str)
{
   bool ret = false;
   if(str.size()) {
      if(is_alpha(str[0])) {
         ret = true;
      }
   }
   return ret;
}

bool factor(Node * & n) 
{
   bool ret = true;
   string var;
   if(n) {
      factor(n->left);
      if(n->value == "sqrt") {
         if(n->right) {
            if(n->right->value == "pow") {
               if(n->right->right) {
                  if(is_var(n->right->right->value)) {
                     var = n->right->right->value;
                     if(n->right->right->right) {
                        if(n->right->right->right->value == ",") {
                           if(n->right->right->right->right) {
                              if(n->right->right->right->right->value == "2.0" || n->right->right->right->right->value == "2") {
                                 n->value = var;
                                 delete_tree(n->left);
                                 delete_tree(n->right);
                                 n->left = 0;
                                 n->right = 0;
                                 change = true;
                                 throw "nothing";
                              }
                           }
                        }
                     }
                  }
               }
            }
         }
      }
      else if(n->value == "+") {
         if(is_number(n->left->value) && is_number(n->right->value)) {
            n->value = to_string(atof(n->left->value.c_str()) + atof(n->right->value.c_str()));
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
      }
      else if(n->value == "-") {
         if(is_number(n->left->value) && is_number(n->right->value)) {
            n->value = to_string(atof(n->left->value.c_str()) - atof(n->right->value.c_str()));
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
         else if(n->left->value == n->right->value) {
            n->value = "0.0";
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
      }
      else if(n->value == "*" && n->left != 0 && n->right != 0) {
         if(is_number(n->left->value) && is_number(n->right->value)) {
            n->value = to_string(atof(n->left->value.c_str()) * atof(n->right->value.c_str()));
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
         else if(n->left->value == "pow" && n->right->value == "pow") {
            string id_1 = n->left->right->value;
            string id_2 = n->right->right->value;
            if(id_1 == id_2) {
               string exp1 = n->left->right->right->right->value;
               string exp2 = n->right->right->right->right->value;
               double d1 = atof(exp1.c_str()), d2 = atof(exp2.c_str());
               string exp = to_string(d1 + d2);
               n = n->right;
               n->right->right->right->value = exp;
               change = true;
               throw "nothing";
            }
         }
         else if(n->left->value == n->right->value) {
            string id = n->left->value;
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            n->value = "pow";
            n->right = new Node(id);
            n->right->right = new Node(",");
            n->right->right->right = new Node("2.0");
            change = true;
            throw "nothing"; 
         }
         else if(is_var(n->left->value) && n->right->value == "pow") {
            string var = n->left->value;
            if(n->right->right->value == var) {
               n = n->right;
               double t2;
               stringstream s;
               s << n->right->right->right->value;
               s >> t2;
               double temp = t2 + 1.0;
               string str = to_string(temp);
               n->right->right->right->value = str;
               change = true;
               throw "nothing"; 
            }
         }
         else if(n->left->value == "pow" && is_var(n->right->value)) {
            string var = n->right->value;
            if(n->left->right->value == var) {
               n = n->left;
               double t2;
               stringstream s;
               s << n->right->right->right->value;
               s >> t2;
               double temp = t2 + 1.0;
               string str = to_string(temp);
               n->right->right->right->value = str;
               change = true;
               throw "nothing";
            }
         }
      }
      else if(n->value == "/" && n->left && n->right) {
         if(is_number(n->left->value) && is_number(n->right->value)) {
            n->value = to_string(atof(n->left->value.c_str()) / atof(n->right->value.c_str()));
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
         else if(n->left->value == n->right->value) {
            n->value = "1.0";
            delete_tree(n->left);
            delete_tree(n->right);
            n->left = 0;
            n->right = 0;
            change = true;
            throw "nothing";
         }
         else if(n->right->value == "()") {
            if(n->right->left->value == "/") {
               if(n->right->left->left->value == "1.0" || n->right->left->left->value == "1") {
                  n->value = "()";
                  n->left = n->right->left->right;
                  n->right = 0;
                  change = true;
                  throw "nothing";
               }
            }
         }
         else if(n->left->value == "1.0" || n->left->value == "1") {
            if(n->right->value == "/") {
               if(n->right->left->value == "1.0" || n->right->left->value == "1") {
                  n->value = "()";
                  Node * temp = n->left;
                  Node * temp2 = n->right;
                  n->left = n->right->right;
                  temp2->right = 0;
                  delete_tree(temp);
                  delete_tree(temp2);
                  n->right = 0;
                  change = true;
                  throw "nothing";
               }
            }
         }
      }
      else if(n->value == "()") {
         if(n->left && n->right == 0) {
            if(is_number(n->left->value) || is_var(n->left->value) || n->left->value == "()") {
               Node * temp = n->left;
               n->left = 0;
               n->value = temp->value;
               delete_tree(temp);
               change = true;
               throw "nothing";
            }
            else if(is_number(n->left->left->value) || is_var(n->left->value) || n->left->value == "()") {
               Node * temp = n->left;
               n->left = 0;
               n->value = temp->value;
               delete_tree(temp);
               change = true;
               throw "nothing";
            }
            //else {
               //factor(n->left);
            //}
         }
      }
      else {
      }
      factor(n->right);
   }
   return ret;
}

bool factor() {
   bool ret = true;
   tree = pop();
   while(change) {
      change = false;
      try {
         factor(tree);
      }
      catch(const char * msg) {

      }
   }
   return ret;
}

string trim(const string & value)
{
   stringstream str;
   str << value;
   string temp;
   string buffer;
   while(str >> temp) {
      if(buffer.size()) buffer += " ";
      buffer += temp;
   }
   return buffer;
}

string factor(const string & equation) 
{
   string ret;
   cout << "factor(\"" << trim(equation) << "\")" << endl;
   if(lex(equation)) {
      expand = false;
      if(parse()) {
         //Node * x = pop();
         if(factor()) {
            ret = trim(to_string());
            cout << "FACTOR: " << ret << endl;
         }
      }
      else {
         cerr << "bad factor...\n";
      }
   }
   return ret;
}

};

// Convenience wrapper: simplifies `equation` with a throw-away Solver and
// returns the simplified text (Solver::factor also echoes to stdout).
inline string factor(const string & equation)
{
   Solver solver;
   return solver.factor(equation);
}

// Convenience wrapper: simplifies `equation`, runs Solver::solve on that text
// for `var`, then simplifies the solver's output once more.
inline string solve(const string & equation, const string & var)
{
   const string simplified = factor(equation);
   Solver solver;
   const string solved = solver.solve(simplified, var);
   return factor(solved);
}

}

#endif /* __SOLVE_H */
