//
// Cforall Version 1.0.0 Copyright (C) 2019 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Util.cpp -- General utilities for working with the AST.
//
// Author           : Andrew Beach
// Created On       : Wed Jan 19  9:46:00 2022
// Last Modified By : Andrew Beach
// Last Modified On : Wed May 11 16:16:00 2022
// Update Count     : 3
//

#include "Util.hpp"

#include "Node.hpp"
#include "ParseNode.hpp"
#include "Pass.hpp"
#include "TranslationUnit.hpp"

#include <vector>

namespace ast {

namespace {

/// Check that ast::ptr/strong references do not form a cycle.
///
/// Keeps a stack of the nodes currently being visited, from the first node
/// entered down to the current one. It asserts that a node is never reached
/// again while it is still on that stack, i.e. while it is one of its own
/// ancestors in the traversal.
struct NoStrongCyclesCore {
	// Nodes entered by previsit but not yet left by postvisit, outermost first.
	std::vector<const Node *> parents;

	/// Entering `node`: it must not already be on the ancestor stack.
	/// This is a linear scan over the current ancestor chain.
	void previsit( const Node * node ) {
		for ( auto & parent : parents ) {
			assert( parent != node );
		}
		parents.push_back( node );
	}

	/// Leaving `node`: visits are expected to nest properly, so the node
	/// being left must be the most recently entered one.
	void postvisit( const Node * node ) {
		assert( !parents.empty() );
		assert( parents.back() == node );
		parents.pop_back();
	}
};

/// Check that every node that can carry a CodeLocation has one set.
/// Asserts that the given parse node's `location` has been set.
void isCodeLocationSet( const ParseNode * node ) {
	assert( node->location.isSet() );
}

/// Check that every label attached to a statement has its location set.
void areLabelLocationsSet( const Stmt * stmt ) {
	for ( const Label& label : stmt->labels ) {
		assert( label.location.isSet() );
	}
}

/// Make sure the reference counts are in a valid combination.
/// What counts as "stable" is defined by Node::isStable(); this only asserts it.
void isStable( const Node * node ) {
	assert( node->isStable() );
}

/// Check that a FunctionDecl is synchronized with its FunctionType.
///
/// The declaration's `type` acts as a cache of its signature. It may be
/// missing, but only while `isTypeFixed` is unset. When it is present, each
/// part of the type must have as many entries as the matching part of the
/// declaration.
void functionDeclMatchesType( const FunctionDecl * decl ) {
	// A missing cached type is tolerated as long as the type is not fixed.
	if ( !decl->isTypeFixed && !decl->type ) {
		return;
	}
	// Past the guard the type is either fixed or present, so it must exist.
	assert( decl->type );

	const FunctionType * type = decl->type;

	// Pairwise size agreement between type and declaration:
	// forall <-> type_params, assertions <-> assertions,
	// params <-> params, returns <-> returns.
	assert( type->forall.size() == decl->type_params.size() );
	assert( type->assertions.size() == decl->assertions.size() );
	assert( type->params.size() == decl->params.size() );
	assert( type->returns.size() == decl->returns.size() );
}

struct InvariantCore {
	// To save on the number of visits: this is a kind of composed core.
	// None of the passes should make changes so ordering doesn't matter.
	NoStrongCyclesCore no_strong_cycles;

	void previsit( const Node * node ) {
		no_strong_cycles.previsit( node );
		isStable( node );
	}

	void previsit( const ParseNode * node ) {
		previsit( (const Node *)node );
		isCodeLocationSet( node );
	}

	void previsit( const FunctionDecl * node ) {
		previsit( (const ParseNode *)node );
		functionDeclMatchesType( node );
	}

	void previsit( const Stmt * node ) {
		previsit( (const ParseNode *)node );
		areLabelLocationsSet( node );
	}

	void postvisit( const Node * node ) {
		no_strong_cycles.postvisit( node );
	}
};

} // namespace

/// Walk the whole translation unit and run every AST invariant check defined
/// in this file on each node. Violations are reported through `assert`, so
/// nothing is checked when assertions are disabled.
void checkInvariants( TranslationUnit & transUnit ) {
	using InvariantPass = Pass<InvariantCore>;
	InvariantPass::run( transUnit );
}

} // namespace ast
