Skip to content

Commit ae6875f

Browse files
authored
[TRTLLM-8976][feat] Move indexer-k-cache to KVCacheManager (#8699)
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 579e106 commit ae6875f

File tree

8 files changed

+255
-162
lines changed

8 files changed

+255
-162
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,12 @@ class KVCacheBlockPool
536536

537537
// FP4 KV caches have extra pools that contain second level scales for dequantization.
538538
bool containsBlockScales;
539+
bool containsIndexerKCache;
539540

540541
KVCacheBlockPool(SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
541542
SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr,
542-
runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false)
543+
runtime::ITensor::SharedPtr secondaryPtr = nullptr, bool containsBlockScales = false,
544+
bool containsIndexerKCache = false)
543545
: numLayers(numLayers)
544546
, kvFactor(kvFactor)
545547
, numKvHeads(numKvHeads)
@@ -549,6 +551,7 @@ class KVCacheBlockPool
549551
, primaryPtr(std::move(primaryPtr))
550552
, secondaryPtr(std::move(secondaryPtr))
551553
, containsBlockScales(containsBlockScales)
554+
, containsIndexerKCache(containsIndexerKCache)
552555
{
553556
}
554557
};
@@ -587,14 +590,17 @@ class WindowBlockManager
587590
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
588591
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
589592
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
590-
std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent = nullptr);
593+
std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent = nullptr, bool enableIndexerKCache = false,
594+
SizeType32 indexerKCacheQuantBlockSize = 128, SizeType32 indexerKCacheIndexHeadDim = 0);
591595

592596
~WindowBlockManager();
593597

594598
void allocatePools(bool useUvm);
595599

596600
void releasePools();
597601

602+
void createIndexerKCachePools();
603+
598604
void startScheduling();
599605

600606
//! \brief Assign blocks for new sequence. Try to reuse blocks.
@@ -721,13 +727,30 @@ class WindowBlockManager
721727
#endif
722728
}
723729

724-
[[nodiscard]] SizeType32 getNumPools(bool includeBlockScalePools = true) const noexcept
730+
[[nodiscard]] SizeType32 getNumPools(
731+
bool includeBlockScalePools = true, bool includeIndexerKCachePools = true) const noexcept
725732
{
726-
if (includeBlockScalePools)
733+
if (includeBlockScalePools && includeIndexerKCachePools)
727734
{
728735
return mPools.size();
729736
}
730-
return std::count_if(mPools.begin(), mPools.end(), [](auto const& pool) { return !pool.containsBlockScales; });
737+
SizeType32 count = 0;
738+
for (auto const& pool : mPools)
739+
{
740+
if (includeBlockScalePools && pool.containsBlockScales)
741+
{
742+
count++;
743+
}
744+
else if (includeIndexerKCachePools && pool.containsIndexerKCache)
745+
{
746+
count++;
747+
}
748+
if (!pool.containsBlockScales && !pool.containsIndexerKCache)
749+
{
750+
count++;
751+
}
752+
}
753+
return count;
731754
}
732755

733756
[[nodiscard]] KVCacheBlockPool const& getPool(SizeType32 poolIdx) const
@@ -962,6 +985,13 @@ class WindowBlockManager
962985
// It may be invalidated to false when other sequence acquires a block that
963986
// is used by another sequence.
964987
std::map<LlmRequest::RequestIdType, bool> mIsValidStoreForReuseSequence;
988+
989+
// Whether to enable indexer K cache
990+
bool mEnableIndexerKCache;
991+
// Quant block size for indexer K cache
992+
SizeType32 mIndexerKCacheQuantBlockSize;
993+
// Index head dim for indexer K cache
994+
SizeType32 mIndexerKCacheIndexHeadDim;
965995
};
966996

967997
class BlockManager
@@ -981,7 +1011,8 @@ class BlockManager
9811011
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
9821012
bool copyOnPartialReuse = true,
9831013
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr,
984-
std::optional<kvc::BaseAgentConfig> agentConfig = std::nullopt);
1014+
std::optional<kvc::BaseAgentConfig> agentConfig = std::nullopt, bool enableIndexerKCache = false,
1015+
SizeType32 indexerKCacheQuantBlockSize = 128, SizeType32 indexerKCacheIndexHeadDim = 0);
9851016

9861017
BlockManager(BlockManager const&) = delete;
9871018
BlockManager& operator=(BlockManager const&) = delete;
@@ -1161,10 +1192,11 @@ class BlockManager
11611192
return getPool(poolIdx).blockSize;
11621193
}
11631194

1164-
[[nodiscard]] SizeType32 getNumPools(bool includeBlockScalePools = true) const
1195+
[[nodiscard]] SizeType32 getNumPools(
1196+
bool includeBlockScalePools = true, bool includeIndexerKCachePools = true) const
11651197
{
1166-
return sumWindows(
1167-
[includeBlockScalePools](auto const& manager) { return manager.getNumPools(includeBlockScalePools); });
1198+
return sumWindows([includeBlockScalePools, includeIndexerKCachePools](auto const& manager)
1199+
{ return manager.getNumPools(includeBlockScalePools, includeIndexerKCachePools); });
11681200
}
11691201

11701202
[[nodiscard]] std::map<SizeType32, WindowSizeMetadata> const& getWindowSizesMetadata() const noexcept
@@ -1496,6 +1528,7 @@ class BaseKVCacheManager
14961528

14971529
[[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
14981530
[[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0;
1531+
[[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
14991532
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
15001533

15011534
virtual void refreshBlocks() = 0;
@@ -1588,7 +1621,9 @@ class KVCacheManager : public BaseKVCacheManager
15881621
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
15891622
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
15901623
bool copyOnpartialReuse = true,
1591-
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
1624+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr,
1625+
bool enableIndexerKCache = false, SizeType32 indexerKCacheQuantBlockSize = 128,
1626+
SizeType32 indexerKCacheIndexHeadDim = 0);
15921627

15931628
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
15941629
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1599,7 +1634,9 @@ class KVCacheManager : public BaseKVCacheManager
15991634
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
16001635
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
16011636
bool copyOnpartialReuse = true,
1602-
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
1637+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr,
1638+
bool enableIndexerKCache = false, SizeType32 indexerKCacheQuantBlockSize = 128,
1639+
SizeType32 indexerKCacheIndexHeadDim = 0);
16031640

16041641
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
16051642
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1610,15 +1647,18 @@ class KVCacheManager : public BaseKVCacheManager
16101647
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
16111648
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
16121649
bool copyOnpartialReuse = true,
1613-
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);
1650+
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr,
1651+
bool enableIndexerKCache = false, SizeType32 indexerKCacheQuantBlockSize = 128,
1652+
SizeType32 indexerKCacheIndexHeadDim = 0);
16141653

16151654
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
16161655
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
16171656
std::vector<SizeType32> const& maxAttentionWindowVec,
16181657
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
16191658
SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
16201659
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
1621-
bool copyOnpartialReuse = true);
1660+
bool copyOnpartialReuse = true, bool enableIndexerKCache = false, SizeType32 indexerKCacheQuantBlockSize = 128,
1661+
SizeType32 indexerKCacheIndexHeadDim = 0);
16221662

16231663
~KVCacheManager() override = default;
16241664

@@ -1849,6 +1889,7 @@ class KVCacheManager : public BaseKVCacheManager
18491889

18501890
runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
18511891
runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;
1892+
runtime::ITensor::SharedPtr getIndexerKCachePool() const override;
18521893

18531894
SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
18541895
{
@@ -1910,6 +1951,7 @@ class KVCacheManager : public BaseKVCacheManager
19101951
runtime::ITensor::SharedPtr mBlockPoolPointers;
19111952
runtime::ITensor::SharedPtr mLayerToPoolMapping;
19121953
runtime::ITensor::SharedPtr mBlockScalePoolPointers;
1954+
runtime::ITensor::SharedPtr mIndexerKCachePoolPointers;
19131955
// GPU bytes allocated for KV-cache
19141956
std::size_t mAllocatedBytes{0};
19151957
};

0 commit comments

Comments
 (0)