/*
  The "validate" phase of translation is used to take a syntax tree and convert it into a standard form that aims to be
  as regular in structure as possible.  Some assumptions can be made regarding the state of the tree after this pass is
  complete, including:

  - No nested structure or union definitions; any in the input are "hoisted" to the level of the containing struct or
    union.

  - All enumeration constants have type EnumInstType.

  - The type "void" never occurs in lists of function parameter or return types; neither do tuple types.  A function
    taking no arguments has no argument types, and tuples are flattened.

  - No context instances exist; they are all replaced by the set of declarations signified by the context, instantiated
    by the particular set of type arguments.

  - Every declaration is assigned a unique id.

  - No typedef declarations or instances exist; the actual type is substituted for each instance.

  - Each type, struct, and union definition is followed by an appropriate assignment operator.

  - Each use of a struct or union is connected to a complete definition of that struct or union, even if that definition
    occurs later in the input.
*/

#include <list>
#include <iterator>
#include "Validate.h"
#include "SynTree/Visitor.h"
#include "SynTree/Mutator.h"
#include "SynTree/Type.h"
#include "SynTree/Statement.h"
#include "Indexer.h"
#include "SynTree/TypeSubstitution.h"
#include "FixFunction.h"
#include "ImplementationType.h"
#include "utility.h"
#include "UniqueName.h"
#include "AddVisit.h"


// Print x to std::cout when a variable named doDebug, visible at the point of use, is true.  The do/while(0) wrapper
// makes the expansion a single statement, so "if ( c ) debugPrint( a ); else ..." parses as intended.
#define debugPrint( x ) do { if ( doDebug ) { std::cout << x; } } while ( 0 )

namespace SymTab {
    // Visitor that hoists struct/union definitions nested inside another struct or union (see handleAggregate).
    // Nested aggregate declarations are queued in declsToAdd and removed from the member list of the aggregate that
    // contained them; hoistStruct (through acceptAndAdd) splices them into the translation unit immediately before
    // the top-level declaration they came from.
    class HoistStruct : public Visitor {
      public:
	// Hoist nested aggregates throughout translationUnit.
	static void hoistStruct( std::list< Declaration * > &translationUnit );
  
	// Declarations queued for insertion; acceptAndAdd splices them out of this list.
	std::list< Declaration * > &get_declsToAdd() { return declsToAdd; }
  
	virtual void visit( StructDecl *aggregateDecl );
	virtual void visit( UnionDecl *aggregateDecl );

	// Statement visitors forward to addVisit (AddVisit.h) -- presumably so aggregates hoisted from inside a
	// function body are inserted into the enclosing statement list; TODO confirm against AddVisit.h.
	virtual void visit( CompoundStmt *compoundStmt );
	virtual void visit( IfStmt *ifStmt );
	virtual void visit( WhileStmt *whileStmt );
	virtual void visit( ForStmt *forStmt );
	virtual void visit( SwitchStmt *switchStmt );
	virtual void visit( ChooseStmt *chooseStmt );
	virtual void visit( CaseStmt *caseStmt );
	virtual void visit( CatchStmt *catchStmt );
      private:
	// Only constructible through the static hoistStruct entry point.
	HoistStruct();

	// Shared implementation of the StructDecl and UnionDecl visitors.
	template< typename AggDecl > void handleAggregate( AggDecl *aggregateDecl );

	// Nested aggregates awaiting insertion, most recently found first (handleAggregate uses push_front).
	std::list< Declaration * > declsToAdd;
	// True while the members of an aggregate are being visited, i.e. any aggregate seen now is nested.
	bool inStruct;
    };

    // First validation pass: gives every enumerator the type EnumInstType of its enumeration, and normalizes the
    // parameter and return-value lists of every function type (each entry is run through FixFunction; a lone "void"
    // entry is dropped and any other void entry is a SemanticError).
    class Pass1 : public Visitor {
	typedef Visitor Parent;
	// These overrides have private (class-default) access; presumably they are only reached by virtual dispatch
	// through the public Visitor interface (e.g. via acceptAll).
	virtual void visit( EnumDecl *aggregateDecl );
	virtual void visit( FunctionType *func );
    };
  
    // Second validation pass (an Indexer, so it tracks declarations by scope while visiting):
    //  - links StructInstType/UnionInstType uses to their definitions; uses seen before a definition with members are
    //    remembered in forwardStructs/forwardUnions and patched when that definition is visited;
    //  - fills each ContextInstType with the declarations the context provides;
    //  - marks TypeInstTypes that name an ftype type parameter.
    class Pass2 : public Indexer {
	typedef Indexer Parent;
      public:
	// doDebug is forwarded to the Indexer base; a non-null indexer is used for struct/union/context lookups
	// instead of this pass's own scopes.
	Pass2( bool doDebug, const Indexer *indexer );
      private:
	virtual void visit( StructInstType *structInst );
	virtual void visit( UnionInstType *unionInst );
	virtual void visit( ContextInstType *contextInst );
	virtual void visit( StructDecl *structDecl );
	virtual void visit( UnionDecl *unionDecl );
	virtual void visit( TypeInstType *typeInst );

	// Indexer used for struct/union/context lookups: either the one passed in, or this object itself.
	const Indexer *indexer;
  
	// Instance types, keyed by aggregate name, whose complete definition has not been seen yet.
	typedef std::map< std::string, std::list< StructInstType * > > ForwardStructsType;
	typedef std::map< std::string, std::list< UnionInstType * > > ForwardUnionsType;
	ForwardStructsType forwardStructs;
	ForwardUnionsType forwardUnions;
    };

    // Third validation pass (an Indexer): expands context assertions in the forall clauses of object and function
    // types (forallFixer) and gives each visited object/function declaration its unique id (fixUniqueId).
    class Pass3 : public Indexer {
	typedef Indexer Parent;
      public:
	// A null indexer means "use this object itself".
	Pass3( const Indexer *indexer );
      private:
	virtual void visit( ObjectDecl *object );
	virtual void visit( FunctionDecl *func );

	// NOTE(review): set by the constructor but not read by the two visitors defined in this file.
	const Indexer *indexer;
    };

    // Visitor that generates an assignment operator "?=?" for every struct and union definition that has members and
    // for every type declaration.  Generated declarations are queued in declsToAdd; addStructAssignment (through
    // acceptAndAdd) splices them in immediately after the top-level declaration that produced them.
    class AddStructAssignment : public Visitor {
      public:
	// Entry point: add the generated assignment operators throughout translationUnit.
	static void addStructAssignment( std::list< Declaration * > &translationUnit );

	// Declarations queued for insertion; acceptAndAdd splices them out of this list.
	std::list< Declaration * > &get_declsToAdd() { return declsToAdd; }
  
	virtual void visit( StructDecl *structDecl );
	virtual void visit( UnionDecl *structDecl );
	virtual void visit( TypeDecl *typeDecl );
	virtual void visit( ContextDecl *ctxDecl );
	virtual void visit( FunctionDecl *functionDecl );

	// Deliberately empty: aggregates appearing only inside these types get no operator here.
	virtual void visit( FunctionType *ftype );
	virtual void visit( PointerType *ftype );
  
	// Each statement is visited as its own scope for structsDone (see visitStatement).
	virtual void visit( CompoundStmt *compoundStmt );
	virtual void visit( IfStmt *ifStmt );
	virtual void visit( WhileStmt *whileStmt );
	virtual void visit( ForStmt *forStmt );
	virtual void visit( SwitchStmt *switchStmt );
	virtual void visit( ChooseStmt *chooseStmt );
	virtual void visit( CaseStmt *caseStmt );
	virtual void visit( CatchStmt *catchStmt );

	AddStructAssignment() : functionNesting( 0 ) {}
      private:
	template< typename StmtClass > void visitStatement( StmtClass *stmt );
  
	std::list< Declaration * > declsToAdd;
	// Names of structs whose assignment operator has already been generated; saved/restored around statements.
	std::set< std::string > structsDone;
	unsigned int functionNesting;			// current level of nested functions
    };

    // Mutator that replaces every use of a typedef name by a clone of the typedef's base type (with the use's
    // qualifiers added) and then removes the typedef declarations.  typedefNames maps each typedef name currently in
    // scope to its declaration; the FunctionDecl, ObjectDecl, CastExpr and CompoundStmt overloads save and restore it
    // so typedefs introduced inside them do not leak out.
    class EliminateTypedef : public Mutator {
      public:
	static void eliminateTypedef( std::list< Declaration * > &translationUnit );
      private:
	virtual Declaration *mutate( TypedefDecl *typeDecl );
	virtual TypeDecl *mutate( TypeDecl *typeDecl );
	virtual DeclarationWithType *mutate( FunctionDecl *funcDecl );
	virtual ObjectDecl *mutate( ObjectDecl *objDecl );
	virtual CompoundStmt *mutate( CompoundStmt *compoundStmt );
	virtual Type *mutate( TypeInstType *aggregateUseType );
	virtual Expression *mutate( CastExpr *castExpr );
  
	std::map< std::string, TypedefDecl * > typedefNames;
    };

    // Validate a whole translation unit.  The steps run in this fixed order: expand typedefs, hoist nested aggregates,
    // Pass1 (enumerators and function parameter/return lists), Pass2 (link aggregate references, expand contexts),
    // generate assignment operators, Pass3 (expand forall assertions, assign unique ids).
    void validate( std::list< Declaration * > &translationUnit, bool doDebug ) {
	Pass1 enumAndFunctionFixer;
	Pass2 typeLinker( doDebug, 0 );
	Pass3 assertionFixer( 0 );
	EliminateTypedef::eliminateTypedef( translationUnit );
	HoistStruct::hoistStruct( translationUnit );
	acceptAll( translationUnit, enumAndFunctionFixer );
	acceptAll( translationUnit, typeLinker );
	AddStructAssignment::addStructAssignment( translationUnit );
	acceptAll( translationUnit, assertionFixer );
    }
    
    // Run Pass1, Pass2 and Pass3 over a single type.  Names are resolved through indexer when it is non-null,
    // otherwise through the passes' own scopes.
    void validateType( Type *type, const Indexer *indexer ) {
	Pass1 enumAndFunctionFixer;
	Pass2 typeLinker( false, indexer );
	Pass3 assertionFixer( indexer );
	type->accept( enumAndFunctionFixer );
	type->accept( typeLinker );
	type->accept( assertionFixer );
    }

    // Visit every declaration of translationUnit with visitor.  Whatever the visitor queued in get_declsToAdd()
    // during one visit is spliced into translationUnit right before the visited declaration (addBefore) or right
    // after it (! addBefore).  The spliced declarations are not visited themselves.
    template< typename Visitor >
    void acceptAndAdd( std::list< Declaration * > &translationUnit, Visitor &visitor, bool addBefore ) {
	typedef std::list< Declaration * >::iterator DeclIter;
	for ( DeclIter current = translationUnit.begin(); current != translationUnit.end(); ) {
	    (*current)->accept( visitor );
	    DeclIter following = current;
	    ++following;
	    std::list< Declaration * > &queued = visitor.get_declsToAdd();
	    if ( ! queued.empty() ) {
		DeclIter insertionPoint = addBefore ? current : following;
		translationUnit.splice( insertionPoint, queued );
	    } // if
	    current = following;
	} // for
    }

    // Hoist nested aggregates out of every declaration in translationUnit; each hoisted aggregate is inserted before
    // the top-level declaration that contained it.
    void HoistStruct::hoistStruct( std::list< Declaration * > &translationUnit ) {
	const bool insertBefore = true;
	HoistStruct hoistVisitor;
	acceptAndAdd( translationUnit, hoistVisitor, insertBefore );
    }

    // Start outside of any aggregate, with nothing queued.
    HoistStruct::HoistStruct() : declsToAdd(), inStruct( false ) {
    }

    // Remove from declList every declaration for which pred returns true, also deleting it when doDelete is set.
    void filter( std::list< Declaration * > &declList, bool (*pred)( Declaration * ), bool doDelete ) {
	std::list< Declaration * >::iterator cur = declList.begin();
	while ( cur != declList.end() ) {
	    if ( ! pred( *cur ) ) {
		++cur;
		continue;
	    } // if
	    if ( doDelete ) {
		delete *cur;
	    } // if
	    cur = declList.erase( cur );
	} // while
    }

    // Predicate for filter(): true for struct and union declarations.
    bool isStructOrUnion( Declaration *decl ) {
	if ( dynamic_cast< StructDecl * >( decl ) ) {
	    return true;
	} // if
	return dynamic_cast< UnionDecl * >( decl ) != 0;
    }

    template< typename AggDecl >
    void HoistStruct::handleAggregate( AggDecl *aggregateDecl ) {
	if ( inStruct ) {
	    // Add elements in stack order corresponding to nesting structure.
	    declsToAdd.push_front( aggregateDecl );
	    Visitor::visit( aggregateDecl );
	} else {
	    inStruct = true;
	    Visitor::visit( aggregateDecl );
	    inStruct = false;
	} // if
	// Always remove the hoisted aggregate from the inner structure.
	filter( aggregateDecl->get_members(), isStructOrUnion, false );
    }

    void HoistStruct::visit( StructDecl *aggregateDecl ) {
	handleAggregate( aggregateDecl );
    }

    void HoistStruct::visit( UnionDecl *aggregateDecl ) {
	handleAggregate( aggregateDecl );
    }

    void HoistStruct::visit( CompoundStmt *compoundStmt ) {
	addVisit( compoundStmt, *this );
    }

    void HoistStruct::visit( IfStmt *ifStmt ) {
	addVisit( ifStmt, *this );
    }

    void HoistStruct::visit( WhileStmt *whileStmt ) {
	addVisit( whileStmt, *this );
    }

    void HoistStruct::visit( ForStmt *forStmt ) {
	addVisit( forStmt, *this );
    }

    void HoistStruct::visit( SwitchStmt *switchStmt ) {
	addVisit( switchStmt, *this );
    }

    void HoistStruct::visit( ChooseStmt *switchStmt ) {
	addVisit( switchStmt, *this );
    }

    void HoistStruct::visit( CaseStmt *caseStmt ) {
	addVisit( caseStmt, *this );
    }

    void HoistStruct::visit( CatchStmt *cathStmt ) {
	addVisit( cathStmt, *this );
    }

    // Give each enumerator the type of its enumeration (an EnumInstType naming the enum), then visit normally.
    void Pass1::visit( EnumDecl *enumDecl ) {
	std::list< Declaration * > &enumerators = enumDecl->get_members();
	for ( std::list< Declaration * >::iterator member = enumerators.begin(); member != enumerators.end(); ++member ) {
	    ObjectDecl *enumerator = dynamic_cast< ObjectDecl * >( *member );
	    assert( enumerator );
	    // only the first qualifier flag is set -- presumably "const"; confirm against Type::Qualifiers
	    Type::Qualifiers enumeratorQualifiers( true, false, false, false, false, false );
	    enumerator->set_type( new EnumInstType( enumeratorQualifiers, enumDecl->get_name() ) );
	} // for
	Parent::visit( enumDecl );
    }

    namespace {
	template< typename DWTIterator >
	void fixFunctionList( DWTIterator begin, DWTIterator end, FunctionType *func ) {
	    // the only case in which "void" is valid is where it is the only one in the list; then
	    // it should be removed entirely
	    // other fix ups are handled by the FixFunction class
	    if ( begin == end ) return;
	    FixFunction fixer;
	    DWTIterator i = begin;
	    *i = (*i )->acceptMutator( fixer );
	    if ( fixer.get_isVoid() ) {
		DWTIterator j = i;
		++i;
		func->get_parameters().erase( j );
		if ( i != end ) { 
		    throw SemanticError( "invalid type void in function type ", func );
		} // if
	    } else {
		++i;
		for ( ; i != end; ++i ) {
		    FixFunction fixer;
		    *i = (*i )->acceptMutator( fixer );
		    if ( fixer.get_isVoid() ) {
			throw SemanticError( "invalid type void in function type ", func );
		    } // if
		} // for
	    } // if
	}
    }

    void Pass1::visit( FunctionType *func ) {
	// Fix up parameters and return types
	fixFunctionList( func->get_parameters().begin(), func->get_parameters().end(), func );
	fixFunctionList( func->get_returnVals().begin(), func->get_returnVals().end(), func );
	Visitor::visit( func );
    }

    // doDebug is forwarded to the Indexer base.  Struct, union and context names are resolved through other_indexer
    // when one is supplied, otherwise through this pass's own scopes.
    Pass2::Pass2( bool doDebug, const Indexer *other_indexer ) : Indexer( doDebug ) {
	indexer = other_indexer ? other_indexer : this;
    }

    // Link a struct instance to the struct's declaration when one is visible.  If the struct is unknown (an implicit
    // forward declaration, not an error) or known only without members, the use is remembered in forwardStructs so
    // that visit( StructDecl * ) can link it to the complete definition later.
    void Pass2::visit( StructInstType *structInst ) {
	Parent::visit( structInst );
	StructDecl *decl = indexer->lookupStruct( structInst->get_name() );
	bool awaitingDefinition = true;
	if ( decl ) {
	    assert( ! structInst->get_baseStruct() || structInst->get_baseStruct()->get_members().empty() || ! decl->get_members().empty() );
	    structInst->set_baseStruct( decl );
	    awaitingDefinition = decl->get_members().empty();
	} // if
	if ( awaitingDefinition ) {
	    forwardStructs[ structInst->get_name() ].push_back( structInst );
	} // if
    }

    // Same as the struct case, for unions (forwardUnions / visit( UnionDecl * )).
    void Pass2::visit( UnionInstType *unionInst ) {
	Parent::visit( unionInst );
	UnionDecl *decl = indexer->lookupUnion( unionInst->get_name() );
	bool awaitingDefinition = true;
	if ( decl ) {
	    unionInst->set_baseUnion( decl );
	    awaitingDefinition = decl->get_members().empty();
	} // if
	if ( awaitingDefinition ) {
	    forwardUnions[ unionInst->get_name() ].push_back( unionInst );
	} // if
    }

    // Fill a context instance with the declarations the context provides:
    //  - a clone of every assertion attached to the context's type parameters (no parameter substitution), and
    //  - the context's own members with its parameters replaced by the instance's arguments (applySubstitution).
    // Naming an undeclared context is a SemanticError.
    void Pass2::visit( ContextInstType *contextInst ) {
	Parent::visit( contextInst );
	ContextDecl *ctx = indexer->lookupContext( contextInst->get_name() );
	if ( ctx == 0 ) {
	    throw SemanticError( "use of undeclared context " + contextInst->get_name() );
	} // if
	typedef std::list< TypeDecl * >::const_iterator ParamIter;
	typedef std::list< DeclarationWithType * >::const_iterator AssertionIter;
	for ( ParamIter param = ctx->get_parameters().begin(); param != ctx->get_parameters().end(); ++param ) {
	    for ( AssertionIter assertion = (*param)->get_assertions().begin(); assertion != (*param)->get_assertions().end(); ++assertion ) {
		// NOTE(review): *assertion is a DeclarationWithType, not a Type, so this cast looks like it can never
		// succeed; the assertion's type ((*assertion)->get_type(), as in forallFixer) was probably intended.
		// Behaviour kept as-is pending confirmation.
		if ( ContextInstType *otherCtx = dynamic_cast< ContextInstType * >( *assertion ) ) {
		    cloneAll( otherCtx->get_members(), contextInst->get_members() );
		} else {
		    contextInst->get_members().push_back( (*assertion)->clone() );
		} // if
	    } // for
	} // for
	applySubstitution( ctx->get_parameters().begin(), ctx->get_parameters().end(), contextInst->get_parameters().begin(), ctx->get_members().begin(), ctx->get_members().end(), back_inserter( contextInst->get_members() ) );
    }

    // On reaching a struct definition with members, point every earlier forward use of that name at it and forget
    // those uses; then index the declaration as usual.
    void Pass2::visit( StructDecl *structDecl ) {
	if ( ! structDecl->get_members().empty() ) {
	    ForwardStructsType::iterator pending = forwardStructs.find( structDecl->get_name() );
	    if ( pending != forwardStructs.end() ) {
		std::list< StructInstType * > &uses = pending->second;
		for ( std::list< StructInstType * >::iterator use = uses.begin(); use != uses.end(); ++use ) {
		    (*use)->set_baseStruct( structDecl );
		} // for
		forwardStructs.erase( pending );
	    } // if
	} // if
	Indexer::visit( structDecl );
    }

    // Same as the struct case, for unions.
    void Pass2::visit( UnionDecl *unionDecl ) {
	if ( ! unionDecl->get_members().empty() ) {
	    ForwardUnionsType::iterator pending = forwardUnions.find( unionDecl->get_name() );
	    if ( pending != forwardUnions.end() ) {
		std::list< UnionInstType * > &uses = pending->second;
		for ( std::list< UnionInstType * >::iterator use = uses.begin(); use != uses.end(); ++use ) {
		    (*use)->set_baseUnion( unionDecl );
		} // for
		forwardUnions.erase( pending );
	    } // if
	} // if
	Indexer::visit( unionDecl );
    }

    // Mark a type instance as an ftype when its name resolves to a type parameter of kind Ftype.
    // NOTE(review): this looks the name up in this pass's own scopes (lookupType) rather than through the "indexer"
    // member used by the other Pass2 visitors; confirm that is intended when validateType supplies an indexer.
    void Pass2::visit( TypeInstType *typeInst ) {
	NamedTypeDecl *named = lookupType( typeInst->get_name() );
	TypeDecl *typeDecl = dynamic_cast< TypeDecl * >( named );
	if ( typeDecl != 0 ) {
	    typeInst->set_isFtype( typeDecl->get_kind() == TypeDecl::Ftype );
	} // if
    }

    // Use other_indexer when one is supplied, otherwise this object's own scopes.
    Pass3::Pass3( const Indexer *other_indexer ) :  Indexer( false ) {
	indexer = other_indexer ? other_indexer : this;
    }

    // Expand and normalize the assertions of every type parameter in the forall clause of func (any Type, not only a
    // function type).  An assertion whose type is a ContextInstType is replaced by clones of the context's member
    // declarations, which are processed again on the next round, so contexts nested in contexts are expanded too.
    // Every other assertion is run through FixFunction and appended back onto its type parameter's assertion list.
    // Throws SemanticError when an assertion fixes up to void.
    void forallFixer( Type *func ) {
	// Fix up assertions
	for ( std::list< TypeDecl * >::iterator type = func->get_forall().begin(); type != func->get_forall().end(); ++type ) {
	    // Take all current assertions off the parameter; the survivors are pushed back below.
	    std::list< DeclarationWithType * > toBeDone, nextRound;
	    toBeDone.splice( toBeDone.end(), (*type )->get_assertions() );
	    while ( ! toBeDone.empty() ) {
		for ( std::list< DeclarationWithType * >::iterator assertion = toBeDone.begin(); assertion != toBeDone.end(); ++assertion ) {
		    if ( ContextInstType *ctx = dynamic_cast< ContextInstType * >( (*assertion )->get_type() ) ) {
			for ( std::list< Declaration * >::const_iterator i = ctx->get_members().begin(); i != ctx->get_members().end(); ++i ) {
			    DeclarationWithType *dwt = dynamic_cast< DeclarationWithType * >( *i );
			    assert( dwt );
			    nextRound.push_back( dwt->clone() );
			}
			// NOTE(review): only the context type is deleted; the assertion declaration that owned it is
			// neither deleted nor re-added, and still points at the deleted type.  Looks like a leak plus a
			// dangling pointer -- confirm the ownership rules before changing.
			delete ctx;
		    } else {
			FixFunction fixer;
			*assertion = (*assertion )->acceptMutator( fixer );
			if ( fixer.get_isVoid() ) {
			    throw SemanticError( "invalid type void in assertion of function ", func );
			}
			(*type )->get_assertions().push_back( *assertion );
		    }
		}
		// Process the declarations contributed by contexts in this round.
		toBeDone.clear();
		toBeDone.splice( toBeDone.end(), nextRound );
	    }
	}
    }

    // Expand forall assertions on the object's type and, when that type is a pointer, on the pointed-to type as well;
    // then visit normally and give the object its unique id.
    void Pass3::visit( ObjectDecl *object ) {
	Type *objectType = object->get_type();
	forallFixer( objectType );
	PointerType *asPointer = dynamic_cast< PointerType * >( objectType );
	if ( asPointer != 0 ) {
	    forallFixer( asPointer->get_base() );
	} // if
	Parent::visit( object );
	object->fixUniqueId();
    }

    // Expand forall assertions on the function's type, visit normally, and give the function its unique id.
    void Pass3::visit( FunctionDecl *func ) {
	Type *funcType = func->get_type();
	forallFixer( funcType );
	Parent::visit( func );
	func->fixUniqueId();
    }

    // Empty label list attached to the statements generated in this file.
    static const std::list< std::string > noLabels;

    // Generate assignment operators throughout translationUnit; each generated declaration is inserted right after
    // the top-level declaration that produced it.
    void AddStructAssignment::addStructAssignment( std::list< Declaration * > &translationUnit ) {
	const bool insertBefore = false;
	AddStructAssignment assignmentAdder;
	acceptAndAdd( translationUnit, assignmentAdder, insertBefore );
    }

    // Append to out the statement copying one member from _src into *_dst:
    //
    //     ?=?( &(*_dst).member, _src.member );
    //
    // (Array members go through makeArrayAssignment instead.)  An unnamed bit-field is skipped, since it has no name
    // through which it could be selected.
    template< typename OutputIterator >
    void makeScalarAssignment( ObjectDecl *srcParam, ObjectDecl *dstParam, DeclarationWithType *member, OutputIterator out ) {
	ObjectDecl *field = dynamic_cast< ObjectDecl * >( member );
	if ( field != NULL && field->get_name() == "" && field->get_bitfieldWidth() != NULL ) {
	    return;
	} // if

	UntypedExpr *assign = new UntypedExpr( new NameExpr( "?=?" ) );

	// left operand: &(*_dst).member
	UntypedExpr *dstDeref = new UntypedExpr( new NameExpr( "*?" ) );
	dstDeref->get_args().push_back( new VariableExpr( dstParam ) );
	Expression *dstAddress = new AddressExpr( new MemberExpr( member, dstDeref ) );
	assign->get_args().push_back( dstAddress );

	// right operand: _src.member
	Expression *srcValue = new MemberExpr( member, new VariableExpr( srcParam ) );
	assign->get_args().push_back( srcValue );

	*out++ = new ExprStmt( noLabels, assign );
    }

    // Append to out the statements copying array member "member" element by element from _src into *_dst:
    //
    //     int _indexN;
    //     for ( ?=?( &_indexN, 0 ); ?<?( _indexN, dimension ); ++?( &_indexN ) )
    //         ?=?( ?+?( (*_dst).member, _indexN ), ?[?]( _src.member, _indexN ) );
    //
    // A member without a dimension (flexible array member) produces nothing; such aggregates need a user-written
    // assignment.
    template< typename OutputIterator >
    void makeArrayAssignment( ObjectDecl *srcParam, ObjectDecl *dstParam, DeclarationWithType *member, ArrayType *array, OutputIterator out ) {
	static UniqueName indexName( "_index" );

	if ( ! array->get_dimension() ) {
	    return;
	} // if

	// int _indexN;
	ObjectDecl *index = new ObjectDecl( indexName.newName(), Declaration::NoStorageClass, LinkageSpec::C, 0, new BasicType( Type::Qualifiers(), BasicType::SignedInt ), 0 );
	*out++ = new DeclStmt( noLabels, index );

	// loop initialization: ?=?( &_indexN, 0 )
	UntypedExpr *startAtZero = new UntypedExpr( new NameExpr( "?=?" ) );
	startAtZero->get_args().push_back( new AddressExpr( new VariableExpr( index ) ) );
	startAtZero->get_args().push_back( new NameExpr( "0" ) );
	Statement *loopInit = new ExprStmt( noLabels, startAtZero );

	// loop condition: ?<?( _indexN, dimension )
	UntypedExpr *inBounds = new UntypedExpr( new NameExpr( "?<?" ) );
	inBounds->get_args().push_back( new VariableExpr( index ) );
	inBounds->get_args().push_back( array->get_dimension()->clone() );

	// loop increment: ++?( &_indexN )
	UntypedExpr *step = new UntypedExpr( new NameExpr( "++?" ) );
	step->get_args().push_back( new AddressExpr( new VariableExpr( index ) ) );

	// loop body: ?=?( ?+?( (*_dst).member, _indexN ), ?[?]( _src.member, _indexN ) )
	UntypedExpr *copyElement = new UntypedExpr( new NameExpr( "?=?" ) );

	UntypedExpr *dstDeref = new UntypedExpr( new NameExpr( "*?" ) );
	dstDeref->get_args().push_back( new VariableExpr( dstParam ) );
	Expression *dstMember = new MemberExpr( member, dstDeref );
	UntypedExpr *dstElement = new UntypedExpr( new NameExpr( "?+?" ) );
	dstElement->get_args().push_back( dstMember );
	dstElement->get_args().push_back( new VariableExpr( index ) );
	copyElement->get_args().push_back( dstElement );

	Expression *srcMember = new MemberExpr( member, new VariableExpr( srcParam ) );
	UntypedExpr *srcElement = new UntypedExpr( new NameExpr( "?[?]" ) );
	srcElement->get_args().push_back( srcMember );
	srcElement->get_args().push_back( new VariableExpr( index ) );
	copyElement->get_args().push_back( srcElement );

	*out++ = new ForStmt( noLabels, loopInit, inBounds, step, new ExprStmt( noLabels, copyElement ) );
    }

    Declaration *makeStructAssignment( StructDecl *aggregateDecl, StructInstType *refType, unsigned int functionNesting ) {
	FunctionType *assignType = new FunctionType( Type::Qualifiers(), false );
  
	ObjectDecl *returnVal = new ObjectDecl( "", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, refType->clone(), 0 );
	assignType->get_returnVals().push_back( returnVal );
  
	ObjectDecl *dstParam = new ObjectDecl( "_dst", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, new PointerType( Type::Qualifiers(), refType->clone() ), 0 );
	assignType->get_parameters().push_back( dstParam );
  
	ObjectDecl *srcParam = new ObjectDecl( "_src", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, refType, 0 );
	assignType->get_parameters().push_back( srcParam );

	// Routines at global scope marked "static" to prevent multiple definitions is separate translation units
	// because each unit generates copies of the default routines for each aggregate.
	FunctionDecl *assignDecl = new FunctionDecl( "?=?", functionNesting > 0 ? Declaration::NoStorageClass : Declaration::Static, LinkageSpec::AutoGen, assignType, new CompoundStmt( noLabels ), true );
	assignDecl->fixUniqueId();
  
	for ( std::list< Declaration * >::const_iterator member = aggregateDecl->get_members().begin(); member != aggregateDecl->get_members().end(); ++member ) {
	    if ( DeclarationWithType *dwt = dynamic_cast< DeclarationWithType * >( *member ) ) {
		if ( ArrayType *array = dynamic_cast< ArrayType * >( dwt->get_type() ) ) {
		    makeArrayAssignment( srcParam, dstParam, dwt, array, back_inserter( assignDecl->get_statements()->get_kids() ) );
		} else {
		    makeScalarAssignment( srcParam, dstParam, dwt, back_inserter( assignDecl->get_statements()->get_kids() ) );
		} // if
	    } // if
	} // for
	assignDecl->get_statements()->get_kids().push_back( new ReturnStmt( noLabels, new VariableExpr( srcParam ) ) );
  
	return assignDecl;
    }

    Declaration *makeUnionAssignment( UnionDecl *aggregateDecl, UnionInstType *refType, unsigned int functionNesting ) {
	FunctionType *assignType = new FunctionType( Type::Qualifiers(), false );
  
	ObjectDecl *returnVal = new ObjectDecl( "", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, refType->clone(), 0 );
	assignType->get_returnVals().push_back( returnVal );
  
	ObjectDecl *dstParam = new ObjectDecl( "_dst", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, new PointerType( Type::Qualifiers(), refType->clone() ), 0 );
	assignType->get_parameters().push_back( dstParam );
  
	ObjectDecl *srcParam = new ObjectDecl( "_src", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, refType, 0 );
	assignType->get_parameters().push_back( srcParam );
  
	// Routines at global scope marked "static" to prevent multiple definitions is separate translation units
	// because each unit generates copies of the default routines for each aggregate.
	FunctionDecl *assignDecl = new FunctionDecl( "?=?",  functionNesting > 0 ? Declaration::NoStorageClass : Declaration::Static, LinkageSpec::AutoGen, assignType, new CompoundStmt( noLabels ), true );
	assignDecl->fixUniqueId();
  
	UntypedExpr *copy = new UntypedExpr( new NameExpr( "__builtin_memcpy" ) );
	copy->get_args().push_back( new VariableExpr( dstParam ) );
	copy->get_args().push_back( new AddressExpr( new VariableExpr( srcParam ) ) );
	copy->get_args().push_back( new SizeofExpr( refType->clone() ) );

	assignDecl->get_statements()->get_kids().push_back( new ExprStmt( noLabels, copy ) );
	assignDecl->get_statements()->get_kids().push_back( new ReturnStmt( noLabels, new VariableExpr( srcParam ) ) );
  
	return assignDecl;
    }

    // Queue an assignment operator for a struct definition that has members.  At most one is generated per struct
    // name while that name is recorded in structsDone.
    void AddStructAssignment::visit( StructDecl *structDecl ) {
	if ( structDecl->get_members().empty() ) return;
	if ( structsDone.count( structDecl->get_name() ) != 0 ) return;
	StructInstType *selfRef = new StructInstType( Type::Qualifiers(), structDecl->get_name() );
	selfRef->set_baseStruct( structDecl );
	declsToAdd.push_back( makeStructAssignment( structDecl, selfRef, functionNesting ) );
	structsDone.insert( structDecl->get_name() );
    }

    // Queue an assignment operator for a union definition that has members.
    void AddStructAssignment::visit( UnionDecl *unionDecl ) {
	if ( unionDecl->get_members().empty() ) return;
	UnionInstType *selfRef = new UnionInstType( Type::Qualifiers(), unionDecl->get_name() );
	selfRef->set_baseUnion( unionDecl );
	declsToAdd.push_back( makeUnionAssignment( unionDecl, selfRef, functionNesting ) );
    }

    // Queue "T ?=?( T *_dst, T _src )" for a type declaration T.  When T has a base type B the body is
    //     return ?=?( (B *)_dst, (B)_src );
    // otherwise only a prototype (no body) is generated.
    void AddStructAssignment::visit( TypeDecl *typeDecl ) {
	CompoundStmt *body = 0;
	TypeInstType *selfRef = new TypeInstType( Type::Qualifiers(), typeDecl->get_name(), false );
	selfRef->set_baseType( typeDecl );
	ObjectDecl *src = new ObjectDecl( "_src", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, selfRef->clone(), 0 );
	ObjectDecl *dst = new ObjectDecl( "_dst", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, new PointerType( Type::Qualifiers(), selfRef->clone() ), 0 );

	if ( typeDecl->get_base() != 0 ) {
	    body = new CompoundStmt( std::list< Label >() );
	    UntypedExpr *assignBase = new UntypedExpr( new NameExpr( "?=?" ) );
	    Expression *dstAsBase = new CastExpr( new VariableExpr( dst ), new PointerType( Type::Qualifiers(), typeDecl->get_base()->clone() ) );
	    assignBase->get_args().push_back( dstAsBase );
	    Expression *srcAsBase = new CastExpr( new VariableExpr( src ), typeDecl->get_base()->clone() );
	    assignBase->get_args().push_back( srcAsBase );
	    body->get_kids().push_back( new ReturnStmt( std::list< Label >(), assignBase ) );
	} // if

	FunctionType *assignType = new FunctionType( Type::Qualifiers(), false );
	assignType->get_returnVals().push_back( new ObjectDecl( "", Declaration::NoStorageClass, LinkageSpec::Cforall, 0, selfRef, 0 ) );
	assignType->get_parameters().push_back( dst );
	assignType->get_parameters().push_back( src );
	FunctionDecl *assignDecl = new FunctionDecl( "?=?", Declaration::NoStorageClass, LinkageSpec::AutoGen, assignType, body, false );
	declsToAdd.push_back( assignDecl );
    }

    // Wrap each queued declaration in a DeclStmt and insert them, in order, into statements just before position i;
    // the queue is left empty.
    void addDecls( std::list< Declaration * > &declsToAdd, std::list< Statement * > &statements, std::list< Statement * >::iterator i ) {
	for ( std::list< Declaration * >::iterator pending = declsToAdd.begin(); pending != declsToAdd.end(); ++pending ) {
	    statements.insert( i, new DeclStmt( noLabels, *pending ) );
	} // for
	declsToAdd.clear();
    }

    // The next three visitors deliberately do nothing: aggregates that appear only as part of a function type, a
    // pointer type or a context declaration must not get assignment operators generated for them.
    void AddStructAssignment::visit( FunctionType * ) {
    }

    void AddStructAssignment::visit( PointerType * ) {
    }

    void AddStructAssignment::visit( ContextDecl * ) {
    }

    // Visit a statement through addVisit, treating it as a scope for structsDone: struct names recorded while inside
    // the statement are forgotten once it has been visited.
    template< typename StmtClass >
    inline void AddStructAssignment::visitStatement( StmtClass *stmt ) {
	std::set< std::string > outerStructs( structsDone );
	addVisit( stmt, *this );
	structsDone.swap( outerStructs );
    }

    // Visit a function declaration.  Its body is visited with functionNesting raised, so struct/union assignment
    // operators generated there are not marked static.
    void AddStructAssignment::visit( FunctionDecl *functionDecl ) {
	maybeAccept( functionDecl->get_functionType(), *this );
	acceptAll( functionDecl->get_oldDecls(), *this );
	++functionNesting;
	maybeAccept( functionDecl->get_statements(), *this );
	--functionNesting;
    }

    void AddStructAssignment::visit( CompoundStmt *compoundStmt ) {
	visitStatement( compoundStmt );
    }

    void AddStructAssignment::visit( IfStmt *ifStmt ) {
	visitStatement( ifStmt );
    }

    void AddStructAssignment::visit( WhileStmt *whileStmt ) {
	visitStatement( whileStmt );
    }

    void AddStructAssignment::visit( ForStmt *forStmt ) {
	visitStatement( forStmt );
    }

    void AddStructAssignment::visit( SwitchStmt *switchStmt ) {
	visitStatement( switchStmt );
    }

    void AddStructAssignment::visit( ChooseStmt *switchStmt ) {
	visitStatement( switchStmt );
    }

    void AddStructAssignment::visit( CaseStmt *caseStmt ) {
	visitStatement( caseStmt );
    }

    void AddStructAssignment::visit( CatchStmt *cathStmt ) {
	visitStatement( cathStmt );
    }

    // Predicate for filter(): true for typedef declarations.
    bool isTypedef( Declaration *decl ) {
	return dynamic_cast< TypedefDecl * >( decl ) != 0;
    }

    // Expand every typedef use in translationUnit, then remove and delete the top-level typedef declarations.
    void EliminateTypedef::eliminateTypedef( std::list< Declaration * > &translationUnit ) {
	EliminateTypedef typedefExpander;
	mutateAll( translationUnit, typedefExpander );
	const bool deleteRemoved = true;
	filter( translationUnit, isTypedef, deleteRemoved );
    }

    // Replace a use of a typedef name in scope by a clone of the typedef's base type, adding the qualifiers written
    // on the use; the original TypeInstType is deleted.  Other type names are returned unchanged.
    Type *EliminateTypedef::mutate( TypeInstType *typeInst ) {
	std::map< std::string, TypedefDecl * >::const_iterator found = typedefNames.find( typeInst->get_name() );
	if ( found == typedefNames.end() ) {
	    return typeInst;
	} // if
	Type *expansion = found->second->get_base()->clone();
	expansion->get_qualifiers() += typeInst->get_qualifiers();
	delete typeInst;
	return expansion;
    }

    // Record a typedef (after expanding typedef names used in its own definition) so later uses can be expanded.
    Declaration *EliminateTypedef::mutate( TypedefDecl *tyDecl ) {
	Declaration *mutated = Mutator::mutate( tyDecl );
	typedefNames[ tyDecl->get_name() ] = tyDecl;
	// When a typedef is a forward declaration:
	//    typedef struct screen SCREEN;
	// the declaration portion must be retained:
	//    struct screen;
	// because the expansion of the typedef is:
	//    void rtn( SCREEN *p ) => void rtn( struct screen *p )
	// hence the type-name "screen" must be defined.
	// Note, qualifiers on the typedef are superfluous for the forward declaration.
	Type *base = tyDecl->get_base();
	if ( StructInstType *structRef = dynamic_cast< StructInstType * >( base ) ) {
	    return new StructDecl( structRef->get_name() );
	} // if
	if ( UnionInstType *unionRef = dynamic_cast< UnionInstType * >( base ) ) {
	    return new UnionDecl( unionRef->get_name() );
	} // if
	return mutated;
    }

    // A type declaration (e.g. a forall type parameter) hides any typedef of the same name, so that name stops being
    // expanded for the rest of the current typedef scope.  The declaration is then mutated like any other node.
    // Previously it was returned unvisited, so typedef names used inside it (presumably its assertions and base type,
    // per Mutator's traversal) stayed unexpanded even though eliminateTypedef later deletes the typedefs themselves.
    TypeDecl *EliminateTypedef::mutate( TypeDecl *typeDecl ) {
	// Hide first, so uses of the parameter's own name inside its assertions are not mistaken for the typedef.
	typedefNames.erase( typeDecl->get_name() );
	// dynamic_cast keeps this correct whether Mutator::mutate( TypeDecl * ) returns TypeDecl * or a more general
	// Declaration *.
	return dynamic_cast< TypeDecl * >( Mutator::mutate( typeDecl ) );
    }

    // Function declarations, object declarations and casts are typedef scopes: the set of typedef names in effect is
    // saved before mutating the node and restored afterwards.
    DeclarationWithType *EliminateTypedef::mutate( FunctionDecl *funcDecl ) {
	std::map< std::string, TypedefDecl * > outerNames( typedefNames );
	DeclarationWithType *result = Mutator::mutate( funcDecl );
	typedefNames.swap( outerNames );
	return result;
    }

    ObjectDecl *EliminateTypedef::mutate( ObjectDecl *objDecl ) {
	std::map< std::string, TypedefDecl * > outerNames( typedefNames );
	ObjectDecl *result = Mutator::mutate( objDecl );
	typedefNames.swap( outerNames );
	return result;
    }

    Expression *EliminateTypedef::mutate( CastExpr *castExpr ) {
	std::map< std::string, TypedefDecl * > outerNames( typedefNames );
	Expression *result = Mutator::mutate( castExpr );
	typedefNames.swap( outerNames );
	return result;
    }

    // A compound statement is a typedef scope.  After its contents have been mutated, every remaining declaration
    // statement that holds a typedef is deleted from the block, and the outer set of typedef names is restored.
    CompoundStmt *EliminateTypedef::mutate( CompoundStmt *compoundStmt ) {
	std::map< std::string, TypedefDecl * > outerNames( typedefNames );
	CompoundStmt *result = Mutator::mutate( compoundStmt );
	std::list< Statement * > &kids = compoundStmt->get_kids();
	std::list< Statement * >::iterator kid = kids.begin();
	while ( kid != kids.end() ) {
	    DeclStmt *declStmt = dynamic_cast< DeclStmt * >( *kid );
	    if ( declStmt && dynamic_cast< TypedefDecl * >( declStmt->get_decl() ) ) {
		delete *kid;
		kid = kids.erase( kid );
	    } else {
		++kid;
	    } // if
	} // while
	typedefNames.swap( outerNames );
	return result;
    }
} // namespace SymTab
