Skip to content

Commit 9d495d9

Browse files
roydsouzaAntiGravity
authored andcommitted
Fix #3: Resolve multimodal BOA/EOA tokens from config.json instead of hardcoding
1 parent f1dddb8 commit 9d495d9

1 file changed

Lines changed: 25 additions & 12 deletions

File tree

Sources/SwiftLM/Server.swift

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,15 +3013,15 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
30133013
) async throws -> ModelContext {
30143014
let context = try await LLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
30153015

3016-
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
3016+
let tokens = OmniModelFactory.extractMultimodalTokens(configuration: configuration)
30173017
let messageGenerator = DefaultMessageGenerator()
30183018
let processor = ALMUserInputProcessor(
30193019
tokenizer: context.tokenizer,
30203020
configuration: context.configuration,
30213021
messageGenerator: messageGenerator,
3022-
boaToken: 255010,
3023-
eoaToken: 255011,
3024-
numAudioEmbeddings: numAudioEmbeddings
3022+
boaToken: tokens.boa,
3023+
eoaToken: tokens.eoa,
3024+
numAudioEmbeddings: tokens.numAudio
30253025
)
30263026

30273027
return .init(
@@ -3081,10 +3081,12 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
30813081
tokenizerLoader: any TokenizerLoader
30823082
) async throws -> ModelContext {
30833083
let vlmContext = try await VLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
3084-
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
3084+
let tokens = OmniModelFactory.extractMultimodalTokens(configuration: configuration)
30853085
let omniProcessor = OmniUserInputProcessor(
30863086
vlmProcessor: vlmContext.processor,
3087-
numAudioEmbeddings: numAudioEmbeddings
3087+
boaToken: tokens.boa,
3088+
eoaToken: tokens.eoa,
3089+
numAudioEmbeddings: tokens.numAudio
30883090
)
30893091

30903092
return .init(
@@ -3095,19 +3097,30 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
30953097
)
30963098
}
30973099

3098-
public static func extractNumAudioEmbeddings(configuration: ResolvedModelConfiguration) -> Int {
3100+
public static func extractMultimodalTokens(configuration: ResolvedModelConfiguration) -> (numAudio: Int, boa: Int, eoa: Int) {
30993101
let configurationURL = configuration.modelDirectory.appending(component: "config.json")
3102+
var numAudio = 128
3103+
var boa = 255010
3104+
var eoa = 255011
3105+
31003106
if let data = try? Data(contentsOf: configurationURL),
31013107
let dict = try? JSONSerialization.jsonObject(with: data) as? [String: Any] {
31023108

3109+
// Extract num_audio_embeddings
31033110
if let subsampling = dict["subsampling_conv_channels"] as? [Int] {
3104-
return subsampling.first ?? 128
3105-
}
3106-
if let audioConfig = dict["audio_config"] as? [String: Any],
3111+
numAudio = subsampling.first ?? 128
3112+
} else if let audioConfig = dict["audio_config"] as? [String: Any],
31073113
let embeddings = audioConfig["num_audio_embeddings"] as? Int {
3108-
return embeddings
3114+
numAudio = embeddings
31093115
}
3116+
3117+
// Extract BOA/EOA tokens
3118+
if let b = dict["boa_token_id"] as? Int { boa = b }
3119+
else if let b = (dict["audio_config"] as? [String: Any])?["boa_token_id"] as? Int { boa = b }
3120+
3121+
if let e = dict["eoa_token_id"] as? Int { eoa = e }
3122+
else if let e = (dict["audio_config"] as? [String: Any])?["eoa_token_id"] as? Int { eoa = e }
31103123
}
3111-
return 128
3124+
return (numAudio, boa, eoa)
31123125
}
31133126
}

0 commit comments

Comments
 (0)