Skip to content

Commit 2b59a08

Browse files
committed
remove unused llm_w,llm_h
1 parent a875132 commit 2b59a08

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

tpu_inference/models/jax/qwen2_vl.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,6 @@ def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
518518
for i in range(num_grids):
519519
t, h, w = grid_thw[i]
520520

521-
llm_h = h // self.spatial_merge_size
522-
llm_w = w // self.spatial_merge_size
523-
524521
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
525522
cu_seqlens_thw = jnp.full(t, h * w, dtype=jnp.int32)
526523

@@ -904,4 +901,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
904901
for i, path in enumerate(sorted(unloaded_params)[:10]):
905902
logger.error(f" {i+1}. {path}")
906903

907-
raise ValueError(f"Not all parameters were loaded. Found {len(unloaded_params)} unloaded parameters.")
904+
raise ValueError(f"Not all parameters were loaded. Found {len(unloaded_params)} unloaded parameters.")

0 commit comments

Comments
 (0)