#include <algorithm>
#include <iostream>
#include <cassert>
#include <list>

#include "SynTree/Type.h"
#include "SynTree/Declaration.h"
#include "SynTree/Statement.h"
#include "SynTree/Expression.h"
#include "SynTree/Initializer.h"

#include "utility.h"
#include "UnimplementedError.h"

#include "CodeGenerator2.h"
#include "OperatorTable.h"
#include "GenType.h"

using namespace std;

namespace CodeGen {
    // Number of spaces added per nesting level of the generated code.
    int CodeGenerator2::tabsize = 4;

    // Primary constructor: output goes to `os`, starting at column 0, not
    // inside a function, with an empty deferred-text buffer.
    CodeGenerator2::CodeGenerator2( std::ostream &os )
	: cur_indent( 0 ), insideFunction( false ), before( os ), after() {
    }

    // Constructor with an explicit starting indentation and function flag.
    // NOTE: `init` is accepted but currently never written to the stream.
    CodeGenerator2::CodeGenerator2( std::ostream &os, std::string init, int indent, bool infunp )
	: cur_indent( indent ), insideFunction( infunp ), before( os ), after() {
    }

    // Same as the previous constructor, taking `init` as a C string; `init`
    // is likewise ignored.
    CodeGenerator2::CodeGenerator2( std::ostream &os, char *init, int indent, bool infunp )
	: cur_indent( indent ), insideFunction( infunp ), before( os ), after() {
    }

    // Returns the identifier to emit for `decl`: its mangled name when one
    // has been assigned (non-empty), otherwise its plain source name.
    string mangleName( DeclarationWithType *decl ) {
	const string &mangled = decl->get_mangleName();
	if ( mangled == "" ) {
	    return decl->get_name();
	} // if
	return mangled;
    }
  
    //*** Declarations
    // Emits a function declaration or definition: storage class, then the
    // function type declared under its (mangled) name, then the body when the
    // function has one.  Old declarations are not emitted; their presence is
    // only flagged with a comment in the output.
    void CodeGenerator2::visit( FunctionDecl *functionDecl ) {
	handleStorageClass( functionDecl );
	before << genType( functionDecl->get_functionType(), mangleName( functionDecl ) );

	// TODO: how to get the old declarations (presumably K&R-style parameter
	// declarations) into the function type?
	// Bound by reference: only emptiness is inspected, so copying the whole
	// list is unnecessary.
	const std::list< Declaration * > &olds = functionDecl->get_oldDecls();
	if ( ! olds.empty() ) {
	    before << " /* function has old declaration */";
	} // if

	// acceptAll( functionDecl->get_oldDecls(), *this );
	if ( functionDecl->get_statements() ) {
	    functionDecl->get_statements()->accept( *this );
	} // if
    }

    void CodeGenerator2::visit( ObjectDecl *objectDecl ) {
	handleStorageClass( objectDecl );
	before << genType( objectDecl->get_type(), mangleName( objectDecl ) );
    
	if ( objectDecl->get_init() ) {
	    before << " = ";
	    objectDecl->get_init()->accept( *this );
	} // if
	if ( objectDecl->get_bitfieldWidth() ) {
	    before << ":";
	    objectDecl->get_bitfieldWidth()->accept( *this );
	} // if
    }

    // Shared body for struct/union declarations: emits the tag name (if any)
    // and, when there are members, a brace-enclosed member list.  The braces
    // sit at the current indentation; each member goes on its own line, one
    // level deeper, terminated by ';'.  Nothing follows the closing brace.
    void CodeGenerator2::handleAggregate( AggregateDecl *aggDecl ) {
	if ( aggDecl->get_name() != "" ) {
	    before << aggDecl->get_name();
	} // if

	std::list< Declaration * > &members = aggDecl->get_members();
	if ( members.empty() ) {
	    return;
	} // if

	before << endl << string( cur_indent, ' ' ) << "{" << endl;
	cur_indent += CodeGenerator2::tabsize;

	std::list< Declaration * >::iterator member = members.begin();
	while ( member != members.end() ) {
	    before << string( cur_indent, ' ' );
	    (*member)->accept( *this );
	    before << ";" << endl;
	    ++member;
	} // while

	cur_indent -= CodeGenerator2::tabsize;
	before << string( cur_indent, ' ' ) << "}";
    }

    void CodeGenerator2::visit( StructDecl *structDecl ) {
	before << "struct ";
	handleAggregate( structDecl );
    }

    void CodeGenerator2::visit( UnionDecl *aggregateDecl ) {
	before << "union ";
	handleAggregate( aggregateDecl );
    }
  
    // Emits an enum declaration: "enum", the tag name if any and, when there
    // are enumerators, a brace-enclosed list with one "name[ = value],"
    // per line, indented one level deeper than the braces.  Brace placement
    // matches handleAggregate: braces at the current indentation and nothing
    // after the closing brace, so a following ';' stays on the same line.
    void CodeGenerator2::visit( EnumDecl *aggDecl ) {
	before << "enum ";

	if ( aggDecl->get_name() != "" ) {
	    before << aggDecl->get_name();
	} // if

	std::list< Declaration* > &memb = aggDecl->get_members();

	if ( ! memb.empty() ) {
	    before << endl << string( cur_indent, ' ' ) << "{" << endl;

	    cur_indent += CodeGenerator2::tabsize;
	    for ( std::list< Declaration* >::iterator i = memb.begin(); i != memb.end(); ++i ) {
		// enumerators are represented as ObjectDecls
		ObjectDecl *obj = dynamic_cast< ObjectDecl* >( *i );
		assert( obj );
		before << string( cur_indent, ' ' ) << mangleName( obj );
		if ( obj->get_init() ) {
		    before << " = ";
		    obj->get_init()->accept( *this );
		} // if
		// a trailing comma after the last enumerator is valid C99
		before << "," << endl;
	    } // for

	    cur_indent -= CodeGenerator2::tabsize;

	    before << string( cur_indent, ' ' ) << "}";
	} // if
    }
  
    void CodeGenerator2::visit( ContextDecl *aggregateDecl ) {}
  
    void CodeGenerator2::visit( TypedefDecl *typeDecl ) {
	before << "typedef ";
	before << genType( typeDecl->get_base(), typeDecl->get_name() );
    }
  
    // Emits a type declaration as an "extern unsigned long" variable named
    // after the type, with " = sizeof( base )" appended when a base type is
    // present.
    // Really, we should mutate this into something that isn't a TypeDecl, but
    // that requires large-scale changes, still to be done.
    void CodeGenerator2::visit( TypeDecl *typeDecl ) {
	before << "extern unsigned long " << typeDecl->get_name();
	if ( ! typeDecl->get_base() ) {
	    return;
	} // if
	before << " = sizeof( " << genType( typeDecl->get_base(), "" ) << " )";
    }

    void CodeGenerator2::visit( SingleInit *init ) {
	init->get_value()->accept( *this );
    }

    void CodeGenerator2::visit( ListInit *init ) {
	before << "{ ";
	genCommaList( init->begin_initializers(), init->end_initializers() );
	before << " }";
    }

    void CodeGenerator2::visit( Constant *constant ) { 
	before << constant->get_value() ;
    }

    //*** Expressions
    // Emits a resolved function application.
    //
    // When the applied function is a VariableExpr whose declaration has
    // Intrinsic linkage and whose name operatorLookup recognizes, the
    // application is printed in C operator syntax chosen by opInfo.type;
    // otherwise it is printed as an ordinary call "f(args)".
    //
    // NOTE: for the *ASSIGN operator kinds this visitor rewrites the first
    // argument in place (applicationExpr's argument list is mutated) before
    // printing -- see the first switch below.
    void CodeGenerator2::visit( ApplicationExpr *applicationExpr ) {
	if ( VariableExpr *varExpr = dynamic_cast< VariableExpr* >( applicationExpr->get_function() ) ) {
	    OperatorInfo opInfo;
	    if ( varExpr->get_var()->get_linkage() == LinkageSpec::Intrinsic && operatorLookup( varExpr->get_var()->get_name(), opInfo ) ) {
		std::list< Expression* >::iterator arg = applicationExpr->get_args().begin();
		// Pass 1 (assignment operators only): rewrite the first argument.
		// An argument of the form &x is replaced by x (the AddressExpr
		// node itself is not freed here); any other argument is wrapped
		// in an untyped "*?" application -- presumably the dereference
		// operator, NOTE(review): confirm in OperatorTable.  This implies
		// the first argument of an intrinsic assignment is the address of
		// the assigned-to object.
		switch ( opInfo.type ) {
		  case OT_PREFIXASSIGN:
		  case OT_POSTFIXASSIGN:
		  case OT_INFIXASSIGN:
		    {
			assert( arg != applicationExpr->get_args().end() );
			if ( AddressExpr *addrExpr = dynamic_cast< AddressExpr * >( *arg ) ) {
            
			    *arg = addrExpr->get_arg();
			} else {
			    UntypedExpr *newExpr = new UntypedExpr( new NameExpr( "*?" ) );
			    newExpr->get_args().push_back( *arg );
			    *arg = newExpr;
			} // if
			break;
		    }
          
		  default:
		    // do nothing
		    ;
		}
        
		// Pass 2: print in C operator syntax.  There is no default case:
		// an operator kind not listed here emits nothing.
		switch ( opInfo.type ) {
		  case OT_INDEX:
		    // "a[b]"
		    assert( applicationExpr->get_args().size() == 2 );
		    (*arg++)->accept( *this );
		    before << "[";
		    (*arg)->accept( *this );
		    before << "]";
		    break;
          
		  case OT_CALL:
		    // there are no intrinsic definitions of the function call operator
		    assert( false );
		    break;
          
		  case OT_PREFIX:
		  case OT_PREFIXASSIGN:
		    // "(op a)"
		    assert( applicationExpr->get_args().size() == 1 );
		    before << "(";
		    before << opInfo.symbol;
		    (*arg)->accept( *this );
		    before << ")";
		    break;
          
		  case OT_POSTFIX:
		  case OT_POSTFIXASSIGN:
		    // "a op" -- not parenthesized
		    assert( applicationExpr->get_args().size() == 1 );
		    (*arg)->accept( *this );
		    before << opInfo.symbol;
		    break;

		  case OT_INFIX:
		  case OT_INFIXASSIGN:
		    // "(a op b)"
		    assert( applicationExpr->get_args().size() == 2 );
		    before << "(";
		    (*arg++)->accept( *this );
		    before << opInfo.symbol;
		    (*arg)->accept( *this );
		    before << ")";
		    break;
          
		  case OT_CONSTANT:
		    // there are no intrinsic definitions of 0 or 1 as functions
		    assert( false );
		}
	    } else {
		// not an intrinsic operator: ordinary call syntax
		varExpr->accept( *this );
		before << "(";
		genCommaList( applicationExpr->get_args().begin(), applicationExpr->get_args().end() );
		before << ")";
	    } // if
	} else {
	    // the function is not a plain variable: ordinary call syntax
	    applicationExpr->get_function()->accept( *this );
	    before << "(";
	    genCommaList( applicationExpr->get_args().begin(), applicationExpr->get_args().end() );
	    before << ")";
	} // if
    }
  
    // Emits an untyped (unresolved) application.  When the function is a
    // NameExpr that operatorLookup recognizes, C operator syntax is used:
    //   OT_INDEX                      a[b]
    //   OT_PREFIX / OT_PREFIXASSIGN   (op a)
    //   OT_POSTFIX / OT_POSTFIXASSIGN a op
    //   OT_INFIX / OT_INFIXASSIGN     (a op b)
    // OT_CALL and OT_CONSTANT are asserted never to occur.  Anything else is
    // printed as an ordinary call "f(args)".
    void CodeGenerator2::visit( UntypedExpr *untypedExpr ) {
	Expression *function = untypedExpr->get_function();
	std::list< Expression* > &args = untypedExpr->get_args();
	NameExpr *opName = dynamic_cast< NameExpr* >( function );
	OperatorInfo opInfo;

	// Not a recognized operator name: plain call syntax.
	if ( opName == 0 || ! operatorLookup( opName->get_name(), opInfo ) ) {
	    function->accept( *this );
	    before << "(";
	    genCommaList( args.begin(), args.end() );
	    before << ")";
	    return;
	} // if

	std::list< Expression* >::iterator operand = args.begin();
	switch ( opInfo.type ) {
	  case OT_INDEX:
	    assert( untypedExpr->get_args().size() == 2 );
	    (*operand)->accept( *this );
	    ++operand;
	    before << "[";
	    (*operand)->accept( *this );
	    before << "]";
	    break;

	  case OT_CALL:
	    assert( false );
	    break;

	  case OT_PREFIX:
	  case OT_PREFIXASSIGN:
	    assert( untypedExpr->get_args().size() == 1 );
	    before << "(" << opInfo.symbol;
	    (*operand)->accept( *this );
	    before << ")";
	    break;

	  case OT_POSTFIX:
	  case OT_POSTFIXASSIGN:
	    assert( untypedExpr->get_args().size() == 1 );
	    (*operand)->accept( *this );
	    before << opInfo.symbol;
	    break;

	  case OT_INFIX:
	  case OT_INFIXASSIGN:
	    assert( untypedExpr->get_args().size() == 2 );
	    before << "(";
	    (*operand)->accept( *this );
	    ++operand;
	    before << opInfo.symbol;
	    (*operand)->accept( *this );
	    before << ")";
	    break;

	  case OT_CONSTANT:
	    // there are no intrinsic definitions of 0 or 1 as functions
	    assert( false );
	    break;
	} // switch
    }
  
    // Emits a name.  A name that operatorLookup recognizes must denote a
    // constant operator (asserted) and is replaced by its C symbol; any other
    // name is emitted unchanged.
    void CodeGenerator2::visit( NameExpr *nameExpr ) {
	OperatorInfo opInfo;
	if ( ! operatorLookup( nameExpr->get_name(), opInfo ) ) {
	    before << nameExpr->get_name();
	    return;
	} // if
	assert( opInfo.type == OT_CONSTANT );
	before << opInfo.symbol;
    }
  
    // Emits "(&operand)".  A variable operand is written directly by its
    // (mangled) name rather than visited, so that the constant-operator
    // substitution done for variables (e.g. "constant_zero" -> "0") is not
    // applied to something whose address is being taken.
    void CodeGenerator2::visit( AddressExpr *addressExpr ) {
	VariableExpr *variableExpr = dynamic_cast< VariableExpr* >( addressExpr->get_arg() );
	before << "(&";
	if ( variableExpr == 0 ) {
	    addressExpr->get_arg()->accept( *this );
	} else {
	    before << mangleName( variableExpr->get_var() );
	} // if
	before << ")";
    }

    // Emits "((T)operand)", where T is the first result type of the cast, or
    // "void" when the cast has no result types.
    void CodeGenerator2::visit( CastExpr *castExpr ) {
	before << "((";
	if ( ! castExpr->get_results().empty() ) {
	    before << genType( castExpr->get_results().front(), "" );
	} else {
	    before << "void";
	} // if
	before << ")";
	castExpr->get_arg()->accept( *this );
	before << ")";
    }
  
    void CodeGenerator2::visit( UntypedMemberExpr *memberExpr ) {
	assert( false );
    }
  
    void CodeGenerator2::visit( MemberExpr *memberExpr ) {
	memberExpr->get_aggregate()->accept( *this );
	before << "." << mangleName( memberExpr->get_member() );
    }
  
    // Emits a variable reference.  An intrinsic variable whose name is a
    // constant operator is replaced by the operator's C symbol; anything else
    // is emitted by its (mangled) name.
    void CodeGenerator2::visit( VariableExpr *variableExpr ) {
	OperatorInfo opInfo;
	const bool isIntrinsicConstant =
	    variableExpr->get_var()->get_linkage() == LinkageSpec::Intrinsic
	    && operatorLookup( variableExpr->get_var()->get_name(), opInfo )
	    && opInfo.type == OT_CONSTANT;
	if ( isIntrinsicConstant ) {
	    before << opInfo.symbol;
	} else {
	    before << mangleName( variableExpr->get_var() );
	} // if
    }
  
    // Emits a constant expression by delegating to the Constant it wraps,
    // which must be present (asserted).
    void CodeGenerator2::visit( ConstantExpr *constantExpr ) {
	assert( constantExpr->get_constant() );
	constantExpr->get_constant()->accept( *this );
    }
  
    // Emits "sizeof(T)" when the operand is a type, "sizeof(expr)" otherwise.
    void CodeGenerator2::visit( SizeofExpr *sizeofExpr ) {
	before << "sizeof(";
	if ( ! sizeofExpr->get_isType() ) {
	    sizeofExpr->get_expr()->accept( *this );
	} else {
	    before << genType( sizeofExpr->get_type(), "" );
	} // if
	before << ")";
    }
  
    // Emits "(a && b)" or "(a || b)", depending on get_isAnd().
    void CodeGenerator2::visit( LogicalExpr *logicalExpr ) {
	const char *op = logicalExpr->get_isAnd() ? " && " : " || ";
	before << "(";
	logicalExpr->get_arg1()->accept( *this );
	before << op;
	logicalExpr->get_arg2()->accept( *this );
	before << ")";
    }
  
    void CodeGenerator2::visit( ConditionalExpr *conditionalExpr ) {
	before << "(";
	conditionalExpr->get_arg1()->accept( *this );
	before << " ? ";
	conditionalExpr->get_arg2()->accept( *this );
	before << " : ";
	conditionalExpr->get_arg3()->accept( *this );
	before << ")";
    }
  
    void CodeGenerator2::visit( CommaExpr *commaExpr ) {
	before << "(";
	commaExpr->get_arg1()->accept( *this );
	before << " , ";
	commaExpr->get_arg2()->accept( *this );
	before << ")";
    }
  
    void CodeGenerator2::visit( TupleExpr *tupleExpr ) {}
  
    void CodeGenerator2::visit( TypeExpr *typeExpr ) {}
  
  
    //*** Statements
    // Emits a brace-enclosed block.  The opening brace starts on a new line
    // at the current indentation; each child statement goes on its own line,
    // one level deeper and preceded by its labels.  Deferred text in `after`
    // is flushed after every child.  The closing brace is followed by a
    // newline.
    void CodeGenerator2::visit( CompoundStmt *compoundStmt ) {
	// Bound by reference: the children are only iterated, so copying the
	// whole list for every block is unnecessary.
	const std::list< Statement * > &ks = compoundStmt->get_kids();

	before << endl << string( cur_indent, ' ' ) << "{" << endl;

	cur_indent += CodeGenerator2::tabsize;

	for ( std::list< Statement * >::const_iterator i = ks.begin(); i != ks.end(); ++i ) {
	    before << string( cur_indent, ' ' ) << printLabels( (*i)->get_labels() );
	    (*i)->accept( *this );
	    shift_left();
	    before << endl;
	} // for
	cur_indent -= CodeGenerator2::tabsize;

	before << string( cur_indent, ' ' ) << "}" << endl;
    }

    // Emits an expression statement "expr;", flushing deferred text before
    // the ';'.  A statement that carries no expression is emitted as a bare
    // ';' (an empty statement) instead of dereferencing a null pointer.
    void CodeGenerator2::visit( ExprStmt *exprStmt ) {
	if ( exprStmt == 0 ) {
	    return;
	} // if
	// Guard the expression itself: it is the pointer dereferenced below,
	// whereas the visited statement cannot be null when reached via accept().
	if ( exprStmt->get_expr() != 0 ) {
	    exprStmt->get_expr()->accept( *this );
	} // if
	shift_left();
	before << ";";
    }

    // Emits "if (cond)" on one line, then the then-branch one level deeper on
    // the next line.  When an else-branch exists, an " else " line at the
    // current indentation follows, and the else-branch is visited one level
    // deeper.  Deferred text is flushed after the condition and after the
    // then-branch.  Anything the else-branch defers is left in `after` for
    // the caller to flush.
    void CodeGenerator2::visit( IfStmt *ifStmt ) {
	// condition; the closing parenthesis travels through the deferred buffer
	before << "if (";
	ifStmt->get_condition()->accept( *this );
	after += ")\n";
	shift_left();

	// then-branch
	cur_indent += CodeGenerator2::tabsize;
	before << string( cur_indent, ' ' );
	ifStmt->get_thenPart()->accept( *this );
	cur_indent -= CodeGenerator2::tabsize;
	shift_left();
	before << endl;

	if ( ifStmt->get_elsePart() == 0 ) {
	    return;
	} // if

	// else-branch
	before << string( cur_indent, ' ' ) << " else " << endl;
	cur_indent += CodeGenerator2::tabsize;
	ifStmt->get_elsePart()->accept( *this );
	cur_indent -= CodeGenerator2::tabsize;
    }

    // Emits "switch (cond)" followed by a brace-enclosed list of branches
    // indented one level deeper.  If the last branch is a BranchStmt, it is
    // detached and emitted on its own after the other branches, preceded by
    // its labels.  A switch with no branches is emitted as an empty body.
    void CodeGenerator2::visit( SwitchStmt *switchStmt ) {
	before << "switch (";
	switchStmt->get_condition()->accept( *this );
	after += ")\n";
	shift_left();

	before << string( cur_indent, ' ' ) << "{" << std::endl;
	cur_indent += CodeGenerator2::tabsize;

	// Work on a copy so the trailing branch can be detached without
	// modifying the syntax tree.
	std::list< Statement * > stmts = switchStmt->get_branches();
	Statement *trailingBranch = 0;

	// horrible, horrible hack: a trailing BranchStmt is emitted separately.
	// The emptiness check is required: back() on an empty list is undefined.
	if ( ! stmts.empty() && dynamic_cast< BranchStmt * >( stmts.back() ) != 0 ) {
	    trailingBranch = stmts.back();
	    stmts.pop_back();
	} // if
	acceptAll( stmts, *this );
	if ( trailingBranch != 0 ) {
	    before << CodeGenerator2::printLabels( trailingBranch->get_labels() );
	    trailingBranch->accept( *this );
	} // if

	cur_indent -= CodeGenerator2::tabsize;

	before << string( cur_indent, ' ' ) << "}" << endl;
    }

    // Emits "case value:" (or "default :") at the current indentation.  The
    // case's statements follow, each on its own line one level deeper,
    // preceded by its labels and followed by ";" and a newline.  Deferred
    // text is flushed after the case label and after every statement.
    void CodeGenerator2::visit( CaseStmt *caseStmt ) {
	before << string( cur_indent, ' ' );
	if ( caseStmt->isDefault() ) {
	    before << "default ";
	} else {
	    before << "case ";
	    caseStmt->get_condition()->accept( *this );
	} // if
	after += ":\n";
	shift_left();

	// Bound by reference: the statements are only iterated, so copying the
	// whole list for every case is unnecessary.
	const std::list< Statement * > &sts = caseStmt->get_statements();

	cur_indent += CodeGenerator2::tabsize;
	for ( std::list< Statement * >::const_iterator i = sts.begin(); i != sts.end(); ++i ) {
	    before << string( cur_indent, ' ' ) << printLabels( (*i)->get_labels() );
	    (*i)->accept( *this );
	    shift_left();
	    // NOTE: statements such as ExprStmt, BranchStmt and ReturnStmt
	    // already end with ';', so this adds a harmless empty statement.
	    before << ";" << endl;
	} // for
	cur_indent -= CodeGenerator2::tabsize;
    }

    // Emits a goto (to a named label, or computed "goto *expr"), a break or a
    // continue, always followed by ';'.  A goto with neither a target label
    // nor a computed target emits only the ';'.
    void CodeGenerator2::visit( BranchStmt *branchStmt ) {
	switch ( branchStmt->get_type() ) {
	  case BranchStmt::Goto:
	    if ( ! branchStmt->get_target().empty() ) {
		before << "goto " << branchStmt->get_target();
	    } else if ( branchStmt->get_computedTarget() != 0 ) {
		before << "goto *";
		branchStmt->get_computedTarget()->accept( *this );
	    } // if
	    break;
	  case BranchStmt::Break:
	    before << "break";
	    break;
	  case BranchStmt::Continue:
	    before << "continue";
	    break;
	} // switch
	before << ";";
    }


    // Emits "return " and the returned expression, if there is one.  The
    // terminating ';' is appended to the deferred buffer `after` and is not
    // flushed here.
    void CodeGenerator2::visit( ReturnStmt *returnStmt ) {
	before << "return ";
	if ( returnStmt->get_expr() != 0 ) {
	    returnStmt->get_expr()->accept( *this );
	} // if
	after += ";";
    }

    // Emits a while loop ("while(cond){" body "}") or a do-while loop
    // ("do{" body "} while(cond);").  The body is wrapped in an extra pair of
    // braces.  Whatever the body left in the deferred buffer `after` (for
    // example a return statement's ';') is flushed before the closing brace,
    // so it lands inside the loop rather than after the '}' (which produced
    // invalid C such as "return x}" or "} while(c;);").  The loop's own
    // trailing text (");" for do-while, then a newline) is left in `after`
    // for the caller to flush.
    void CodeGenerator2::visit( WhileStmt *whileStmt ) {
	if ( whileStmt->get_isDoWhile() ) {
	    before << "do";
	} else {
	    before << "while(";
	    whileStmt->get_condition()->accept( *this );
	    after += ")";
	} // if
	after += "{\n";
	shift_left();

	whileStmt->get_body()->accept( *this );
	// flush text deferred by the body so it precedes the closing brace
	shift_left();

	before << string( cur_indent, ' ' ) << "}";

	if ( whileStmt->get_isDoWhile() ) {
	    before << " while(";
	    whileStmt->get_condition()->accept( *this );
	    after += ");";
	} // if

	after += "\n";
    }

    // Emits "for (init cond; incr)" and a newline.  The body, if present, is
    // emitted one level deeper and preceded by its labels.  A missing
    // initialization is replaced by a bare ';' (presumably because an
    // initialization statement prints its own ';', as ExprStmt and DeclStmt
    // do).  Deferred text is flushed after each clause.
    void CodeGenerator2::visit( ForStmt *forStmt ) {
	before << "for (";

	// initialization clause
	if ( forStmt->get_initialization() == 0 ) {
	    before << ";";
	} else {
	    forStmt->get_initialization()->accept( *this );
	} // if
	shift_left();

	// condition clause
	if ( forStmt->get_condition() != 0 ) {
	    forStmt->get_condition()->accept( *this );
	} // if
	shift_left();
	before << ";";

	// increment clause
	if ( forStmt->get_increment() != 0 ) {
	    forStmt->get_increment()->accept( *this );
	} // if
	shift_left();
	before << ")" << endl;

	if ( forStmt->get_body() == 0 ) {
	    return;
	} // if
	cur_indent += CodeGenerator2::tabsize;
	before << string( cur_indent, ' ' ) << CodeGenerator2::printLabels( forStmt->get_body()->get_labels() );
	forStmt->get_body()->accept( *this );
	cur_indent -= CodeGenerator2::tabsize;
    }

    void CodeGenerator2::visit( NullStmt *nullStmt ) {
	//before << /* "\r" << */ string( cur_indent, ' ' ) << CodeGenerator2::printLabels( nullStmt->get_labels() );
	before << "/* null statement */ ;";
    }

    void CodeGenerator2::visit( DeclStmt *declStmt ) {
	declStmt->get_decl()->accept( *this );
    
	if ( doSemicolon( declStmt->get_decl() ) ) {
	    after += ";";
	} // if
	shift_left();
    }

    std::string CodeGenerator2::printLabels( std::list< Label > &l ) {
	std::string str( "" );
	l.unique();

	for ( std::list< Label >::iterator i = l.begin(); i != l.end(); i++ )
	    str += *i + ": ";

	return str;
    }

    // Moves the deferred text accumulated in `after` to the output stream and
    // empties the buffer.
    void CodeGenerator2::shift_left() {
	before << after;
	after.clear();
    }

    // Emits the storage-class keyword for `decl`, followed by a space.  No
    // keyword is written for NoStorageClass or Auto.
    void CodeGenerator2::handleStorageClass( Declaration *decl ) {
	switch ( decl->get_storageClass() ) {
	  case Declaration::Static:
	    before << "static ";
	    break;
	  case Declaration::Extern:
	    before << "extern ";
	    break;
	  case Declaration::Register:
	    before << "register ";
	    break;
	  case Declaration::Fortran:
	    before << "fortran ";
	    break;
	  case Declaration::NoStorageClass:
	  case Declaration::Auto:
	    // nothing to emit
	    break;
	} // switch
    }
} // namespace CodeGen
