From 01408fb4ad243a8b2c97b18825fc40c46becc07e Mon Sep 17 00:00:00 2001 From: Alex Wertheim Date: Mon, 8 May 2023 20:57:22 +0000 Subject: [PATCH 01/26] Added argument for prompt generation of fixed token length --- example_xla.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/example_xla.py b/example_xla.py index dd111e58b..425f223d5 100644 --- a/example_xla.py +++ b/example_xla.py @@ -87,6 +87,7 @@ def main( dim: int = 4096, n_layers: int = 32, n_heads: int = 32, + prompt_len: int = 6, ): rank, world_size = setup_model_parallel() if rank > 0: @@ -96,7 +97,9 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts = [ + prompts = [generator.tokenizer.decode(range(1, prompt_len))] + print(prompts) + # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt "I believe the meaning of life is", # "Simply put, the theory of relativity states that ", @@ -122,7 +125,7 @@ def main( #plush girafe => girafe peluche # #cheese =>""", - ] + # ] for _ in range(2): with torch.no_grad(): results = generator.generate( @@ -145,8 +148,9 @@ def _fn( dim: int = 4096, n_layers: int = 32, n_heads: int = 32, + prompt_len: int = 6, ): - main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads) + main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len) def mp_main( mp: bool, @@ -159,11 +163,12 @@ def mp_main( dim: int = 4096, n_layers: int = 32, n_heads: int = 32, + prompt_len: int = 6, ): if mp: - xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)) + xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len)) else: - main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads) + main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len) if __name__ == "__main__": From e2591a7c4d22c38c0bcfab3380e7074c05ce2449 Mon Sep 17 00:00:00 2001 From: Alex Wertheim Date: Mon, 8 May 2023 21:06:58 +0000 Subject: [PATCH 02/26] Commented out old prompt --- example_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_xla.py b/example_xla.py index 425f223d5..d61a0a851 100644 --- a/example_xla.py +++ b/example_xla.py @@ -101,7 +101,7 @@ def main( print(prompts) # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt - "I believe the meaning of life is", + # "I believe the meaning of life is", # "Simply put, the theory of relativity states that ", # "Building a website can be done in 10 simple steps:\n", # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api From ef0452e6b04104db3a1e388e479464d42fa92548 Mon Sep 17 00:00:00 2001 From: Alex Wertheim Date: Mon, 8 May 2023 21:12:22 +0000 Subject: [PATCH 03/26] Cast Tuple to List when creating prompt string --- example_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_xla.py b/example_xla.py index d61a0a851..ded195ede 100644 --- a/example_xla.py +++ b/example_xla.py @@ -97,7 +97,7 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts = [generator.tokenizer.decode(range(1, prompt_len))] + prompts = 
[generator.tokenizer.decode(range(1, prompt_len))]
+    prompts = [generator.tokenizer.decode(List(range(1, prompt_len)))]
     print(prompts)
     # prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt

From b721b839d5317834b77464c8275602dde89bafd2 Mon Sep 17 00:00:00 2001
From: Alex Wertheim
Date: Mon, 8 May 2023 21:16:25 +0000
Subject: [PATCH 04/26] Fix list cast, correct size

---
 example_xla.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/example_xla.py b/example_xla.py
index ded195ede..431631d6d 100644
--- a/example_xla.py
+++ b/example_xla.py
@@ -97,7 +97,7 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )
 
-    prompts = [generator.tokenizer.decode(List(range(1, prompt_len)))]
+    prompts = [generator.tokenizer.decode(list(range(prompt_len)))]
     print(prompts)
     # prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt

From 8b27dfda1856e78f9f55edec441fc05bfc34c3bf Mon Sep 17 00:00:00 2001
From: Alex Wertheim
Date: Wed, 10 May 2023 21:58:12 +0000
Subject: [PATCH 05/26] Make `max_gen_len` an exposed parameter

- Made `max_gen_len` an exposed parameter
- Set the default for `max_seq_len` from 512 to 2048
- Set `total_len` to the min of `max_seq_len` and
  `max_gen_len + max_prompt_size`
---
 example_xla.py      | 17 ++++++++++-------
 llama/generation.py |  4 ++--
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/example_xla.py b/example_xla.py
index 431631d6d..695523335 100644
--- a/example_xla.py
+++ b/example_xla.py
@@ -81,13 +81,14 @@ def main(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
     prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     rank, world_size = setup_model_parallel()
     if rank > 0:
@@ -129,7 +130,7 @@ def main(
     for _ in range(2):
         with torch.no_grad():
             results = generator.generate(
-                prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+                prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
             )
 
         for result in results:
@@ -142,33 +143,35 @@ def _fn(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
     prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
-    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len)
+    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)
 
 def mp_main(
     mp: bool,
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
     prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     if mp:
-        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len))
+        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len))
     else:
-        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len)
+        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)
 
 
 if
__name__ == "__main__": diff --git a/llama/generation.py b/llama/generation.py index b0ea81e2d..9fcfd1b9b 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -58,7 +58,7 @@ def generate( prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] - total_len = params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long() for k, t in enumerate(prompt_tokens): @@ -84,7 +84,7 @@ def generate( ) xm.mark_step() self.model.cache_kvs = cache_kvs - print(f"Decoded in {time.time() - decoding_start_time:.5f} seconds") + print(f"Decoded {total_len-1} tokens in {time.time() - decoding_start_time:.5f} seconds") decoded = [] for i, t in enumerate(tokens.tolist()): From 10d9c0bbcaacfdb45fcc68847689332093548349 Mon Sep 17 00:00:00 2001 From: Alex Wertheim Date: Wed, 10 May 2023 22:11:26 +0000 Subject: [PATCH 06/26] Reintroduced max_prompt_size --- llama/generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama/generation.py b/llama/generation.py index 9fcfd1b9b..81f33e7ba 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -58,6 +58,8 @@ def generate( prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + max_prompt_size = max([len(t) for t in prompt_tokens]) + total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long() From 984fb6847db64b7906a91921245948f006375c35 Mon Sep 17 00:00:00 2001 From: Alex Wertheim Date: Wed, 10 May 2023 23:41:31 +0000 Subject: [PATCH 07/26] Modified how prompts is generated - To avoid decoding->encoding errors, `prompts` is now set to be `prompt_len` many copies of the fixed 8th token ("the") - Removed a print debugging statement --- example_xla.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/example_xla.py b/example_xla.py index 695523335..f35b49afe 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,8 +98,7 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts = [generator.tokenizer.decode(list(range(prompt_len)))] - print(prompts) + prompts = [generator.tokenizer.decode([8]*prompt_len)] # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", From 12a2c53a091d0bdd33d9e7b800b852fcc8d8c3a5 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Thu, 11 May 2023 21:02:25 +0000 Subject: [PATCH 08/26] bucketize_prompt_len --- example_xla.py | 2 +- llama/generation.py | 23 +++++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/example_xla.py b/example_xla.py index f35b49afe..4cbb4f5d4 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,7 +98,7 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts = [generator.tokenizer.decode([8]*prompt_len)] + prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)] # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", diff --git a/llama/generation.py b/llama/generation.py index 81f33e7ba..05be9913d 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -43,6 +43,20 @@ def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_ten return tokens, input_tokens, 
cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs + def _create_start_pos_buckets(self, max_seq_len): + buckets = [max_seq_len] + while buckets[-1] > 256: + buckets.append(buckets[-1] // 2) + buckets.append(1) + buckets.reverse() + + return buckets + + def _select_start_pos_bucket(self, prompt_size, buckets): + for i in range(1, len(buckets)): + if prompt_size < buckets[i]: + return buckets[i - 1] + def generate( self, prompts: List[str], @@ -58,18 +72,23 @@ def generate( prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + min_prompt_size = min([len(t) for t in prompt_tokens]) max_prompt_size = max([len(t) for t in prompt_tokens]) + assert min_prompt_size >= 1 and max_prompt_size < params.max_seq_len total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) - tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long() + tokens = torch.full((params.max_batch_size, params.max_seq_len), self.tokenizer.pad_id).long() for k, t in enumerate(prompt_tokens): tokens[k, : len(t)] = torch.tensor(t).long() device = xm.xla_device() tokens = tokens.to(device) input_text_mask = tokens != self.tokenizer.pad_id - start_pos = 1 + # start_pos = 1 + start_pos_buckets = self._create_start_pos_buckets(params.max_seq_len) + start_pos = self._select_start_pos_bucket(min_prompt_size, start_pos_buckets) + print(f"start_pos = {start_pos}") cur_pos_tensor = torch.tensor(start_pos).to(device) input_pos_tensor = torch.arange(0, start_pos).to(device) output_pos_tensor = cur_pos_tensor - 1 From 6a7c6f10eb5d8f79d9ef9e933302778ee406b164 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Thu, 11 May 2023 22:43:16 +0000 Subject: [PATCH 09/26] update --- llama/generation.py | 65 ++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 05be9913d..08b3350bb 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -43,19 +43,19 @@ def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_ten return tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs - def _create_start_pos_buckets(self, max_seq_len): - buckets = [max_seq_len] - while buckets[-1] > 256: - buckets.append(buckets[-1] // 2) - buckets.append(1) - buckets.reverse() - - return buckets - - def _select_start_pos_bucket(self, prompt_size, buckets): - for i in range(1, len(buckets)): - if prompt_size < buckets[i]: - return buckets[i - 1] +# def _create_start_pos_buckets(self, max_seq_len): +# buckets = [max_seq_len] +# while buckets[-1] > 64: +# buckets.append(buckets[-1] // 2) +# buckets.append(1) +# buckets.reverse() +# +# return buckets +# +# def _select_start_pos_bucket(self, prompt_size, buckets): +# for i in range(1, len(buckets)): +# if prompt_size < buckets[i]: +# return buckets[i - 1] def generate( self, @@ -86,18 +86,40 @@ def generate( input_text_mask = tokens != self.tokenizer.pad_id # start_pos = 1 - start_pos_buckets = self._create_start_pos_buckets(params.max_seq_len) - start_pos = self._select_start_pos_bucket(min_prompt_size, start_pos_buckets) - print(f"start_pos = {start_pos}") - cur_pos_tensor = torch.tensor(start_pos).to(device) - input_pos_tensor = torch.arange(0, start_pos).to(device) - output_pos_tensor = cur_pos_tensor - 1 - input_tokens = tokens.index_select(1, input_pos_tensor) + # start_pos_buckets = self._create_start_pos_buckets(params.max_seq_len) + # start_pos = self._select_start_pos_bucket(min_prompt_size, 
start_pos_buckets) + # print(f"start_pos = {start_pos}") + # cur_pos_tensor = torch.tensor(start_pos).to(device) + # input_pos_tensor = torch.arange(0, start_pos).to(device) + # output_pos_tensor = cur_pos_tensor - 1 + # input_tokens = tokens.index_select(1, input_pos_tensor) cache_kvs = self.model.cache_kvs xm.mark_step(wait=True) decoding_start_time = time.time() - for _ in range(start_pos, total_len): + prev_pos = 0 + while prev_pos < min_prompt_size: + section_len = 1 + while prev_pos + section_len * 2 <= min_prompt_size: + section_len *= 2 + cur_pos = prev_pos + section_len + cur_pos_tensor = torch.tensor(cur_pos).to(device) + input_pos_tensor = torch.arange(prev_pos, cur_pos).to(device) + output_pos_tensor = cur_pos_tensor - 1 + input_tokens = tokens.index_select(1, input_pos_tensor) + xm.mark_step(wait=True) + + tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ + = self._generate_one_token_fn( + tokens, input_tokens, input_text_mask, cur_pos_tensor, + input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p + ) + xm.mark_step() + + prev_pos = cur_pos + + assert cur_pos_tensor.item() == prev_pos + 1 + for _ in range(prev_pos + 1, total_len): tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, @@ -105,6 +127,7 @@ def generate( ) xm.mark_step() self.model.cache_kvs = cache_kvs + print(f"Processed prompts with {min_prompt_size} to {max_prompt_size} tokens, and generated {total_len - max_prompt_size} tokens") print(f"Decoded {total_len-1} tokens in {time.time() - decoding_start_time:.5f} seconds") decoded = [] From f58733ff0612f5903bcbc97cca59ba94107fbaf5 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Thu, 11 May 2023 23:12:14 +0000 Subject: [PATCH 10/26] update --- llama/generation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 08b3350bb..733965364 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -94,20 +94,22 @@ def generate( # output_pos_tensor = cur_pos_tensor - 1 # input_tokens = tokens.index_select(1, input_pos_tensor) cache_kvs = self.model.cache_kvs - xm.mark_step(wait=True) + xm.mark_step() decoding_start_time = time.time() + min_section_len = 16 prev_pos = 0 - while prev_pos < min_prompt_size: + while prev_pos + min_section_len - 1 < min_prompt_size: section_len = 1 while prev_pos + section_len * 2 <= min_prompt_size: section_len *= 2 cur_pos = prev_pos + section_len + print(f"Processing prompt pos [{prev_pos}, {cur_pos}), section length {section_len}") cur_pos_tensor = torch.tensor(cur_pos).to(device) input_pos_tensor = torch.arange(prev_pos, cur_pos).to(device) output_pos_tensor = cur_pos_tensor - 1 input_tokens = tokens.index_select(1, input_pos_tensor) - xm.mark_step(wait=True) + xm.mark_step() tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( From c54802b33d4d79b5f19d6b018bce131e02b07f45 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 12 May 2023 00:00:39 +0000 Subject: [PATCH 11/26] update --- example_xla.py | 4 ++-- llama/generation.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/example_xla.py b/example_xla.py index 4cbb4f5d4..e7f43253e 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,7 +98,6 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts 
= [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)] # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", @@ -126,8 +125,9 @@ def main( # #cheese =>""", # ] - for _ in range(2): + for i in range(2): with torch.no_grad(): + prompts = [generator.tokenizer.decode([8] * (prompt_len // (i + 1))) for _ in range(max_batch_size)] results = generator.generate( prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p ) diff --git a/llama/generation.py b/llama/generation.py index 733965364..ff1748f04 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -97,9 +97,8 @@ def generate( xm.mark_step() decoding_start_time = time.time() - min_section_len = 16 prev_pos = 0 - while prev_pos + min_section_len - 1 < min_prompt_size: + while prev_pos < min_prompt_size: section_len = 1 while prev_pos + section_len * 2 <= min_prompt_size: section_len *= 2 @@ -130,7 +129,7 @@ def generate( xm.mark_step() self.model.cache_kvs = cache_kvs print(f"Processed prompts with {min_prompt_size} to {max_prompt_size} tokens, and generated {total_len - max_prompt_size} tokens") - print(f"Decoded {total_len-1} tokens in {time.time() - decoding_start_time:.5f} seconds") + print(f"Totally decoded {total_len - 1} tokens in {time.time() - decoding_start_time:.5f} seconds") decoded = [] for i, t in enumerate(tokens.tolist()): From ee3a349d38bb0a60b8fae94e5617256b3a72a62e Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 12 May 2023 00:05:13 +0000 Subject: [PATCH 12/26] tmp test --- example_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_xla.py b/example_xla.py index e7f43253e..0e163dcd6 100644 --- a/example_xla.py +++ b/example_xla.py @@ -127,7 +127,7 @@ def main( # ] for i in range(2): with torch.no_grad(): - prompts = [generator.tokenizer.decode([8] * (prompt_len // (i + 1))) for _ in range(max_batch_size)] + prompts = [generator.tokenizer.decode([8] * (14 - i * 10)) for _ in range(max_batch_size)] results = generator.generate( prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p ) From 7b287367fbdeec2cb1e3016f7e1e540066631ebf Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 12 May 2023 00:20:30 +0000 Subject: [PATCH 13/26] tmp test --- example_xla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example_xla.py b/example_xla.py index 0e163dcd6..4cbb4f5d4 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,6 +98,7 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) + prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)] # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", @@ -125,9 +126,8 @@ def main( # #cheese =>""", # ] - for i in range(2): + for _ in range(2): with torch.no_grad(): - prompts = [generator.tokenizer.decode([8] * (14 - i * 10)) for _ in range(max_batch_size)] results = generator.generate( prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p ) From 4029650b6c3c0d344385891840979b042498fe14 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 12 May 2023 00:28:29 +0000 Subject: [PATCH 14/26] clean up --- llama/generation.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index ff1748f04..c660342fc 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ 
-43,20 +43,6 @@ def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_ten return tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs -# def _create_start_pos_buckets(self, max_seq_len): -# buckets = [max_seq_len] -# while buckets[-1] > 64: -# buckets.append(buckets[-1] // 2) -# buckets.append(1) -# buckets.reverse() -# -# return buckets -# -# def _select_start_pos_bucket(self, prompt_size, buckets): -# for i in range(1, len(buckets)): -# if prompt_size < buckets[i]: -# return buckets[i - 1] - def generate( self, prompts: List[str], @@ -85,14 +71,6 @@ def generate( tokens = tokens.to(device) input_text_mask = tokens != self.tokenizer.pad_id - # start_pos = 1 - # start_pos_buckets = self._create_start_pos_buckets(params.max_seq_len) - # start_pos = self._select_start_pos_bucket(min_prompt_size, start_pos_buckets) - # print(f"start_pos = {start_pos}") - # cur_pos_tensor = torch.tensor(start_pos).to(device) - # input_pos_tensor = torch.arange(0, start_pos).to(device) - # output_pos_tensor = cur_pos_tensor - 1 - # input_tokens = tokens.index_select(1, input_pos_tensor) cache_kvs = self.model.cache_kvs xm.mark_step() From 74e120ca7a6026a8f770a258e18effd41d08261d Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 12 May 2023 04:57:38 +0000 Subject: [PATCH 15/26] adjust scale factor --- llama/generation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index c660342fc..77c87ba17 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -76,10 +76,11 @@ def generate( decoding_start_time = time.time() prev_pos = 0 + scale_factor = 8 while prev_pos < min_prompt_size: section_len = 1 - while prev_pos + section_len * 2 <= min_prompt_size: - section_len *= 2 + while prev_pos + section_len * scale_factor <= min_prompt_size: + section_len *= scale_factor cur_pos = prev_pos + section_len print(f"Processing prompt pos [{prev_pos}, {cur_pos}), section length {section_len}") cur_pos_tensor = torch.tensor(cur_pos).to(device) From 94f19e9feff5aba5c2fcb3e9c0e885b7f111e41e Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 16:57:25 +0000 Subject: [PATCH 16/26] turn temperature and top_p into tensors --- llama/generation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 77c87ba17..a826e1350 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -20,11 +20,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer): backend="torchxla_trace_once", fullgraph=True) def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p): + input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + temperature_tensor, top_p_tensor): logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs) if temperature > 0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) + probs = torch.softmax(logits / temperature_tensor, dim=-1) + next_token = sample_top_p(probs, top_p_tensor) else: next_token = torch.argmax(logits, dim=-1) next_token = next_token.reshape(-1) @@ -71,6 +72,9 @@ def generate( tokens = tokens.to(device) input_text_mask = tokens != self.tokenizer.pad_id + temperature_tensor = torch.tensor(temperature).to(device) + top_p_tensor = torch.tensor(top_p).to(device) + cache_kvs = self.model.cache_kvs xm.mark_step() @@ -92,7 
+96,8 @@ def generate( tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p + input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + temperature_tensor, top_p_tensor ) xm.mark_step() @@ -103,7 +108,8 @@ def generate( tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p + input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + temperature_tensor, top_p_tensor ) xm.mark_step() self.model.cache_kvs = cache_kvs @@ -133,7 +139,7 @@ def generate( def sample_top_p(probs, p): probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p + mask = (probs_sum - probs_sort) > p probs_sort = torch.where(mask, 0.0, probs_sort) probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = torch.multinomial(probs_sort, num_samples=1) From d0cc999e18573299bd187394538f444e445c586e Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 17:03:02 +0000 Subject: [PATCH 17/26] tmp test --- example_xla.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/example_xla.py b/example_xla.py index 4cbb4f5d4..dc3af063a 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,7 +98,6 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) - prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)] # prompts = [ # For these prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", @@ -126,10 +125,18 @@ def main( # #cheese =>""", # ] - for _ in range(2): + + pairs = [] + for l in [1500, 1600, 100, 500, 800]: + for t in [0.1, 0.5, 0.8]: + for p in [0.8, 0.8]: + pairs.append([l, t, p]) + + for l, t, p in pairs: + prompts = [generator.tokenizer.decode([8]*l) for _ in range(max_batch_size)] with torch.no_grad(): results = generator.generate( - prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p + prompts, max_gen_len=max_gen_len, temperature=t, top_p=p ) for result in results: From 644c88b32a9da70309e0114c0337f742b46a6e51 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 17:42:13 +0000 Subject: [PATCH 18/26] tmp update --- example_xla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example_xla.py b/example_xla.py index dc3af063a..23c8065f1 100644 --- a/example_xla.py +++ b/example_xla.py @@ -127,9 +127,9 @@ def main( # ] pairs = [] - for l in [1500, 1600, 100, 500, 800]: - for t in [0.1, 0.5, 0.8]: - for p in [0.8, 0.8]: + for l in [1500]: + for t in [0.1, 0.5]: + for p in [0.8, 0.9]: pairs.append([l, t, p]) for l, t, p in pairs: From a50045cb5115cfabde1af3cbe920e2d090d25cdd Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 18:15:24 +0000 Subject: [PATCH 19/26] update --- example_xla.py | 1 + llama/generation.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/example_xla.py b/example_xla.py index 23c8065f1..0ce3b0e15 100644 --- a/example_xla.py +++ b/example_xla.py @@ -133,6 +133,7 @@ def main( pairs.append([l, t, p]) for l, t, p in pairs: + print(l, t, p) prompts = 
[generator.tokenizer.decode([8]*l) for _ in range(max_batch_size)] with torch.no_grad(): results = generator.generate( diff --git a/llama/generation.py b/llama/generation.py index a826e1350..26ebe5223 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -20,10 +20,10 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer): backend="torchxla_trace_once", fullgraph=True) def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + input_pos_tensor, output_pos_tensor, cache_kvs, temperature_tensor, top_p_tensor): logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs) - if temperature > 0: + if temperature_tensor.item() > 0: probs = torch.softmax(logits / temperature_tensor, dim=-1) next_token = sample_top_p(probs, top_p_tensor) else: @@ -96,7 +96,7 @@ def generate( tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + input_pos_tensor, output_pos_tensor, cache_kvs, temperature_tensor, top_p_tensor ) xm.mark_step() @@ -108,7 +108,7 @@ def generate( tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \ = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, - input_pos_tensor, output_pos_tensor, cache_kvs, temperature, + input_pos_tensor, output_pos_tensor, cache_kvs, temperature_tensor, top_p_tensor ) xm.mark_step() From a185dda95bdeaaeec7ff9ec3aba57c28f729b83a Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 18:38:44 +0000 Subject: [PATCH 20/26] tmp experiment --- llama/generation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 26ebe5223..b8eebfb09 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -23,11 +23,11 @@ def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_ten input_pos_tensor, output_pos_tensor, cache_kvs, temperature_tensor, top_p_tensor): logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs) - if temperature_tensor.item() > 0: - probs = torch.softmax(logits / temperature_tensor, dim=-1) - next_token = sample_top_p(probs, top_p_tensor) - else: - next_token = torch.argmax(logits, dim=-1) + #if temperature_tensor.item() > 0: + probs = torch.softmax(logits / temperature_tensor, dim=-1) + next_token = sample_top_p(probs, top_p_tensor) + #else: + # next_token = torch.argmax(logits, dim=-1) next_token = next_token.reshape(-1) # only replace token if prompt has already been generated input_text_mask_tmp = input_text_mask.index_select(1, cur_pos_tensor).squeeze(dim=1) From 799ed7de81073c0a1335229d0c6958ca2d157dc0 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 19:03:26 +0000 Subject: [PATCH 21/26] update --- llama/generation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index b8eebfb09..6f9da850a 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -21,13 +21,13 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer): def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, - temperature_tensor, top_p_tensor): + temperature_tensor, top_p_tensor, 
with_temp): logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs) - #if temperature_tensor.item() > 0: - probs = torch.softmax(logits / temperature_tensor, dim=-1) - next_token = sample_top_p(probs, top_p_tensor) - #else: - # next_token = torch.argmax(logits, dim=-1) + if with_temp: + probs = torch.softmax(logits / temperature_tensor, dim=-1) + next_token = sample_top_p(probs, top_p_tensor) + else: + next_token = torch.argmax(logits, dim=-1) next_token = next_token.reshape(-1) # only replace token if prompt has already been generated input_text_mask_tmp = input_text_mask.index_select(1, cur_pos_tensor).squeeze(dim=1) @@ -97,7 +97,7 @@ def generate( = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, - temperature_tensor, top_p_tensor + temperature_tensor, top_p_tensor, temperature > 0 ) xm.mark_step() @@ -109,7 +109,7 @@ def generate( = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, - temperature_tensor, top_p_tensor + temperature_tensor, top_p_tensor, temperature > 0 ) xm.mark_step() self.model.cache_kvs = cache_kvs From 36f17d879ee378804db788301833013ff2d2caf5 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 19:27:29 +0000 Subject: [PATCH 22/26] tmp test --- example_xla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example_xla.py b/example_xla.py index 0ce3b0e15..3f7e1d150 100644 --- a/example_xla.py +++ b/example_xla.py @@ -128,7 +128,7 @@ def main( pairs = [] for l in [1500]: - for t in [0.1, 0.5]: + for t in [0.1, 0.5, 0]: for p in [0.8, 0.9]: pairs.append([l, t, p]) From 7220c02ded9dbd6fb23cbb3bc918a205912c58cf Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 19:30:54 +0000 Subject: [PATCH 23/26] update --- llama/generation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 6f9da850a..ddbc4c2b6 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -74,6 +74,7 @@ def generate( temperature_tensor = torch.tensor(temperature).to(device) top_p_tensor = torch.tensor(top_p).to(device) + with_temp = temperature > 0 cache_kvs = self.model.cache_kvs xm.mark_step() @@ -97,7 +98,7 @@ def generate( = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, - temperature_tensor, top_p_tensor, temperature > 0 + temperature_tensor, top_p_tensor, with_temp ) xm.mark_step() @@ -109,7 +110,7 @@ def generate( = self._generate_one_token_fn( tokens, input_tokens, input_text_mask, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs, - temperature_tensor, top_p_tensor, temperature > 0 + temperature_tensor, top_p_tensor, with_temp ) xm.mark_step() self.model.cache_kvs = cache_kvs From ddb7a5eb63f1d286aa0edaee3f0cfd157c7a0cf6 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 19:41:30 +0000 Subject: [PATCH 24/26] recover tmp changes --- example_xla.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/example_xla.py b/example_xla.py index 3f7e1d150..4cbb4f5d4 100644 --- a/example_xla.py +++ b/example_xla.py @@ -98,6 +98,7 @@ def main( ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads ) + prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)] # prompts = [ # For these 
prompts, the expected answer is the natural continuation of the prompt # "I believe the meaning of life is", @@ -125,19 +126,10 @@ def main( # #cheese =>""", # ] - - pairs = [] - for l in [1500]: - for t in [0.1, 0.5, 0]: - for p in [0.8, 0.9]: - pairs.append([l, t, p]) - - for l, t, p in pairs: - print(l, t, p) - prompts = [generator.tokenizer.decode([8]*l) for _ in range(max_batch_size)] + for _ in range(2): with torch.no_grad(): results = generator.generate( - prompts, max_gen_len=max_gen_len, temperature=t, top_p=p + prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p ) for result in results: From 8ab9f48e2cbafa85c9dadc47f774a0e77a88e585 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 21:14:30 +0000 Subject: [PATCH 25/26] add comment --- llama/generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama/generation.py b/llama/generation.py index ddbc4c2b6..8a0ce6ea1 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -72,6 +72,8 @@ def generate( tokens = tokens.to(device) input_text_mask = tokens != self.tokenizer.pad_id + # Passing tensors instead of floats into self._generate_one_token_fn, + # so that different values would not trigger compilations of new graphs temperature_tensor = torch.tensor(temperature).to(device) top_p_tensor = torch.tensor(top_p).to(device) with_temp = temperature > 0 From 516351f59005ff69b7cf03b4c79d30c6b42dc447 Mon Sep 17 00:00:00 2001 From: Liyang Lu Date: Fri, 19 May 2023 21:31:23 +0000 Subject: [PATCH 26/26] minor update --- llama/generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama/generation.py b/llama/generation.py index 8a0ce6ea1..925f82223 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -74,8 +74,8 @@ def generate( # Passing tensors instead of floats into self._generate_one_token_fn, # so that different values would not trigger compilations of new graphs - temperature_tensor = torch.tensor(temperature).to(device) - top_p_tensor = torch.tensor(top_p).to(device) + temperature_tensor = torch.tensor(float(temperature)).to(device) + top_p_tensor = torch.tensor(float(top_p)).to(device) with_temp = temperature > 0 cache_kvs = self.model.cache_kvs
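
Note on the benchmark prompts used throughout the series (PATCH 01 through 08): instead of the natural-language prompts, every sequence in the batch is built as `prompt_len` copies of token id 8, which decodes to "the" per the PATCH 07 message, so the prompt length that reaches the model is controlled exactly and does not depend on a decode/encode round trip. A minimal sketch of that construction, with a stand-in tokenizer since the real script uses the LLaMA SentencePiece tokenizer:

class FakeTokenizer:
    # Stand-in for the SentencePiece tokenizer used by the real script.
    def decode(self, ids):
        return " ".join("the" if i == 8 else f"<{i}>" for i in ids)

def make_benchmark_prompts(tokenizer, prompt_len, max_batch_size):
    # Mirrors: prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
    return [tokenizer.decode([8] * prompt_len) for _ in range(max_batch_size)]

prompts = make_benchmark_prompts(FakeTokenizer(), prompt_len=6, max_batch_size=2)
print(prompts)  # ['the the the the the the', 'the the the the the the']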
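
Note on the prompt prefill added in PATCH 09 and retuned in PATCH 15: the prompt is consumed in a few large sections rather than one position at a time, and each section length is a power of `scale_factor`, which keeps the set of distinct input shapes (and hence XLA graph compilations) small. The helper below is an illustrative, standalone restatement of the section-length selection only; it mirrors the while-loop added to generate(), but the real loop also runs the model and advances the KV cache for each section:

def prefill_sections(min_prompt_size: int, scale_factor: int = 8):
    """Yield (start, end) index ranges covering the prompt prefix.

    Each section is the largest power of `scale_factor` that still fits in
    the remaining prompt, as in the PATCH 09/15 loop.
    """
    prev_pos = 0
    while prev_pos < min_prompt_size:
        section_len = 1
        while prev_pos + section_len * scale_factor <= min_prompt_size:
            section_len *= scale_factor
        cur_pos = prev_pos + section_len
        yield prev_pos, cur_pos
        prev_pos = cur_pos

# A 1500-token prompt is covered by sections of 512, 512, then 64s, 8s and 1s,
# so only lengths drawn from {1, 8, 64, 512} ever reach the compiled model.
print(list(prefill_sections(1500)))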
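
Note on the sampling-parameter handling in PATCHES 16, 21, and 26: `temperature` and `top_p` are passed into the compiled step function as tensors (per the PATCH 25 comment, so that new values do not trigger compilation of new graphs), while the greedy-versus-sampling decision travels as a plain Python bool (`with_temp`) and simply selects one of two traced branches. The snippet below restates that split together with the same `sample_top_p` helper; it runs on plain CPU torch for illustration, whereas the series runs it under torch/XLA via torch.compile with backend="torchxla_trace_once":

import torch

def sample_top_p(probs, p):
    # Same nucleus-sampling helper as llama/generation.py (post-PATCH 16).
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = (probs_sum - probs_sort) > p
    probs_sort = torch.where(mask, 0.0, probs_sort)
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

def pick_next_token(logits, temperature_tensor, top_p_tensor, with_temp):
    # `with_temp` is an ordinary bool, so it only chooses which branch gets
    # traced; the tensor arguments can change value without retracing.
    if with_temp:
        probs = torch.softmax(logits / temperature_tensor, dim=-1)
        return sample_top_p(probs, top_p_tensor)
    return torch.argmax(logits, dim=-1)

logits = torch.randn(4, 32000)
temperature = torch.tensor(0.8)
top_p = torch.tensor(0.95)
print(pick_next_token(logits, temperature, top_p, with_temp=True).reshape(-1))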