/*
 * This file is part of the Cforall project
 *
 * $Id: FindFunction.cc,v 1.5 2005/08/29 20:14:13 rcbilson Exp $
 *
 */

#include "FindFunction.h"
#include "SynTree/Type.h"
#include "SynTree/Declaration.h"
#include "SynTree/Visitor.h"

namespace GenPoly {

// Mutator that walks a type looking for function types accepted by a
// caller-supplied predicate.  Every accepted function type is appended to the
// caller's list; in replace mode it is also substituted in the tree by a newly
// allocated FunctionType.  While a forall clause is in scope, the type
// variables it binds are removed from the working copy of tyVars.
class FindFunction : public Mutator
{
public:
  typedef std::list< FunctionType* > FunctionList;

  FindFunction( FunctionList &functions,
                const TyVarMap &tyVars,
                bool replaceMode,
                FindFunctionPredicate predicate );

  virtual Type *mutate( FunctionType *functionType );
  virtual Type *mutate( PointerType *pointerType );

private:
  // Drops from tyVars every variable re-bound by the given forall list.
  void handleForall( const std::list< TypeDecl* > &forall );

  FunctionList &functions;          // caller-owned output list of matches
  TyVarMap tyVars;                  // private working copy, narrowed per scope
  bool replaceMode;                 // true => substitute matches in the tree
  FindFunctionPredicate predicate;  // decides which function types match
};

// Collects into `functions` every function type reachable from `type` that
// `predicate` accepts, given the in-scope type variables `tyVars`.
// The type tree is not modified: with replaceMode false the finder never
// substitutes a node, so the mutator's return value is intentionally ignored.
void
findFunction( Type *type, std::list< FunctionType* > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate )
{
  const bool replaceMatches = false;
  FindFunction collector( functions, tyVars, replaceMatches, predicate );
  type->acceptMutator( collector );
}

// Like findFunction, but runs the finder in replace mode: each accepted
// function type is recorded in `functions` and substituted in the tree.
// The mutator's result is written back through `type`, so the caller's
// pointer may afterwards refer to a different node than before the call.
void
findAndReplaceFunction( Type *&type, std::list< FunctionType* > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate )
{
  const bool replaceMatches = true;
  FindFunction replacer( functions, tyVars, replaceMatches, predicate );
  Type *mutated = type->acceptMutator( replacer );
  type = mutated;
}

// Binds the caller's output list by reference and takes a private copy of the
// type-variable map, which the traversal narrows and restores as it descends.
FindFunction::FindFunction( std::list< FunctionType* > &functions, const TyVarMap &tyVars, bool replaceMode, FindFunctionPredicate predicate )
  : functions( functions ),
    tyVars( tyVars ),
    replaceMode( replaceMode ),
    predicate( predicate )
{
  // all state is set up by the member initializers above
}

// Removes from the working tyVars map every type variable whose name is
// re-declared by `forall`; names not currently in the map are skipped.
void
FindFunction::handleForall( const std::list< TypeDecl* > &forall )
{
  typedef std::list< TypeDecl* >::const_iterator DeclIter;
  for( DeclIter decl = forall.begin(), last = forall.end(); decl != last; ++decl ) {
    TyVarMap::iterator shadowed = tyVars.find( (*decl)->get_name() );
    if( shadowed == tyVars.end() ) continue;
    tyVars.erase( shadowed );
  }
}

// Visits a function type.  Variables bound by its forall clause are removed
// from tyVars for the duration of this visit and restored before returning.
// Only the return types are traversed (parameter types are not visited here),
// and that traversal happens before this function type is tested, so matches
// nested in the return types are appended to `functions` first.
// Returns the node to keep in the tree: the original function type, or, for a
// match in replace mode, a newly allocated FunctionType.
Type* 
FindFunction::mutate( FunctionType *functionType )
{
  const TyVarMap savedTyVars( tyVars );
  handleForall( functionType->get_forall() );

  mutateAll( functionType->get_returnVals(), *this );

  Type *result = functionType;
  const bool matches = predicate( functionType, tyVars );
  if( matches ) {
    functions.push_back( functionType );
    if( replaceMode ) result = new FunctionType( Type::Qualifiers(), true );
  }

  tyVars = savedTyVars;
  return result;
}

// Visits a pointer type.  Any type variables bound by the pointer's own forall
// clause are removed from tyVars while the default Mutator traversal of the
// pointer runs, then the previous map is restored.
Type *
FindFunction::mutate( PointerType *pointerType )
{
  const TyVarMap savedTyVars( tyVars );
  handleForall( pointerType->get_forall() );
  Type * const result = Mutator::mutate( pointerType );
  tyVars = savedTyVars;
  return result;
}

} // namespace GenPoly
