/* -*- C++ -*- */
#include <cassert>
#include <cctype>
#include <algorithm>

#include "ParseNode.h"
#include "SynTree/Type.h"
#include "SynTree/Constant.h"
#include "SynTree/Expression.h"
#include "SynTree/Declaration.h"
#include "UnimplementedError.h"
#include "parseutility.h"
#include "utility.h"

using namespace std;

// Default construction: no name, no designator.
ExpressionNode::ExpressionNode()
  : ParseNode(), argName( 0 )
{
}

// Takes ownership of name_: its text is handed to ParseNode and the string
// itself is freed.
ExpressionNode::ExpressionNode( string *name_ )
  : ParseNode( *name_ ), argName( 0 )
{
  delete name_;
}

// Copy: only the name is forwarded to the ParseNode base; the designator,
// when present, is deep-copied via clone().
ExpressionNode::ExpressionNode( const ExpressionNode &other )
  : ParseNode( other.name ),
    argName( other.argName ? other.argName->clone() : 0 )
{
}

// Designates this expression by a bare name (wrapped in a VarRefNode that
// takes ownership of aName).  Returns this for chaining.
ExpressionNode * ExpressionNode::set_asArgName( std::string *aName ) {
  ExpressionNode *designator = new VarRefNode( aName );
  argName = designator;
  return this;
}

// Designates this expression by an arbitrary designator expression.
// Returns this for chaining.
ExpressionNode * ExpressionNode::set_asArgName( ExpressionNode *aDesignator ) {
  argName = aDesignator;
  return this;
}

// Prints "(designated by:  <designator>)" and a newline when this expression
// carries a designator (argName); prints nothing otherwise.
//   os     - destination stream
//   indent - number of leading spaces; also forwarded to the designator's
//            printOneLine.  Negative values are treated as 0.
void ExpressionNode::printDesignation( std::ostream &os, int indent ) const {
  if( argName ) {
    // std::string(count, ch): the repeat count comes first.  Clamp so a
    // negative indent cannot turn into a huge size_type.
    os << string( indent > 0 ? indent : 0, ' ' ) << "(designated by:  ";
    argName->printOneLine( os, indent );
    os << ")" << std::endl;
  }
}

// A placeholder for an absent expression.
NullExprNode::NullExprNode() {}

// The copy is default-constructed: the designator is not carried over.
NullExprNode *NullExprNode::clone() const {
  return new NullExprNode();
}

// Prints the designator (if any) followed by "null expression"; indent is unused.
void NullExprNode::print( std::ostream &os, int indent ) const {
  printDesignation( os );
  os << "null expression";
}

// Prints the designator (if any) followed by "null"; indent is unused.
void NullExprNode::printOneLine( std::ostream &os, int indent ) const {
  printDesignation( os );
  os << "null";
}

// A null expression builds to no SynTree node at all.
Expression *NullExprNode::build() const {
  Expression *none = 0;
  return none;
}

// Starts a comma list whose elements are this expression followed by exp.
CommaExprNode *ExpressionNode::add_to_list( ExpressionNode *exp )
{
  CommaExprNode *list = new CommaExprNode( this, exp );
  return list;
}

// enum ConstantNode::Type { Integer, Float, Character, String, Range }

// Empty constant: signed, no 'l' suffixes, size 0.
ConstantNode::ConstantNode( void )
  : ExpressionNode(), sign( true ), longs( 0 ), size( 0 )
{
}

// Named constant; ownership of name_ passes to ExpressionNode.
ConstantNode::ConstantNode( string *name_ )
  : ExpressionNode( name_ ), sign( true ), longs( 0 ), size( 0 )
{
}

// Literal of kind t built from its source text.  Takes ownership of inVal
// (a null inVal means empty text) and normalizes the text via classify().
ConstantNode::ConstantNode( Type t, string *inVal )
  : type( t ), sign( true ), longs( 0 ), size( 0 )
{
  value = inVal ? *inVal : string( "" );
  delete inVal;   // deleting a null pointer is a no-op

  classify( value );
}

// Member-wise copy of all literal attributes.
ConstantNode::ConstantNode( const ConstantNode &other )
  : ExpressionNode( other ),
    type( other.type ),
    value( other.value ),
    sign( other.sign ),
    base( other.base ),
    longs( other.longs ),
    size( other.size )
{
}

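// std::tolower cannot be passed directly to std::transform: it is overloaded (<cctype> and <locale> versions), so the function argument is ambiguous (observed with g++ 3.1). This non-overloaded wrapper avoids that.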
// Lower-cases a single character; a non-overloaded wrapper usable as a
// std::transform functor.
inline char
tolower_hack( char c )
{
  // std::tolower requires a value representable as unsigned char (or EOF);
  // passing a negative plain char is undefined behaviour, so convert first.
  return static_cast<char>( std::tolower( static_cast<unsigned char>( c ) ) );
}

// Normalizes the literal text held in str (the constructor passes the member
// `value` itself) and records suffix information in longs/sign.
//
//   Integer, Float: the trailing run of characters that are not hexadecimal
//     digits is taken as the suffix.  value becomes the text before it, with
//     '_' separators removed.  The suffix is matched case-insensitively:
//     "ll" sets longs = 2, otherwise "l" sets longs = 1, and "u" clears sign.
//     NOTE(review): 'f'/'F' are hex digits, so a float suffix such as "f" is
//     not split off and remains part of value.
//   Character: '_' is removed when the text has "\x" starting at index 1.
//     NOTE(review): only "\x" escapes are recognised; octal escapes are left
//     untouched.
//   Any other kind: no change.
void ConstantNode::classify(std::string &str){
  switch(type){
    case Integer:
    case Float:
      {
        std::string sfx("");
        char c;
        int i = static_cast<int>( str.length() ) - 1;

        // Scan backwards collecting the suffix (reversed, which does not
        // matter for the substring searches below).  isxdigit needs an
        // unsigned-char value: a negative plain char is undefined behaviour.
        while( i >= 0 && !isxdigit( static_cast<unsigned char>( c = str.at(i--) ) ) )
          sfx += c;

        // When a hex digit stopped the scan, i sits one before it, so
        // i + 2 characters keep everything up to and including that digit.
        value = str.substr( 0, i + 2 );

        // get rid of underscores
        value.erase(remove(value.begin(), value.end(), '_'), value.end());

        std::transform(sfx.begin(), sfx.end(), sfx.begin(), tolower_hack);

        if( sfx.find("ll") != string::npos ){
          longs = 2;
        } else if (sfx.find("l") != string::npos ){
          longs = 1;
        }

        assert((longs >= 0) && (longs <= 2));

        if( sfx.find("u") != string::npos )
          sign = false;

        break;
      }
    case Character:
      {
        // remove underscores from hex escapes; the size check keeps the
        // comparison from throwing std::out_of_range on an empty literal
        if( str.size() >= 3 && str.compare( 1, 2, "\\x" ) == 0 )
          value.erase(remove(value.begin(), value.end(), '_'), value.end());

        break;
      }
  default:
    // shouldn't be here
    ;
  }
}

// Returns the literal kind this constant was created with.
ConstantNode::Type ConstantNode::get_type( void ) const
{
  const Type kind = type;
  return kind;
}

// Appends more literal text to this constant.  Takes ownership of newValue
// and frees it; a null newValue is ignored.
// For String constants, adjacent literals are joined by dropping the last
// character of the current text and the first character of the new text
// (the closing/opening quotes), e.g. "ab" + "cd" -> "abcd".
// Returns this, so calls can be chained.
ConstantNode*
ConstantNode::append( std::string *newValue )
{
  if( newValue ) {
    if( type == String ) {
      // Guard against empty text: resizing an empty value to size()-1, or
      // taking substr(1) of an empty newValue, would throw (and leak newValue).
      if( !value.empty() )
        value.erase( value.size() - 1 );
      if( !newValue->empty() )
        value.append( *newValue, 1, std::string::npos );
    } else {
      value += *newValue;
    }

    delete newValue;
  }
  return this;
}

// One-line rendering: indentation, the designator (if any), then the literal
// text, then a trailing space.  Character text is wrapped in single quotes
// and String text in double quotes; numeric text is printed as-is.
void ConstantNode::printOneLine( std::ostream &os, int indent ) const
{
  os << string( indent, ' ' );
  printDesignation( os );

  if( type == Integer || type == Float ) {
    os << value;
  } else if( type == Character ) {
    os << "'" << value << "'";
  } else if( type == String ) {
    os << '"' << value << '"';
  }

  os << ' ';
}

// Same as printOneLine, terminated by a newline.
void ConstantNode::print( std::ostream &os, int indent ) const
{
  printOneLine( os, indent );
  os << endl;
}

// Converts this literal into a SynTree ConstantExpr:
//   Integer   -> BasicType SignedInt
//   Float     -> BasicType Float
//   Character -> BasicType Char
//   String    -> ArrayType of Char, dimension value.size() - 1
// The designator (argName), if any, is built and attached to the result.
// Any other literal kind (presumably Range) is not handled: asserts, and
// returns 0 when assertions are disabled.
Expression *ConstantNode::build() const {
  ::Type::Qualifiers q;
  BasicType *bt = 0;   // never read uninitialized, even for unhandled kinds

  switch(get_type()){
  case Integer:
    /* Cfr. standard 6.4.4.1 */
    bt = new BasicType(q, BasicType::SignedInt);
    break;

  case Float:
    bt = new BasicType(q, BasicType::Float);
    break;

  case Character:
    bt = new BasicType(q, BasicType::Char);
    break;

  case String:
    {
      // string should probably be a primitive type
      // (scoped local: avoids shadowing the member `value` at function scope)
      std::string text = get_value();
      ArrayType *at = new ArrayType( q, new BasicType( q, BasicType::Char ),
                                     new ConstantExpr( Constant( new BasicType( q, BasicType::SignedInt ),
                                                                 toString( text.size() - 1 ) ) ),  // account for '\0'
                                     false, false );

      return new ConstantExpr( Constant( at, text ), maybeBuild< Expression >( get_argName() ) );
    }

  default:
    // shouldn't happen: no SynTree type for this literal kind
    assert( false );
    return 0;
  }

  return new ConstantExpr( Constant( bt, get_value() ), maybeBuild< Expression >( get_argName() ) );
}


// Unnamed, non-label reference.
VarRefNode::VarRefNode()
  : isLabel( false )
{
}

// Reference to name_ (ownership passes to ExpressionNode); labelp marks a
// label reference.
VarRefNode::VarRefNode( string *name_, bool labelp )
  : ExpressionNode( name_ ), isLabel( labelp )
{
}

VarRefNode::VarRefNode( const VarRefNode &other )
  : ExpressionNode( other ), isLabel( other.isLabel )
{
}

// Builds a NameExpr for the referenced name, carrying the built designator.
Expression *VarRefNode::build() const {
  Expression *designator = maybeBuild< Expression >( get_argName() );
  return new NameExpr( get_name(), designator );
}

// Designator (if any), then the name and a trailing space.
void VarRefNode::printOneLine( std::ostream &os, int indent ) const {
  printDesignation( os );
  os << get_name() << ' ';
}

// Multi-line form: "Referencing: Variable: <name>" on its own indented line.
void VarRefNode::print( std::ostream &os, int indent ) const {
  printDesignation( os );
  os << '\r' << string( indent, ' ' ) << "Referencing: " << "Variable: " << get_name() << endl;
}


// An operator of kind t.
OperatorNode::OperatorNode( Type t )
  : type( t )
{
}

OperatorNode::OperatorNode( const OperatorNode &other )
  : ExpressionNode( other ), type( other.type )
{
}

OperatorNode::~OperatorNode()
{
}

// Returns the operator kind.
OperatorNode::Type OperatorNode::get_type( void ) const
{
  return type;
}

// Designator (if any), then the operator's printable name and a space.
void OperatorNode::printOneLine( std::ostream &os, int indent ) const
{
  printDesignation( os );
  const char *label = OpName[ type ];
  os << label << ' ';
}

// Multi-line form: "Operator: <name>" on its own indented line.
void OperatorNode::print( std::ostream &os, int indent ) const
{
  printDesignation( os );
  const char *label = OpName[ type ];
  os << '\r' << string( indent, ' ' ) << "Operator: " << label << endl;
}

// Printable name of this operator, taken from OpName.
std::string OperatorNode::get_typename( void ) const
{
  const char *label = OpName[ type ];
  return label;
}

// Printable operator names, indexed by OperatorNode::Type: printOneLine,
// print and get_typename all read OpName[ type ].  Entries must appear in
// exactly the order of the Type enumeration, and opFuncName (further down
// in this file) is indexed the same way, so the two tables must stay
// parallel.  Index ranges: n-adic 0-2, triadic 3-4, diadic 5-42,
// monadic 43-53.
const char *OperatorNode::OpName[] =
  { "TupleC",  "Comma", "TupleFieldSel",// "TuplePFieldSel", //n-adic
    // triadic
    "Cond",   "NCond",
    // diadic
    "SizeOf",      "AlignOf", "Attr", "CompLit", "Plus",    "Minus",   "Mul",     "Div",     "Mod",      "Or",
      "And",       "BitOr",   "BitAnd",  "Xor",     "Cast",    "LShift",  "RShift",  "LThan",   "GThan",
      "LEThan",    "GEThan", "Eq",      "Neq",     "Assign",  "MulAssn", "DivAssn", "ModAssn", "PlusAssn",
      "MinusAssn", "LSAssn", "RSAssn",  "AndAssn", "ERAssn",  "OrAssn",  "Index",   "FieldSel","PFieldSel",
      "Range",
    // monadic
    "UnPlus", "UnMinus", "AddressOf", "PointTo", "Neg", "BitNeg", "Incr", "IncrPost", "Decr", "DecrPost", "LabelAddress"
  };

// Empty application: no function, no arguments.
CompositeExprNode::CompositeExprNode( void )
  : ExpressionNode(), function( 0 ), arguments( 0 )
{
}

// Named, empty application; ownership of name_ passes to ExpressionNode.
CompositeExprNode::CompositeExprNode( string *name_ )
  : ExpressionNode( name_ ), function( 0 ), arguments( 0 )
{
}

// Application of f to the argument list headed by args (args may be null,
// meaning no arguments).
CompositeExprNode::CompositeExprNode( ExpressionNode *f, ExpressionNode *args )
  : function( f ), arguments( args )
{
}

// Application of f to two arguments: arg2 is attached to arg1 via set_link.
// arg1 must not be null.
CompositeExprNode::CompositeExprNode( ExpressionNode *f, ExpressionNode *arg1, ExpressionNode *arg2 )
  : function( f ), arguments( arg1 )
{
  ExpressionNode *head = arguments;
  head->set_link( arg2 );
}

// Deep copy: clones the function node and each node of the other argument
// list, re-linking the clones into a new list.
CompositeExprNode::CompositeExprNode( const CompositeExprNode &other )
  : ExpressionNode( other ), function( maybeClone( other.function ) ), arguments( 0 )
{
  // `arguments` must start out null: the loop tests it to decide whether to
  // start the new list or extend it (it was previously read uninitialized).
  for( ParseNode *cur = other.arguments; cur != 0; cur = cur->get_link() ) {
    if( arguments ) {
      arguments->set_link( cur->clone() );
    } else {
      arguments = static_cast< ExpressionNode* >( cur->clone() );
    }
  }
}

// Releases the owned function node and the head of the argument list.
CompositeExprNode::~CompositeExprNode() {
  delete function;
  delete arguments;
}

// The names that users use to define operator functions (e.g. "?+?" for
// binary +, "?++" for postfix ++).  Indexed by OperatorNode::Type, in
// parallel with OperatorNode::OpName.  CompositeExprNode::build() applies the
// entry for an operator as an ordinary function name.  Empty entries belong
// to operators that build() lowers to dedicated nodes instead (sizeof, cast,
// &&/||, member selection, ...).
// NOTE(review): "Range" and "LabAddress" do not follow the ?op? spelling.
static const char *opFuncName[] =
  { "",  "", "",
    "",   "",
    // diadic
    "",   "", "", "", "?+?",    "?-?",   "?*?",     "?/?",     "?%?",     "",       "",
      "?|?",  "?&?",  "?^?",     "",    "?<<?",  "?>>?",  "?<?",   "?>?",    "?<=?",
      "?>=?", "?==?",      "?!=?",     "?=?",  "?*=?", "?/=?", "?%=?", "?+=?", "?-=?",
      "?<<=?", "?>>=?",  "?&=?", "?^=?",  "?|=?",  "?[?]",   "","","Range",
    // monadic
    "+?", "-?", "", "*?", "!?", "~?", "++?", "?++", "--?", "?--", "LabAddress"
  };

#include "utility.h"
// Lowers this parse-tree application into a SynTree Expression.
//
// All arguments are built first (buildList).  If the function is not an
// OperatorNode, the result is an UntypedExpr applying the built function to
// the built arguments, carrying the designator if any.  Otherwise:
//   - increment/decrement and assignment operators first wrap their first
//     argument in an AddressExpr;
//   - operators with a user-definable spelling become an UntypedExpr calling
//     the name in opFuncName (e.g. "?+?");
//   - AddressOf, Cast, FieldSel, PFieldSel, SizeOf/AlignOf, Attr, Or/And,
//     Cond, Comma and TupleC build dedicated expression nodes;
//   - CompLit and NCond throw UnimplementedError.
// Returns 0 for any operator not handled here.
Expression *CompositeExprNode::build() const {
  OperatorNode *op;
  std::list<Expression *> args;

  buildList(get_args(), args);

  if (!( op = dynamic_cast<OperatorNode *>(function)) ){
    // a function as opposed to an operator
    return new UntypedExpr(function->build(), args, maybeBuild< Expression >( get_argName() ));

  } else {

    // Pass 1: argument adjustments required by the operator's rewrite rule.
    switch(op->get_type()){
    case OperatorNode::Incr:
    case OperatorNode::Decr:
    case OperatorNode::IncrPost:
    case OperatorNode::DecrPost:
    case OperatorNode::Assign:
    case OperatorNode::MulAssn:
    case OperatorNode::DivAssn:
    case OperatorNode::ModAssn:
    case OperatorNode::PlusAssn:
    case OperatorNode::MinusAssn:
    case OperatorNode::LSAssn:
    case OperatorNode::RSAssn:
    case OperatorNode::AndAssn:
    case OperatorNode::ERAssn:
    case OperatorNode::OrAssn:
      // the rewrite rules for these expressions specify that the first argument has its address taken
      assert( !args.empty() );
      args.front() = new AddressExpr( args.front() );
      break;

    default:
      /* do nothing */
      ;
    }

    // Pass 2: build the node for the operator.
    switch(op->get_type()){

    case OperatorNode::Incr:
    case OperatorNode::Decr:
    case OperatorNode::IncrPost:
    case OperatorNode::DecrPost:
    case OperatorNode::Assign:
    case OperatorNode::MulAssn:
    case OperatorNode::DivAssn:
    case OperatorNode::ModAssn:
    case OperatorNode::PlusAssn:
    case OperatorNode::MinusAssn:
    case OperatorNode::LSAssn:
    case OperatorNode::RSAssn:
    case OperatorNode::AndAssn:
    case OperatorNode::ERAssn:
    case OperatorNode::OrAssn:
    case OperatorNode::Plus:
    case OperatorNode::Minus:
    case OperatorNode::Mul:
    case OperatorNode::Div:
    case OperatorNode::Mod:
    case OperatorNode::BitOr:
    case OperatorNode::BitAnd:
    case OperatorNode::Xor:
    case OperatorNode::LShift:
    case OperatorNode::RShift:
    case OperatorNode::LThan:
    case OperatorNode::GThan:
    case OperatorNode::LEThan:
    case OperatorNode::GEThan:
    case OperatorNode::Eq:
    case OperatorNode::Neq:
    case OperatorNode::Index:
    case OperatorNode::Range:
    case OperatorNode::UnPlus:
    case OperatorNode::UnMinus:
    case OperatorNode::PointTo:
    case OperatorNode::Neg:
    case OperatorNode::BitNeg:
    case OperatorNode::LabelAddress:
      // user-definable operators: a call to the corresponding operator function name
      return new UntypedExpr( new NameExpr( opFuncName[ op->get_type() ] ), args );

    case OperatorNode::AddressOf:
      assert( args.size() == 1 );
      assert( args.front() );

      return new AddressExpr( args.front() );

    case OperatorNode::Cast:
      {
        // The argument list is a TypeValueNode (the target type) linked to
        // the expression being cast.
        // NOTE(review): that expression is built again below although
        // buildList already built it into `args`; the nodes in `args` are
        // neither used nor freed in this case.
        TypeValueNode * arg = dynamic_cast<TypeValueNode *>(get_args());
        assert( arg );

        DeclarationNode *decl_node = arg->get_decl();
        ExpressionNode *expr_node = dynamic_cast<ExpressionNode *>(arg->get_link());

        Type *targetType = decl_node->buildType();
        if( dynamic_cast< VoidType* >( targetType ) ) {
          // a cast to void is represented without a target type
          delete targetType;
          return new CastExpr( expr_node->build(), maybeBuild< Expression >( get_argName() ) );
        } else {
          return new CastExpr(expr_node->build(),targetType, maybeBuild< Expression >( get_argName() ) );
        }
      }

    case OperatorNode::FieldSel:
      {
        assert( args.size() == 2 );

        NameExpr *member = dynamic_cast<NameExpr *>(args.back());
        // TupleExpr *memberTup = dynamic_cast<TupleExpr *>(args.back());

        if ( member != 0 )
          {
            UntypedMemberExpr *ret = new UntypedMemberExpr(member->get_name(), args.front());
            delete member;
            return ret;
          }
        /* else if ( memberTup != 0 )
          {
            UntypedMemberExpr *ret = new UntypedMemberExpr(memberTup->get_name(), args.front());
            delete member;
            return ret;
            } */

        // Only a plain member name is supported.  Return explicitly so that,
        // with assertions disabled, control cannot fall through into the
        // PFieldSel case (which would dereference a null member).
        assert( false );
        return 0;
      }

    case OperatorNode::PFieldSel:
      {
        assert( args.size() == 2 );

        NameExpr *member = dynamic_cast<NameExpr *>(args.back());  // modify for Tuples   xxx
        assert( member != 0 );

        // p->m is rewritten as (*p).m
        UntypedExpr *deref = new UntypedExpr( new NameExpr( "*?" ) );
        deref->get_args().push_back( args.front() );

        UntypedMemberExpr *ret = new UntypedMemberExpr(member->get_name(), deref);
        delete member;
        return ret;
      }

    case OperatorNode::AlignOf:
    case OperatorNode::SizeOf:
      {
        // NOTE(review): AlignOf is also lowered to a SizeofExpr here.
/// 	bool isSizeOf = (op->get_type() == OperatorNode::SizeOf);

        if( TypeValueNode * arg = dynamic_cast<TypeValueNode *>(get_args()) ) {
          return new SizeofExpr(arg->get_decl()->buildType());
        } else {
          return new SizeofExpr(args.front());
        }
      }

    case OperatorNode::Attr:
      {
        VarRefNode *var = dynamic_cast<VarRefNode *>(get_args());
        assert( var );
        if( !get_args()->get_link() ) {
          return new AttrExpr(var->build(), (Expression*)0);
        } else if( TypeValueNode * arg = dynamic_cast<TypeValueNode *>(get_args()->get_link()) ) {
          return new AttrExpr(var->build(), arg->get_decl()->buildType());
        } else {
          return new AttrExpr(var->build(), args.back());
        }
      }


    case OperatorNode::CompLit:
      throw UnimplementedError( "C99 compound literals" );

      // the short-circuited operators
    case OperatorNode::Or:
    case OperatorNode::And:
      assert(args.size() == 2);
      return new LogicalExpr( notZeroExpr( args.front() ), notZeroExpr( args.back() ), (op->get_type() == OperatorNode::And) );

    case OperatorNode::Cond:
      {
        assert(args.size() == 3);
        std::list< Expression* >::const_iterator i = args.begin();
        Expression *arg1 = notZeroExpr( *i++ );
        Expression *arg2 = *i++;
        Expression *arg3 = *i++;
        return new ConditionalExpr( arg1, arg2, arg3 );
      }

    case OperatorNode::NCond:
      throw UnimplementedError( "GNU 2-argument conditional expression" );

    case OperatorNode::Comma:
      {
        assert(args.size() == 2);
        std::list< Expression* >::const_iterator i = args.begin();
        Expression *ret = *i++;
        while( i != args.end() ) {
          ret = new CommaExpr( ret, *i++ );
        }
        return ret;
      }

      // Tuples
    case OperatorNode::TupleC:
      {
        TupleExpr *ret = new TupleExpr();
        std::copy( args.begin(), args.end(), back_inserter( ret->get_exprs() ) );
        return ret;
      }

    default:
      // shouldn't happen
      return 0;
    }
  }
}

void CompositeExprNode::printOneLine(std::ostream &os, int indent) const
{
  printDesignation(os);
  os << "( ";
  function->printOneLine( os, indent );
  for( ExpressionNode *cur = arguments; cur != 0; cur = dynamic_cast< ExpressionNode* >( cur->get_link() ) ) {
    cur->printOneLine( os, indent );
  }
  os << ") ";
}

void CompositeExprNode::print(std::ostream &os, int indent) const
{
  printDesignation(os);
  os << '\r' << string(indent, ' ') << "Application of: " << endl;
  function->print( os, indent + ParseNode::indent_by );

  os << '\r' << string(indent, ' ') ;
  if( arguments ) {
    os << "... on arguments: " << endl;
    arguments->printList(os, indent + ParseNode::indent_by);
  } else
    os << "... on no arguments: " << endl;
}

void CompositeExprNode::set_function(ExpressionNode *f){
  function = f;
}

void CompositeExprNode::set_args(ExpressionNode *args){
  arguments = args;
}

ExpressionNode *CompositeExprNode::get_function(void) const {
  return function;
}

ExpressionNode *CompositeExprNode::get_args(void) const {
  return arguments;
}

void CompositeExprNode::add_arg(ExpressionNode *arg){
  if(arguments)
    arguments->set_link(arg);
  else
    set_args(arg);
}

// A comma expression is a CompositeExprNode whose function is the Comma operator.
CommaExprNode::CommaExprNode()
  : CompositeExprNode( new OperatorNode( OperatorNode::Comma ) )
{
}

CommaExprNode::CommaExprNode( ExpressionNode *exp )
  : CompositeExprNode( new OperatorNode( OperatorNode::Comma ), exp )
{
}

CommaExprNode::CommaExprNode( ExpressionNode *exp1, ExpressionNode *exp2 )
  : CompositeExprNode( new OperatorNode( OperatorNode::Comma ), exp1, exp2 )
{
}

// Extends this comma list with exp and returns this.
CommaExprNode *CommaExprNode::add_to_list( ExpressionNode *exp )
{
  add_arg( exp );
  CommaExprNode *self = this;
  return self;
}

CommaExprNode::CommaExprNode( const CommaExprNode &other )
  : CompositeExprNode( other )
{
}

// A valof expression wrapping statement body s (ownership is taken; the
// destructor deletes it).
ValofExprNode::ValofExprNode( StatementNode *s )
  : body( s )
{
}

// Deep copy: clones the *other* node's body.  The previous code passed the
// still-uninitialized member `body` to maybeClone instead of other.body.
ValofExprNode::ValofExprNode( const ValofExprNode &other )
  : ExpressionNode( other ), body( maybeClone( other.body ) )
{
}

ValofExprNode::~ValofExprNode() {
  delete body;
}

void ValofExprNode::print( std::ostream &os, int indent ) const {
  printDesignation(os);
  os << string(indent, ' ') << "Valof Expression:" << std::endl;
  get_body()->print(os, indent + 4);
}

void ValofExprNode::printOneLine( std::ostream &, int indent ) const
{
  assert( false );
}

Expression *ValofExprNode::build() const {
  return new UntypedValofExpr ( get_body()->build(), maybeBuild< Expression >( get_argName() ) );
}

// Control part of a for statement.  init_ may be null, a DeclarationNode, or
// an ExpressionNode; the latter two are wrapped in a StatementNode.  Any
// other node type raises SemanticError.
ForCtlExprNode::ForCtlExprNode( ParseNode *init_, ExpressionNode *cond, ExpressionNode *incr )
  throw (SemanticError)
  : condition( cond ), change( incr )
{
  init = 0;
  if( init_ == 0 )
    return;

  if( DeclarationNode *decl = dynamic_cast< DeclarationNode * >( init_ ) ) {
    init = new StatementNode( decl );
  } else if( ExpressionNode *exp = dynamic_cast< ExpressionNode * >( init_ ) ) {
    init = new StatementNode( StatementNode::Exp, exp );
  } else {
    throw SemanticError( "Error in for control expression" );
  }
}

ForCtlExprNode::ForCtlExprNode( const ForCtlExprNode &other )
  : ExpressionNode( other ), init( maybeClone( other.init ) ), condition( maybeClone( other.condition ) ), change( maybeClone( other.change ) )
{
}

ForCtlExprNode::~ForCtlExprNode(){
  delete init;
  delete condition;
  delete change;
}

Expression *ForCtlExprNode::build() const {
  // this shouldn't be used!
  assert( false );
  return 0;
}

void ForCtlExprNode::print( std::ostream &os, int indent ) const{
  os << string(indent,' ') << "For Control Expression -- : " << endl;

  os << "\r" << string(indent + 2,' ') << "initialization: ";
  if(init != 0)
    init->print(os, indent + 4);

  os << "\n\r" << string(indent + 2,' ') << "condition: ";
  if(condition != 0)
    condition->print(os, indent + 4);
  os << "\n\r" << string(indent + 2,' ') << "increment: ";
  if(change != 0)
    change->print(os, indent + 4);
}

void
ForCtlExprNode::printOneLine( std::ostream &, int indent ) const
{
  assert( false );
}

// An expression node that stands for a type, given by declaration decl.
TypeValueNode::TypeValueNode( DeclarationNode *decl )
  : decl( decl )
{
}

TypeValueNode::TypeValueNode( const TypeValueNode &other )
  : ExpressionNode( other ), decl( maybeClone( other.decl ) )
{
}

// Builds a TypeExpr around the type built from the declaration.
Expression *
TypeValueNode::build() const
{
  ::Type *built = decl->buildType();
  return new TypeExpr( built );
}

// "Type:" at the given indent, followed by the declaration indented 2 further.
void
TypeValueNode::print( std::ostream &os, int indent ) const
{
  os << std::string( indent, ' ' );
  os << "Type:";
  get_decl()->print( os, indent + 2 );
}

// Same as print but without the leading indentation.
void
TypeValueNode::printOneLine( std::ostream &os, int indent ) const
{
  os << "Type:";
  get_decl()->print( os, indent + 2 );
}

// Flattens nested comma expressions in a linked expression list.  A
// comma-operator CompositeExprNode is replaced by its own argument list,
// with whatever followed it in the outer list attached via add_arg.  Every
// later element is processed the same way.
// Returns the (possibly new) head of the list; a null list yields null.
ExpressionNode *flattenCommas( ExpressionNode *list )
{
  // A comma node with no arguments recurses with a null list; without this
  // guard list->get_link() below would dereference null.
  if( list == 0 )
    return 0;

  if( CompositeExprNode *composite = dynamic_cast< CompositeExprNode * >( list ) ) {
    OperatorNode *op = dynamic_cast< OperatorNode * >( composite->get_function() );
    if( op && op->get_type() == OperatorNode::Comma ) {
      if( ExpressionNode *next = dynamic_cast< ExpressionNode * >( list->get_link() ) )
        composite->add_arg( next );
      return flattenCommas( composite->get_args() );
    }
  }

  if( ExpressionNode *next = dynamic_cast< ExpressionNode * >( list->get_link() ) )
    list->set_next( flattenCommas( next ) );

  return list;
}

// If tuple is a tuple-constructor application (operator TupleC), returns its
// argument list; otherwise returns tuple unchanged.
ExpressionNode *tupleContents( ExpressionNode *tuple )
{
  CompositeExprNode *composite = dynamic_cast< CompositeExprNode * >( tuple );
  if( !composite )
    return tuple;

  OperatorNode *op = dynamic_cast< OperatorNode * >( composite->get_function() );
  const bool isTupleCtor = op && op->get_type() == OperatorNode::TupleC;
  return isTupleCtor ? composite->get_args() : tuple;
}
