//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// FindFunction.cc -- locate (and optionally replace) function types nested within a type
//
// Author           : Richard C. Bilson
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Rob Schluntz
// Last Modified On : Fri Feb 05 12:22:20 2016
// Update Count     : 6
//

#include "FindFunction.h"

#include <list>
#include <string>

#include "SynTree/Type.h"
#include "SynTree/Declaration.h"
#include "SynTree/Visitor.h"

#include "ScrubTyVars.h"

namespace GenPoly {
	/// Mutator that walks a type looking for function types accepted by `predicate`.
	/// Accepted function types are appended to `functions`; in replace mode each one is also
	/// swapped, in the mutated result, for a clone whose type parameters have been scrubbed.
	class FindFunction : public Mutator {
	  public:
		/// @param functions   output list that accepted function types are appended to; it is held by
		///                    reference, so it must outlive this object
		/// @param tyVars      type variables in scope; copied, so the caller's map is never modified
		/// @param replaceMode if true, accepted function types are replaced by scrubbed clones
		/// @param predicate   called with a function type and the current type variables; decides
		///                    whether that function type is collected
		FindFunction( std::list< FunctionType* > &functions, const TyVarMap &tyVars, bool replaceMode, FindFunctionPredicate predicate );

		/// Searches the return types, then tests (and possibly replaces) the function type itself.
		virtual Type *mutate( FunctionType *functionType );
		/// Erases the pointer's forall-bound names from tyVars, then defers to the default Mutator traversal.
		virtual Type *mutate( PointerType *pointerType );
	  private:
		/// Erases from `tyVars` any names bound by `forall` (called right after a beginScope()).
		void handleForall( const std::list< TypeDecl* > &forall );

		/// Output list of accepted function types (not owned by this object).
		std::list< FunctionType* > &functions;
		/// Private copy of the type-variable map; scopes are opened and closed on it during traversal.
		TyVarMap tyVars;
		/// When true, accepted function types are replaced instead of only being collected.
		bool replaceMode;
		/// Selection criterion for function types.
		FindFunctionPredicate predicate;
	};

	/// Appends to `functions` every function type found in `type` that `predicate` accepts.
	/// Runs in search-only mode: the mutator's return value is discarded and `type` is not reassigned.
	void findFunction( Type *type, std::list< FunctionType* > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate ) {
		const bool replace = false;
		FindFunction collector( functions, tyVars, replace, predicate );
		type->acceptMutator( collector );
	}

	/// Appends to `functions` every function type found in `type` that `predicate` accepts, and
	/// replaces each of them with a clone whose type parameters are scrubbed.  `type` is reassigned
	/// to whatever the traversal returns.
	void findAndReplaceFunction( Type *&type, std::list< FunctionType* > &functions, const TyVarMap &tyVars, FindFunctionPredicate predicate ) {
		const bool replace = true;
		FindFunction replacer( functions, tyVars, replace, predicate );
		Type *mutated = type->acceptMutator( replacer );
		type = mutated;
	}

	/// Binds the output list by reference and takes a private copy of the type-variable map, so
	/// the scopes opened during traversal never touch the caller's map.
	FindFunction::FindFunction( std::list< FunctionType* > &functionsOut, const TyVarMap &scopeTyVars, bool replaceFound, FindFunctionPredicate accepts )
		: functions( functionsOut ),
		  tyVars( scopeTyVars ),
		  replaceMode( replaceFound ),
		  predicate( accepts ) {
	}

	/// Erases from `tyVars` every name bound by `forall` that the map currently contains, so those
	/// names are not seen by `predicate` or ScrubTyVars while the current scope is open.
	/// @param forall the forall-bound type declarations of the type being visited
	void FindFunction::handleForall( const std::list< TypeDecl* > &forall ) {
		for ( std::list< TypeDecl* >::const_iterator decl = forall.begin(); decl != forall.end(); ++decl ) {
			// Use the declaration's own name as the key. Erasing via the found entry's key
			// (`var->first`) would hand erase() a reference into the very entry being removed,
			// which is only safe if erase() never reads the key after destroying that entry.
			const std::string &name = (*decl)->get_name();
			if ( tyVars.find( name ) != tyVars.end() ) {
				tyVars.erase( name );
			} // if
		} // for
	}

	/// Visits a function type.  Names bound by its forall clause are erased from tyVars inside a
	/// fresh scope, and its return types are searched recursively.  The function type itself is then
	/// recorded if `predicate` accepts it.  In replace mode an accepted type is returned as a
	/// scrubbed clone; otherwise the original node is returned unchanged.
	Type * FindFunction::mutate( FunctionType *functionType ) {
		tyVars.beginScope();
		handleForall( functionType->get_forall() );
		// only the return types are searched here; parameter types are not visited
		mutateAll( functionType->get_returnVals(), *this );

		const bool matched = predicate( functionType, tyVars );
		if ( matched ) {
			functions.push_back( functionType );
		} // if
		// the original node stays in `functions`; any replacement is a clone with type parameters
		// replaced by void*
		Type *result = ( matched && replaceMode ) ? ScrubTyVars::scrub( functionType->clone(), tyVars ) : functionType;

		tyVars.endScope();
		return result;
	}

	/// Visits a pointer type.  Names bound by the pointer's forall clause are erased from tyVars
	/// inside a fresh scope, and the base Mutator performs the actual traversal.
	Type * FindFunction::mutate( PointerType *pointerType ) {
		tyVars.beginScope();
		handleForall( pointerType->get_forall() );
		Type *mutated = Mutator::mutate( pointerType );
		tyVars.endScope();
		return mutated;
	}
} // namespace GenPoly

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
