//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// GenInit.cc --
//
// Author           : Rob Schluntz
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Rob Schluntz
// Last Modified On : Fri May 13 11:37:48 2016
// Update Count     : 166
//

#include <stack>
#include <list>
#include "GenInit.h"
#include "InitTweak.h"
#include "SynTree/Declaration.h"
#include "SynTree/Type.h"
#include "SynTree/Expression.h"
#include "SynTree/Statement.h"
#include "SynTree/Initializer.h"
#include "SynTree/Mutator.h"
#include "SymTab/Autogen.h"
#include "SymTab/Mangler.h"
#include "GenPoly/PolyMutator.h"
#include "GenPoly/DeclMutator.h"
#include "GenPoly/ScopedSet.h"
#include "ResolvExpr/typeops.h"

namespace InitTweak {
	namespace {
		// shared empty lists: passed wherever a new statement needs a label list
		// or a new initializer needs a designator list
		const std::list< Label > noLabels{};
		const std::list< Expression * > noDesignators{};
	}

	class ReturnFixer final : public GenPoly::PolyMutator {
	  public:
		/// consistently allocates a temporary variable for the return value
		/// of a function so that anything which the resolver decides can be constructed
		/// into the return type of a function can be returned.
		static void makeReturnTemp( std::list< Declaration * > &translationUnit );

		ReturnFixer();

		using GenPoly::PolyMutator::mutate;
		/// records the type and name of the function while its body is mutated; a function
		/// with several return values gets a single tuple-typed return value instead
		virtual DeclarationWithType * mutate( FunctionDecl *functionDecl ) override;
		/// when the enclosing function has a single non-lvalue return value and is not ?=?,
		/// returns a temporary constructed from the returned expression instead
		virtual Statement * mutate( ReturnStmt * returnStmt ) override;

	  protected:
		/// type of the function currently being mutated; null outside of any function.
		/// Initialized so that the ValueGuard in mutate( FunctionDecl * ) never copies an
		/// indeterminate pointer value when the first function is entered.
		FunctionType * ftype = nullptr;
		/// generates the names of the return temporaries
		UniqueName tempNamer;
		/// name of the function currently being mutated
		std::string funcName;
	};

	class CtorDtor final : public GenPoly::PolyMutator {
	  public:
		using Parent = GenPoly::PolyMutator;
		using Parent::mutate;

		/// Create constructor and destructor statements for object declarations. The actual
		/// call statements are added in after the resolver has run, so that the initializer
		/// expression is only removed if a constructor is found, and so that the same
		/// destructor call is inserted in all of the appropriate locations.
		static void generateCtorDtor( std::list< Declaration * > &translationUnit );

		DeclarationWithType * mutate( ObjectDecl * ) override;
		DeclarationWithType * mutate( FunctionDecl *functionDecl ) override;

		// The traversal does not descend into any of the following declarations looking
		// for objects that need to be constructed or destructed.
		Declaration * mutate( StructDecl *aggregateDecl ) override;
		Declaration * mutate( UnionDecl *aggregateDecl ) override { return aggregateDecl; }
		Declaration * mutate( EnumDecl *aggregateDecl ) override { return aggregateDecl; }
		Declaration * mutate( TraitDecl *aggregateDecl ) override { return aggregateDecl; }
		TypeDecl * mutate( TypeDecl *typeDecl ) override { return typeDecl; }
		Declaration * mutate( TypedefDecl *typeDecl ) override { return typeDecl; }

		// nor into function types (and therefore parameter lists)
		Type * mutate( FunctionType *funcType ) override { return funcType; }

		CompoundStmt * mutate( CompoundStmt * compoundStmt ) override;

	  private:
		/// is the object's type - for an array, its innermost element type - managed?
		bool isManaged( ObjectDecl * objDecl ) const;
		/// is the type managed? A tuple type is also managed if any component is.
		bool isManaged( Type * type ) const;
		/// if dwt is a user-defined constructor or destructor, record its type as managed
		void handleDWT( DeclarationWithType * dwt );

		/// Mangled names of the types for which a constructor or destructor exists in the
		/// current scope. These types require a ConstructorInit node to be generated; anything
		/// else is a POD type and thus should not have a ConstructorInit generated.
		GenPoly::ScopedSet< std::string > managedTypes;
		/// true while the contents of a function are being mutated
		bool inFunction = false;
	};

	class HoistArrayDimension final : public GenPoly::DeclMutator {
	  public:
		using Parent = GenPoly::DeclMutator;

		/// Hoist the dimension out of array types in object declarations into a single
		/// const variable of type size_t, so that side-effecting array dimensions are
		/// computed only once.
		static void hoistArrayDimension( std::list< Declaration * > & translationUnit );

	  private:
		using Parent::mutate;

		DeclarationWithType * mutate( ObjectDecl * objectDecl ) override;
		DeclarationWithType * mutate( FunctionDecl *functionDecl ) override;

		// The traversal does not descend into any of the following declarations looking
		// for array types whose dimension should be hoisted.
		Declaration * mutate( StructDecl *aggregateDecl ) override { return aggregateDecl; }
		Declaration * mutate( UnionDecl *aggregateDecl ) override { return aggregateDecl; }
		Declaration * mutate( EnumDecl *aggregateDecl ) override { return aggregateDecl; }
		Declaration * mutate( TraitDecl *aggregateDecl ) override { return aggregateDecl; }
		TypeDecl * mutate( TypeDecl *typeDecl ) override { return typeDecl; }
		Declaration * mutate( TypedefDecl *typeDecl ) override { return typeDecl; }

		// nor into function types (and therefore parameter lists)
		Type * mutate( FunctionType *funcType ) override { return funcType; }

		/// move a non-constant dimension of the given array type into a new variable
		void hoist( Type * type );

		/// storage class of the object declaration currently being mutated
		DeclarationNode::StorageClass storageclass = DeclarationNode::NoStorageClass;
		/// true while the contents of a function are being mutated
		bool inFunction = false;
	};

	/// Runs the initialization passes over the translation unit. The order is significant:
	/// return temporaries first, then array-dimension hoisting, then ctor/dtor generation.
	void genInit( std::list< Declaration * > & translationUnit ) {
		// allocate return temporaries (and collapse multiple return values into a tuple)
		ReturnFixer::makeReturnTemp( translationUnit );
		// must precede ctor/dtor generation, which may duplicate an array dimension expression
		HoistArrayDimension::hoistArrayDimension( translationUnit );
		// replace initializers of objects that need construction with ConstructorInit nodes
		CtorDtor::generateCtorDtor( translationUnit );
	}

	void ReturnFixer::makeReturnTemp( std::list< Declaration * > & translationUnit ) {
		// a fresh fixer per call, so its temporary-name generator and per-function state start clean
		ReturnFixer returnFixer;
		mutateAll( translationUnit, returnFixer );
	}

	/// return temporaries are named from the base "_retVal"
	ReturnFixer::ReturnFixer()
		: tempNamer( "_retVal" )
	{}

	Statement *ReturnFixer::mutate( ReturnStmt *returnStmt ) {
		std::list< DeclarationWithType * > & retVals = ftype->get_returnVals();
		// mutate( FunctionDecl * ) has already collapsed multiple return values into one
		assert( retVals.size() == 0 || retVals.size() == 1 );

		// nothing to do for a bare return, or when there is no single return value
		if ( ! returnStmt->get_expr() || retVals.size() != 1 ) return returnStmt;
		// assignment operators are left alone
		if ( funcName == "?=?" ) return returnStmt;
		// hands off if the function returns an lvalue - we don't want to allocate a
		// temporary if a variable's address is being returned
		DeclarationWithType * retVal = retVals.front();
		if ( retVal->get_type()->get_isLvalue() ) return returnStmt;

		// ensure the return value is not destructed: the temporary gets an empty ListInit
		// whose maybeConstructed flag is false
		ObjectDecl *retTemp = new ObjectDecl( tempNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, 0, retVal->get_type()->clone(), new ListInit( std::list<Initializer*>(), noDesignators, false ) );
		stmtsToAdd.push_back( new DeclStmt( noLabels, retTemp ) );

		// construct the temporary from the returned expression with a separate ?{} call
		UntypedExpr *ctorCall = new UntypedExpr( new NameExpr( "?{}" ) );
		ctorCall->get_args().push_back( new AddressExpr( new VariableExpr( retTemp ) ) );
		ctorCall->get_args().push_back( returnStmt->get_expr() );
		stmtsToAdd.push_back( new ExprStmt( noLabels, ctorCall ) );

		// return the temporary in place of the original expression
		returnStmt->set_expr( new VariableExpr( retTemp ) );
		return returnStmt;
	}

	DeclarationWithType* ReturnFixer::mutate( FunctionDecl *functionDecl ) {
		// xxx - need to handle named return values - this pass may need to happen after
		// resolution? The ordering is tricky because return statements must be constructed;
		// the simplest way to do that (while also handling multiple returns) is to structure
		// the returnVals into a tuple, as done here. However, if the tuple return value is
		// structured before resolution, it's difficult to resolve named return values, since
		// the name is lost in conversion to a tuple. This might be easiest to deal with after
		// reference types are added, as it may then be possible to uniformly move named
		// return values to the parameter list directly.

		// the enclosing function's type and name are restored when these guards go out of scope
		ValueGuard< FunctionType * > savedFtype( ftype );
		ValueGuard< std::string > savedFuncName( funcName );

		ftype = functionDecl->get_functionType();
		funcName = functionDecl->get_name();

		std::list< DeclarationWithType * > & retVals = ftype->get_returnVals();
		if ( retVals.size() > 1 ) {
			// replace the return values with a single tuple-typed return value; the tuple
			// type is extracted before the original list is cleared
			TupleType * tupleType = safe_dynamic_cast< TupleType * >( ResolvExpr::extractResultType( ftype ) );
			ObjectDecl * tupleRet = new ObjectDecl( tempNamer.newName(), DeclarationNode::NoStorageClass, LinkageSpec::C, 0, tupleType, new ListInit( std::list<Initializer*>(), noDesignators, false ) );
			retVals.clear();
			retVals.push_back( tupleRet );
		}

		return Mutator::mutate( functionDecl );
	}

	// Array dimension expressions are precomputed because constructor generation may
	// duplicate them, which would be incorrect for a side-effecting computation.
	void HoistArrayDimension::hoistArrayDimension( std::list< Declaration * > & translationUnit ) {
		HoistArrayDimension dimensionHoister;
		dimensionHoister.mutateDeclarationList( translationUnit );
	}

	/// Mutates the object declaration, then hoists non-constant dimensions out of its type.
	/// The object's storage class is made visible to hoist() (static objects are skipped).
	DeclarationWithType * HoistArrayDimension::mutate( ObjectDecl * objectDecl ) {
		// Save and restore rather than resetting to NoStorageClass: a declaration nested
		// inside this one and mutated by Parent::mutate would otherwise clobber the value
		// before hoist() reads it, and an exception would leave it stale.
		ValueGuard< DeclarationNode::StorageClass > oldStorageClass( storageclass );
		storageclass = objectDecl->get_storageClass();
		DeclarationWithType * temp = Parent::mutate( objectDecl );
		hoist( objectDecl->get_type() );
		return temp;
	}

	/// Replaces every non-constant dimension of an array type with a const variable of the
	/// size type, initialized from the original dimension expression. Nested array types are
	/// handled too: the inner dimension of T[10][f()] is hoisted even though the outer
	/// dimension is constant. Each new variable is registered via addDeclaration (presumably
	/// emitted ahead of the current declaration - see DeclMutator).
	void HoistArrayDimension::hoist( Type * type ) {
		static UniqueName dimensionName( "_array_dim" );

		// C doesn't allow variable sized arrays at global scope or for static variables,
		// so don't hoist dimension.
		if ( ! inFunction ) return;
		if ( storageclass == DeclarationNode::Static ) return;

		ArrayType * arrayType = dynamic_cast< ArrayType * >( type );
		if ( ! arrayType ) return;

		// A missing or constexpr dimension does not need hoisting - only a dimension with
		// potential side effects does. This must not end the walk, though: the element type
		// may still have a side-effecting dimension of its own.
		auto dimension = arrayType->get_dimension();
		if ( dimension && ! isConstExpr( dimension ) ) {
			ObjectDecl * arrayDimension = new ObjectDecl( dimensionName.newName(), storageclass, LinkageSpec::C, 0, SymTab::SizeType->clone(), new SingleInit( dimension ) );
			arrayDimension->get_type()->set_isConst( true );

			arrayType->set_dimension( new VariableExpr( arrayDimension ) );
			addDeclaration( arrayDimension );
		}

		hoist( arrayType->get_base() );
	}

	DeclarationWithType * HoistArrayDimension::mutate( FunctionDecl *functionDecl ) {
		// objects reached while mutating a function are function-local, so their dimensions
		// may be hoisted; the previous flag comes back when the guard is destroyed
		ValueGuard< bool > savedInFunction( inFunction );
		inFunction = true;
		return Parent::mutate( functionDecl );
	}

	void CtorDtor::generateCtorDtor( std::list< Declaration * > & translationUnit ) {
		// a fresh pass per call, starting with no managed types and outside any function
		CtorDtor ctorDtorPass;
		mutateAll( translationUnit, ctorDtorPass );
	}

	bool CtorDtor::isManaged( Type * type ) const {
		// a tuple is also managed if any of its components are managed
		if ( TupleType * tuple = dynamic_cast< TupleType * >( type ) ) {
			for ( Type * component : tuple->get_types() ) {
				if ( isManaged( component ) ) return true;
			}
		}
		// otherwise the type is managed exactly when its mangled name is recorded in scope
		return managedTypes.find( SymTab::Mangler::mangle( type ) ) != managedTypes.end();
	}

	bool CtorDtor::isManaged( ObjectDecl * objDecl ) const {
		// an array (of arrays ...) is managed exactly when its innermost element type is
		Type * elemType = objDecl->get_type();
		for ( ArrayType * arr = dynamic_cast< ArrayType * >( elemType ); arr; arr = dynamic_cast< ArrayType * >( elemType ) ) {
			elemType = arr->get_base();
		}
		return isManaged( elemType );
	}

	void CtorDtor::handleDWT( DeclarationWithType * dwt ) {
		// only a user-defined (non-overridable) constructor or destructor marks a type as managed
		if ( LinkageSpec::isOverridable( dwt->get_linkage() ) ) return;
		if ( ! isCtorDtor( dwt->get_name() ) ) return;

		// the managed type is the one pointed to by the first parameter
		std::list< DeclarationWithType * > & params = GenPoly::getFunctionType( dwt->get_type() )->get_parameters();
		assert( ! params.empty() );
		PointerType * firstParamType = safe_dynamic_cast< PointerType * >( params.front()->get_type() );
		managedTypes.insert( SymTab::Mangler::mangle( firstParamType->get_base() ) );
	}

	ConstructorInit * genCtorInit( ObjectDecl * objDecl ) {
		// genImplicitCall (Autogen.h) generates the ctor/dtor calls for a constructable object:
		// the constructor call draws its arguments from the object's initializer, while the
		// destructor call is given no initializer at all
		std::list< Statement * > ctorStmts;
		std::list< Statement * > dtorStmts;

		InitExpander initArgs( objDecl->get_init() );
		InitExpander noArgs( static_cast< Initializer * >( nullptr ) );
		SymTab::genImplicitCall( initArgs, new VariableExpr( objDecl ), "?{}", back_inserter( ctorStmts ), objDecl );
		SymTab::genImplicitCall( noArgs, new VariableExpr( objDecl ), "^?{}", front_inserter( dtorStmts ), objDecl, false );

		// genImplicitCall currently produces a single Statement - a CompoundStmt wrapping
		// everything that needs to happen. A Statement ** out-parameter would therefore work,
		// but is inherently unsafe; collecting into lists is slightly less efficient, yet the
		// assertion below reports it immediately if that assumption is ever broken. (The lists
		// could then be wrapped in a CompoundStmt here, but that seems better done by
		// genImplicitCall itself.) Producing no statements is possible, e.g. for an array
		// type without a dimension; such an object is simply not constructed.
		assert( ctorStmts.size() == dtorStmts.size() && ctorStmts.size() <= 1 );
		if ( ctorStmts.size() != 1 ) return nullptr;

		// Remember the init expression in case no ctors exist; if a ctor does exist, the
		// ctor expression is used instead of the init. The resolver makes that decision.
		assert( dynamic_cast< ImplicitCtorDtorStmt * >( ctorStmts.front() ) && dynamic_cast< ImplicitCtorDtorStmt * >( dtorStmts.front() ) );
		return new ConstructorInit( ctorStmts.front(), dtorStmts.front(), objDecl->get_init() );
	}

	/// Replaces the initializer of an object that needs construction with a ConstructorInit
	/// (via genCtorInit). Throws SemanticError for designated or too-deeply nested
	/// initializers on such objects.
	DeclarationWithType * CtorDtor::mutate( ObjectDecl * objDecl ) {
		// an object declaration may itself be a ctor/dtor that makes its type managed
		handleDWT( objDecl );
		// Hands off if @=, extern, builtin, etc. - this guard must cover both reasons for
		// constructing, not only managed types. Beyond that, construct if the type is managed,
		// or if the object is global and its initializer is not constexpr, since that is not
		// legal C.
		if ( tryConstruct( objDecl ) && ( isManaged( objDecl ) || ( ! inFunction && ! isConstExpr( objDecl->get_init() ) ) ) ) {
			// constructed objects cannot be designated
			if ( isDesignated( objDecl->get_init() ) ) throw SemanticError( "Cannot include designations in the initializer for a managed Object. If this is really what you want, then initialize with @=.", objDecl );
			// constructed objects should not have initializers nested too deeply
			if ( ! checkInitDepth( objDecl ) ) throw SemanticError( "Managed object's initializer is too deep ", objDecl );

			objDecl->set_init( genCtorInit( objDecl ) );
		}
		return Parent::mutate( objDecl );
	}

	/// Mutates the assertions, old-style declarations and body of a function (but not its
	/// parameters) inside a new managed-type scope.
	DeclarationWithType * CtorDtor::mutate( FunctionDecl *functionDecl ) {
		ValueGuard< bool > oldInFunc( inFunction );
		inFunction = true;

		// recorded before the function's own scope opens, so a ctor/dtor definition makes
		// its type managed in the enclosing scope
		handleDWT( functionDecl );

		managedTypes.beginScope();
		// Closes the scope opened above on every exit path, including when a SemanticError
		// propagates out of the body; otherwise, should a caller catch the error and keep
		// traversing, managed types from this function would leak into later declarations.
		struct ScopeCloser {
			CtorDtor & self;
			~ScopeCloser() { self.managedTypes.endScope(); }
		} closeScope{ *this };

		// go through assertions and recursively add seen ctor/dtors
		for ( auto & tyDecl : functionDecl->get_functionType()->get_forall() ) {
			for ( DeclarationWithType *& assertion : tyDecl->get_assertions() ) {
				assertion = assertion->acceptMutator( *this );
			}
		}
		// parameters should not be constructed and destructed, so don't mutate FunctionType
		mutateAll( functionDecl->get_oldDecls(), *this );
		functionDecl->set_statements( maybeMutate( functionDecl->get_statements(), *this ) );

		return functionDecl;
	}

	/// Does not traverse the struct's members, but marks the struct type as managed when
	/// any of its fields has a managed type.
	Declaration* CtorDtor::mutate( StructDecl *aggregateDecl ) {
		// don't construct members, but need to take note if there is a managed member,
		// because that means that this type is also managed
		for ( Declaration * member : aggregateDecl->get_members() ) {
			if ( ObjectDecl * field = dynamic_cast< ObjectDecl * >( member ) ) {
				if ( isManaged( field ) ) {
					// isManaged( Type * ) looks up the mangled name of a *type*, so record the
					// mangling of an (unqualified) use of this struct. Mangling the StructDecl
					// itself need not produce that same string.
					StructInstType inst( Type::Qualifiers(), aggregateDecl->get_name() );
					managedTypes.insert( SymTab::Mangler::mangle( &inst ) );
					break;
				}
			}
		}
		return aggregateDecl;
	}

	/// Each compound statement gets its own managed-type scope.
	CompoundStmt * CtorDtor::mutate( CompoundStmt * compoundStmt ) {
		managedTypes.beginScope();
		// Closes the scope on every exit path, including when a SemanticError propagates out
		// of the statements; otherwise, should a caller catch the error and keep traversing,
		// the scope stack would stay unbalanced.
		struct ScopeCloser {
			CtorDtor & self;
			~ScopeCloser() { self.managedTypes.endScope(); }
		} closeScope{ *this };
		return Parent::mutate( compoundStmt );
	}

} // namespace InitTweak

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
