#include "avl.h"
#include "avl-private.h"

// AVL tree specific (internal) operations:
// rotateLeft, rotateRight, fix
//
// AVL tree enhanced height operation
//
// calcBalance is a simple computation of height(R) - height(L)

// Height of the (sub)tree rooted at t; an empty tree has height 0.
//
// No full traversal is needed: the stored balance factor says which child
// is taller, so a single root-to-leaf walk that always steps into the
// taller child counts the height. The result is only as good as the
// balance fields along that walk.
forall(otype K | Comparable(K), otype V)
int height(tree(K, V) * t){
  tree(K, V) * node = t;
  int levels = 0;
  while (! empty(node)){
    levels = 1 + levels;
    if (node->balance > 0){
      // right subtree is strictly taller
      node = node->right;
    } else {
      // left subtree is taller, or both are equal (balance 0), in which
      // case either branch gives the same answer
      node = node->left;
    }
  }
  return levels;
}

// Recompute t's balance factor as height(right) - height(left), store it
// in t->balance and return the stored value. t must not be empty, since it
// is dereferenced unconditionally.
forall(otype K | Comparable(K), otype V)
int calcBalance(tree(K, V) * t){
  int leftHeight = height(t->left);
  int rightHeight = height(t->right);
  t->balance = rightHeight - leftHeight;
  return t->balance;
}

// Make t's parent point back down at t.
// The slot is chosen by key order: if the parent's key is smaller than t's
// key, t becomes the right child, otherwise the left child. A node with no
// parent is left untouched.
forall(otype K | Comparable(K), otype V)
void relinkToParent(tree(K, V) * t){
  tree(K, V) * up = t->parent; // FIX ME!! (temporary kept from the original)
  if (! empty(up)){
    if (up->key < t->key){
      up->right = t;
    } else {
      up->left = t;
    }
  }
}

// Left rotation around t.
// t's right child is promoted to the root of this subtree, t becomes its
// left child, and the promoted node's former left subtree is re-attached
// as t's right subtree. Parent pointers are updated and the promoted node
// is hooked back into t's former parent. Returns the promoted node.
//
// t->right must not be NULL (it is dereferenced unconditionally).
// Balance fields are not updated here.
forall(otype K | Comparable(K), otype V)
tree(K, V) * rotateLeft(tree(K, V) * t){
  tree(K, V) * pivot = t->right;
  tree(K, V) * moved = pivot->left; // subtree that changes owner; FIX ME!!

  // the pivot takes over t's place under t's old parent
  pivot->parent = t->parent;
  t->parent = pivot;

  // t drops down to the left of the pivot and adopts the moved subtree
  pivot->left = t;
  t->right = moved;
  if (moved != NULL) {
    moved->parent = t;
  }

  // let the old parent (if any) point at the pivot
  relinkToParent(pivot);
  return pivot;
}

// Right rotation around t.
// t's left child is promoted to the root of this subtree, t becomes its
// right child, and the promoted node's former right subtree is re-attached
// as t's left subtree. Parent pointers are updated and the promoted node
// is hooked back into t's former parent. Returns the promoted node.
//
// t->left must not be NULL (it is dereferenced unconditionally).
// Balance fields are not updated here.
forall(otype K | Comparable(K), otype V)
tree(K, V) * rotateRight(tree(K, V) * t){
  tree(K, V) * pivot = t->left;
  tree(K, V) * moved = pivot->right; // subtree that changes owner; FIX ME!!

  // the pivot takes over t's place under t's old parent
  pivot->parent = t->parent;
  t->parent = pivot;

  // t drops down to the right of the pivot and adopts the moved subtree
  pivot->right = t;
  t->left = moved;
  if (moved != NULL){
    moved->parent = t;
  }

  // let the old parent (if any) point at the pivot
  relinkToParent(pivot);
  return pivot;
}

// Rebalance a node whose balance factor is -2 or 2 and return the new
// root of that subtree.
//
//   balance == -2 (left-heavy):  rotate right; if the left child leans
//                                right (balance 1), rotate it left first.
//   balance ==  2 (right-heavy): rotate left; if the right child leans
//                                left (balance -1), rotate it right first.
//
// The rotations only rewire child/parent pointers, so the nodes they move
// would otherwise keep the balance factors they had before the rotation.
// height() navigates by those factors, so stale values can make later
// height/balance computations wrong. The balances of the new root's
// children and of the new root itself are therefore recomputed, bottom-up,
// before returning.
forall(otype K | Comparable(K), otype V)
tree(K, V) * fix(tree(K, V) * t){
  // ensure that t's balance factor is one of
  // the appropriate values
  assert(t->balance == 2 || t->balance == -2);

  tree(K, V) * newRoot = t;
  if (t->balance == -2){
    tree(K, V) * left = t->left; // FIX ME!!
    if (left->balance == 1){
      t->left = rotateLeft(t->left);
    }
    newRoot = rotateRight(t);
  } else if (t->balance == 2){
    tree(K, V) * right = t->right; // FIX ME!!
    if (right->balance == -1){
      t->right = rotateRight(t->right);
    }
    newRoot = rotateLeft(t);
  } else {
    // shouldn't ever get here
    assert((int)0);
    return t;
  }

  // Every node whose children changed is now either newRoot or one of its
  // two children. Children first, so that height() sees correct factors
  // when the root's balance is computed.
  tree(K, V) * lower = newRoot->left;
  if (! empty(lower)){
    calcBalance(lower);
  }
  lower = newRoot->right;
  if (! empty(lower)){
    calcBalance(lower);
  }
  calcBalance(newRoot);
  return newRoot;
}

// Refresh t's balance factor and rebalance the subtree if it has drifted
// to -2 or 2. Returns the root of the (possibly rotated) subtree; when no
// rebalancing is needed that is t itself.
forall(otype K | Comparable(K), otype V)
tree(K, V) * tryFix(tree(K, V) * t){
  int b = calcBalance(t);

  if (b != -2 && b != 2){
    // nothing to repair: the factor must already be a legal AVL value
    assert(b == 0 || b == 1 || b == -1);
    return t;
  }
  return fix(t);
}

// Record p as the parent of c. An empty c is silently ignored, so callers
// can pass a possibly-empty child without checking first.
forall(otype K | Comparable(K), otype V)
void setParent(tree(K, V) * c, tree(K, V) * p){
  if (empty(c)){
    return;
  }
  c->parent = p;
}

