//
// Cforall Version 1.0.0 Copyright (C) 2015 University of Waterloo
//
// The contents of this file are covered under the licence agreement in the
// file "LICENCE" distributed with Cforall.
//
// Corun.cpp -- generate code for corun and cofor statements
//
// Author           : Colby Parsons
// Created On       : Monday October 9 15:16:42 2023
// Last Modified By : Peter A. Buhr
// Last Modified On : Thu Dec 14 17:32:17 2023
// Update Count     : 1
//

#include "AST/Decl.hpp"
#include "AST/Expr.hpp"
#include "AST/Pass.hpp"
#include "AST/Stmt.hpp"
#include "AST/TranslationUnit.hpp"
#include "Common/UniqueName.h"
using namespace ast;
using namespace std;

namespace Concurrency {

// Pass that lowers corun and cofor statements into ordinary code built on the
// runner_block / cofor_runner types and helper routines declared by <cofor.hfa>.
//   * corun S;                     -> a local function wrapping S, a runner_block object,
//                                     and a constructor call binding the two.
//   * cofor ( init; cond; inc ) S; -> a local worker function plus a loop that passes a
//                                     malloc'd copy of each loop-variable value to send_work.
struct CorunKeyword : public WithDeclsToAdd<>, public WithStmtsToAdd<> {
    // Generators for the per-statement names of emitted functions and objects.
    UniqueName CorunFnNamer = "__CFA_corun_lambda_"s;
    UniqueName CoforFnNamer = "__CFA_cofor_lambda_"s;
    UniqueName RunnerBlockNamer = "__CFA_corun_block_"s;

    // Fixed (not uniquified) names used inside a generated cofor. Each is declared inside a
    // function or compound statement created by this pass, so a nested cofor declares its
    // own copies that shadow the enclosing ones.
    string coforArgName = "__CFA_cofor_lambda_arg";
    string numProcsName = "__CFA_cofor_num_procs";
    string currProcsName = "__CFA_cofor_curr_procs";
    string thdArrName = "__CFA_cofor_thread_array";
    string loopTempName = "__CFA_cofor_loop_temp";

    // Definitions found by previsit; both null until <cofor.hfa> (or equivalent) is seen.
    const StructDecl * runnerBlockDecl = nullptr;
    const StructDecl * coforRunnerDecl = nullptr;

    // Records the definitions (declarations with a body) of runner_block (used by corun)
    // and cofor_runner (used by cofor). Forward declarations are ignored; a second
    // definition of either trips the assert.
    void previsit( const StructDecl * decl ) {
        if ( !decl->body ) {
            return;
        } else if ( "runner_block" == decl->name ) {
            assert( !runnerBlockDecl );
            runnerBlockDecl = decl;
        } else if ( "cofor_runner" == decl->name ) {
            assert( !coforRunnerDecl );
            coforRunnerDecl = decl;
        }
    }

    // Builds the untyped call expression `fname( args... )` at location loc.
    static UntypedExpr * makeCall( const CodeLocation & loc, const string & fname,
            vector<ptr<Expr>> && args = {} ) {
        return new UntypedExpr( loc, new NameExpr( loc, fname ), std::move( args ) );
    }

    // Lowers a cofor statement. With v the variable declared by the single initializer,
    // the returned compound statement has the shape:
    //
    //   {
    //       init;
    //       void <fnName>( void * __CFA_cofor_lambda_arg ) {
    //           typeof(v) v = *((typeof(v) *)__CFA_cofor_lambda_arg);
    //           body;
    //       }
    //       unsigned __CFA_cofor_num_procs = get_proc_count();
    //       unsigned __CFA_cofor_curr_procs = 0;
    //       cofor_runner __CFA_cofor_thread_array[__CFA_cofor_num_procs];
    //       start_runners( __CFA_cofor_thread_array, __CFA_cofor_num_procs, <fnName> );
    //       for ( ; cond; inc ) {
    //           typeof(v) * __CFA_cofor_loop_temp = malloc();
    //           *__CFA_cofor_loop_temp = v;
    //           send_work( __CFA_cofor_thread_array, __CFA_cofor_num_procs,
    //                      __CFA_cofor_curr_procs, __CFA_cofor_loop_temp );
    //       }
    //       end_runners( __CFA_cofor_thread_array, __CFA_cofor_num_procs );
    //   }
    //
    // where <fnName> is a fresh name from CoforFnNamer. Reports a SemanticError when the
    // runtime types are missing or the loop control is not a single declaration; returns
    // nullptr when the cofor has no body.
    Stmt * postvisit( const CoforStmt * stmt ) {
        if ( !runnerBlockDecl || !coforRunnerDecl )
            SemanticError( stmt->location, "To use cofor statements add #include <cofor.hfa>" );

        if ( stmt->inits.size() != 1 )
            SemanticError( stmt->location, "Cofor statements must have a single initializer in the loop control" );

        if ( !stmt->body )
            return nullptr;

        // Validate the initializer before allocating any new nodes, so an error leaves
        // nothing half-built behind.
        const DeclStmt * declStmtPtr = dynamic_cast<const DeclStmt *>(stmt->inits.at(0).get());
        if ( ! declStmtPtr )
            SemanticError( stmt->location, "Cofor statement initializer is somehow not a decl statement?" );

        // NOTE(review): presumably `decl` is already a Decl pointer, which would make this
        // effectively a null check.
        const Decl * declPtr = dynamic_cast<const Decl *>(declStmtPtr->decl.get());
        if ( ! declPtr )
            SemanticError( stmt->location, "Cofor statement initializer is somehow not a decl?" );

        const CodeLocation & loc = stmt->location;
        const string fnName = CoforFnNamer.newName();

        // typeof(v) for the loop variable v. The same node is referenced from three places
        // below. NOTE(review): this relies on the AST tolerating shared child nodes — confirm.
        Type * initType = new TypeofType( new NameExpr( loc, declPtr->name ) );

        CompoundStmt * body = new CompoundStmt( loc );

        // Generates:
        // init;
        // Kept in the enclosing scope because the dispatch loop reads and updates v.
        body->push_back( deepCopy( stmt->inits.at(0) ) );

        // Generates the worker's body:
        //   typeof(v) v = *((typeof(v) *)__CFA_cofor_lambda_arg);
        //   body;
        CompoundStmt * fnBody = new CompoundStmt( loc );
        fnBody->push_back( new DeclStmt( loc,
            new ObjectDecl( loc,
                declPtr->name,
                initType,
                new SingleInit( loc,
                    UntypedExpr::createDeref( loc,
                        new CastExpr( loc,
                            new NameExpr( loc, coforArgName ),
                            new PointerType( initType ), ExplicitCast
                        )
                    )
                )
            )
        ));
        fnBody->push_back( deepCopy( stmt->body ) );

        // Generates:
        // void <fnName>( void * __CFA_cofor_lambda_arg ) { ...fnBody... }
        Stmt * coforLambda = new DeclStmt( loc,
            new FunctionDecl( loc,
                fnName,                                             // name
                {
                    new ObjectDecl( loc,
                        coforArgName,
                        new ast::PointerType( new ast::VoidType() )
                    )
                },                                                  // params
                {},                                                 // returns
                fnBody                                              // body
            )
        );
        body->push_back( coforLambda );

        // Generates:
        // unsigned __CFA_cofor_num_procs = get_proc_count();
        body->push_back( new DeclStmt( loc,
                new ObjectDecl( loc,
                    numProcsName,
                    new BasicType( BasicType::Kind::UnsignedInt ),
                    new SingleInit( loc, makeCall( loc, "get_proc_count" ) )
                )
            )
        );

        // Generates:
        // unsigned __CFA_cofor_curr_procs = 0;
        body->push_back( new DeclStmt( loc,
                new ObjectDecl( loc,
                    currProcsName,
                    new BasicType( BasicType::Kind::UnsignedInt ),
                    new SingleInit( loc, ConstantExpr::from_int( loc, 0 ) )
                )
            )
        );

        // Generates:
        // cofor_runner __CFA_cofor_thread_array[__CFA_cofor_num_procs];
        body->push_back( new DeclStmt( loc,
                new ObjectDecl( loc,
                    thdArrName,
                    new ast::ArrayType(
                        new StructInstType( coforRunnerDecl ),
                        new NameExpr( loc, numProcsName ),
                        ast::FixedLen,
                        ast::DynamicDim
                    )
                )
            )
        );

        // Generates:
        // start_runners( __CFA_cofor_thread_array, __CFA_cofor_num_procs, <fnName> );
        body->push_back( new ExprStmt( loc,
            makeCall( loc, "start_runners", {
                new NameExpr( loc, thdArrName ),
                new NameExpr( loc, numProcsName ),
                new NameExpr( loc, fnName )
            } )
        ));

        // Generates the dispatch-loop body:
        //   typeof(v) * __CFA_cofor_loop_temp = malloc();
        // NOTE(review): this block never frees the temporary; presumably the runner that
        // receives it via send_work does — confirm against cofor.hfa.
        CompoundStmt * forLoopBody = new CompoundStmt( loc );
        forLoopBody->push_back( new DeclStmt( loc,
                new ObjectDecl( loc,
                    loopTempName,
                    new PointerType( initType ),
                    new SingleInit( loc, makeCall( loc, "malloc" ) )
                )
            )
        );

        //   *__CFA_cofor_loop_temp = v;
        forLoopBody->push_back( new ExprStmt( loc,
            UntypedExpr::createAssign( loc,
                UntypedExpr::createDeref( loc, new NameExpr( loc, loopTempName ) ),
                new NameExpr( loc, declPtr->name )
            )
        ));

        //   send_work( __CFA_cofor_thread_array, __CFA_cofor_num_procs,
        //              __CFA_cofor_curr_procs, __CFA_cofor_loop_temp );
        forLoopBody->push_back( new ExprStmt( loc,
            makeCall( loc, "send_work", {
                new NameExpr( loc, thdArrName ),
                new NameExpr( loc, numProcsName ),
                new NameExpr( loc, currProcsName ),
                new NameExpr( loc, loopTempName )
            } )
        ));

        // Generates:
        // for ( ; cond; inc ) { ...forLoopBody... }
        // The initializer was already emitted above, so the loop has none of its own.
        body->push_back( new ForStmt( loc,
            {},
            deepCopy( stmt->cond ),
            deepCopy( stmt->inc ),
            forLoopBody
        ));

        // Generates:
        // end_runners( __CFA_cofor_thread_array, __CFA_cofor_num_procs );
        body->push_back( new ExprStmt( loc,
            makeCall( loc, "end_runners", {
                new NameExpr( loc, thdArrName ),
                new NameExpr( loc, numProcsName )
            } )
        ));

        return body;
    }

    // Lowers `corun S;`. Two declarations are queued before the statement:
    //   void <fnName>() { S }
    //   runner_block <objName>;
    // The statement itself is replaced by the constructor call
    //   ?{}( <objName>, <fnName> );
    // where <fnName>/<objName> are fresh names from CorunFnNamer/RunnerBlockNamer.
    // Reports a SemanticError when the runtime types are missing; returns nullptr when
    // the corun has no statement.
    Stmt * postvisit( const CorunStmt * stmt ) {
        if ( !runnerBlockDecl || !coforRunnerDecl )
            SemanticError( stmt->location, "To use corun statements add #include <cofor.hfa>" );

        if ( !stmt->stmt )
            return nullptr;

        const CodeLocation & loc = stmt->location;
        const string fnName = CorunFnNamer.newName();
        const string objName = RunnerBlockNamer.newName();

        // Generates:
        // void <fnName>() { ... stmt->stmt ... }
        Stmt * runnerLambda = new DeclStmt( loc,
            new FunctionDecl( loc,
                fnName,                                             // name
                {},                                                 // params
                {},                                                 // returns
                new CompoundStmt( loc, { deepCopy(stmt->stmt) } )   // body
            )
        );

        // Generates:
        // runner_block <objName>;
        Stmt * objDecl = new DeclStmt( loc,
            new ObjectDecl( loc,
                objName,
                new StructInstType( runnerBlockDecl )
            )
        );

        // Generates:
        // ?{}( <objName>, <fnName> );   i.e.  <objName>{ <fnName> };
        Stmt * threadStart = new ExprStmt( loc,
            makeCall( loc, "?{}", {
                new NameExpr( loc, objName ),
                new NameExpr( loc, fnName )
            } )
        );

        stmtsToAddBefore.push_back( runnerLambda );
        stmtsToAddBefore.push_back( objDecl );

        return threadStart;
    }
};

// Entry point of this file: runs the CorunKeyword pass over translationUnit,
// replacing its corun and cofor statements with the code generated above.
void implementCorun( TranslationUnit & translationUnit ) {
    using CorunPass = Pass<CorunKeyword>;
    CorunPass::run( translationUnit );
}

} // namespace Concurrency
