Commit 9cb0c35

Add files via upload

wangxiongts authored Nov 26, 2024
1 parent a420ba0 commit 9cb0c35
Showing 31 changed files with 5,544 additions and 3 deletions.
577 changes: 577 additions & 0 deletions License.txt


84 changes: 81 additions & 3 deletions README.md
@@ -1,19 +1,31 @@
# Freeze-Omni: A Smart and Low Latency Speech-to-speech Dialogue Model with Frozen LLM



<p align="center">
<img src="./assets/logo.png" width="70%" height="70%">
</p>

<font size=7><div align='center' > [[🍎 Project(Demo) Page](https://freeze-omni.github.io)] [[📖 arXiv Paper](https://arxiv.org/abs/2411.00774)] [[🤗 Hugging Face](https://huggingface.co/VITA-MLLM/Freeze-Omni)]</div></font>


---

## 🔥 News
* **`2024.11.26`** 🌟 The inference code, server demo, and model weights **have been released**. Sorry for the long wait!
* **`2024.11.4`** 🌟 We are very proud to launch Freeze-Omni, a speech-to-speech dialogue model
with both low latency and high intelligence! We have submitted the open-source code, but it is still under internal review. We are moving the process forward as quickly as possible, so stay tuned!

## Contents <!-- omit in toc -->


- [Freeze-Omni Overview](#-freeze-omni-overview)
- [Experimental Results](#-experimental-results)
- [Inference](#-inference)
- [Requirements and Installation](#requirements-and-installation)
- [Quick Start](#quick-start)
- [Real-Time Interactive Demo](#real-time-interactive-demo)


## 👀 Freeze-Omni Overview
Freeze-Omni is a speech-to-speech dialogue model, exhibiting the characteristic of being "**smart**" as it is constructed upon a "**frozen**" text-modality LLM. This enables it to keep the original intelligence of the LLM backbone, without being affected by the forgetting problem induced by the fine-tuning process for integration of the speech modality. Specifically, Freeze-Omni contains a speech encoder that supports streaming speech input and a speech decoder that generates streaming output speech. **Three key strategies** are adopted to implement the speech-to-speech dialogue system:
@@ -22,7 +34,7 @@ Freeze-Omni is a speech-to-speech dialogue model, exhibiting the characteristic

- **AR-based Speech Output**. Freeze-Omni has an AR speech decoder based on a single codebook, which achieves low-latency streaming speech output. A prefix-tuning method is used so that training on only a small amount of Q&A data is enough to produce high-quality speech synthesis.

- **Chunk-level State Prediction**. Freeze-Omni adds a classification layer after the last layer of the backbone LLM to predict different states. These states determine whether the user is interrupting the dialogue, enabling duplex interaction between the user and the bot (see the illustrative sketch below).
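
Conceptually, this state head is just a small classifier stacked on the backbone's last hidden state. The snippet below is a minimal illustrative sketch, not the repository's actual module: the class name, the number of states, their meanings, and the hidden size are all assumptions.

```python
import torch
import torch.nn as nn

class ChunkStateHead(nn.Module):
    # Toy sketch only: the real Freeze-Omni module, its state inventory, and its
    # hidden size may differ. Here we assume three states, e.g. 0 = keep
    # listening, 1 = start/continue speaking, 2 = stop speaking.
    def __init__(self, hidden_size: int, num_states: int = 3):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_states)

    def forward(self, last_hidden: torch.Tensor) -> torch.Tensor:
        # last_hidden: [batch, hidden_size] hidden state of the final LLM layer
        # for the current speech chunk; returns logits over dialogue states.
        return self.classifier(last_hidden)

# Usage sketch: classify the newest chunk's hidden state.
head = ChunkStateHead(hidden_size=896)
state = head(torch.randn(1, 896)).argmax(dim=-1)
```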

<p align="center">
<img src="./assets/overview.png" width="88%" height="88%">
@@ -59,19 +71,83 @@ Besides we implement a Model as a Server strategy. We first started several mode
</p>


## 📐 Inference
### Requirements and Installation
**Environment Requirements**:
```
git clone https://github.com/VITA-MLLM/Freeze-Omni
cd Freeze-Omni
conda create -n freeze-omni python=3.10 -y
conda activate freeze-omni
pip install --upgrade pip
pip install -r requirements.txt
```
**Required weights**:
- Download [Freeze-Omni checkpoint](https://huggingface.co/VITA-MLLM/Freeze-Omni) then move `checkpoints` into the root dir of this repo (`Freeze-Omni/checkpoints`)
- Download [Qwen2-7B-Instruct checkpoint](https://huggingface.co/Qwen/Qwen2-7B-Instruct) then move `Qwen2-7B-Instruct` into the root dir of this repo (`Freeze-Omni/Qwen2-7B-Instruct`)
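
If you prefer to script the downloads, the `huggingface_hub` package provides `snapshot_download`. The sketch below is one possible way to fetch both repositories; the staging directory name is our own choice, and depending on the layout of the Hugging Face repo you may still need to move the downloaded `checkpoints` folder into the repo root as described above.

```python
from huggingface_hub import snapshot_download

# Freeze-Omni weights; afterwards make sure the `checkpoints` folder ends up
# at Freeze-Omni/checkpoints.
snapshot_download(repo_id="VITA-MLLM/Freeze-Omni", local_dir="./freeze_omni_hf")

# Qwen2-7B-Instruct backbone, placed directly at Freeze-Omni/Qwen2-7B-Instruct.
snapshot_download(repo_id="Qwen/Qwen2-7B-Instruct", local_dir="./Qwen2-7B-Instruct")
```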

### Quick Start
**From python command**
```
CUDA_VISIBLE_DEVICES=0 python3 bin/inference.py \
--model_path ./checkpoints \
--input_wav ./assets/question.wav \
--output_wav ./assets/answer.wav \
--llm_path ./Qwen2-7B-Instruct \
--top_p 0.8 \
--top_k 20 \
--temperature 0.8
```
**From script**
```
sh scripts/run_inference.sh
```
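
For batch experiments, the same entry point can also be driven from Python. The helper below is a hypothetical convenience wrapper, not part of this repo; the flags and default paths are taken from the Quick Start command above.

```python
import os
import subprocess

def run_freeze_omni(input_wav: str, output_wav: str) -> None:
    # Hypothetical helper: wraps the Quick Start command so several wav files
    # can be processed from a single Python script.
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
    subprocess.run(
        ["python3", "bin/inference.py",
         "--model_path", "./checkpoints",
         "--llm_path", "./Qwen2-7B-Instruct",
         "--input_wav", input_wav,
         "--output_wav", output_wav,
         "--top_p", "0.8", "--top_k", "20", "--temperature", "0.8"],
        check=True, env=env)

run_freeze_omni("./assets/question.wav", "./assets/answer.wav")
```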

### Real-Time Interactive Demo

To have a good interactive experience, please pay attention to the following three points:

- **Ensure a high-speed network connection**.
- **Use high-performance GPUs for deployment**. In the demo video, we use one NVIDIA A100 GPU; an A800, H800, or H20 will perform even better.
- **Maintain a quiet environment**.

**From python command**
```
CUDA_VISIBLE_DEVICES=0 python3 bin/server.py \
--ip your_server_ip \
--port your_server_port \
--max_users 3 \
--llm_exec_nums 1 \
--timeout 180 \
--model_path ./checkpoints \
--llm_path ./Qwen2-7B-Instruct \
--top_p 0.8 \
--top_k 20 \
--temperature 0.8
```
**From script**

Change the **ip** and **port** in `scripts/run_demo_server.sh` with yours and run:
```
sh scripts/run_demo_server.sh
```

## ✒️ Citation

If you find our work helpful for your research, please consider citing it.

```bibtex
@article{xiong2024freeze,
title={Freeze-Omni: A Smart and Low Latency Speech-to-speech Dialogue Model with Frozen LLM},
author={Xiong Wang and Yangze Li and Chaoyou Fu and Yunhang Shen and Lei Xie and Ke Li and Xing Sun and Long Ma},
journal={arXiv preprint arXiv:2411.00774},
year={2024}
}
```

## &#x1F4E3; Statement

**Freeze-Omni is trained on a large-scale corpus, and its output has randomness. Any content generated by Freeze-Omni does not represent the views of the model developers. We are not responsible for any problems arising from the use, misuse, or dissemination of Freeze-Omni, including but not limited to public opinion risks and data security issues.**

## 📜 Related Works

@@ -81,3 +157,5 @@ Explore our related research:
- **[MME]** [MME: A Comprehensive Evaluation Benchmark for Multimodal Large Language Models](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation)
- **[Video-MME]** [Video-MME: The First-Ever Comprehensive Evaluation Benchmark of Multi-modal LLMs in Video Analysis](https://github.com/BradyFU/Video-MME)

## 👍 Acknowledgement
Freeze-Omni is built with reference to the following outstanding works: [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [TiCodec](https://github.com/y-ren16/ticodec)
Binary file added assets/answer.wav
Binary file not shown.
Binary file modified assets/out_cer.png
Binary file added assets/question.wav
Binary file not shown.
190 changes: 190 additions & 0 deletions bin/inference.py
@@ -0,0 +1,190 @@
from __future__ import print_function

import argparse
import os
import json
import queue
import torch
import yaml
import threading
import struct
import time
import torchaudio
import datetime
import builtins
import math

import soundfile as sf
import numpy as np
import torch.nn.functional as F
import torchaudio.compliance.kaldi as k

from torch.utils.data import DataLoader

from models.pipeline import inferencePipeline
from models.decoder.llm2tts import llm2TTS
from web.parms import GlobalParams
from web.pool import TTSObjectPool

def get_args():
    parser = argparse.ArgumentParser(description='Freeze-Omni')
    parser.add_argument('--model_path', required=True, help='model_path to load')
    parser.add_argument('--llm_path', required=True, help='llm_path to load')
    parser.add_argument('--top_k', type=int, default=5)
    parser.add_argument('--top_p', type=float, default=0.8)
    parser.add_argument('--temperature', type=float, default=0.7)
    parser.add_argument('--input_wav', required=True, help='input wav')
    parser.add_argument('--output_wav', required=True, help='output wav')

    args = parser.parse_args()
    print(args)
    return args

class audioEncoderProcessor:
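    """Buffer streaming audio with a small overlap between consecutive chunks and
    turn each chunk into kaldi-style fbank features for the speech encoder."""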
    def __init__(self, chunk_size=16):
        self.chunk_size = chunk_size
        self.chunk_overlap = 3
        self.feat_dim = 80
        self.frame_size = 400
        self.frame_shift = 160
        self.frame_overlap = self.frame_size - self.frame_shift
        self.CHUNK = self.frame_shift * self.chunk_size
        self.reset()

    def get_chunk_size(self):
        return self.CHUNK

    def reset(self):
        self.input_chunk = torch.zeros([1, self.chunk_size + self.chunk_overlap, self.feat_dim])
        self.input_sample = torch.zeros([1, self.CHUNK + self.frame_overlap, 1])

    def fbank_shift(self, sample_data):
        # fbank feature shift
        self.input_sample[:, :self.frame_overlap, :] = self.input_sample[:, -self.frame_overlap:, :].clone()
        self.input_sample[:, self.frame_overlap:, :] = sample_data

    def chunk_data_shift(self, xs):
        # chunk feature shift
        self.input_chunk[:, :self.chunk_overlap, :] = self.input_chunk[:, -self.chunk_overlap:, :].clone()
        self.input_chunk[:, self.chunk_overlap:, :] = xs.squeeze(0)

    def process(self,
                audio: torch.Tensor):
        with torch.no_grad():
            sample_data = torch.tensor(audio).reshape(1, -1, 1)[:, :, :1] * 32768
            self.fbank_shift(sample_data)
            # use kaldi api to compute fbank
            xs = k.fbank(waveform=self.input_sample.squeeze(-1), dither=0,
                         frame_length=25, frame_shift=10, num_mel_bins=self.feat_dim)
            self.chunk_data_shift(xs)
        return self.input_chunk.clone()

def decoder(cur_hidden_state, pipeline, cur_text, tts, codec_chunk_size, codec_padding_size, decoder_topk, wav):
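    # Synthesize one finished text segment: embed the post-processed text with the
    # LLM's token embeddings, then stream codec chunks from the speech decoder into `wav`.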
    hidden_state_output = torch.cat(cur_hidden_state).squeeze(1)
    cur_text_procced = pipeline.post_process(cur_text)
    print("Synthesis: ", [cur_text_procced])
    embeddings = pipeline.model.llm_decoder.model.embed_tokens(
        torch.tensor(pipeline.model.tokenizer.encode(cur_text_procced)).cuda()
    )
    for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
                       hidden_state_output.reshape(-1, 896).unsqueeze(0),
                       codec_chunk_size, codec_padding_size):
        wav.append(seg)
def inference(pipeline, audio_processor, tts, configs):
"""
Perform inference for a speech dialogue system.
Parameters:
- pipeline: Speech dialogue pipeline.
- audio_processor: Processes raw audio data into a format suitable for the pipeline.
- tts: The speech decoder moudule.
- configs: Input args.
Returns:
- None
"""
wav, fs = sf.read(configs.input_wav)
wav = torch.tensor(wav)
if fs != 16000:
wav = torchaudio.transforms.Resample(orig_freq=fs, new_freq=16000)(wav.float())
fs = 16000

codec_chunk_size = 40
codec_padding_size = 10
decoder_topk = 2

# Satge0: preprocess
# set system role, stat will be set to 'sl'
stat = 'pre'
outputs = pipeline.speech_dialogue(None, stat=stat, role="You are a helpful assistant.")
chunk_size = audio_processor.get_chunk_size()

# Satge1: start listen
# stat will be auto set to 'cl' after Stage1
wav_input = torch.zeros(math.ceil(wav.shape[0] / chunk_size) * chunk_size)
wav_input[:wav.shape[0]] = wav
for i in range(0, wav_input.shape[0], chunk_size):
fbank = audio_processor.process(wav_input[i:i+chunk_size])
outputs = pipeline.speech_dialogue(fbank, **outputs)
outputs['stat'] = 'cl'
audio_processor.reset()

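    # Stage2: reset the caches and switch the dialogue state to 'ss' so that
    # the model starts generating the spoken response.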
    outputs['adapter_cache'] = None
    outputs['encoder_cache'] = None
    outputs['pe_index'] = 0
    outputs['stat'] = 'ss'

    # Stage3: start speak
    outputs = pipeline.speech_dialogue(None, **outputs)
    cur_hidden_state = []
    cur_hidden_state.append(outputs['hidden_state'])

    whole_text = ''
    last_text = ''
    cur_text = ''
    wav = []
    # Stage4: continue speak until stat is set to 'sl'
    # use 'stop' to interrupt generation, stat needs to be manually set to 'sl'
    stop = False
    while True:
        if len(outputs['past_tokens']) > 128:
            stop = True
        if stop:
            break
        del outputs['text']
        del outputs['hidden_state']
        outputs = pipeline.speech_dialogue(None, **outputs)
        if outputs['stat'] == 'cs':
            cur_hidden_state.append(outputs['hidden_state'])
            whole_text += outputs['text'][len(last_text):]
            cur_text += outputs['text'][len(last_text):]
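            # Flush a TTS segment whenever the newly generated text ends with a
            # sentence-final punctuation mark (but not a decimal point in a number).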
            suffix_list = ["。", "：", "？", "！", ".", "?", "!", "\n"]
            if outputs['text'][len(last_text):].endswith(tuple(suffix_list)):
                if outputs['text'][len(last_text):].endswith(".") and last_text and last_text[-1].isdigit():
                    pass
                else:
                    if len(cur_hidden_state) > 0:
                        decoder(cur_hidden_state, pipeline, cur_text, tts,
                                codec_chunk_size, codec_padding_size, decoder_topk, wav)
                        cur_hidden_state = []
                    cur_text = ""
        if outputs['stat'] == 'sl':
            break
        # print(outputs['text'])
        last_text = outputs['text']
    if len(cur_hidden_state) != 0:
        decoder(cur_hidden_state, pipeline, cur_text, tts,
                codec_chunk_size, codec_padding_size, decoder_topk, wav)

    sf.write(configs.output_wav, torch.cat(wav, -1).squeeze().float().cpu().numpy(), 24000)
    outputs['stat'] = 'sl'
    outputs['last_id'] = None
    print(whole_text)

if __name__ == '__main__':
    configs = get_args()
    pipeline = inferencePipeline(configs)
    tts = llm2TTS(configs.model_path)
    audio_processor = audioEncoderProcessor()
    inference(pipeline, audio_processor, tts, configs)
