/*
 * This file is part of the Cforall project
 *
 * $Id: Specialize.cc,v 1.4 2005/08/29 20:14:13 rcbilson Exp $
 *
 */

#include <cassert>

#include "Specialize.h"
#include "PolyMutator.h"

#include "SynTree/Declaration.h"
#include "SynTree/Statement.h"
#include "SynTree/Expression.h"
#include "SynTree/Type.h"
#include "SynTree/TypeSubstitution.h"
#include "SynTree/Mutator.h"
#include "ResolvExpr/FindOpenVars.h"
#include "UniqueName.h"
#include "utility.h"


namespace GenPoly {

// Empty label list shared by every statement this file constructs (thunk bodies and
// the declaration statements that carry the generated thunks).
const std::list<Label> noLabels;

// Mutator that, whenever an expression is used where a pointer-to-function type is
// expected and needsSpecialization() says the current type substitution requires it,
// replaces the expression by the address of a generated C-linkage "thunk" whose body
// forwards the thunk's parameters to the original expression.
class Specialize : public PolyMutator
{
public:
  // paramPrefix: prefix from which the parameter names of generated thunks are built.
  // explicit: a bare std::string should not silently convert into a whole pass object.
  explicit Specialize( std::string paramPrefix = "_p" );
  
  // Recurse into a call, then specialize its inferred and explicit arguments.
  virtual Expression* mutate(ApplicationExpr *applicationExpr);
  // Recurse into the operand of `&`, then specialize it against the expression's type.
  virtual Expression* mutate(AddressExpr *addrExpr);
  // Recurse into the cast operand, then specialize it against the cast's result type.
  virtual Expression* mutate(CastExpr *castExpr);
  // The next three return their argument unchanged without visiting its operands.
  virtual Expression* mutate(LogicalExpr *logicalExpr);
  virtual Expression* mutate(ConditionalExpr *conditionalExpr);
  virtual Expression* mutate(CommaExpr *commaExpr);

private:
  // Returns `actual` unchanged, or the address of a new thunk of type `formalType`
  // that forwards to `actual`; inferParams (if non-null) is copied onto the thunk's call.
  Expression *doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams = 0 );
  // Runs doSpecialization over each (formal parameter, argument) pair of appExpr.
  void handleExplicitParams( ApplicationExpr *appExpr );
  
  UniqueName thunkNamer;    // source of fresh thunk function names
  std::string paramPrefix;  // current thunk-parameter prefix; lengthened for nested thunks
};

// Public entry point: apply one Specialize pass (default parameter prefix) to every
// declaration of the translation unit. A single pass object is shared so that its
// thunk-naming state carries across all declarations.
void
convertSpecializations( std::list< Declaration* >& translationUnit )
{
  Specialize pass;
  mutateAll( translationUnit, pass );
}

// thunkNamer is seeded with "_thunk" (UniqueName presumably derives fresh names from
// it — confirm); paramPrefix is the base for thunk parameter names.
Specialize::Specialize( std::string paramPrefix )
  : thunkNamer( "_thunk" )
  , paramPrefix( paramPrefix )
{
  // all state is established by the member initializers above
}

// Decide whether a value of type actualType must be wrapped in a thunk to be used
// where formalType is expected. Collects the open type variables of both types and
// returns true iff env binds one of them either to a non-variable type or to a type
// variable that is not in the collected closed set. Without an env, returns false.
bool
needsSpecialization( Type *formalType, Type *actualType, TypeSubstitution *env )
{
  // no substitution means no bindings, hence nothing to specialize
  if( !env ) {
    return false;
  }

  using namespace ResolvExpr;
  OpenVarSet openVars, closedVars;
  AssertionSet need, have;
  // note the differing trailing flag between the two calls (see FindOpenVars.h)
  findOpenVars( formalType, openVars, closedVars, need, have, false );
  findOpenVars( actualType, openVars, closedVars, need, have, true );

  for( OpenVarSet::const_iterator var = openVars.begin(); var != openVars.end(); ++var ) {
    Type *bound = env->lookup( var->first );
    if( bound == 0 ) {
      continue;   // unbound variable: no evidence either way, keep looking
    }
    TypeInstType *boundVar = dynamic_cast< TypeInstType* >( bound );
    // bound to something other than a type variable, or to a variable that is not closed
    if( boundVar == 0 || closedVars.find( boundVar->get_name() ) == closedVars.end() ) {
      return true;
    }
  }
  return false;
}

// If `actual` needs specialization to be used where `formalType` is expected (see
// needsSpecialization) and `formalType` is a pointer to a function type, build a
// C-linkage thunk of that function type whose body applies `actual` to the thunk's
// parameters, append the thunk's declaration to stmtsToAdd, and return the thunk's
// address. In every other case `actual` is returned unchanged.
//   formalType  - type expected at the use site
//   actual      - expression being passed; becomes the callee inside the thunk
//   inferParams - if non-null, copied onto the thunk's forwarding call
Expression*
Specialize::doSpecialization( Type *formalType, Expression *actual, InferredParams *inferParams )
{
  // front() on an empty result list is undefined behaviour: every actual handed to
  // this routine must already carry a result type
  assert( !actual->get_results().empty() );
  if( needsSpecialization( formalType, actual->get_results().front(), env ) ) {
    PointerType *ptrType;
    FunctionType *funType;
    // only a pointer-to-function formal can be satisfied by the address of a thunk
    if( ( ptrType = dynamic_cast< PointerType* >( formalType ) ) && ( funType = dynamic_cast< FunctionType* >( ptrType->get_base() ) ) ) {
      // the thunk's signature is a private copy of the formal function type
      FunctionType *newType = funType->clone();
      if( env ) {
        TypeSubstitution newEnv( *env );
        // it is important to replace only occurrences of type variables that occur free in the
        // thunk's type
        newEnv.applyFree( newType );
      }
      // freshly named C-linkage thunk with an initially empty body
      FunctionDecl *thunkFunc = new FunctionDecl( thunkNamer.newName(), Declaration::NoStorageClass, LinkageSpec::C, newType, new CompoundStmt( noLabels ), false );
      thunkFunc->fixUniqueId();
      
      // name the thunk's parameters and forward each one, in order, to `actual`
      UniqueName paramNamer( paramPrefix );
      ApplicationExpr *appExpr = new ApplicationExpr( actual );
      for( std::list< DeclarationWithType* >::iterator param = thunkFunc->get_functionType()->get_parameters().begin(); param != thunkFunc->get_functionType()->get_parameters().end(); ++param ) {
        (*param)->set_name( paramNamer.newName() );
        appExpr->get_args().push_back( new VariableExpr( *param ) );
      }
      appExpr->set_env( maybeClone( env ) );
      if( inferParams ) {
        appExpr->get_inferParams() = *inferParams;
      }
      
      // handle any specializations that may still be present in the forwarding call.
      // Nested thunks get a longer parameter-name prefix, and the pending stmtsToAdd
      // are set aside so that thunks created here go into this thunk's body; the
      // pending statements are restored afterwards.
      std::string oldParamPrefix = paramPrefix;
      paramPrefix += "p";
      std::list< Statement* > oldStmts;
      oldStmts.splice( oldStmts.end(), stmtsToAdd );
      handleExplicitParams( appExpr );
      paramPrefix = oldParamPrefix;
      thunkFunc->get_statements()->get_kids().splice( thunkFunc->get_statements()->get_kids().end(), stmtsToAdd );
      stmtsToAdd.splice( stmtsToAdd.end(), oldStmts );
      
      // return the forwarded result if the function type has return values,
      // otherwise evaluate the call as a plain expression statement
      Statement *appStmt;
      if( funType->get_returnVals().empty() ) {
        appStmt = new ExprStmt( noLabels, appExpr );
      } else {
        appStmt = new ReturnStmt( noLabels, appExpr );
      }
      thunkFunc->get_statements()->get_kids().push_back( appStmt );
      // queue the thunk definition and substitute its address for `actual`
      stmtsToAdd.push_back( new DeclStmt( noLabels, thunkFunc ) );
      return new AddressExpr( new VariableExpr( thunkFunc ) );
    } else {
      return actual;
    }
  } else {
    return actual;
  }
}

// Specialize each explicit argument of appExpr against the type of the corresponding
// formal parameter of the callee. The callee's first result type must be a pointer to
// a function type. Pairing is positional and stops at the shorter of the two lists,
// so surplus arguments or parameters are left untouched.
void
Specialize::handleExplicitParams( ApplicationExpr *appExpr )
{
  // create thunks for the explicit parameters
  assert( !appExpr->get_function()->get_results().empty() );
  PointerType *pointer = dynamic_cast< PointerType* >( appExpr->get_function()->get_results().front() );
  assert( pointer );
  FunctionType *function = dynamic_cast< FunctionType* >( pointer->get_base() );
  // a pointer to a non-function type would otherwise be dereferenced as null below
  assert( function );
  std::list< DeclarationWithType* >::iterator formal;
  std::list< Expression* >::iterator actual;
  for( formal = function->get_parameters().begin(), actual = appExpr->get_args().begin(); formal != function->get_parameters().end() && actual != appExpr->get_args().end(); ++formal, ++actual ) {
    *actual = doSpecialization( (*formal)->get_type(), *actual, &appExpr->get_inferParams() );
  }
}

// Specialize a call. Nested expressions (callee, then arguments) are visited first;
// then every inferred-parameter expression and every explicit argument that needs it
// is replaced by a thunk address (see doSpecialization). Always returns appExpr.
Expression* 
Specialize::mutate(ApplicationExpr *appExpr)
{
  // inner expressions first, so nested calls are handled before this one
  appExpr->get_function()->acceptMutator( *this );
  mutateAll( appExpr->get_args(), *this );

  // inferred parameters: each entry's expr is checked against its formalType
  InferredParams &inferred = appExpr->get_inferParams();
  for( InferredParams::iterator entry = inferred.begin(); entry != inferred.end(); ++entry ) {
    Expression *specialized = doSpecialization( entry->second.formalType, entry->second.expr, &inferred );
    entry->second.expr = specialized;
  }

  // explicit arguments, matched positionally against the callee's parameters
  handleExplicitParams( appExpr );
  return appExpr;
}

// Address-of: visit the operand, then specialize it against the first result type of
// the address expression itself. Always returns addrExpr.
Expression* 
Specialize::mutate(AddressExpr *addrExpr)
{
  Expression *operand = addrExpr->get_arg();
  operand->acceptMutator( *this );
  Type *expected = addrExpr->get_results().front();
  addrExpr->set_arg( doSpecialization( expected, operand ) );
  return addrExpr;
}

// Cast: visit the operand, then specialize it against the cast's first result type.
// Always returns castExpr.
Expression* 
Specialize::mutate(CastExpr *castExpr)
{
  castExpr->get_arg()->acceptMutator( *this );
  // A cast with no result type (presumably a cast to void — confirm) gives no formal
  // type to specialize against; calling front() on the empty list would be undefined
  // behaviour, so leave the operand as it is.
  if( castExpr->get_results().empty() ) {
    return castExpr;
  }
  castExpr->set_arg( doSpecialization( castExpr->get_results().front(), castExpr->get_arg() ) );
  return castExpr;
}

// NOTE(review): the three overrides below return their argument without visiting its
// operands, so nothing nested inside a logical, conditional or comma expression is
// specialized by this pass. Confirm this is intentional; if such operands can contain
// calls, casts or address-of expressions that need thunks, they are currently skipped.

// Logical expression: returned as-is, operands not visited.
Expression* 
Specialize::mutate(LogicalExpr *logicalExpr)
{
  return logicalExpr;
}

// Conditional expression: returned as-is, operands not visited.
Expression* 
Specialize::mutate(ConditionalExpr *condExpr)
{
  return condExpr;
}

// Comma expression: returned as-is, operands not visited.
Expression* 
Specialize::mutate(CommaExpr *commaExpr)
{
  return commaExpr;
}

} // namespace GenPoly
