/*
 * This file is part of the Cforall project
 *
 * $Id: PolyMutator.cc,v 1.7 2005/08/29 20:14:13 rcbilson Exp $
 *
 */

#include "PolyMutator.h"
#include "SynTree/Declaration.h"
#include "SynTree/Type.h"
#include "SynTree/Expression.h"
#include "SynTree/Statement.h"
#include "SynTree/Mutator.h"


namespace GenPoly {

namespace {
const std::list<Label> noLabels;
}

/* Construct a mutator with no current type substitution; env is value-initialized
   (i.e. a null pointer) until mutateExpression meets an expression carrying one. */
PolyMutator::PolyMutator()
  : env()
{
}

/* Mutate every statement of a list in place.
   While a statement is being mutated, the pass may queue extra statements:
   - stmtsToAdd      : spliced in immediately before that statement;
   - stmtsToAddAfter : spliced in immediately after it, i.e. just before the next
                       original statement, or at the end of the list for the last one.
   Spliced-in statements land before the cursor, so they are not mutated themselves. */
void
PolyMutator::mutateStatementList( std::list< Statement* > &statements )
{
  std::list< Statement* >::iterator cursor = statements.begin();
  while( cursor != statements.end() ) {
    // statements queued "after" the previous statement go right before this one
    if( ! stmtsToAddAfter.empty() ) statements.splice( cursor, stmtsToAddAfter );

    *cursor = (*cursor)->acceptMutator( *this );

    // statements queued "before" this statement
    if( ! stmtsToAdd.empty() ) statements.splice( cursor, stmtsToAdd );

    ++cursor;
  }
  // whatever the last statement queued "after" itself goes at the very end
  if( ! stmtsToAddAfter.empty() ) statements.splice( statements.end(), stmtsToAddAfter );
}

/* Mutate a single (possibly null) statement that is not part of a statement list,
   e.g. a branch or loop body.
   If the mutation queued nothing, the mutated statement is returned as is.
   Otherwise a new CompoundStmt is returned holding, in order:
     the stmtsToAdd queue, the mutated statement, the stmtsToAddAfter queue.
   Both queues are empty afterwards. */
Statement*
PolyMutator::mutateStatement( Statement *stmt )
{
  Statement *mutated = maybeMutate( stmt, *this );
  if( stmtsToAdd.empty() && stmtsToAddAfter.empty() ) {
    return mutated;
  }

  CompoundStmt *wrapper = new CompoundStmt( noLabels );
  std::list< Statement* > &kids = wrapper->get_kids();
  kids.splice( kids.end(), stmtsToAdd );
  kids.push_back( mutated );
  kids.splice( kids.end(), stmtsToAddAfter );
  // NOTE(review): doEndScope() is called only when a wrapper is built, and there is
  // no matching begin-scope call in this function -- confirm this is intended.
  doEndScope();
  return wrapper;
}

/* Mutate a (possibly null) expression and return the result; a null expression is
   returned unchanged. If the expression carries a type substitution it becomes the
   current env. env is only ever overwritten, never cleared, so it keeps the most
   recently seen substitution. */
Expression*
PolyMutator::mutateExpression( Expression *expr )
{
  if( ! expr ) {
    return expr;
  }
  if( expr->get_env() ) {
    env = expr->get_env();
  }
  return expr->acceptMutator( *this );
}

/* Mutate the statements of a block in place, then signal the end of the
   block through the doEndScope() hook. The same node is returned. */
CompoundStmt*
PolyMutator::mutate(CompoundStmt *compoundStmt)
{
  std::list< Statement* > &kids = compoundStmt->get_kids();
  mutateStatementList( kids );
  doEndScope();
  return compoundStmt;
}

/* Mutate an if statement in place.
   The branches are mutated before the condition. Each mutateStatement call consumes
   the statements its branch queued, so anything the condition queues afterwards is
   left pending for the caller instead of being folded into a branch. */
Statement*
PolyMutator::mutate(IfStmt *ifStmt)
{
  Statement *newThen = mutateStatement( ifStmt->get_thenPart() );
  ifStmt->set_thenPart( newThen );

  Statement *newElse = mutateStatement( ifStmt->get_elsePart() );
  ifStmt->set_elsePart( newElse );

  Expression *newCond = mutateExpression( ifStmt->get_condition() );
  ifStmt->set_condition( newCond );

  return ifStmt;
}

/* Mutate a while loop in place: body first (wrapped by mutateStatement if it queued
   statements), then the condition.
   NOTE(review): statements queued by the condition stay pending for the caller, which
   (e.g. mutateStatementList) would place them before the loop; they would then run
   once, not on every iteration -- confirm this is acceptable. */
Statement*
PolyMutator::mutate(WhileStmt *whileStmt)
{
  Statement *newBody = mutateStatement( whileStmt->get_body() );
  whileStmt->set_body( newBody );

  Expression *newCond = mutateExpression( whileStmt->get_condition() );
  whileStmt->set_condition( newCond );

  return whileStmt;
}

/* Mutate a for loop in place.
   Order: body, initialization, condition, increment. The body is handled first so
   that statements queued by the header parts are not folded into the body wrapper.
   The initialization goes through maybeMutate rather than mutateStatement, so it is
   never wrapped here; whatever it queues is left pending for the caller. */
Statement*
PolyMutator::mutate(ForStmt *forStmt)
{
  Statement *newBody = mutateStatement( forStmt->get_body() );
  forStmt->set_body( newBody );

  forStmt->set_initialization( maybeMutate( forStmt->get_initialization(), *this ) );

  Expression *newCond = mutateExpression( forStmt->get_condition() );
  forStmt->set_condition( newCond );

  Expression *newIncr = mutateExpression( forStmt->get_increment() );
  forStmt->set_increment( newIncr );

  return forStmt;
}

/* Mutate a switch statement in place: all branches as a statement list, then the
   controlling expression. */
Statement*
PolyMutator::mutate(SwitchStmt *switchStmt)
{
  std::list< Statement* > &branches = switchStmt->get_branches();
  mutateStatementList( branches );

  Expression *newCond = mutateExpression( switchStmt->get_condition() );
  switchStmt->set_condition( newCond );

  return switchStmt;
}

/* Mutate a choose statement in place, exactly like a switch: all branches as a
   statement list, then the controlling expression. */
Statement*
PolyMutator::mutate(ChooseStmt *chooseStmt)
{
  std::list< Statement* > &branches = chooseStmt->get_branches();
  mutateStatementList( branches );

  Expression *newCond = mutateExpression( chooseStmt->get_condition() );
  chooseStmt->set_condition( newCond );

  return chooseStmt;
}

/* Mutate one case of a switch/choose in place: its statements as a list, then its
   case-label expression (which may be null, e.g. for a default case). */
Statement*
PolyMutator::mutate(CaseStmt *caseStmt)
{
  std::list< Statement* > &body = caseStmt->get_statements();
  mutateStatementList( body );

  Expression *newCond = mutateExpression( caseStmt->get_condition() );
  caseStmt->set_condition( newCond );

  return caseStmt;
}

/* Mutate a try statement in place: the guarded block first, then each handler
   (via mutateAll on the catcher list). */
Statement*
PolyMutator::mutate(TryStmt *tryStmt)
{
  // guarded block
  tryStmt->set_block( maybeMutate( tryStmt->get_block(), *this ) );
  // handlers, mutated in place
  mutateAll( tryStmt->get_catchers(), *this );
  return tryStmt;
}

/* Mutate a catch clause in place: the handler body first (wrapped by
   mutateStatement if needed), then the exception declaration. */
Statement*
PolyMutator::mutate(CatchStmt *catchStmt)
{
  Statement *newBody = mutateStatement( catchStmt->get_body() );
  catchStmt->set_body( newBody );

  catchStmt->set_decl( maybeMutate( catchStmt->get_decl(), *this ) );
  return catchStmt;
}

/* Mutate the (possibly absent) returned expression in place. */
Statement*
PolyMutator::mutate(ReturnStmt *retStmt)
{
  Expression *newValue = mutateExpression( retStmt->get_expr() );
  retStmt->set_expr( newValue );
  return retStmt;
}

/* Mutate the expression of an expression statement in place. */
Statement*
PolyMutator::mutate(ExprStmt *exprStmt)
{
  Expression *newExpr = mutateExpression( exprStmt->get_expr() );
  exprStmt->set_expr( newExpr );
  return exprStmt;
}


/* Mutate each argument of an untyped call in place, left to right.
   Only the argument list is visited here. */
Expression*
PolyMutator::mutate(UntypedExpr *untypedExpr)
{
  std::list< Expression* > &args = untypedExpr->get_args();
  for( std::list< Expression* >::iterator arg = args.begin(); arg != args.end(); ++arg ) {
    *arg = mutateExpression( *arg );
  }
  return untypedExpr;
}
 
/* static class method
   Record in tyVarMap (name -> kind) every type variable quantified by the forall
   list of `type`. If `type` is a pointer, continue with its base type, then that
   type's base if it is also a pointer, and so on. A later entry with the same name
   overwrites an earlier one. `type` and every base reached must be non-null, and
   every forall entry must be non-null (asserted). */
void
PolyMutator::makeTyVarMap( Type *type, TyVarMap &tyVarMap )
{
  Type *current = type;
  for( ;; ) {
    const std::list< TypeDecl* > &forall = current->get_forall();
    for( std::list< TypeDecl* >::const_iterator decl = forall.begin(); decl != forall.end(); ++decl ) {
      assert( *decl );
      tyVarMap[ (*decl)->get_name() ] = (*decl)->get_kind();
    }

    PointerType *pointer = dynamic_cast< PointerType* >( current );
    if( ! pointer ) {
      break;
    }
    current = pointer->get_base();
  }
}

/* static class method */
} // namespace GenPoly
