|
6 | 6 | import time
|
7 | 7 | import json
|
8 | 8 | import ctypes
|
| 9 | +import typing |
9 | 10 | import fnmatch
|
10 | 11 | import multiprocessing
|
11 | 12 |
|
@@ -249,24 +250,26 @@ def __init__(
|
249 | 250 | self._kv_overrides_array[i].key = k.encode("utf-8")
|
250 | 251 | if isinstance(v, bool):
|
251 | 252 | self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
|
252 |
| - self._kv_overrides_array[i].value.bool_value = v |
| 253 | + self._kv_overrides_array[i].value.val_bool = v |
253 | 254 | elif isinstance(v, int):
|
254 | 255 | self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
|
255 |
| - self._kv_overrides_array[i].value.int_value = v |
| 256 | + self._kv_overrides_array[i].value.val_i64 = v |
256 | 257 | elif isinstance(v, float):
|
257 | 258 | self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
258 |
| - self._kv_overrides_array[i].value.float_value = v |
| 259 | + self._kv_overrides_array[i].value.val_f64 = v |
259 | 260 | elif isinstance(v, str): # type: ignore
|
260 | 261 | v_bytes = v.encode("utf-8")
|
261 | 262 | if len(v_bytes) > 128: # TODO: Make this a constant
|
262 | 263 | raise ValueError(f"Value for {k} is too long: {v}")
|
263 | 264 | v_bytes = v_bytes.ljust(128, b"\0")
|
264 | 265 | self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
|
265 | 266 | # copy min(v_bytes, 128) to str_value
|
| 267 | + address = typing.cast(int, ctypes.addressof(self._kv_overrides_array[i].value) + llama_cpp.llama_model_kv_override_value.val_str.offset) |
| 268 | + buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char)) |
266 | 269 | ctypes.memmove(
|
267 |
| - self._kv_overrides_array[i].value.str_value, |
| 270 | + buffer_start, |
268 | 271 | v_bytes,
|
269 |
| - min(len(v_bytes), 128), |
| 272 | + 128, |
270 | 273 | )
|
271 | 274 | else:
|
272 | 275 | raise ValueError(f"Unknown value type for {k}: {v}")
|
|
0 commit comments