import sys
import re
import pandas as pd
from collections import defaultdict

# Matched top to bottom, bailing on first match
# More general rules are not double-counted if they occur later
# Such cases are commented "overlap"; don't move those too far up
CATEGORY_RULES = {
    "text-import": [
        "_X19eagerCopyCtorHelperFv_S10string_resPKcm__1;_X12_constructorFv_S10string_resPKcm__1;__memmove_ssse3",
        "_X19eagerCopyCtorHelperFv_S10string_resPKcm__1;_X12_constructorFv_S10string_resPKcm__1;__memcpy_ssse3",
        "helper;__memcpy_ssse3",
#        "strlen"
    ],
    "gc": [
        "_X19eagerCopyCtorHelperFv_S10string_resPKcm__1;_X12_constructorFv_S10string_resPKcm__1;_X7garbageFv_S9VbyteHeapi__1"
    ],
    "malloc-free": [
        "operator new;_X8doMallocFPv_mj__1",
        "operator new;malloc",
        "_X6doFreeFv_Pv__1",
        "free"
    ],
    "ctor-dtor": [
        "std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >::_M_create",
        "operator new", # overlap stl malloc-free
        "operator delete",
        "_X12_constructorFv_S10string_resPKcm__1" # overlap cfa text import
    ]
}

DEFAULT_CATEGORY = "other"

# Suffixes that mark a harness leaf frame when no category rule matched
# (mangled and plain spellings of the same helper — either one wins).
_HARNESS_SUFFIXES = ("_X6helperFv_i__1", "helper")

def classify_stack(stack):
    """Map one collapsed stack string to a cost category.

    CATEGORY_RULES is scanned in insertion order; the first rule whose
    pattern is a substring of `stack` wins.  Stacks ending in a harness
    helper frame classify as "harness-leaf"; everything else falls
    through to DEFAULT_CATEGORY.
    """
    for category, patterns in CATEGORY_RULES.items():
        for pattern in patterns:
            if pattern in stack:
                return category
    # Plain suffix tests: str.endswith with a tuple replaces the two
    # former re.search(r"...$") calls (no regex features were used).
    if stack.endswith(_HARNESS_SUFFIXES):
        return "harness-leaf"
    return DEFAULT_CATEGORY

# def parse_sut_and_size(filename):
#     # Extract SUT after "perfexp-" and before the next hyphen
#     sut_match = re.search(r"perfexp-([a-zA-Z0-9]+)", filename)
#     # Extract SIZE from "corpus-A-B-C.txt", capturing B
#     size_match = re.search(r"corpus-\d+-(\d+)-\d+\.txt", filename)
    
#     if not sut_match or not size_match:
#         print("Error: Could not parse sut or size from filename.")
#         sys.exit(1)
    
#     return sut_match.group(1), size_match.group(1)

def read_and_aggregate(input_file):
    """Read a collapsed-stack file and tally samples per category.

    Each non-blank line is whitespace-separated tokens whose last token
    is an integer sample count; the remaining tokens form the stack.
    Returns (category -> {lineno: sample_count}, grand total of samples).
    """
    per_category = defaultdict(lambda: defaultdict(int))  # category -> lineno -> sample_count
    grand_total = 0

    with open(input_file) as stream:
        for lineno, raw in enumerate(stream, start=1):
            stripped = raw.strip()
            if not stripped:
                continue  # blank lines carry no samples
            tokens = stripped.split()
            samples = int(tokens[-1])
            frames = ' '.join(tokens[:-1])
            per_category[classify_stack(frames)][lineno] += samples
            grand_total += samples

    return per_category, grand_total

def flatten(category_map, total_samples):  # (sut/size columns currently disabled)
    """Build a one-row-per-category DataFrame from the tallies.

    Columns: category, samples_in_category, total_samples, fraction
    (0.0 when there are no samples at all), and sources — the
    contributing "lineno:count" pairs joined with "|".
    """
    def row_for(category, per_line):
        # Total for this category plus a compact provenance string.
        in_category = sum(per_line.values())
        return {
            "category": category,
            "samples_in_category": in_category,
            "total_samples": total_samples,
            "fraction": in_category / total_samples if total_samples else 0.0,
            "sources": "|".join(f"{ln}:{n}" for ln, n in per_line.items()),
        }

    return pd.DataFrame([row_for(c, m) for c, m in category_map.items()])

def main():
    """CLI entry point: aggregate one input file and print TSV to stdout."""
    args = sys.argv[1:]
    if len(args) != 1:
        print("Usage: python3 process-allocn-attrib.py <input_file>")
        sys.exit(1)

    # sut, size = parse_sut_and_size(args[0])
    tallies, overall = read_and_aggregate(args[0])
    frame = flatten(tallies, overall)  # , sut, size)

    # Tab-separated, no header/index: suited for downstream concatenation.
    frame.to_csv(sys.stdout, sep="\t", index=False, header=False)

if __name__ == "__main__":
    main()
