Skip to content

Commit 5522bbc

Browse files
authored
Add fn 'get_with_hints_dtype' in VarBuilder (huggingface#1877) (huggingface#1897)
* quantized models(awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
1 parent 888c09a commit 5522bbc

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

candle-nn/src/var_builder.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
178178
name: &str,
179179
hints: B::Hints,
180180
) -> Result<Tensor> {
181-
let path = self.path(name);
182-
self.data
183-
.backend
184-
.get(s.into(), &path, hints, self.data.dtype, &self.data.device)
181+
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
185182
}
186183

187184
/// Retrieve the tensor associated with the given name at the current path.
188185
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
189186
self.get_with_hints(s, name, Default::default())
190187
}
188+
189+
/// Retrieve the tensor associated with the given name & dtype at the current path.
190+
pub fn get_with_hints_dtype<S: Into<Shape>>(
191+
&self,
192+
s: S,
193+
name: &str,
194+
hints: B::Hints,
195+
dtype: DType,
196+
) -> Result<Tensor> {
197+
let path = self.path(name);
198+
self.data
199+
.backend
200+
.get(s.into(), &path, hints, dtype, &self.data.device)
201+
}
191202
}
192203

193204
struct Zeros;

0 commit comments

Comments
 (0)