//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Expr.hpp --
//
// Author           : Aaron B. Moss
// Created On       : Fri May 10 10:30:00 2019
// Last Modified By : Aaron B. Moss
// Last Modified On : Fri May 10 10:30:00 2019
// Update Count     : 1
//

#pragma once

#include <cassert>
#include <cstdlib>        // for abort
#include <map>
#include <new>            // for placement new
#include <utility>        // for move
#include <vector>

#include "Fwd.hpp"        // for UniqueId
#include "ParseNode.hpp"
#include "Visitor.hpp"

namespace ast {

/// Contains the ID of a declaration and a type that is derived from that declaration,
/// but subject to decay-to-pointer and type parameter renaming.
struct ParamEntry {
	UniqueId decl;         ///< ID of the declaration this entry refers to
	ptr<Type> actualType;  ///< presumably the type as seen at the use site -- TODO confirm
	ptr<Type> formalType;  ///< presumably the declared (formal) type -- TODO confirm
	ptr<Expr> expr;        ///< associated expression; null in a default-constructed entry

	/// Constructs an empty entry: decl is 0 and every pointer is null.
	ParamEntry() : decl( 0 ), actualType( nullptr ), formalType( nullptr ), expr( nullptr ) {}

	/// Constructs an entry from its parts.
	/// Takes const pointers because ast::ptr<> can hold const nodes. Non-const arguments
	/// still convert implicitly, so existing callers are unaffected.
	ParamEntry( UniqueId id, const Type* actual, const Type* formal, const Expr* e )
	: decl( id ), actualType( actual ), formalType( formal ), expr( e ) {}
};

/// Pre-resolution list of parameters to infer
using ResnSlots = std::vector<UniqueId>;
/// Post-resolution map of inferred parameters, keyed by declaration ID
using InferredParams = std::map< UniqueId, ParamEntry >;

/// Base node for expressions
class Expr : public ParseNode {
public:
	/// Saves space (~16 bytes) by combining ResnSlots and InferredParams into one
	/// tagged union; `mode` records which union member (if any) is currently live.
	struct InferUnion {
		enum { Empty, Slots, Params } mode;
		union data_t {
			char def;                    ///< trivial placeholder, active while mode == Empty
			ResnSlots resnSlots;         ///< live while mode == Slots
			InferredParams inferParams;  ///< live while mode == Params

			data_t() : def('\0') {}
			// Non-trivial members are destroyed explicitly by InferUnion::reset().
			~data_t() {}
		} data;

		/// initializes from other InferUnion
		/// Precondition: `mode` already equals o.mode and no non-trivial member of `data`
		/// is live yet (this is how the constructors below call it).
		void init_from( const InferUnion& o ) {
			switch ( o.mode ) {
			case Empty:  return;
			case Slots:  new(&data.resnSlots) ResnSlots{ o.data.resnSlots }; return;
			case Params: new(&data.inferParams) InferredParams{ o.data.inferParams }; return;
			}
		}

		/// initializes from other InferUnion (move semantics)
		/// Same precondition as the copying overload; `o` keeps its mode and holds a
		/// moved-from container, which its own destructor still cleans up.
		void init_from( InferUnion&& o ) {
			switch ( o.mode ) {
			case Empty:  return;
			case Slots:  new(&data.resnSlots) ResnSlots{ std::move(o.data.resnSlots) }; return;
			case Params:
				new(&data.inferParams) InferredParams{ std::move(o.data.inferParams) }; return;
			}
		}

		/// Destroys whichever variant member is live and returns the union to Empty.
		/// Because the tag is updated too, calling reset() more than once is safe.
		void reset() {
			switch( mode ) {
			case Empty:  break;
			case Slots:  data.resnSlots.~ResnSlots(); break;
			case Params: data.inferParams.~InferredParams(); break;
			}
			mode = Empty;
		}

		InferUnion() : mode(Empty), data() {}
		InferUnion( const InferUnion& o ) : mode( o.mode ), data() { init_from( o ); }
		InferUnion( InferUnion&& o ) : mode( o.mode ), data() { init_from( std::move(o) ); }
		InferUnion& operator= ( const InferUnion& ) = delete;
		InferUnion& operator= ( InferUnion&& ) = delete;
		~InferUnion() { reset(); }

		/// Returns the pre-resolution slot list, creating an empty one if the union is Empty.
		/// Asking for slots once the union holds Params is a logic error. It asserts, and
		/// aborts even under NDEBUG rather than returning a reference to an unconstructed member.
		ResnSlots& resnSlots() {
			switch (mode) {
			case Empty: new(&data.resnSlots) ResnSlots{}; mode = Slots; // fallthrough
			case Slots: return data.resnSlots;
			case Params: break;
			}
			assert(!"Cannot return to resnSlots from Params");
			std::abort();
		}

		/// Returns the post-resolution parameter map, creating an empty one if needed.
		/// Any existing slot list is destroyed first (Slots -> Params is one-way).
		InferredParams& inferParams() {
			switch (mode) {
			case Slots:
				data.resnSlots.~ResnSlots();
				// Keep the tag accurate in case constructing the map below throws.
				mode = Empty;
				// fallthrough
			case Empty:
				new(&data.inferParams) InferredParams{};
				mode = Params;
				// fallthrough
			case Params:
				break;
			}
			return data.inferParams;
		}
	};

	/// Default-constructed by the Expr constructor (presumably the expression's type).
	ptr<Type> result;
	/// Default-constructed by the Expr constructor.
	ptr<TypeSubstitution> env;
	/// Inference bookkeeping: slots before resolution, params after.
	InferUnion inferred;
	/// Flag set through set_extension(); false by default.
	bool extension = false;

	Expr(const CodeLocation & loc ) : ParseNode( loc ), result(), env(), inferred() {}

	/// Sets the extension flag and returns this expression, so calls can be chained.
	Expr * set_extension( bool ex ) { extension = ex; return this; }

	virtual const Expr * accept( Visitor & v ) const override = 0;
private:
	Expr * clone() const override = 0;
};

/// A type used as an expression (e.g. a type generator parameter)
class TypeExpr final : public Expr {
public:
	/// The type being used in expression position.
	ptr<Type> type;

	/// Builds the expression at location `loc`, holding type `t`.
	TypeExpr( const CodeLocation & loc, const Type * t )
	: Expr( loc ), type( t ) {}

	/// Visitor dispatch: forwards this node to `v` and returns the visitor's result.
	const Expr * accept( Visitor & v ) const override {
		return v.visit( this );
	}

private:
	/// Returns a heap-allocated copy of this node.
	TypeExpr * clone() const override {
		return new TypeExpr( *this );
	}
};


//=================================================================================================
/// This disgusting and giant piece of boiler-plate is here to solve a cyclic dependency;
/// remove it only if there is a better solution.
/// The problem is that ast::ptr< ... > uses increment/decrement, which won't work well with
/// forward declarations.
inline void increment( const class Expr * node, Node::ref_type ref ) {
	node->increment( ref );
}
inline void decrement( const class Expr * node, Node::ref_type ref ) {
	node->decrement( ref );
}
// inline void increment( const class ApplicationExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class ApplicationExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class NameExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class NameExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class AddressExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class AddressExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class LabelAddressExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class LabelAddressExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class CastExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class CastExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class KeywordCastExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class KeywordCastExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class VirtualCastExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class VirtualCastExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class MemberExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class MemberExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedMemberExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedMemberExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class VariableExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class VariableExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class ConstantExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class ConstantExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class SizeofExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class SizeofExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class AlignofExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class AlignofExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedOffsetofExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedOffsetofExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class OffsetofExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class OffsetofExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class OffsetPackExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class OffsetPackExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class AttrExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class AttrExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class LogicalExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class LogicalExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class ConditionalExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class ConditionalExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class CommaExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class CommaExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class TypeExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class TypeExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class AsmExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class AsmExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class ImplicitCopyCtorExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class ImplicitCopyCtorExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class ConstructorExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class ConstructorExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class CompoundLiteralExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class CompoundLiteralExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedValofExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedValofExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class RangeExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class RangeExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedTupleExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedTupleExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class TupleExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class TupleExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class TupleIndexExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class TupleIndexExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class TupleAssignExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class TupleAssignExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class StmtExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class StmtExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UniqueExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UniqueExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class UntypedInitExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class UntypedInitExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class InitExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class InitExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class DeletedExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class DeletedExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class DefaultArgExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class DefaultArgExpr * node, Node::ref_type ref ) { node->decrement(ref); }
// inline void increment( const class GenericExpr * node, Node::ref_type ref ) { node->increment(ref); }
// inline void decrement( const class GenericExpr * node, Node::ref_type ref ) { node->decrement(ref); }
}

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
