diff --git a/src/model.rs b/src/model.rs
index d59bf0e0..3d674ab3 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -144,7 +144,7 @@ impl Llama {
     fn self_attention(
         hidden_states: &mut Tensor, // (seq, n_kv_h * n_groups * dqkv)
         att_scores: &mut Tensor,    // (n_kv_h, n_groups, seq, total_seq)
-        q: &Tensor,                 // (seq, n_kv_h * n_groups * dqkv)
+        q: &Tensor,                 // (seq, n_kv_h * n_groups, dqkv)
         k: &Tensor,                 // (total_seq, n_kv_h * dqkv)
         v: &Tensor,                 // (total_seq, n_kv_h * dqkv)
         n_kv_h: usize,
diff --git a/src/tensor.rs b/src/tensor.rs
index b56d2dd9..9864f746 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -53,7 +53,7 @@ impl Tensor {

     pub fn slice(&self, start: usize, shape: &Vec<usize>) -> Self {
         let new_length: usize = shape.iter().product();
-        assert!(self.offset + start + new_length <= self.length);
+        assert!(new_length <= self.length && start <= self.length - new_length);
         Tensor {
             data: self.data.clone(),
             shape: shape.clone(),
@@ -61,8 +61,6 @@
             length: new_length,
         }
     }
-
-
 }

 // Some helper functions for testing and debugging