#include "avl.h"
#include "avl-private.h"

// Generic swap routine taking pointers to two objects of the same type.
// Only declared here so node_swap can call it; the definition comes from the
// stdlib and is not visible in this file.
forall(otype T)
void swap(T *, T *);

// Exchanges the payload (key and value) of two tree nodes.
// Only the key/value fields are touched; the link fields of both nodes
// are left exactly as they were.
forall(otype K | Comparable(K), otype V)
void node_swap(tree(K, V) * t, tree(K, V) * t2){
  // keys first ...
  K * key_a = &t->key;
  K * key_b = &t2->key;
  swap(key_a, key_b);

  // ... then the values that belong to them
  V * val_a = &t->value;
  V * val_b = &t2->value;
  swap(val_a, val_b);
}

// Returns the leftmost node of t's right subtree, i.e. the node reached by
// stepping once to the right and then following left links until the left
// child is empty.
// Precondition: t->right is dereferenced without a check, so the caller must
// only pass a node whose right subtree is non-empty.
forall(otype K | Comparable(K), otype V)
tree(K, V) * find_successor(tree(K, V) * t){
  tree(K, V) * cursor = t->right;
  // keep descending while there is still something further left
  while (! empty(cursor->left)){
    cursor = cursor->left;
  }
  return cursor;
}

// Frees exactly one node without touching its former subtrees.
// The child links are cleared before the node is handed to delete so that
// the delete cannot follow them (presumably the tree's destructor deletes
// children recursively -- confirm against the tree definition).
// The caller is responsible for having re-linked t's children elsewhere.
forall(otype K | Comparable(K), otype V)
void deleteSingleNode(tree(K, V) * t) {
  t->left = NULL;
  t->right = NULL;
  // release the node itself; calling deleteSingleNode here again would
  // recurse without end and never free anything
  delete(t);
}

// Performs the actual removal once the node holding the key has been found.
// Except for the leaf case, t itself stays in the tree: the key/value of a
// replacement node are swapped into t and the replacement node is freed
// instead, so the link from t's parent to t never has to change.
// Returns NULL when t was a leaf (t has been deleted), otherwise t.
forall(otype K | Comparable(K), otype V)
tree(K, V) * remove_node(tree(K, V) * t){
  // leaf: nothing to relink, just free the node
  if (empty(t->left) && empty(t->right)){
    delete(t);
    return NULL;
  }

  // only a right child: pull its contents up into t
  if (empty(t->left)){
    tree(K, V) * child = t->right;
    node_swap(t, child);

    // t adopts the child's subtrees
    t->left = child->left;
    t->right = child->right;
    setParent(t->left, t);
    setParent(t->right, t);

    // child now carries the removed key/value; free it alone
    deleteSingleNode(child);
    return t;
  }

  // only a left child: mirror image of the case above
  if (empty(t->right)){
    tree(K, V) * child = t->left;
    node_swap(t, child);

    // t adopts the child's subtrees
    t->left = child->left;
    t->right = child->right;
    setParent(t->left, t);
    setParent(t->right, t);

    deleteSingleNode(child);
    return t;
  }

  // two children: the successor takes t's place (by content)
  tree(K, V) * succ = find_successor(t);
  tree(K, V) * succ_parent = succ->parent;
  // the successor has no left child, so only its right subtree needs a new home
  tree(K, V) * orphan = succ->right;

  // splice the successor out of its parent
  if (succ_parent->left == succ){
    succ_parent->left = orphan;
  } else {
    assert(succ_parent->right == succ);
    succ_parent->right = orphan;
  }
  setParent(orphan, succ_parent);

  // move the successor's key/value into t, then free the detached node
  // NOTE(review): the nodes between t and succ_parent are not passed to
  // tryFix from here -- confirm that balance is restored elsewhere.
  node_swap(t, succ);
  deleteSingleNode(succ);
  return t;
}

// Recursively searches the subtree rooted at t for key and removes the
// matching node. Returns the (possibly different) root of the subtree, which
// the caller must store back into the link it followed.
// *worked is set to 1 when the key is not found and is left untouched
// otherwise.
// NOTE(review): despite its name, a non-zero *worked therefore signals failure.
forall(otype K | Comparable(K), otype V)
tree(K, V) * remove_helper(tree(K, V) * t, K key, int * worked){
  // ran off the bottom of the tree: the key is not present
  if (empty(t)){
    *worked = 1;
    return t;
  }

  if (t->key == key) {
    // found it; may yield NULL when t was a leaf
    t = remove_node(t);
  } else if (t->key < key){
    // key is larger than this node's key -> continue in the right subtree
    t->right = remove_helper(t->right, key, worked);
  } else {
    // key is smaller than this node's key -> continue in the left subtree
    t->left = remove_helper(t->left, key, worked);
  }

  // nothing left to repair if the subtree vanished entirely
  if (empty(t)) {
    return t;
  }
  // otherwise give this node a chance to be fixed up on the way back
  return tryFix(t);
}

// Removes key from the tree *t and stores the resulting root back into *t.
// Returns 0 when a matching node was found and removed, 1 when the key was
// not present in the tree.
forall(otype K | Comparable(K), otype V)
int remove(tree(K, V) ** t, K key){
  // stays 0 unless remove_helper reports a missing key
  int status = 0;
  *t = remove_helper(*t, key, &status);
  return status;
}
