﻿
#define RPMALLOC_EXPORT __declspec(dllimport)
#include "llvm/Support/rpnew.h"

#pragma comment(lib, "LLVM-21.lib")

#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h" 
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Transforms/Utils/Instrumentation.h"
#include "llvm/Transforms/IPO/HotColdSplitting.h"

// 最適化パスのヘッダ
#include "llvm/Transforms/Scalar/SROA.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include "llvm/Transforms/Scalar/ADCE.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "llvm/Transforms/IPO/DeadArgumentElimination.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Transforms/IPO/MergeFunctions.h"

#include <iterator>
#include <limits>

using namespace llvm;

// ログを出力したい場合は1。通常ビルド時は0。
#define TE_ENABLE_LOGGING 0

namespace {

// ポインタキャストを剥がして直接の呼び出し先を取得
static Function* getDirectCallee(CallBase* CB) {
  if (!CB)
    return nullptr;

  if (Function* F = CB->getCalledFunction())
    return F;

  Value* Called = CB->getCalledOperand();
  if (!Called)
    return nullptr;

  Value* Stripped = Called->stripPointerCasts();
  return dyn_cast<Function>(Stripped);
}

// ============================================================================
// TTI を使った関数サイズ計測
// ============================================================================
static bool computeFunctionTotalSize(const Function& F,
                                     const TargetTransformInfo& TTI,
                                     int64_t& OutSize) {
  int64_t Sum = 0;
  const TargetTransformInfo::TargetCostKind CostKind =
    TargetTransformInfo::TCK_CodeSize;

  for (const Instruction& I : instructions(F)) {
    InstructionCost IC = TTI.getInstructionCost(&I, CostKind);

    if (!IC.isValid())
      return false;

    int64_t Val = IC.getValue();
    if (Val < 0) Val = 0;

    if (Sum > INT64_MAX - Val) return false;  // Overflow check
    Sum += Val;
  }
  OutSize = Sum;
  return true;
}

// ============================================================================
// ヘルパー: インライン展開を実行し、直後にクリーンアップ最適化
// ============================================================================
static bool applyInlineAndOptimize(Function& F, CallBase* CB,
                                   FunctionAnalysisManager& FAM,
                                   bool IsSimulation) {
  if (!CB->getParent()) {
    return false;
  }

  InlineFunctionInfo IFI;
  InlineResult IR = InlineFunction(*CB, IFI, /*MergeAttributes=*/true);

  if (!IR.isSuccess()) {
    return false;
  }

  // 解析情報の無効化
  FAM.invalidate(F, PreservedAnalyses::none());

  if (IsSimulation) {
    // クリーンアップ最適化
    FunctionPassManager FPM;
    for (const Instruction& I : instructions(F)) {
      if (isa<AllocaInst>(&I)) {
        FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
        break;
      }
    }
    FPM.addPass(EarlyCSEPass());
    FPM.addPass(SimplifyCFGPass());
    FPM.addPass(InstCombinePass());
    FPM.addPass(ADCEPass());
    FPM.run(F, FAM);
  }

  return true;
}

// ============================================================================
// クローン -> インライン -> 最適化 -> 計測
// ============================================================================
static int64_t measureSizeAfterInliningAndOptimization(Function& OriginalF,
                                                       CallBase* CB,
                                                       bool& Success) {
  Success = false;
  if (!CB) return INT64_MAX;

  // 関数のクローン作成
  ValueToValueMapTy VMap;
  Function* ClonedF = CloneFunction(&OriginalF, VMap);
  if (!ClonedF) return INT64_MAX;

  // クローンされた関数内での CallBase を特定
  CallBase* ClonedCB = dyn_cast_or_null<CallBase>(VMap[CB]);
  if (!ClonedCB) {
    ClonedF->eraseFromParent();
    return INT64_MAX;
  }

  // ローカル最適化の実行準備
  LoopAnalysisManager LAM;
  FunctionAnalysisManager ClonedFAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;

  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(ClonedFAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, ClonedFAM, CGAM, MAM);

  // インライン展開と最適化
  if (!applyInlineAndOptimize(*ClonedF, ClonedCB, ClonedFAM, true)) {
    ClonedF->eraseFromParent();
    return INT64_MAX;
  }

  // サイズ計測
  const TargetTransformInfo& ClonedTTI =
      ClonedFAM.getResult<TargetIRAnalysis>(*ClonedF);
  int64_t NewSize = 0;
  if (!computeFunctionTotalSize(*ClonedF, ClonedTTI, NewSize)) {
    NewSize = static_cast<int64_t>(std::distance(instructions(*ClonedF).begin(),
                                                 instructions(*ClonedF).end()));
  }

  ClonedF->eraseFromParent();
  Success = true;
  return NewSize;
}

// ============================================================================
// メインパス
// ============================================================================
struct TeInlinerPass : public PassInfoMixin<TeInlinerPass> {
  // ユーザー定義の閾値設定
  const uint64_t ExtendedCalleeSizeLimit = 500;     // Calleeの最大サイズ
  const uint64_t MaxCallerSizeForHot = 3000;        // Hot: キャッシュ溢れを防ぐため厳しく
  const uint64_t MaxCallerSizeForWarm = 10000;      // Warm: 従来の制限値を維持
  const uint64_t MaxCallerSizeForCold = UINT64_MAX; // Cold: ゴミ捨て場として大幅緩和
  const int MaxInlineIterations = 200;              // 1関数あたりの最大インライン回数
  const int64_t MaxColdCalleeSizeInHotCaller = 25;  // Hot CallerにCold Callerからのインライン対象を見つけた場合、Cold Callerのサイズがこれより大きい場合は処理を打ち切る

  struct CandidateInfo {
    WeakTrackingVH Handle;
  };

  PreservedAnalyses run(Function& F, FunctionAnalysisManager& FAM) {
    if (F.isDeclaration())
      return PreservedAnalyses::all();

    // 解析情報の取得
    auto& CallerTTI = FAM.getResult<TargetIRAnalysis>(F);

    bool Changed = false;
    bool MadeChange;
    int IterationCount = 0;

    do {
      MadeChange = false;
      IterationCount++;

      // 安全装置: 無限ループ防止
      if (IterationCount > MaxInlineIterations) {
        break;
      }

      uint64_t CallerInstCount = F.getInstructionCount(); // Callerの命令数

      ProfileSummaryInfo* PSI = nullptr;
      BlockFrequencyInfo* BFI = nullptr;

      auto& MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
      PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
      BFI = &FAM.getResult<BlockFrequencyAnalysis>(F);

      // Caller の熱さ判定
      bool IsCallerHotFunc = false;
      bool IsCallerColdFunc = false;

      if (PSI && PSI->hasProfileSummary()) {
        if (BFI) {
          IsCallerHotFunc = PSI->isFunctionHotInCallGraph(&F, *BFI);
          IsCallerColdFunc = PSI->isFunctionColdInCallGraph(&F, *BFI);
        } else {
          IsCallerHotFunc = PSI->isFunctionEntryHot(&F);
          IsCallerColdFunc = F.hasFnAttribute(Attribute::Cold) || PSI->isFunctionEntryCold(&F);
        }
      } else {
        IsCallerColdFunc = F.hasFnAttribute(Attribute::Cold);
      }
      bool IsCallerWarmFunc = !IsCallerHotFunc && !IsCallerColdFunc;

      // Callerのサイズ制限
      uint64_t CurrentCallerSizeLimit = MaxCallerSizeForWarm;
      if (IsCallerHotFunc) {
        CurrentCallerSizeLimit = MaxCallerSizeForHot;
      } else if (IsCallerColdFunc) {
        CurrentCallerSizeLimit = MaxCallerSizeForCold;
      }

      // Callerが巨大すぎる場合は処理を打ち切る
      if (CallerInstCount > CurrentCallerSizeLimit) {
        break;
      }
      
      // 現在の関数サイズ(TTIコスト)を計測
      int64_t CurrentSize = 0;
      if (!computeFunctionTotalSize(F, CallerTTI, CurrentSize)) {
        break;
      }

      CallBase* BestCandidate = nullptr;
      int64_t BestScore = std::numeric_limits<int64_t>::min();
      SmallVector<CandidateInfo, 64> NormalCandidates;
      SmallVector<CandidateInfo, 16> PriorityCandidates;

      for (Instruction& I : instructions(F)) {
        if (CallBase* CB = dyn_cast<CallBase>(&I)) {
          // 親ブロックがない（死んでいる）命令はスキップ
          if (!CB->getParent()) continue;

          Function* Callee = getDirectCallee(CB);

          // 基本的な除外条件
          if (!Callee || Callee->isDeclaration() || Callee->isIntrinsic() ||
              Callee == &F)
            continue;

          // 【重要】Windows x64 LLVM21 バグ回避
          // 可変長引数関数はインライン展開すると不正なコードが生成される場合があるため除外
          if (Callee->isVarArg()) continue;

          // 属性による除外
          if (Callee->hasFnAttribute(Attribute::AlwaysInline) ||
              Callee->hasFnAttribute(Attribute::NoInline))
            continue;

          // CallerとCallee のターゲット機能に互換性があるか
          if (!CallerTTI.areInlineCompatible(&F, Callee)) continue;

          // Solitary判定
          bool IsSolitary = false;
          if (!Callee->hasAvailableExternallyLinkage() &&
              Callee->hasLocalLinkage() && Callee->hasOneUse()) {
            IsSolitary = true;
          }

          // Solitaryでなければスキップ
          if (!IsSolitary) {
            continue;
          }

          const uint64_t CalleeSize = Callee->getInstructionCount();

          // 動的サイズ制限 & I-Cache保護
          const uint64_t LimitForHot = 100;
          const uint64_t LimitForWarm = 225;
          const uint64_t LimitForCold = 400;  // 実行回数が少ない
          const uint64_t LimitForZero = ExtendedCalleeSizeLimit;  // 実行回数ゼロ
          const uint64_t LimitForHotPair = 250;      // Hot -> Hot
          const uint64_t LimitForWarmPair = 350;     // Warm -> Warm
          uint64_t CurrentSizeLimit = LimitForWarm;  // デフォルト

          bool IsCallSiteCold = false;
          bool IsCalleeCold = false;

          if (PSI && PSI->hasProfileSummary()) {
            bool IsCalleeHot = PSI->isFunctionEntryHot(Callee);
            if (IsCalleeHot) {
              IsCalleeCold = false;
            } else {
              IsCalleeCold = Callee->hasFnAttribute(Attribute::Cold) ||
                             PSI->isFunctionEntryCold(Callee);
            }
            bool IsCalleeWarm = !IsCalleeHot && !IsCalleeCold;

            if (BFI) {
              BasicBlock* BB = CB->getParent();

              // 1. 実行回数ゼロ (Dead Code) -> 最も緩い制限 (ゴミ捨て場)
              // isColdBlockよりも強い条件として、明示的にカウント0をチェック
              auto OptCount = BFI->getBlockProfileCount(BB);
              if (OptCount.has_value() && OptCount.value() == 0) {
                CurrentSizeLimit = IsCallerHotFunc ? LimitForCold : LimitForZero;
                IsCallSiteCold = true;
              }
              // 2. Hot Block -> 厳しい制限 (I-Cache保護)
              else if (PSI->isHotBlock(BB, BFI)) {
                CurrentSizeLimit = LimitForHot;
                IsCallSiteCold = false;
              }
              // 3. Cold Block -> 緩い制限
              else if (PSI->isColdBlock(BB, BFI)) {
                CurrentSizeLimit = IsCallerHotFunc ? LimitForWarm : LimitForCold;
                IsCallSiteCold = true;
              }
              // 4. Warm Block (その他) -> 標準制限
              else {
                CurrentSizeLimit = LimitForWarm;
                IsCallSiteCold = false;
              }
            } else {
              // BFIなしフォールバック
              if (IsCallerHotFunc)
                CurrentSizeLimit = LimitForHot;
              else if (IsCallerColdFunc)
                CurrentSizeLimit = LimitForZero;
              IsCallSiteCold = (CurrentSizeLimit >= LimitForCold);
            }

            // 優先インライン展開 (Priority)
            // 条件に合致すれば、シミュレーションなしで即決(Priority)リストへ
            // ここでcontinueすることで、後の厳しいチェックをスキップ

            // 1. Cold同士
            if (IsCallerColdFunc && IsCalleeCold) {
              PriorityCandidates.push_back({WeakTrackingVH(CB)});
              continue;
            }
            // 2. Hot同士
            if (IsCallerHotFunc && IsCalleeHot &&
                !IsCallSiteCold &&
                CalleeSize <= LimitForHotPair) {
              PriorityCandidates.push_back({WeakTrackingVH(CB)});
              continue;
            }
            // 3. Warm同士
            if (IsCallerWarmFunc && IsCalleeWarm &&
                CalleeSize <= LimitForWarmPair) {
              PriorityCandidates.push_back({WeakTrackingVH(CB)});
              continue;
            }
          } else {
            // PSIがない場合のフォールバック
            IsCalleeCold = Callee->hasFnAttribute(Attribute::Cold);
          }

          // 1. 動的リミットによる足切り
          if (CalleeSize > CurrentSizeLimit) {
            if (TE_ENABLE_LOGGING) {
              errs() << "[TeInliner] Skip " << Callee->getName()
                     << ": Too large for context (" << CalleeSize << " > "
                     << CurrentSizeLimit << ")\n";
            }
            continue;
          }

          // 2. I-Cache汚染防止 (Hot Callerへの巨大Cold混入防止)
          if (IsCallerHotFunc) {
            if (IsCalleeCold && !IsCallSiteCold) {
              if (CalleeSize > MaxColdCalleeSizeInHotCaller) {
                if (TE_ENABLE_LOGGING) {
                  errs() << "[TeInliner] Skip: Large Cold (" << CalleeSize
                         << ") into Hot Caller\n";
                }
                continue;
              }
            }
          }

          NormalCandidates.push_back({WeakTrackingVH(CB)});
        }
      }

      // 削減効果が高そうな順にソート
      std::stable_sort(
          NormalCandidates.begin(), NormalCandidates.end(),
          [](const CandidateInfo& A, const CandidateInfo& B) {
            if (!A.Handle || !B.Handle) return false;

            CallBase* CBA = dyn_cast_or_null<CallBase>(A.Handle);
            CallBase* CBB = dyn_cast_or_null<CallBase>(B.Handle);
            if (!CBA || !CBB) return false;

            // ヘルパー: 定数引数の数を数える
            auto CountConstArgs = [](CallBase* CB) -> int {
              return std::count_if(CB->arg_begin(), CB->arg_end(),
                                   [](Value* V) { return isa<Constant>(V); });
            };

            // 基準1: 定数引数の数が多い方を優先 (定数畳み込み期待)
            int ConstsA = CountConstArgs(CBA);
            int ConstsB = CountConstArgs(CBB);
            if (ConstsA != ConstsB) {
              return ConstsA > ConstsB;  // 多いほうが先
            }

            // 基準2: 引数の総数が多い方を優先 (呼び出しオーバーヘッド削減期待)
            // 引数が多いほど、Call命令前の mov/push 命令が多く削減できる
            if (CBA->arg_size() != CBB->arg_size()) {
              return CBA->arg_size() > CBB->arg_size();  // 多いほうが先
            }

            // 基準3: Calleeのサイズが小さい方を優先 (処理コスト安)
            return getDirectCallee(CBA)->getInstructionCount() <
                   getDirectCallee(CBB)->getInstructionCount();
          });


      // 優先的にインライン展開
      for (const auto& Info : PriorityCandidates) {
        if (!Info.Handle)
          continue;

        CallBase* CB = dyn_cast_or_null<CallBase>(Info.Handle);
        if (!CB)
          continue;

        Function* Callee = getDirectCallee(CB);
        std::string CalleeName = Callee->getName().str();
        std::string CallerName = F.getName().str();

        // シミュレーションなしで強制的にインライン展開
        if (applyInlineAndOptimize(F, CB, FAM, false)) {
          if (TE_ENABLE_LOGGING) {
            errs() << formatv("[TeInliner] Priority Inlined {0} into {1}\n",
                              CalleeName, CallerName);
          }

          // ゾンビ関数の即時破壊
          // ネストした関数の連鎖的なインライン化（A->B->C）を促進
          if (!Callee->isDeclaration()) {
            Callee->deleteBody();
          }

          MadeChange = true;
          Changed = true;
        }
      }


      // 通常候補の処理 (シミュレーションあり)
      // 試行回数の動的制限 (Callerが大きくなるほど試行回数を減らす)
      uint64_t CallerSize = F.getInstructionCount();
      size_t MaxCandidatesToTry = 50000 / (CallerSize + 1000);
      MaxCandidatesToTry = std::clamp<size_t>(MaxCandidatesToTry, 5, 40);
      if (NormalCandidates.size() > MaxCandidatesToTry) {
        NormalCandidates.resize(MaxCandidatesToTry);
      }

      for (const auto& Info : NormalCandidates) {
        if (!Info.Handle)
          continue;

        CallBase* CB = dyn_cast_or_null<CallBase>(Info.Handle);
        if (!CB)
          continue;

        const int64_t AcceptanceThreshold = -static_cast<int64_t>(ExtendedCalleeSizeLimit);

        // シミュレーション実行
        bool Success = false;
        int64_t EstimatedNewSize = measureSizeAfterInliningAndOptimization(F, CB, Success);
        if (!Success)
          continue;

        int64_t SizeSaving = CurrentSize - EstimatedNewSize;
        if (TE_ENABLE_LOGGING) {
          errs() << "[TeInliner] Eval " << getDirectCallee(CB)->getName()
                 << ": Saving=" << SizeSaving
                 << " Threshold=" << AcceptanceThreshold
                 << (SizeSaving >= AcceptanceThreshold ? " [ACCEPT]" : " [REJECT]")
                 << "\n";
        }

        // 閾値チェック
        if (SizeSaving >= AcceptanceThreshold) {
          if (SizeSaving > BestScore) {
            BestScore = SizeSaving;
            BestCandidate = CB;
          }
        }
      }

      // ベストな候補が見つかれば、本番のインライン展開を実行
      if (BestCandidate) {
        Function* Callee = getDirectCallee(BestCandidate);
        std::string CalleeName = Callee->getName().str();
        std::string CallerName = F.getName().str();

        if (applyInlineAndOptimize(F, BestCandidate, FAM, false)) {
          if (TE_ENABLE_LOGGING) {
            errs() << formatv(
                "[TeInliner] Inlined {0} into {1}. Est.Saving: {2}\n",
                CalleeName, CallerName, BestScore);
          }

          // ゾンビ関数の即時破壊
          // ネストした関数の連鎖的なインライン化（A->B->C）を促進
          if (!Callee->isDeclaration()) {
            Callee->deleteBody();
          }

          MadeChange = true;
          Changed = true;
        }
      }

    } while (MadeChange);

    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
  static bool isRequired() { return true; }
};

// ============================================================================
// TeColdOptPass (Coldコードのサイズを縮小)
// ============================================================================
struct TeColdOptPass : public PassInfoMixin<TeColdOptPass> {
  PreservedAnalyses run(Module& M, ModuleAnalysisManager& MAM) {
    bool Changed = false;

    auto& FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();

    for (Function& F : M) {
      if (F.isDeclaration()) continue;

      // 条件1: PGOデータがあり、実行回数が0
      if (F.hasProfileData()) {
        auto EntryCount = F.getEntryCount();
        if (EntryCount.has_value() && EntryCount->getCount() == 0) {
          if (F.hasFnAttribute(Attribute::MinSize)) {
            continue;
          }

          // 条件2: 関数内にループが存在しない(直線的なコード)
          // ループがある場合(計算処理など)は、ベクトル化阻害のリスクがあるため除外
          // ループがない場合のみ-Oz化
          auto& LI = FAM.getResult<LoopAnalysis>(F);
          if (LI.empty()) {
            F.addFnAttr(Attribute::MinSize);
            F.addFnAttr(Attribute::OptimizeForSize);
            Changed = true;
          }
        }
      }
    }
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
};

}  // namespace

// Plugin registration
extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo
llvmGetPassPluginInfo() {
  return {
    LLVM_PLUGIN_API_VERSION, "TeInliner", LLVM_VERSION_STRING,
    [](PassBuilder& PB) {
      auto IsLTOPreLink = [](ThinOrFullLTOPhase Phase) {
        return Phase == ThinOrFullLTOPhase::ThinLTOPreLink ||
               Phase == ThinOrFullLTOPhase::FullLTOPreLink;
      };

      // 1. Main Inliner
      PB.registerOptimizerEarlyEPCallback([&](ModulePassManager& MPM,
                                              OptimizationLevel Level,
                                              ThinOrFullLTOPhase Phase) {
        if (IsLTOPreLink(Phase)) return;

        // インライン展開
        MPM.addPass(createModuleToFunctionPassAdaptor(TeInlinerPass()));

        // インライン展開で生じたallocaをSROAでレジスタ化
        // 後続のLLVパイプライン(LoopVectorize等)が最適化を行いやすい状態にする
        FunctionPassManager LocalCleanup;
        LocalCleanup.addPass(SROAPass(SROAOptions::ModifyCFG));
        LocalCleanup.addPass(EarlyCSEPass(true));
        MPM.addPass(createModuleToFunctionPassAdaptor(std::move(LocalCleanup)));
      });

      // 2. Final Cleanup
      PB.registerOptimizerLastEPCallback([&](ModulePassManager& MPM,
                                             OptimizationLevel,
                                             ThinOrFullLTOPhase Phase) {
        if (IsLTOPreLink(Phase)) return;

        // Hot/Cold Splitting
        MPM.addPass(HotColdSplittingPass());

        MPM.addPass(MergeFunctionsPass());

        // サイズの最適化
        MPM.addPass(TeColdOptPass());

        FunctionPassManager CleanupFPM;
        CleanupFPM.addPass(SCCPPass());                       // 定数伝播
        CleanupFPM.addPass(SROAPass(SROAOptions::ModifyCFG)); // メモリ最適化
        CleanupFPM.addPass(AggressiveInstCombinePass());      // 強力な結合
        CleanupFPM.addPass(InstCombinePass());
        CleanupFPM.addPass(SimplifyCFGPass());
        CleanupFPM.addPass(ADCEPass());                       // ゴミ掃除
        MPM.addPass(createModuleToFunctionPassAdaptor(std::move(CleanupFPM)));

        // モジュール単位の残骸処理
        MPM.addPass(DeadArgumentEliminationPass());
      });
    }
  };
}