//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Lvalue.cc --
//
// Author           : Richard C. Bilson
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Peter A. Buhr
// Last Modified On : Fri Mar 17 09:11:18 2017
// Update Count     : 5
//

#include <cassert>                       // for safe_dynamic_cast
#include <string>                        // for string

#include "Common/SemanticError.h"        // for SemanticError
#include "GenPoly.h"                     // for isPolyType
#include "Lvalue.h"
#include "Parser/LinkageSpec.h"          // for Spec, isBuiltin, Intrinsic
#include "ResolvExpr/TypeEnvironment.h"  // for AssertionSet, OpenVarSet
#include "ResolvExpr/Unify.h"            // for unify
#include "SymTab/Indexer.h"              // for Indexer
#include "SynTree/Declaration.h"         // for Declaration, FunctionDecl
#include "SynTree/Expression.h"          // for Expression, ConditionalExpr
#include "SynTree/Mutator.h"             // for mutateAll, Mutator
#include "SynTree/Statement.h"           // for ReturnStmt, Statement (ptr o...
#include "SynTree/Type.h"                // for PointerType, Type, FunctionType
#include "SynTree/Visitor.h"             // for Visitor, acceptAll

namespace GenPoly {
	namespace {
		/// Pass1 rewrites the *uses* of functions whose return is lvalue-qualified:
		///  - a call to such a function is wrapped in a dereference and its own result becomes a pointer;
		///  - inside such a function, each returned expression has its address taken.
		class Pass1 : public Mutator {
		  public:
			Pass1();

			virtual DeclarationWithType *mutate( FunctionDecl *funcDecl );
			virtual Expression *mutate( ApplicationExpr *appExpr );
			virtual Statement *mutate( ReturnStmt *retStmt );

		  private:
			/// Return-value declaration of the function body currently being mutated, when that function's return is
			/// lvalue-qualified (see mutate( FunctionDecl * )); null within other function bodies.
			DeclarationWithType* retval;
		};

		/// Pass2 rewrites the *declarations* of lvalue-qualified returns: the return value's type is turned into a
		/// pointer to the original type, to match the dereference/address-of rewriting done by Pass1.
		class Pass2 : public Visitor {
		  public:
			virtual void visit( FunctionType *funType );
		};

		/// GCC-style generalized lvalues (an extension that has since been removed from GCC), see
		/// https://gcc.gnu.org/onlinedocs/gcc-3.4.6/gcc/Lvalues.html#Lvalues
		/// Distributes address-of and member access over comma and conditional operands, e.g.
		/// &(a,b) becomes (a, &b) and &(a ? b : c) becomes (a ? &b : &c).
		class GeneralizedLvalue : public Mutator {
			using Parent = Mutator;

			virtual Expression * mutate( MemberExpr * memExpr );
			virtual Expression * mutate( AddressExpr * addressExpr );

			/// Shared rewrite used by both mutate overloads; mkExpr rebuilds the outer expression around an operand.
			template<typename Func>
			Expression * applyTransformation( Expression * expr, Expression * arg, Func mkExpr );
		};
	} // namespace

	/// Runs the lvalue-return conversion over a whole translation unit, in three sequential passes:
	/// expression/return rewriting (Pass1), return-type rewriting (Pass2), then generalized-lvalue distribution.
	void convertLvalue( std::list< Declaration* >& translationUnit ) {
		Pass1 useRewriter;
		mutateAll( translationUnit, useRewriter );

		Pass2 declRewriter;
		acceptAll( translationUnit, declRewriter );

		GeneralizedLvalue distributor;
		mutateAll( translationUnit, distributor );
	}

	/// Applies only the generalized-lvalue rewrite to a single expression and returns the (possibly replaced) result.
	Expression * generalizedLvalue( Expression * expr ) {
		GeneralizedLvalue rewriter;
		Expression * const rewritten = expr->acceptMutator( rewriter );
		return rewritten;
	}

	namespace {
		/// Returns the type of the function's first return value when that type is lvalue-qualified; null otherwise
		/// (including when the function has no return values).
		Type* isLvalueRet( FunctionType *function ) {
			if ( ! function->get_returnVals().empty() ) {
				Type * firstRet = function->get_returnVals().front()->get_type();
				if ( firstRet->get_lvalue() ) return firstRet;
			}
			return nullptr;
		}

		/// True when the call's function is a direct reference (VariableExpr) to a declaration with Intrinsic linkage.
		bool isIntrinsicApp( ApplicationExpr *appExpr ) {
			VariableExpr * callee = dynamic_cast< VariableExpr * >( appExpr->get_function() );
			return callee != nullptr && callee->get_var()->get_linkage() == LinkageSpec::Intrinsic;
		}

		/// Starts outside any function body, so there is no enclosing lvalue-returning function yet.
		/// retval must be initialized: mutate( FunctionDecl * ) saves and restores its current value, and reading an
		/// uninitialized pointer is undefined behaviour.
		Pass1::Pass1() : retval( nullptr ) {
		}

		/// Mutates the body of a function definition. While the body is processed, retval is set to the function's
		/// return-value declaration if the function is non-builtin with an lvalue-qualified return, and to null
		/// otherwise; the previous retval is restored afterwards so nested definitions do not leak their state.
		/// Declarations without a body are returned untouched.
		DeclarationWithType * Pass1::mutate( FunctionDecl *funcDecl ) {
			if ( ! funcDecl->get_statements() ) return funcDecl;

			DeclarationWithType * const savedRetval = retval;
			FunctionType * funType = funcDecl->get_functionType();
			const bool returnsLvalue = ! LinkageSpec::isBuiltin( funcDecl->get_linkage() ) && isLvalueRet( funType );
			retval = returnsLvalue ? funType->get_returnVals().front() : nullptr;

			// rewrite calls and return statements inside the body
			funcDecl->set_statements( funcDecl->get_statements()->acceptMutator( *this ) );

			retval = savedRetval;
			return funcDecl;
		}

		/// Rewrites a call to a non-intrinsic function whose return is lvalue-qualified. The call's own result type is
		/// turned into a pointer, matching Pass2's rewrite of the callee's return type. The call is then wrapped in a
		/// "*?" dereference whose result is the original result type, so the expression still denotes the same object.
		/// The function expression and arguments are mutated first. Other calls are returned unchanged.
		Expression * Pass1::mutate( ApplicationExpr *appExpr ) {
			// Keep the mutated function expression. If it is itself a call to an lvalue-returning function, it is
			// replaced by a dereference whose result is the function pointer type. Dropping that replacement would
			// leave the inner call typed as a pointer, and the cast to FunctionType below would fail.
			appExpr->set_function( appExpr->get_function()->acceptMutator( *this ) );
			mutateAll( appExpr->get_args(), *this );

			PointerType *pointer = safe_dynamic_cast< PointerType* >( appExpr->get_function()->get_result() );
			FunctionType *function = safe_dynamic_cast< FunctionType* >( pointer->get_base() );

			Type *funType = isLvalueRet( function );
			if ( ! funType || isIntrinsicApp( appExpr ) ) return appExpr;

			Expression *expr = appExpr;
			Type *appType = appExpr->get_result();
			if ( isPolyType( funType ) && ! isPolyType( appType ) ) {
				// make sure cast for polymorphic type is inside dereference
				expr = new CastExpr( appExpr, new PointerType( Type::Qualifiers(), appType->clone() ) );
			}
			UntypedExpr *deref = new UntypedExpr( new NameExpr( "*?" ) );
			deref->set_result( appType->clone() );
			// appType itself (not a copy) becomes the base of the call's new pointer result type
			appExpr->set_result( new PointerType( Type::Qualifiers(), appType ) );
			deref->get_args().push_back( expr );
			return deref;
		}

		/// Inside a function whose return is lvalue-qualified (retval set), the returned expression must be an lvalue.
		/// It is mutated and then wrapped in an address-of, matching Pass2's rewrite of the return type to a pointer.
		/// Every other return statement is traversed normally, so calls to lvalue-returning functions inside it are
		/// still dereferenced by mutate( ApplicationExpr * ).
		/// Throws SemanticError if an lvalue-qualified function returns a non-lvalue.
		Statement * Pass1::mutate( ReturnStmt *retStmt ) {
			if ( ! retval || ! retStmt->get_expr() ) {
				// The expression must still be visited. Otherwise "return f();" with an lvalue-returning f would keep
				// the raw call, which Pass2 retypes to return a pointer.
				return Mutator::mutate( retStmt );
			}

			if ( ! retStmt->get_expr()->get_result()->get_lvalue() ) {
				throw SemanticError( "Attempt to return non-lvalue from an lvalue-qualified function" );
			}

			// Casts are intentionally not stripped before taking the address: they may already have been stripped.
			retStmt->set_expr( new AddressExpr( retStmt->get_expr()->acceptMutator( *this ) ) );
			return retStmt;
		}

		/// If the function type's first return value is lvalue-qualified, retypes that return declaration as a pointer to
		/// its old type; then continues the normal traversal of the function type.
		void Pass2::visit( FunctionType *funType ) {
			if ( isLvalueRet( funType ) ) {
				DeclarationWithType *retParm = funType->get_returnVals().front();

				// Retype the existing return declaration in place. The old type (not a copy) becomes the pointer's base.
				retParm->set_type( new PointerType( Type::Qualifiers(), retParm->get_type() ) );
			} // if

			Visitor::visit( funType );
		}

		/// Distributes the operation built by mkExpr over a comma or conditional operand.
		/// @param expr   the expression being rewritten (a MemberExpr or AddressExpr in this file); it is deleted when a
		///               rewrite happens and returned unchanged otherwise
		/// @param arg    the operand of expr to inspect (both callers pass a child of expr: the aggregate / the argument)
		/// @param mkExpr rebuilds an expression of expr's kind around a given operand
		/// @return for arg == (a, b): (a, mkExpr(b)); for arg == (c ? x : y): (c ? mkExpr(x) : mkExpr(y));
		///         otherwise expr itself. A rewritten result is run through this mutator again, so nested comma or
		///         conditional operands are handled too.
		/// Note: mkExpr is always invoked before expr is deleted, so it may still read from expr (the MemberExpr caller
		/// reads memExpr->get_member()).
		template<typename Func>
		Expression * GeneralizedLvalue::applyTransformation( Expression * expr, Expression * arg, Func mkExpr ) {
			if ( CommaExpr * commaExpr = dynamic_cast< CommaExpr * >( arg ) ) {
				// operands are cloned because arg is a child of expr, which is deleted below
				Expression * arg1 = commaExpr->get_arg1()->clone();
				Expression * arg2 = commaExpr->get_arg2()->clone();
				Expression * ret = new CommaExpr( arg1, mkExpr( arg2 ) );
				// hand expr's environment to the replacement; it is detached first, presumably so deleting expr
				// does not also free it
				ret->set_env( expr->get_env() );
				expr->set_env( nullptr );
				delete expr;
				return ret->acceptMutator( *this );
			} else if ( ConditionalExpr * condExpr = dynamic_cast< ConditionalExpr * >( arg ) ) {
				// operands are cloned because arg is a child of expr, which is deleted below
				Expression * arg1 = condExpr->get_arg1()->clone();
				Expression * arg2 = condExpr->get_arg2()->clone();
				Expression * arg3 = condExpr->get_arg3()->clone();
				ConditionalExpr * ret = new ConditionalExpr( arg1, mkExpr( arg2 ), mkExpr( arg3 ) );
				// same environment hand-off as in the comma case
				ret->set_env( expr->get_env() );
				expr->set_env( nullptr );
				delete expr;

				// conditional expr type may not be either of the argument types, need to unify
				// (unify is run in a fresh, empty environment; if it yields a commonType that becomes the result,
				// otherwise the result falls back to a copy of the second operand's result type)
				using namespace ResolvExpr;
				Type* commonType = nullptr;
				TypeEnvironment newEnv;
				AssertionSet needAssertions, haveAssertions;
				OpenVarSet openVars;
				unify( ret->get_arg2()->get_result(), ret->get_arg3()->get_result(), newEnv, needAssertions, haveAssertions, openVars, SymTab::Indexer(), commonType );
				ret->set_result( commonType ? commonType : ret->get_arg2()->get_result()->clone() );
				return ret->acceptMutator( *this );
			}
			// neither a comma nor a conditional operand: nothing to distribute
			return expr;
		}

		/// Rewrites (a, b).m to (a, b.m) and (c ? x : y).m to (c ? x.m : y.m), after first mutating the children.
		Expression * GeneralizedLvalue::mutate( MemberExpr * memExpr ) {
			// Use the node the base traversal returns, exactly as mutate( AddressExpr * ) does, instead of assuming
			// Parent::mutate always hands back the same object.
			memExpr = safe_dynamic_cast< MemberExpr * >( Parent::mutate( memExpr ) );
			return applyTransformation( memExpr, memExpr->get_aggregate(), [=]( Expression * aggr ) { return new MemberExpr( memExpr->get_member(), aggr ); } );
		}

		/// Rewrites &(a, b) to (a, &b) and &(c ? x : y) to (c ? &x : &y), after the base traversal of the children.
		Expression * GeneralizedLvalue::mutate( AddressExpr * addrExpr ) {
			Expression * traversed = Parent::mutate( addrExpr );
			AddressExpr * address = safe_dynamic_cast< AddressExpr * >( traversed );
			auto takeAddress = []( Expression * operand ) { return new AddressExpr( operand ); };
			return applyTransformation( address, address->get_arg(), takeAddress );
		}
	} // namespace
} // namespace GenPoly

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
