/*
 * This file is part of the Cforall project
 *
 * $Id: PtrsCastable.cc,v 1.5 2005/08/29 20:14:16 rcbilson Exp $
 *
 */

#include "typeops.h"
#include "SynTree/Type.h"
#include "SynTree/Declaration.h"
#include "SynTree/Visitor.h"
#include "SymTab/Indexer.h"


namespace ResolvExpr {

// Visitor that is applied to the *source* type of a pointer cast and records
// an integer verdict about casting to the destination type given at
// construction.  Each visit() overload handles one kind of source type.
// The verdict is one of -1, 0 or 1 in this file; NOTE(review): presumably
// positive means castable and negative means not castable -- confirm
// against the callers of ptrsCastable().
class PtrsCastable : public Visitor
{
public:
  // dest is the destination type of the cast; env and indexer are kept by
  // reference and used for type-variable lookups in objectCast().
  PtrsCastable( Type *dest, const TypeEnvironment &env, const SymTab::Indexer &indexer );
  
  // Verdict recorded by the most recent visit(); 0 if nothing has set it.
  int get_result() const { return result; }

  // One overload per kind of source type.
  virtual void visit(VoidType *voidType);
  virtual void visit(BasicType *basicType);
  virtual void visit(PointerType *pointerType);
  virtual void visit(ArrayType *arrayType);
  virtual void visit(FunctionType *functionType);
  virtual void visit(StructInstType *inst);
  virtual void visit(UnionInstType *inst);
  virtual void visit(EnumInstType *inst);
  virtual void visit(ContextInstType *inst);
  virtual void visit(TypeInstType *inst);
  virtual void visit(TupleType *tupleType);

private:
  Type *dest;                      // destination type of the cast
  int result;                      // verdict; initialised to 0 by the constructor
  const TypeEnvironment &env;      // consulted for type-variable classes
  const SymTab::Indexer &indexer;  // consulted for declared type names
};

// Decides whether src denotes something other than a function type.
// Returns -1 when src is a FunctionType, or when src is a type instance that
// resolves to kind TypeDecl::Ftype -- either through a TypeDecl found by the
// indexer or, when the indexer knows no such name, through the class found
// in env.  Returns 1 in every other case.
int
objectCast( Type *src, const TypeEnvironment &env, const SymTab::Indexer &indexer )
{
  // A literal function type is rejected outright.
  if( dynamic_cast< FunctionType* >( src ) ) {
    return -1;
  }

  // Only type instances need a further look; everything else is accepted.
  TypeInstType *typeInst = dynamic_cast< TypeInstType* >( src );
  if( !typeInst ) {
    return 1;
  }

  EqvClass eqvClass;
  NamedTypeDecl *ntDecl = indexer.lookupType( typeInst->get_name() );
  if( ntDecl ) {
    // A declaration found by the indexer takes precedence; env is not consulted.
    TypeDecl *tyDecl = dynamic_cast< TypeDecl* >( ntDecl );
    if( tyDecl && tyDecl->get_kind() == TypeDecl::Ftype ) {
      return -1;
    }
    return 1;
  }

  // No declaration: fall back to the environment's class for this name.
  if( env.lookup( typeInst->get_name(), eqvClass ) && eqvClass.kind == TypeDecl::Ftype ) {
    return -1;
  }
  return 1;
}

// Entry point: computes the cast verdict for src -> dest.
//   1. If dest is a type instance with a class in env, the answer is
//      delegated to ptrsAssignable() against that class's type.
//   2. If dest is void, the answer is objectCast( src ).
//   3. Otherwise a PtrsCastable visitor is run over src.
// NOTE(review): step 1 uses ptrsAssignable rather than recursing into
// ptrsCastable, and passes the class's type without checking it -- confirm
// both are intended.
int
ptrsCastable( Type *src, Type *dest, const TypeEnvironment &env, const SymTab::Indexer &indexer )
{
  TypeInstType *destAsTypeInst = dynamic_cast< TypeInstType* >( dest );
  if( destAsTypeInst ) {
    EqvClass boundClass;
    if( env.lookup( destAsTypeInst->get_name(), boundClass ) ) {
      return ptrsAssignable( src, boundClass.type, env );
    }
  }

  if( dynamic_cast< VoidType* >( dest ) ) {
    return objectCast( src, env, indexer );
  }

  // General case: dispatch on the dynamic kind of the source type.
  PtrsCastable ptrs( dest, env, indexer );
  src->accept( ptrs );
  return ptrs.get_result();
}

// Stores the destination type and the two lookup contexts.  The verdict
// starts at 0 and keeps that value until a visit() overload assigns it.
PtrsCastable::PtrsCastable( Type *dest, const TypeEnvironment &env, const SymTab::Indexer &indexer )
  : dest( dest ),
    result( 0 ),
    env( env ),
    indexer( indexer )
{
}

void 
PtrsCastable::visit(VoidType *voidType)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(BasicType *basicType)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(PointerType *pointerType)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(ArrayType *arrayType)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(FunctionType *functionType)
{
  result = -1;
}

void 
PtrsCastable::visit(StructInstType *inst)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(UnionInstType *inst)
{
  result = objectCast( dest, env, indexer );
}

void 
PtrsCastable::visit(EnumInstType *inst)
{
  if( dynamic_cast< EnumInstType* >( inst ) ) {
    result = 1;
  } else if( BasicType *bt = dynamic_cast< BasicType* >( inst ) ) {
    if( bt->get_kind() == BasicType::SignedInt ) {
      result = 0;
    } else {
      result = 1;
    }
  } else {
    result = objectCast( dest, env, indexer );
  }
}

void 
PtrsCastable::visit(ContextInstType *inst)
{
  // I definitely don't think we should be doing anything here
}

// Source is a type instance (type variable): the verdict is 1 only when
// objectCast() accepts both the source and the destination, and -1 otherwise.
// objectCast() reports rejection as -1, which is non-zero and therefore true
// in a boolean context, so its results must be compared against 0 explicitly.
// Using them directly as booleans made this verdict always 1.
void
PtrsCastable::visit(TypeInstType *inst)
{
  const bool srcIsObject = objectCast( inst, env, indexer ) > 0;
  // Short-circuit preserved: the destination is only examined when the
  // source has been accepted.
  const bool bothObjects = srcIsObject && objectCast( dest, env, indexer ) > 0;
  result = bothObjects ? 1 : -1;
}

void 
PtrsCastable::visit(TupleType *tupleType)
{
  result = objectCast( dest, env, indexer );
}

} // namespace ResolvExpr
