Skip to content

Commit 2defd56

Browse files
authored
Merge pull request #157 from xming521/dev
fix some bugs
2 parents bca3240 + 45a5797 commit 2defd56

13 files changed

Lines changed: 104 additions & 173 deletions

File tree

tests/test_full_pipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import pytest
1111

12+
from weclone.utils.config import load_config
1213
from weclone.utils.config_models import DataModality, WCMakeDatasetConfig
13-
from weclone.utils.configV2 import load_config
1414
from weclone.utils.log import logger
1515

1616
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

tests/test_full_pipeV2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import pytest
1111

12+
from weclone.utils.config import load_config
1213
from weclone.utils.config_models import DataModality, WCMakeDatasetConfig
13-
from weclone.utils.configV2 import load_config
1414
from weclone.utils.log import logger
1515

1616
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

weclone/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import click
88
import commentjson
99

10+
from weclone.utils.config import load_config
1011
from weclone.utils.config_models import CliArgs
11-
from weclone.utils.configV2 import load_config
1212
from weclone.utils.log import capture_output, logger
1313

1414
cli_config: CliArgs | None = None
@@ -70,7 +70,7 @@ def cli():
7070
@apply_common_decorators()
7171
def qa_generator():
7272
"""处理聊天记录CSV文件,生成问答对数据集。"""
73-
from weclone.data.qa_generatorV2 import DataProcessor
73+
from weclone.data.qa_generator import DataProcessor
7474

7575
processor = DataProcessor()
7676
processor.main()

weclone/core/inference/offline_infer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,22 @@ def vllm_infer(
4242
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
4343

4444
model_args, data_args, _, generating_args = get_infer_args(
45-
dict(
46-
model_name_or_path=model_name_or_path,
47-
adapter_name_or_path=adapter_name_or_path,
48-
dataset=dataset,
49-
dataset_dir=dataset_dir,
50-
template=template,
51-
cutoff_len=cutoff_len,
52-
max_samples=max_samples,
53-
preprocessing_num_workers=16,
54-
vllm_config=vllm_config,
55-
temperature=temperature,
56-
top_p=top_p,
57-
top_k=top_k,
58-
max_new_tokens=max_new_tokens,
59-
repetition_penalty=repetition_penalty,
60-
)
45+
{
46+
"model_name_or_path": model_name_or_path,
47+
"adapter_name_or_path": adapter_name_or_path,
48+
"dataset": dataset,
49+
"dataset_dir": dataset_dir,
50+
"template": template,
51+
"cutoff_len": cutoff_len,
52+
"max_samples": max_samples,
53+
"preprocessing_num_workers": 16,
54+
"vllm_config": vllm_config,
55+
"temperature": temperature,
56+
"top_p": top_p,
57+
"top_k": top_k,
58+
"max_new_tokens": max_new_tokens,
59+
"repetition_penalty": repetition_penalty,
60+
}
6161
)
6262

6363
tokenizer_module = load_tokenizer(model_args)

weclone/data/chat_parsers/wechat_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tqdm import tqdm
88

99
from weclone.data.models import ChatMessage
10-
from weclone.data.qa_generatorV2 import DataProcessor
10+
from weclone.data.qa_generator import DataProcessor
1111
from weclone.utils.log import logger
1212

1313
data_dir = "./dataset/wechat/dat"

weclone/data/clean/strategies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def judge(self, data: List[QaPairV2]) -> None:
113113
inputs.append(prompt_value.to_string())
114114
outputs = vllm_infer(
115115
inputs,
116-
self.make_dataset_config["model_name_or_path"],
117-
template=self.make_dataset_config["template"],
116+
self.make_dataset_config.model_name_or_path,
117+
template=self.make_dataset_config.template,
118118
temperature=0,
119119
guided_decoding_class=QaPairScore,
120120
repetition_penalty=1.2,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
)
2323
from weclone.data.strategies import LLMStrategy, TimeWindowStrategy
2424
from weclone.data.utils import ImageToTextProcessor, check_image_file_exists
25+
from weclone.utils.config import load_config
2526
from weclone.utils.config_models import DataModality, PlatformType, WCMakeDatasetConfig
26-
from weclone.utils.configV2 import load_config
2727
from weclone.utils.log import logger
2828

2929

weclone/eval/test_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from tqdm import tqdm
88

99
from weclone.utils.config import load_config
10+
from weclone.utils.config_models import WCInferConfig
1011

11-
config = load_config("web_demo")
12+
config = cast(WCInferConfig, load_config("web_demo"))
1213

1314
config = {
14-
"default_prompt": config["default_system"],
15+
"default_prompt": config.default_system,
1516
"model": "gpt-3.5-turbo",
1617
"history_len": 15,
1718
}

weclone/eval/web_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
def main():
7-
config = load_config("web_demo")
7+
load_config("web_demo")
88
demo = create_web_demo()
99
demo.queue()
1010
demo.launch(server_name="0.0.0.0", share=True, inbrowser=True)

weclone/train/train_sft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from llamafactory.train.tuner import run_exp
88

99
from weclone.data.clean.strategies import LLMCleaningStrategy
10+
from weclone.utils.config import load_config
1011
from weclone.utils.config_models import WCMakeDatasetConfig, WCTrainSftConfig
11-
from weclone.utils.configV2 import load_config
1212
from weclone.utils.log import logger
1313

1414

0 commit comments

Comments
 (0)