I have 8 RTX-4090 cards, and running grok-1 fails.
Reproducing steps:
1. Clone the grok-1 repository.
2. Install the requirements: pip install -r requirements.txt
3. Download the Hugging Face weights: git clone https://huggingface.co/xai-org/grok-1
4. Change the local mesh config in run.py to local_mesh_config=(1, 1)
5. Run run.py: python3 run.py
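Because the reproduction sets local_mesh_config=(1, 1) on an 8-GPU machine, it may be worth first confirming how many devices JAX actually sees compared with the mesh being requested. The sketch below assumes, as the name suggests, that the product of the local_mesh_config entries is the number of local devices the mesh spans; `check_local_mesh` is a hypothetical helper, not part of grok-1.

```python
import math

import jax


def check_local_mesh(local_mesh_config):
    """Hypothetical helper: compare the device count a mesh shape implies
    (assumed here to be the product of its entries) with what JAX can see."""
    requested = math.prod(local_mesh_config)
    visible = jax.local_device_count()
    return requested, visible


requested, visible = check_local_mesh((1, 1))
print(f"local_mesh_config=(1, 1) spans {requested} device(s); JAX sees {visible}")
# On a working CUDA install this lists one entry per visible GPU.
print(jax.devices())
```

If `visible` is 1, or jax.devices() lists only a CPU device, JAX is not using the eight cards, which is worth ruling out before digging into the model code.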
The error:
Traceback (most recent call last):
  File "/app/grok-1/run.py", line 87, in <module>
    main()
  File "/app/grok-1/run.py", line 82, in main
    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
  File "/app/grok-1/runners.py", line 597, in sample_from_model
    next(server)
  File "/app/grok-1/runners.py", line 481, in run
    rngs, last_output, memory, settings = self.prefill_memory(
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/app/grok-1/runners.py", line 352, in hk_prefill_memory
    settings = jax.tree_map(
  File "/app/grok-1/runners.py", line 353, in <lambda>
    lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,).
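The last frame of the traceback writes each leaf `v` of the new settings into row `i` along axis 0 of the matching leaf `o`. The message says that leaf has length 0 on axis 0, so there is no row to write into. The sketch below reproduces the same TypeError in isolation with plain JAX; `write_setting` is a hypothetical stand-in for the lambda at runners.py line 353, and the value 0.01 (the temperature passed in run.py) is only an example.

```python
import jax.numpy as jnp
from jax import lax


def write_setting(o, v, i):
    # Hypothetical stand-in for the lambda in hk_prefill_memory: write the
    # single value v into row i of o along axis 0.
    return lax.dynamic_update_index_in_dim(o, v, i, axis=0)


value = jnp.float32(0.01)

# A settings leaf with one batch row: the write succeeds.
print(write_setting(jnp.zeros((1,), jnp.float32), value, 0).shape)  # (1,)

# A settings leaf with zero batch rows: the scalar update is expanded to
# shape (1,), which cannot fit into an operand of shape (0,).
try:
    write_setting(jnp.zeros((0,), jnp.float32), value, 0)
except TypeError as err:
    print("TypeError:", err)
```

If the standalone call fails the same way, the write itself is not the bug: whatever allocated the settings buffers produced a zero-sized batch axis, so how that batch size is derived (possibly from the device count and mesh config) is likely the place to look.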