#include "InitTweak.h"
#include "SynTree/Visitor.h"
#include "SynTree/Statement.h"
#include "SynTree/Initializer.h"
#include "SynTree/Expression.h"
#include "GenPoly/GenPoly.h"

namespace InitTweak {
	namespace {
		class HasDesignations : public Visitor {
		public:
			bool hasDesignations = false;
			template<typename Init>
			void handleInit( Init * init ) {
				if ( ! init->get_designators().empty() ) hasDesignations = true;
				else Visitor::visit( init );
			}
			virtual void visit( SingleInit * singleInit ) { handleInit( singleInit); }
			virtual void visit( ListInit * listInit ) { handleInit( listInit); }
		};

		class InitExpander_OLD : public Visitor {
			public:
			virtual void visit( SingleInit * singleInit );
			virtual void visit( ListInit * listInit );
			std::list< Expression * > argList;
		};

		void InitExpander_OLD::visit( SingleInit * singleInit ) {
			argList.push_back( singleInit->get_value()->clone() );
		}

		void InitExpander_OLD::visit( ListInit * listInit ) {
			// xxx - for now, assume no nested list inits
			std::list<Initializer*>::iterator it = listInit->begin_initializers();
			for ( ; it != listInit->end_initializers(); ++it ) {
				(*it)->accept( *this );
			}
		}
	}

	std::list< Expression * > makeInitList( Initializer * init ) {
		InitExpander_OLD expander;
		maybeAccept( init, expander );
		return expander.argList;
	}

	bool isDesignated( Initializer * init ) {
		HasDesignations finder;
		maybeAccept( init, finder );
		return finder.hasDesignations;
	}

	class InitExpander::ExpanderImpl {
	public:
		virtual std::list< Expression * > next( std::list< Expression * > & indices ) = 0;
	};

	class InitImpl : public InitExpander::ExpanderImpl {
	public:
		InitImpl( Initializer * init ) {
			if ( init ) inits.push_back( init );
		}

		virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
			// this is wrong, but just a placeholder for now
			return ! inits.empty() ? makeInitList( inits.front() ) : std::list< Expression * >();
		}
	private:
		std::list< Initializer * > inits;
	};

	class ExprImpl : public InitExpander::ExpanderImpl {
	public:
		ExprImpl( Expression * expr ) : arg( expr ) {}

		virtual std::list< Expression * > next( std::list< Expression * > & indices ) {
			std::list< Expression * > ret;
			Expression * expr = maybeClone( arg );
			if ( expr ) {
				for ( std::list< Expression * >::reverse_iterator it = indices.rbegin(); it != indices.rend(); ++it ) {
					// go through indices and layer on subscript exprs ?[?]
					++it;
					UntypedExpr * subscriptExpr = new UntypedExpr( new NameExpr( "?[?]") );
					subscriptExpr->get_args().push_back( expr );
					subscriptExpr->get_args().push_back( (*it)->clone() );
					expr = subscriptExpr;
				}
				ret.push_back( expr );
			}
			return ret;
		}
	private:
		Expression * arg;
	};

	InitExpander::InitExpander( Initializer * init ) : expander( new InitImpl( init ) ) {}

	InitExpander::InitExpander( Expression * expr ) : expander( new ExprImpl( expr ) ) {}

	std::list< Expression * > InitExpander::operator*() {
		return cur;
	}

	InitExpander & InitExpander::operator++() {
		cur = expander->next( indices );
		return *this;
	}

	// use array indices list to build switch statement
	void InitExpander::addArrayIndex( Expression * index, Expression * dimension ) {
		indices.push_back( index );
		indices.push_back( dimension );
	}

	template< typename OutIterator >
	void build( UntypedExpr * callExpr, InitExpander::IndexList::iterator idx, InitExpander::IndexList::iterator end, OutIterator out ) {
		if ( idx == end ) return;
		Expression * index = *idx++;
		assert( idx != end );
		Expression * dimension = *idx++;

		// if ( idx == end ) {
		// 	// loop through list of expressions belonging to the current initializer
		// 	UntypedExpr * cond = new UntypedExpr( new NameExpr( "?<?") );
		// 	cond->get_args().push_back( index->clone() );
		// 	cond->get_args().push_back( dimension->clone() );

		// 	UntypedExpr * call = callExpr->clone();
		// 	std::list< Expression * > args = *++expander; // xxx - need a way to indentify the end of an init list
		// 	call->get_args().splice( args );

		// 	*out++ = new IfStmt( noLabels, cond, new ExprStmt( call ), NULL );

		// 	UntypedExpr * increment = new UntypedExpr( new NameExpr( "++?" ) );
		// 	increment->get_args().push_back( index->clone() );
		// 	*out++ = new ExprStmt( increment );
		// } else {
		// 	std::list< Statement * > branches;
		// 	for (...) { // loop over conditions?
		// 		std::list< Statement * > stmts;
		// 		build( idx, end, back_inserter( stmts ) );
		// 		CaseStmt * caseStmt = new CaseStmt( noLabels, condition, stmts );
		// 		branches.push_back( caseStmt );
		// 	}
		// 	*out++ = new SwitchStmt( noLabels, index->clone(), branches );
		// }
	}

	// generate switch statement, consuming all of expander's elements
	Statement * InitExpander::buildListInit( UntypedExpr * dst ) {
		std::list< Statement * > results;
		build( dst, indices.begin(), indices.end(), back_inserter( results ) );
		assert( results.size() <= 1 );
		return ! results.empty() ? results.front() : NULL;
	}

	bool tryConstruct( ObjectDecl * objDecl ) {
		return ! LinkageSpec::isBuiltin( objDecl->get_linkage() ) &&
			(objDecl->get_init() == NULL ||
				( objDecl->get_init() != NULL && objDecl->get_init()->get_maybeConstructed() )) &&
			! isDesignated( objDecl->get_init() )
			&& objDecl->get_storageClass() != DeclarationNode::Extern;
	}

	Expression * getCtorDtorCall( Statement * stmt ) {
		if ( stmt == NULL ) return NULL;
		if ( ExprStmt * exprStmt = dynamic_cast< ExprStmt * >( stmt ) ) {
			return exprStmt->get_expr();
		} else if ( CompoundStmt * compoundStmt = dynamic_cast< CompoundStmt * >( stmt ) ) {
			// could also be a compound statement with a loop, in the case of an array
			if( compoundStmt->get_kids().size() == 2 ) {
				// loop variable and loop
				ForStmt * forStmt = dynamic_cast< ForStmt * >( compoundStmt->get_kids().back() );
				assert( forStmt && forStmt->get_body() );
				return getCtorDtorCall( forStmt->get_body() );
			} else if ( compoundStmt->get_kids().size() == 1 ) {
				// should be the call statement, but in any case there's only one option
				return getCtorDtorCall( compoundStmt->get_kids().front() );
			} else {
				assert( false && "too many statements in compoundStmt for getCtorDtorCall" );
			}
		} if ( ImplicitCtorDtorStmt * impCtorDtorStmt = dynamic_cast< ImplicitCtorDtorStmt * > ( stmt ) ) {
			return getCtorDtorCall( impCtorDtorStmt->get_callStmt() );
		} else {
			// should never get here
			assert( false && "encountered unknown call statement" );
		}
	}
	namespace {
		VariableExpr * getCalledFunction( ApplicationExpr * appExpr ) {
			assert( appExpr );
			return dynamic_cast< VariableExpr * >( appExpr->get_function() );
		}
	}

	ApplicationExpr * isIntrinsicCallExpr( Expression * expr ) {
		ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr );
		if ( ! appExpr ) return NULL;
		VariableExpr * function = getCalledFunction( appExpr );
		assert( function );
		// check for Intrinsic only - don't want to remove all overridable ctor/dtors because autogenerated ctor/dtor
		// will call all member dtors, and some members may have a user defined dtor.
		return function->get_var()->get_linkage() == LinkageSpec::Intrinsic ? appExpr : NULL;
	}

	bool isInstrinsicSingleArgCallStmt( Statement * stmt ) {
		Expression * callExpr = getCtorDtorCall( stmt );
		if ( ! callExpr ) return false;
		if ( ApplicationExpr * appExpr = isIntrinsicCallExpr( callExpr ) ) {
			assert( ! appExpr->get_function()->get_results().empty() );
			FunctionType *funcType = GenPoly::getFunctionType( appExpr->get_function()->get_results().front() );
			assert( funcType );
			return funcType->get_parameters().size() == 1;
		}
		return false;
	}

	namespace {
		template<typename CallExpr>
		Expression *& callArg( CallExpr * callExpr, unsigned int pos ) {
			if ( pos >= callExpr->get_args().size() ) assert( false && "asking for argument that doesn't exist. Return NULL/throw exception?" );
			for ( Expression *& arg : callExpr->get_args() ) {
				if ( pos == 0 ) return arg;
				pos--;
			}
			assert( false );
		}
	}

	Expression *& getCallArg( Expression * callExpr, unsigned int pos ) {
		if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( callExpr ) ) {
			return callArg( appExpr, pos );
		} else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * >( callExpr ) ) {
			return callArg( untypedExpr, pos );
		} else {
			assert( false && "Unexpected expression type passed to getCallArg" );
		}
	}

	namespace {
		std::string funcName( Expression * func ) {
			if ( NameExpr * nameExpr = dynamic_cast< NameExpr * >( func ) ) {
				return nameExpr->get_name();
			} else if ( VariableExpr * varExpr = dynamic_cast< VariableExpr * >( func ) ) {
				return varExpr->get_var()->get_name();
			}	else if ( CastExpr * castExpr = dynamic_cast< CastExpr * >( func ) ) {
				return funcName( castExpr->get_arg() );
			} else {
				assert( false && "Unexpected expression type being called as a function in call expression" );
			}
		}
	}

	std::string getFunctionName( Expression * expr ) {
		if ( ApplicationExpr * appExpr = dynamic_cast< ApplicationExpr * >( expr ) ) {
			return funcName( appExpr->get_function() );
		} else if ( UntypedExpr * untypedExpr = dynamic_cast< UntypedExpr * > ( expr ) ) {
			return funcName( untypedExpr->get_function() );
		} else {
			std::cerr << expr << std::endl;
			assert( false && "Unexpected expression type passed to getFunctionName" );
		}
	}

	Type * getPointerBase( Type * type ) {
		if ( PointerType * ptrType = dynamic_cast< PointerType * >( type ) ) {
			return ptrType->get_base();
		} else if ( ArrayType * arrayType = dynamic_cast< ArrayType * >( type ) ) {
			return arrayType->get_base();
		} else {
			return NULL;
		}
	}

	Type * isPointerType( Type * type ) {
		if ( getPointerBase( type ) ) return type;
		else return NULL;
	}

	class ConstExprChecker : public Visitor {
	public:
		ConstExprChecker() : isConstExpr( true ) {}

		virtual void visit( ApplicationExpr *applicationExpr ) { isConstExpr = false; }
		virtual void visit( UntypedExpr *untypedExpr ) { isConstExpr = false; }
		virtual void visit( NameExpr *nameExpr ) { isConstExpr = false; }
		virtual void visit( CastExpr *castExpr ) { isConstExpr = false; }
		virtual void visit( LabelAddressExpr *labAddressExpr ) { isConstExpr = false; }
		virtual void visit( UntypedMemberExpr *memberExpr ) { isConstExpr = false; }
		virtual void visit( MemberExpr *memberExpr ) { isConstExpr = false; }
		virtual void visit( VariableExpr *variableExpr ) { isConstExpr = false; }
		virtual void visit( ConstantExpr *constantExpr ) { /* bottom out */ }
		// these might be okay?
		// virtual void visit( SizeofExpr *sizeofExpr );
		// virtual void visit( AlignofExpr *alignofExpr );
		// virtual void visit( UntypedOffsetofExpr *offsetofExpr );
		// virtual void visit( OffsetofExpr *offsetofExpr );
		// virtual void visit( OffsetPackExpr *offsetPackExpr );
		// virtual void visit( AttrExpr *attrExpr );
		// virtual void visit( CommaExpr *commaExpr );
		// virtual void visit( LogicalExpr *logicalExpr );
		// virtual void visit( ConditionalExpr *conditionalExpr );
		virtual void visit( TupleExpr *tupleExpr ) { isConstExpr = false; }
		virtual void visit( SolvedTupleExpr *tupleExpr ) { isConstExpr = false; }
		virtual void visit( TypeExpr *typeExpr ) { isConstExpr = false; }
		virtual void visit( AsmExpr *asmExpr ) { isConstExpr = false; }
		virtual void visit( UntypedValofExpr *valofExpr ) { isConstExpr = false; }
		virtual void visit( CompoundLiteralExpr *compLitExpr ) { isConstExpr = false; }

		bool isConstExpr;
	};

	bool isConstExpr( Expression * expr ) {
		if ( expr ) {
			ConstExprChecker checker;
			expr->accept( checker );
			return checker.isConstExpr;
		}
		return true;
	}

	bool isConstExpr( Initializer * init ) {
		if ( init ) {
			ConstExprChecker checker;
			init->accept( checker );
			return checker.isConstExpr;
		} // if
		// for all intents and purposes, no initializer means const expr
		return true;
	}

}
