//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Lvalue.cc --
//
// Author           : Richard C. Bilson
// Created On       : Mon May 18 07:44:20 2015
// Last Modified By : Peter A. Buhr
// Last Modified On : Tue Dec 15 15:33:13 2015
// Update Count     : 3
//

#include <cassert>

#include "Lvalue.h"

#include "GenPoly.h"

#include "SynTree/Declaration.h"
#include "SynTree/Type.h"
#include "SynTree/Expression.h"
#include "SynTree/Statement.h"
#include "SynTree/Visitor.h"
#include "SynTree/Mutator.h"
#include "SymTab/Indexer.h"
#include "ResolvExpr/Resolver.h"
#include "ResolvExpr/typeops.h"

#include "Common/UniqueName.h"
#include "Common/utility.h"

namespace GenPoly {
	namespace {
		/// Empty label list.
		/// NOTE(review): not referenced anywhere else in this file; candidate for removal.
		const std::list<Label> noLabels;

		/// Replace uses of lvalue returns with appropriate pointers:
		/// calls to non-intrinsic functions whose first return value is lvalue-qualified are
		/// wrapped in a dereference, and return statements inside such functions are changed
		/// to return the address of their operand.
		class Pass1 : public Mutator {
		  public:
			Pass1();

			virtual Expression *mutate( ApplicationExpr *appExpr );
			virtual Statement *mutate( ReturnStmt *retStmt );
			virtual DeclarationWithType *mutate( FunctionDecl *funDecl );
		  private:
			/// First return value of the lvalue-returning function whose body is currently being
			/// mutated; null while mutating the body of any other function.
			DeclarationWithType* retval;
		};

		/// Replace declarations of lvalue returns with appropriate pointers:
		/// the first return value of an lvalue-returning function type is retyped as a pointer.
		class Pass2 : public Visitor {
		  public:
			virtual void visit( FunctionType *funType );
		};

		/// GCC-like Generalized Lvalues (which have since been removed from GCC)
		/// https://gcc.gnu.org/onlinedocs/gcc-3.4.6/gcc/Lvalues.html#Lvalues
		/// Replaces &(a,b) with (a, &b), &(a ? b : c) with (a ? &b : &c)
		class GeneralizedLvalue : public Mutator {
			typedef Mutator Parent;

			virtual Expression * mutate( AddressExpr * addressExpr );
		};
	} // namespace

	/// Run the lvalue-conversion passes over a whole translation unit, in order:
	///  1. Pass1 rewrites call sites and return statements of lvalue-returning functions,
	///  2. Pass2 retypes the return values of those functions as pointers,
	///  3. GeneralizedLvalue distributes address-of over comma and conditional expressions.
	void convertLvalue( std::list< Declaration* >& translationUnit ) {
		Pass1 rewriteUses;
		mutateAll( translationUnit, rewriteUses );

		Pass2 retypeReturns;
		acceptAll( translationUnit, retypeReturns );

		GeneralizedLvalue distributeAddressOf;
		mutateAll( translationUnit, distributeAddressOf );
	}

	namespace {
		/// Return the type of the function's first return value if that type is lvalue-qualified,
		/// otherwise null (including when the function has no return values).
		Type* isLvalueRet( FunctionType *function ) {
			if ( ! function->get_returnVals().empty() ) {
				Type *firstRetType = function->get_returnVals().front()->get_type();
				if ( firstRetType->get_isLvalue() ) {
					return firstRetType;
				} // if
			} // if
			return 0;
		}

		/// True when the call's callee is a plain variable reference whose declaration has
		/// intrinsic linkage; any other callee expression yields false.
		bool isIntrinsicApp( ApplicationExpr *appExpr ) {
			VariableExpr *calleeVar = dynamic_cast< VariableExpr* >( appExpr->get_function() );
			return calleeVar && calleeVar->get_var()->get_linkage() == LinkageSpec::Intrinsic;
		}

		/// Start outside of any function body: no lvalue return value is active.
		/// retval must be initialized because mutate( FunctionDecl * ) saves/restores it and
		/// mutate( ReturnStmt * ) tests it; leaving it indeterminate is undefined behaviour.
		Pass1::Pass1() : retval( 0 ) {
		}

		/// Mutate the body of a function definition (declarations without a body are returned
		/// untouched). While the body is processed, retval holds the function's first return
		/// value if the function is non-builtin and returns an lvalue, and null otherwise; the
		/// previous retval is restored afterwards so nested function definitions are handled.
		DeclarationWithType * Pass1::mutate( FunctionDecl *funcDecl ) {
			if ( ! funcDecl->get_statements() ) return funcDecl;

			DeclarationWithType *enclosingRetval = retval;
			FunctionType *funcType = funcDecl->get_functionType();
			bool returnsLvalue = ! LinkageSpec::isBuiltin( funcDecl->get_linkage() ) && isLvalueRet( funcType );
			retval = returnsLvalue ? funcType->get_returnVals().front() : 0;

			// fix expressions and return statements in this function
			funcDecl->set_statements( funcDecl->get_statements()->acceptMutator( *this ) );

			retval = enclosingRetval;
			return funcDecl;
		}

		/// Rewrite a call to a non-intrinsic function whose first return value is lvalue-qualified.
		/// The call's result type becomes a pointer to its former result type, and the call is
		/// wrapped in an untyped `*?` dereference whose result is the former type. If the declared
		/// return type is polymorphic but the call's result is not, a cast to pointer-to-result is
		/// inserted inside the dereference. Other calls are returned unchanged after their callee
		/// and arguments have been mutated.
		Expression * Pass1::mutate( ApplicationExpr *appExpr ) {
			// Store the mutated callee back (as the base Mutator does): if the callee is itself a
			// call to an lvalue-returning function, its replacement dereference must not be lost.
			appExpr->set_function( appExpr->get_function()->acceptMutator( *this ) );
			mutateAll( appExpr->get_args(), *this );

			// safe_dynamic_cast presumably asserts when the callee is not a pointer to function
			PointerType *pointer = safe_dynamic_cast< PointerType* >( appExpr->get_function()->get_result() );
			FunctionType *function = safe_dynamic_cast< FunctionType* >( pointer->get_base() );

			Type *funType = isLvalueRet( function );
			if ( funType && ! isIntrinsicApp( appExpr ) ) {
				Expression *expr = appExpr;
				Type *appType = appExpr->get_result();
				if ( isPolyType( funType ) && ! isPolyType( appType ) ) {
					// make sure cast for polymorphic type is inside dereference
					expr = new CastExpr( appExpr, new PointerType( Type::Qualifiers(), appType->clone() ) );
				}
				UntypedExpr *deref = new UntypedExpr( new NameExpr( "*?" ) );
				deref->set_result( appType->clone() );
				// appType is reused (not cloned) as the base of the call's new pointer result
				appExpr->set_result( new PointerType( Type::Qualifiers(), appType ) );
				deref->get_args().push_back( expr );
				return deref;
			} else {
				return appExpr;
			} // if
		}

		/// Rewrite a return statement. Inside the body of a function whose first return value is
		/// lvalue-qualified (retval non-null), the returned expression must be an lvalue and is
		/// replaced by its address, matching Pass2's retyping of that return value as a pointer.
		/// In every other case the returned expression is still traversed normally.
		/// @throws SemanticError if an lvalue-returning function returns a non-lvalue expression
		Statement * Pass1::mutate( ReturnStmt *retStmt ) {
			if ( ! retval || ! retStmt->get_expr() ) {
				// Do not skip the expression: it may contain calls to lvalue-returning functions
				// (e.g. `return f() + 1;`) that mutate( ApplicationExpr * ) must still rewrite.
				return Mutator::mutate( retStmt );
			} // if

			if ( ! retStmt->get_expr()->get_result()->get_isLvalue() ) {
				throw SemanticError( "Attempt to return non-lvalue from an lvalue-qualified function" );
			} // if

			// Casts are not stripped here (even though the address of a cast may not be taken)
			// because they may already have been stripped.
			retStmt->set_expr( new AddressExpr( retStmt->get_expr()->acceptMutator( *this ) ) );
			return retStmt;
		}

		/// If the function type's first return value is lvalue-qualified, retype that return value
		/// as a pointer to its former type; then continue the default traversal of the type.
		void Pass2::visit( FunctionType *funType ) {
			if ( isLvalueRet( funType ) ) {
				DeclarationWithType *retParm = funType->get_returnVals().front();

				// retype the existing return declaration in place: its old type becomes the pointer's base
				retParm->set_type( new PointerType( Type::Qualifiers(), retParm->get_type() ) );
			} // if

			Visitor::visit( funType );
		}

		/// After mutating the operand, push the address-of through comma and conditional operands:
		/// &(a, b) becomes (a, &b) and &(a ? b : c) becomes (a ? &b : &c). The pieces are cloned
		/// before the original address expression is deleted; any other operand is left as is.
		Expression * GeneralizedLvalue::mutate( AddressExpr * addrExpr ) {
			addrExpr = safe_dynamic_cast< AddressExpr * >( Parent::mutate( addrExpr ) );
			Expression * operand = addrExpr->get_arg();

			if ( CommaExpr * comma = dynamic_cast< CommaExpr * >( operand ) ) {
				Expression * lhs = comma->get_arg1()->clone();
				Expression * rhs = comma->get_arg2()->clone();
				delete addrExpr;
				return new CommaExpr( lhs, new AddressExpr( rhs ) );
			} // if

			if ( ConditionalExpr * cond = dynamic_cast< ConditionalExpr * >( operand ) ) {
				Expression * test = cond->get_arg1()->clone();
				Expression * whenTrue = cond->get_arg2()->clone();
				Expression * whenFalse = cond->get_arg3()->clone();
				delete addrExpr;
				return new ConditionalExpr( test, new AddressExpr( whenTrue ), new AddressExpr( whenFalse ) );
			} // if

			return addrExpr;
		}
	} // namespace
} // namespace GenPoly

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
