@@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
3434 OptFloat const & topPMin, std::optional<TokenIdType> const & topPResetIds, OptFloat const & topPDecay,
3535 std::optional<RandomSeedType> const & seed, OptFloat const & temperature, OptSize32 const & minTokens,
3636 OptFloat const & beamSearchDiversityRate, OptFloat const & repetitionPenalty, OptFloat const & presencePenalty,
37- OptFloat const & frequencyPenalty, OptFloat const & lengthPenalty, OptSize32 const & earlyStopping ,
38- OptSize32 const & noRepeatNgramSize , OptSize32 const & numReturnSequences, OptFloat const & minP ,
39- OptVec<SizeType32> const & beamWidthArray)
37+ OptFloat const & frequencyPenalty, OptSize32 const & promptIgnoreLength, OptFloat const & lengthPenalty ,
38+ OptSize32 const & earlyStopping , OptSize32 const & noRepeatNgramSize, OptSize32 const & numReturnSequences ,
39+ OptFloat const & minP, OptVec<SizeType32> const & beamWidthArray)
4040 : mBeamWidth (checkBeamWidth(beamWidth))
4141 , mTopK (checkTopK(topK))
4242 , mTopP (checkTopP(topP))
@@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
5050 , mRepetitionPenalty (checkRepetitionPenalty(repetitionPenalty))
5151 , mPresencePenalty (presencePenalty)
5252 , mFrequencyPenalty (frequencyPenalty)
53+ , mPromptIgnoreLength (checkPromptIgnoreLength(promptIgnoreLength))
5354 , mLengthPenalty (checkLengthPenalty(lengthPenalty))
5455 , mEarlyStopping (checkEarlyStopping(earlyStopping))
5556 , mNoRepeatNgramSize (checkNoRepeatNgramSize(noRepeatNgramSize))
@@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
6768 && mTemperature == other.mTemperature && mMinTokens == other.mMinTokens
6869 && mBeamSearchDiversityRate == other.mBeamSearchDiversityRate && mRepetitionPenalty == other.mRepetitionPenalty
6970 && mPresencePenalty == other.mPresencePenalty && mFrequencyPenalty == other.mFrequencyPenalty
70- && mLengthPenalty == other.mLengthPenalty && mEarlyStopping == other.mEarlyStopping
71- && mNoRepeatNgramSize == other.mNoRepeatNgramSize && mNumReturnSequences == other.mNumReturnSequences
72- && mMinP == other.mMinP && mBeamWidthArray == other.mBeamWidthArray ;
71+ && mPromptIgnoreLength == other.mPromptIgnoreLength && mLengthPenalty == other.mLengthPenalty
72+ && mEarlyStopping == other.mEarlyStopping && mNoRepeatNgramSize == other.mNoRepeatNgramSize
73+ && mNumReturnSequences == other.mNumReturnSequences && mMinP == other.mMinP
74+ && mBeamWidthArray == other.mBeamWidthArray ;
7375}
7476
7577// Getters
@@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
143145 return mFrequencyPenalty ;
144146}
145147
148+ OptSize32 SamplingConfig::getPromptIgnoreLength () const
149+ {
150+ return mPromptIgnoreLength ;
151+ }
152+
146153OptFloat SamplingConfig::getLengthPenalty () const
147154{
148155 return mLengthPenalty ;
@@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
240247 mFrequencyPenalty = frequencyPenalty;
241248}
242249
250+ void SamplingConfig::setPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
251+ {
252+ mPromptIgnoreLength = checkPromptIgnoreLength (promptIgnoreLength);
253+ }
254+
243255void SamplingConfig::setLengthPenalty (OptFloat const & lengthPenalty)
244256{
245257 mLengthPenalty = lengthPenalty; // TODO: re-enable `checkLengthPenalty` later
@@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
362374 return repetitionpenalty;
363375}
364376
377+ OptSize32 const & SamplingConfig::checkPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
378+ {
379+ if (promptIgnoreLength.has_value ())
380+ {
381+ TLLM_CHECK (promptIgnoreLength.value () >= 0 );
382+ }
383+ return promptIgnoreLength;
384+ }
385+
365386OptFloat const & SamplingConfig::checkLengthPenalty (OptFloat const & lengthPenalty)
366387{
367388 if (lengthPenalty.has_value ())
0 commit comments