#include "avl.h"
#include "avl-private.h"

// Looks up key in the binary search tree rooted at t.
//
// t   - root of the (sub)tree to search; may be empty.
// key - key to look for; compared against node keys with == and <.
//
// Returns a pointer to the value stored in the matching node, or NULL when
// no node carries the key. The pointer refers to storage inside the node
// itself, not to a copy.
forall(otype K | Comparable(K), otype V)
V * find(tree(K, V) * t, K key){
  // Walk downwards until the key is found or we fall off the tree.
  while (!empty(t)){
    if (t->key == key){
      return &t->value;
    }

    if (t->key < key){
      // node key is smaller than the target: continue in the right subtree
      t = t->right;
    } else {
      // node key is larger than the target: continue in the left subtree
      t = t->left;
    }
  }

  // ran out of nodes without a match
  return NULL;
}

// Reports whether t denotes an empty tree.
//
// Returns 1 when t is the NULL pointer and 0 otherwise.
forall(otype K | Comparable(K), otype V)
int empty(tree(K, V) * t){
  if (t == NULL){
    return 1;
  }
  return 0;
}

// Inserts (key, value) into the tree rooted at *t, then walks from the new
// node back up to the root, calling tryFix on every node along the way.
//
// t     - address of the root pointer. On success *t is overwritten with the
//         node returned by the final tryFix call (or with the freshly
//         created node when the tree was empty).
// key   - key to insert; compared against existing keys with == and <.
// value - value stored alongside key.
//
// Returns 0 on success, or 99 when key is already present, in which case the
// tree is left unmodified.
forall(otype K | Comparable(K), otype V)
int insert(tree(K, V) ** t, K key, V value) {
  // handles a non-empty tree
  // problem if the following signature is used: tries to use an adapter to call helper, but shouldn't
  // be necessary - seems to be a problem with helper's type variables not being renamed
  // tree(K, V) * helper(tree(K, V) * t, K key, V value){
  //
  // Descends from t to the leaf position for key (key and value are captured
  // from the enclosing call), links a newly created node there and returns
  // it. Returns NULL, without modifying anything, when key already exists.
  tree(K, V) * helper(tree(K, V) * t){
    if (t->key == key){
      // ran into the same key - uh-oh
      return NULL;
    } else if (t->key < key){
      if (t->right == NULL){
        t->right = create(key, value);
        // The temporary is a workaround (see FIX ME): the parent link is set
        // through a local instead of through t->right directly.
        tree(K, V) * right = t->right; // FIX ME!
        right->parent = t;             // !!!!!!!
        return t->right;
      } else {
        return helper(t->right);
      }
    } else {
      if (t->left == NULL){
        t->left = create(key, value);
        // Same workaround as for the right child above.
        tree(K, V) * left = t->left;   // FIX ME!
        left->parent = t;              // !!!!!!!
        return t->left;
      } else {
        return helper(t->left);
      }
    }
  }

  if (empty(*t)){
    // be nice and return a new tree
    *t = create(key, value);
    return 0;
  }
  tree(K, V) * newTree = helper(*t);
  if (newTree == NULL){
    // insert error handling code, only possibility
    // currently is that the key already exists
    return 99;
  }
  // move up the tree, updating balance factors
  // if the balance factor is -1, 0, or 1 keep going
  // if the balance factor is -2 or 2, call fix
  // (tryFix is defined elsewhere; presumably it rebalances the subtree and
  // returns that subtree's possibly new root - TODO confirm)
  while (newTree->parent != NULL){ // loop until newTree == NULL?
    newTree = tryFix(newTree);
    tree(K, V) * parent = newTree->parent;  // FIX ME!!
    // sanity check: the node handed back must still be linked under its parent
    assert(parent->left == newTree ||
         parent->right == newTree);
    newTree = newTree->parent;
  }

  // NOTE: the key is already linked into the tree at this point, so insert
  // must not be invoked again here. *t still holds the pre-fix root, and a
  // second descent from it could miss the new node and link a duplicate.

  // do it one more time - this is the root
  newTree = tryFix(newTree);
  *t = newTree;
  return 0;
}
