Skip to content

Commit 6744876

Browse files
Merge branch 'release/1.1' into user/barry/fix_sm100_r1_ci
2 parents 59e3a81 + b326be2 commit 6744876

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ class BlockRange
128128
BaseKVCacheManager const* mManager;
129129
runtime::ITensor::SharedPtr mPool;
130130
SizeType32 mWindowSize;
131-
const LlmRequest::RequestIdType mRequestId;
131+
LlmRequest::RequestIdType const mRequestId;
132132
std::vector<SizeType32> mBlockIds;
133133

134134
static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
@@ -203,7 +203,18 @@ class BlockIterator
203203
{
204204
if (mIdx < mRange->mBlockIds.size())
205205
{
206-
mCurrent = runtime::ITensor::slice(mRange->mPool, mRange->mBlockIds.at(mIdx), 1);
206+
if (mRange->mManager != nullptr)
207+
{
208+
BlockPtr const& block
209+
= mRange->mManager->getBlockManager().getBlockById(mRange->mBlockIds.at(mIdx), mRange->mWindowSize);
210+
TLLM_CHECK_WITH_INFO(block->isPrimary(), "cache transceiver only supports primary blocks");
211+
auto const blockOffset = block->getMemoryPoolBlockIndex();
212+
mCurrent = runtime::ITensor::slice(mRange->mPool, blockOffset, 1);
213+
}
214+
else
215+
{
216+
mCurrent = runtime::ITensor::slice(mRange->mPool, mRange->mBlockIds.at(mIdx), 1);
217+
}
207218
}
208219
}
209220

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1952,7 +1952,7 @@ SizeType32 KVCacheManager::getNeededBlocksOneStep(
19521952
return 0;
19531953
}
19541954

1955-
auto const numCurrTokens = mSequences.at(req.mRequestId).getNumTokens();
1955+
auto const numCurrTokens = getSequence(req.mRequestId).getNumTokens();
19561956
auto const generatedTokens = numCurrTokens - req.getPromptLen();
19571957
auto const maxTokensToAddToKVCache = req.mMaxNewTokens - generatedTokens;
19581958
auto const tokensPerStep = req.getNumDraftTokens() + 1;
@@ -2198,7 +2198,13 @@ void KVCacheManager::addSequence(
21982198
void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
21992199
{
22002200
auto const requestId = llmRequest.mRequestId;
2201-
if (mSequences.find(requestId) != mSequences.end())
2201+
bool found = false;
2202+
{
2203+
// protect the mSequences
2204+
std::scoped_lock lock(mSequencesMtx);
2205+
found = mSequences.find(requestId) != mSequences.end();
2206+
}
2207+
if (found)
22022208
{
22032209
auto& sequence = getSequence(requestId);
22042210
if (mEnableBlockReuse && !llmRequest.isDummyRequest())

0 commit comments

Comments
 (0)