//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Stmt.hpp --
//
// Author           : Aaron B. Moss
// Created On       : Wed May  8 13:00:00 2019
// Last Modified By : Andrew Beach
// Last Modified On : Wed May 15 16:01:00 2019
// Update Count     : 2
//

#pragma once

#include <cstddef>                // for size_t
#include <list>
#include <string>
#include <utility>                // for move
#include <vector>

#include "Label.hpp"
#include "Node.hpp"               // for node, ptr
#include "ParseNode.hpp"
#include "Visitor.hpp"
#include "Common/CodeLocation.h"

namespace ast {

class Expr;

/// Abstract base class shared by every statement node.
/// Holds the list of labels attached to the statement; each concrete statement
/// class provides its own `accept` (visitor dispatch) and `clone` (polymorphic copy).
class Stmt : public ParseNode {
public:
	/// Labels attached to this statement (empty when unlabelled)
	std::vector<Label> labels;

	/// Construct a statement at `loc`, moving the label list `ls` into place.
	Stmt( const CodeLocation& loc, std::vector<Label>&& ls = {} )
	: ParseNode( loc ),
	  labels( std::move( ls ) )
	{}

	/// Copy the ParseNode base and the label list of `other`.
	Stmt( const Stmt& other )
	: ParseNode( other ),
	  labels( other.labels )
	{}

	/// Visitor dispatch; overridden by every concrete statement.
	auto accept( Visitor& v ) const -> const Stmt* override = 0;

private:
	/// Polymorphic copy; overridden by every concrete statement.
	auto clone() const -> Stmt* override = 0;
};

/// Compound statement `{ ... }`
class CompoundStmt final : public Stmt {
public:
	std::list<ptr<Stmt>> kids;

	CompoundStmt(const CodeLocation& loc, std::list<ptr<Stmt>>&& ks = {} )
	: Stmt(loc), kids(std::move(ks)) {}

	CompoundStmt( const CompoundStmt& o );
	CompoundStmt( CompoundStmt&& o ) = default;

	void push_back( Stmt* s ) { kids.emplace_back( s ); }
	void push_front( Stmt* s ) { kids.emplace_front( s ); }

	const CompoundStmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	CompoundStmt* clone() const override { return new CompoundStmt{ *this }; }
};

/// Empty statement `;` (possibly labelled).
class NullStmt final : public Stmt {
public:
	/// Build an empty statement at `loc` carrying the labels `ls`.
	NullStmt( const CodeLocation& loc, std::vector<Label>&& ls = {} )
	: Stmt( loc, std::move( ls ) )
	{}

	/// Visitor dispatch to the NullStmt overload.
	const NullStmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	/// Heap-allocated copy of this statement.
	NullStmt* clone() const override {
		return new NullStmt( *this );
	}
};

/// Expression wrapped by statement
class ExprStmt final : public Stmt {
public:
	ptr<Expr> expr;

	ExprStmt( const CodeLocation& loc, const Expr* e ) : Stmt(loc), expr(e) {}

	const Stmt * accept( Visitor& v ) const override { return v.visit( this ); }
private:
	ExprStmt * clone() const override { return new ExprStmt{ *this }; }
};

class AsmStmt final : public Stmt {
public:
	bool isVolatile;
	ptr<Expr> instruction;
	std::vector<ptr<Expr>> output, input;
	std::vector<ptr<ConstantExpr>> clobber;
	std::vector<Label> gotoLabels;

	AsmStmt( const CodeLocation& loc, bool isVolatile, const Expr * instruction,
		std::vector<ptr<Expr>>&& output, std::vector<ptr<Expr>>&& input,
		std::vector<ptr<ConstantExpr>>&& clobber, std::vector<Label>&& gotoLabels,
		std::vector<Label>&& labels = {})
	: Stmt(loc, std::move(labels)), isVolatile(isVolatile), instruction(instruction),
	  output(std::move(output)), input(std::move(input)), clobber(std::move(clobber)),
	  gotoLabels(std::move(gotoLabels)) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	AsmStmt* clone() const override { return new AsmStmt{ *this }; }
};

/// Directive statement: carries the directive text as a string.
class DirectiveStmt final : public Stmt {
public:
	/// The directive text
	std::string directive;

	/// Build a directive statement at `loc`.
	/// `directive` is a sink parameter: it is taken by value and moved into the member,
	/// so callers passing a temporary no longer pay for an extra copy. Callers passing
	/// lvalues or string literals are unaffected.
	DirectiveStmt( const CodeLocation& loc, std::string directive,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), directive(std::move(directive)) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	DirectiveStmt* clone() const override { return new DirectiveStmt{ *this }; }
};

/// Conditional statement.
/// `inits` holds statements attached before the condition; `elsePart` is stored in a
/// `ptr` and is presumably null when there is no else branch (confirm with callers).
class IfStmt final : public Stmt {
public:
	ptr<Expr> cond;
	ptr<Stmt> thenPart;
	ptr<Stmt> elsePart;
	std::vector<ptr<Stmt>> inits;

	/// `elsePart` is `const Stmt*` like `thenPart`. It was previously declared
	/// `Stmt * const`, whose top-level const was meaningless and which wrongly rejected
	/// pointers to const statements. Non-const pointers and nullptr still convert.
	IfStmt( const CodeLocation& loc, const Expr* cond, const Stmt* thenPart,
		const Stmt* elsePart, std::vector<ptr<Stmt>>&& inits,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), cond(cond), thenPart(thenPart), elsePart(elsePart),
	  inits(std::move(inits)) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	IfStmt* clone() const override { return new IfStmt{ *this }; }
};

class SwitchStmt final : public Stmt {
public:
	ptr<Expr> cond;
	std::vector<ptr<Stmt>> stmts;

	SwitchStmt( const CodeLocation& loc, const Expr* cond, std::vector<ptr<Stmt>>&& stmts,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), cond(cond), stmts(std::move(stmts)) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	SwitchStmt* clone() const override { return new SwitchStmt{ *this }; }
};

/// One `case` (or `default`) clause of a switch: its label expression and the
/// statements that follow it. A null `cond` marks the `default` clause.
class CaseStmt final : public Stmt {
public:
	ptr<Expr> cond;
	std::vector<ptr<Stmt>> stmts;

	CaseStmt( const CodeLocation& loc, const Expr* cond, std::vector<ptr<Stmt>>&& stmts,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), cond(cond), stmts(std::move(stmts)) {}

	/// True for the `default` clause (no case expression).
	/// Const-qualified so it can be queried through `const CaseStmt*`, which is what
	/// visitors receive.
	bool isDefault() const { return !cond; }

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	CaseStmt* clone() const override { return new CaseStmt{ *this }; }
};

class WhileStmt final : public Stmt {
public:
	ptr<Expr> cond;
	ptr<Stmt> body;
	std::vector<ptr<Stmt>> inits;
	bool isDoWhile;

	WhileStmt( const CodeLocation& loc, const Expr* cond, const Stmt* body,
		std::vector<ptr<Stmt>>&& inits, bool isDoWhile = false, std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), cond(cond), body(body), inits(std::move(inits)),
	  isDoWhile(isDoWhile) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	WhileStmt* clone() const override { return new WhileStmt{ *this }; }
};

class ForStmt final : public Stmt {
public:
	std::vector<ptr<Stmt>> inits;
	ptr<Expr> cond;
	ptr<Expr> increment;
	ptr<Stmt> body;

	ForStmt( const CodeLocation& loc, std::vector<ptr<Stmt>>&& inits, const Expr* cond,
		const Expr* increment, const Stmt* body, std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), inits(std::move(inits)), cond(cond), increment(increment),
	  body(body) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	ForStmt* clone() const override { return new ForStmt{ *this }; }
};

/// Jump statement: goto (to a label or a computed target), break, continue,
/// fallthrough, and fallthrough-to-default.
class BranchStmt final : public Stmt {
public:
	enum Kind { Goto, Break, Continue, FallThrough, FallThroughDefault };
	/// Number of Kind values; sizes the kindNames table
	static constexpr size_t kindEnd = 1 + static_cast<size_t>(FallThroughDefault);

	const Label originalTarget;
	Label target;
	/// Target expression of a computed goto; set only by the second constructor below
	ptr<Expr> computedTarget;
	Kind kind;

	/// Branch of kind `kind` to `target`; defined out of line (not visible in this header).
	BranchStmt( const CodeLocation& loc, Kind kind, Label target,
		std::vector<Label>&& labels = {} );
	/// Computed goto: kind is Goto and the targets are built from `loc` alone.
	BranchStmt( const CodeLocation& loc, const Expr* computedTarget,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), originalTarget(loc), target(loc),
	  computedTarget(computedTarget), kind(Goto) {}

	/// Printable name of this branch's kind.
	/// Const-qualified so it works through `const BranchStmt*` (what visitors receive).
	const char * kindName() const { return kindNames[kind]; }

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	BranchStmt* clone() const override { return new BranchStmt{ *this }; }
	/// Names indexed by Kind; defined out of line
	static const char * kindNames[kindEnd];
};

/// `return` statement with an optional value expression.
class ReturnStmt final : public Stmt {
public:
	/// Returned expression (presumably null for a bare `return;` — confirm)
	ptr<Expr> expr;

	ReturnStmt(
		const CodeLocation& loc,
		const Expr* value,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  expr( value )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	ReturnStmt* clone() const override {
		return new ReturnStmt( *this );
	}
};

/// Exception-raising statement.
class ThrowStmt final : public Stmt {
public:
	/// Whether the raise is terminating or resuming
	/// (presumably `throw` vs `throwResume` — confirm)
	enum Kind { Terminate, Resume };

	/// Exception expression being raised
	ptr<Expr> expr;
	/// Optional target expression of the raise (may be null — confirm with callers)
	ptr<Expr> target;
	/// Terminate or Resume
	Kind kind;

	ThrowStmt(
		const CodeLocation& loc,
		Kind k,
		const Expr* ex,
		const Expr* tgt,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  expr( ex ),
	  target( tgt ),
	  kind( k )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	ThrowStmt* clone() const override {
		return new ThrowStmt( *this );
	}
};

class TryStmt final : public Stmt {
public:
	ptr<CompoundStmt> body;
	std::vector<ptr<CatchStmt>> handlers;
	ptr<FinallyStmt> finally;

	TryStmt( const CodeLocation& loc, const CompoundStmt* body,
		std::vector<ptr<CatchStmt>>&& handlers, const FinallyStmt* finally,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), body(body), handlers(std::move(handlers)), finally(finally) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	TryStmt* clone() const override { return new TryStmt{ *this }; }
};

/// One exception handler attached to a TryStmt.
class CatchStmt final : public Stmt {
public:
	/// Whether this handler matches terminating or resuming raises
	enum Kind { Terminate, Resume };

	/// Declaration bound by the handler
	ptr<Decl> decl;
	/// Optional guard condition (presumably null when absent — confirm)
	ptr<Expr> cond;
	/// Handler body
	ptr<Stmt> body;
	/// Terminate or Resume
	Kind kind;

	CatchStmt(
		const CodeLocation& loc,
		Kind k,
		const Decl* d,
		const Expr* guard,
		const Stmt* handler,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  decl( d ),
	  cond( guard ),
	  body( handler ),
	  kind( k )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	CatchStmt* clone() const override {
		return new CatchStmt( *this );
	}
};

/// `finally` clause of a TryStmt.
class FinallyStmt final : public Stmt {
public:
	/// Clause body
	ptr<CompoundStmt> body;

	FinallyStmt(
		const CodeLocation& loc,
		const CompoundStmt* block,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  body( block )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	FinallyStmt* clone() const override {
		return new FinallyStmt( *this );
	}
};

/// `waitfor` statement: a set of guarded clauses plus optional timeout and else parts.
/// The constructor sets only location and labels; clauses, timeout and orElse start
/// default-constructed and are filled in afterwards.
class WaitForStmt final : public Stmt {
public:
	/// Function expression and the argument expressions a clause waits on
	struct Target {
		ptr<Expr> function;
		std::vector<ptr<Expr>> arguments;
	};

	/// One waiting clause: what to wait for, statement to run, guard condition
	struct Clause {
		Target target;
		ptr<Stmt> stmt;
		ptr<Expr> cond;
	};

	/// Timeout part: time expression, statement to run, guard condition
	struct Timeout {
		ptr<Expr> time;
		ptr<Stmt> stmt;
		ptr<Expr> cond;
	};

	/// Else part: statement to run and its guard condition
	struct OrElse {
		ptr<Stmt> stmt;
		ptr<Expr> cond;
	};

	std::vector<Clause> clauses;
	Timeout timeout;
	OrElse orElse;

	WaitForStmt( const CodeLocation& loc, std::vector<Label>&& ls = {} )
	: Stmt( loc, std::move( ls ) )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	WaitForStmt* clone() const override {
		return new WaitForStmt( *this );
	}
};

class WithStmt final : public Stmt {
public:
	std::vector<ptr<Expr>> exprs;
	ptr<Stmt> stmt;

	WithStmt( const CodeLocation& loc, std::vector<ptr<Expr>>&& exprs, const Stmt* stmt,
		std::vector<Label>&& labels = {} )
	: Stmt(loc, std::move(labels)), exprs(std::move(exprs)), stmt(stmt) {}

	const Stmt* accept( Visitor& v ) const override { return v.visit( this ); }
private:
	WithStmt* clone() const override { return new WithStmt{ *this }; }
};

/// Declaration appearing in statement position.
class DeclStmt final : public Stmt {
public:
	/// The wrapped declaration
	ptr<Decl> decl;

	DeclStmt(
		const CodeLocation& loc,
		const Decl* d,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  decl( d )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	DeclStmt* clone() const override {
		return new DeclStmt( *this );
	}
};

/// Wrapper around a single statement, held as `readonly<Stmt>`.
/// Presumably marks a constructor/destructor call inserted implicitly (inferred from
/// the name — confirm).
class ImplicitCtorDtorStmt final : public Stmt {
public:
	/// The wrapped call statement
	readonly<Stmt> callStmt;

	ImplicitCtorDtorStmt(
		const CodeLocation& loc,
		const Stmt* call,
		std::vector<Label>&& ls = {}
	)
	: Stmt( loc, std::move( ls ) ),
	  callStmt( call )
	{}

	const Stmt* accept( Visitor& v ) const override {
		return v.visit( this );
	}

private:
	ImplicitCtorDtorStmt* clone() const override {
		return new ImplicitCtorDtorStmt( *this );
	}
};

//=================================================================================================
/// Reference-counting hooks for statement node types.
/// ast::ptr<...> adjusts reference counts through free `increment`/`decrement` functions, and
/// that does not work with forward declarations alone. Explicit per-class overloads are
/// therefore spelled out here to break a cyclic dependency. Each one simply forwards to the
/// node's own `increment`/`decrement` member with the given reference kind.
/// Only Stmt, CompoundStmt, ExprStmt and NullStmt are currently enabled; the rest are kept
/// commented out. NOTE(review): confirm where (if anywhere) the disabled classes get their
/// overloads before enabling them, to avoid duplicate definitions.
/// Remove this block only if a better solution to the cycle is found.
inline void increment( const class Stmt * node, Node::ref_type ref ) { node->increment( ref ); }
inline void decrement( const class Stmt * node, Node::ref_type ref ) { node->decrement( ref ); }
inline void increment( const class CompoundStmt * node, Node::ref_type ref ) { node->increment( ref ); }
inline void decrement( const class CompoundStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
inline void increment( const class ExprStmt * node, Node::ref_type ref ) { node->increment( ref ); }
inline void decrement( const class ExprStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class AsmStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class AsmStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class DirectiveStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class DirectiveStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class IfStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class IfStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class WhileStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class WhileStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class ForStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class ForStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class SwitchStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class SwitchStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class CaseStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class CaseStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class BranchStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class BranchStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class ReturnStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class ReturnStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class ThrowStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class ThrowStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class TryStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class TryStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class CatchStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class CatchStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class FinallyStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class FinallyStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class WaitForStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class WaitForStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class WithStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class WithStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class DeclStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class DeclStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
inline void increment( const class NullStmt * node, Node::ref_type ref ) { node->increment( ref ); }
inline void decrement( const class NullStmt * node, Node::ref_type ref ) { node->decrement( ref ); }
// inline void increment( const class ImplicitCtorDtorStmt * node, Node::ref_type ref ) { node->increment( ref ); }
// inline void decrement( const class ImplicitCtorDtorStmt * node, Node::ref_type ref ) { node->decrement( ref ); }

}

// Local Variables: //
// tab-width: 4 //
// mode: c++ //
// compile-command: "make install" //
// End: //
