Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llm-factor: migrate to candle #2755

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

karthik2804
Copy link
Contributor

This PR replaces the dependency on rustformers/llm to huggingface/candle. This allows us to run newer models like Llama 3(.1). This now requires the models to be of the safetensors format.

This PR also removes the concept of well-known models. This ensures a consistent directory structure for all models. The rationale is that, with this change, the only group of models initially supported is the Llama family.

Closes #2735

@radu-matei
Copy link
Member

Given this is a breaking change, I'd suggest adding the 3.0 label.

@karthik2804
Copy link
Contributor Author

@radu-matei I do not believe I can add labels in this repository.

@rylev rylev added the spin-3.0 label Aug 27, 2024
crates/llm-local/Cargo.toml Outdated Show resolved Hide resolved
crates/llm-local/src/bert.rs Show resolved Hide resolved
crates/llm-local/src/llama.rs Outdated Show resolved Hide resolved
@karthik2804 karthik2804 force-pushed the llama3-llm-factor branch 3 times, most recently from 611f2b2 to 5abb8ca Compare September 13, 2024 09:16
@karthik2804 karthik2804 marked this pull request as ready for review September 13, 2024 09:26
@karthik2804
Copy link
Contributor Author

The test failure does not seem to be related?

Comment on lines 7 to 13
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can replace this with:

#[derive(Deserialize)]
struct SafeTensorsJson {
  weight_map: HashMap<String, String>
}

let json: SafeTensorsJson = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reverted this change and the other because it was leading to some off error, where the returned vector was a duplicate of the same thing repeated several times which meant the same files were being loaded over and over which led to consuming large amounts of memory.

crates/llm-local/src/utils.rs Outdated Show resolved Hide resolved
crates/llm-local/src/utils.rs Outdated Show resolved Hide resolved
crates/llm-local/src/utils.rs Outdated Show resolved Hide resolved
Comment on lines 15 to 23
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| model_dir.join(v))
.collect::<Vec<_>>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| model_dir.join(v))
.collect::<Vec<_>>();
safetensors_files.extend(weight_map.values().map(|v| model_dir.join(v))

This assumes no need to call as_str because of the suggested change above.

crates/llm-local/src/llama.rs Outdated Show resolved Hide resolved
crates/llm-local/src/lib.rs Outdated Show resolved Hide resolved
crates/llm-local/src/lib.rs Outdated Show resolved Hide resolved
crates/llm-local/src/lib.rs Outdated Show resolved Hide resolved
}

#[async_trait]
trait CachedInferencingModel: Send + Sync {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we document this trait? What about it makes it Cached? Are implementors required to cache results or does it just happen that the current implementors do?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with keeping the name, but I personally find the name CachedInferencingModel confusing when implementors aren't required to cache anything. InferencingModel seems like a more appropriate name.

@karthik2804 karthik2804 force-pushed the llama3-llm-factor branch 2 times, most recently from 08f1611 to 3352e7e Compare September 16, 2024 13:39
crates/llm-local/src/token_output_stream.rs Outdated Show resolved Hide resolved
crates/llm-local/src/token_output_stream.rs Outdated Show resolved Hide resolved
crates/llm-local/src/token_output_stream.rs Outdated Show resolved Hide resolved
Comment on lines 25 to 28
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => anyhow::bail!("cannot decode: {err}"),
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => anyhow::bail!("cannot decode: {err}"),
}
self.tokenizer.decode(tokens, true).context("failed to decode token stream")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does look like I cannot do this because tokenizer.decode returns a Result<String, Box<dyn Error + Send + Sync>> which does not seem to be suitable to use context on(?)

};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand what this check is supposed to be doing. Why do we care about the length of the next text vs the previous, and why do we care whether the last character is alphanumeric?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The length check is to see if we have any new tokens. The alphanumeric check is supposed to be to check if we have a valid token to decode. That is what I gather from the python function the docs link to

https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The python code is dealing with unfinished utf-8 byte sequences which is not possible at this point in the Rust code. Rust chars are guaranteed to be valid utf-8. The check for alphanumeric chars is checking that the character is A-Z | a-z | 0-9 which does seem to be what we want.

The Tokenizer::decode function returns Strings so I'm guessing somehow the tokenizer crate is taking care of byte sequences that aren't valid utf-8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crates/llm-local/src/token_output_stream.rs Outdated Show resolved Hide resolved
crates/llm-local/src/llama.rs Outdated Show resolved Hide resolved
crates/llm-local/src/llama.rs Outdated Show resolved Hide resolved
crates/llm-local/src/lib.rs Outdated Show resolved Hide resolved
}

#[async_trait]
trait CachedInferencingModel: Send + Sync {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with keeping the name, but I personally find the name CachedInferencingModel confusing when implementors aren't required to cache anything. InferencingModel seems like a more appropriate name.

crates/llm-local/src/lib.rs Outdated Show resolved Hide resolved
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The python code is dealing with unfinished utf-8 byte sequences which is not possible at this point in the Rust code. Rust chars are guaranteed to be valid utf-8. The check for alphanumeric chars is checking that the character is A-Z | a-z | 0-9 which does seem to be what we want.

The Tokenizer::decode function returns Strings so I'm guessing somehow the tokenizer crate is taking care of byte sequences that aren't valid utf-8?

crates/llm-local/src/token_output_stream.rs Outdated Show resolved Hide resolved
Copy link
Collaborator

@rylev rylev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

@karthik2804 karthik2804 force-pushed the llama3-llm-factor branch 2 times, most recently from 1a20e61 to 279c58c Compare September 19, 2024 15:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: 🆕 New
Development

Successfully merging this pull request may close these issues.

Migrate from rustformes/llm to Candle
3 participants