-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlangchain_custom_llm.py
More file actions
92 lines (78 loc) · 2.94 KB
/
langchain_custom_llm.py
File metadata and controls
92 lines (78 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import Optional
from langchain.llms.base import LLM
import torch
from typing import List, Mapping, Optional, Any
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline, AutoModelForSeq2SeqLM
# Select the compute device: first CUDA GPU if available, otherwise CPU.
# (The original wrapped the else-branch in a redundant nested torch.device();
# torch.device() takes a plain string for both cases.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)
class CustomChain(LLM):
    """LangChain-compatible LLM wrapper around a local HuggingFace seq2seq model.

    Loads a T5-family model (default: google/flan-t5-base) with its tokenizer
    and serves generation through LangChain's ``LLM`` interface via ``_call``.
    """

    # Pydantic-declared fields (LLM is a pydantic model, so plain attribute
    # assignment would fail without these declarations).
    model_folder_path: str = Field(None, alias='model_folder_path')
    model_name: str = Field(None, alias='model_name')
    model: Any = None
    tokenizer: Any = None
    # Optional generation arguments (only max_tokens is currently consumed
    # by _call; the rest are kept for interface/back-compat and are reported
    # through _get_model_default_parameters).
    backend: Optional[str] = 't5'
    temp: Optional[float] = 0.7
    top_p: Optional[float] = 0.1
    top_k: Optional[int] = 40
    n_batch: Optional[int] = 8
    n_threads: Optional[int] = 4
    n_predict: Optional[int] = 256
    max_tokens: Optional[int] = 200
    repeat_last_n: Optional[int] = 64
    repeat_penalty: Optional[float] = 1.18

    def __init__(self, model_name: str = 'google/flan-t5-base', cache_dir: str = 'cache_dir'):
        """Load tokenizer and model, then move the model to the global device.

        Args:
            model_name: HuggingFace model id (was hard-coded; now a
                backward-compatible parameter with the original default).
            cache_dir: Local directory for the HuggingFace download cache.
        """
        super(CustomChain, self).__init__()
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=cache_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name, cache_dir=cache_dir, torch_dtype=torch.bfloat16)
        self.model.to(device)

    @property
    def _get_model_default_parameters(self):
        """Default generation parameters reported via _identifying_params."""
        return {
            "max_tokens": 400,
            "n_predict": 400,
            "top_k": 40,
            "top_p": 0.1,
            "temp": 0.0001,
            "n_batch": 1,
            "repeat_penalty": 1.18,
            "repeat_last_n": 64,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """
        Get all the identifying parameters
        """
        return {
            'model_name': self.model_name,
            'model_path': self.model_folder_path,
            'model_parameters': self._get_model_default_parameters,
        }

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses for serialization/telemetry.
        return 't5'

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """
        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered
        Returns:
            The string generated by the model
        """
        inputs = self.tokenizer(prompt, return_tensors='pt')
        # Fixed typo: was `resposne`. max_new_tokens now honors the
        # configurable max_tokens field (same default of 200 as before).
        response = self.tokenizer.decode(
            self.model.generate(
                inputs['input_ids'].to(device),
                max_new_tokens=self.max_tokens,
            )[0],
            skip_special_tokens=True,
        )
        # The original silently ignored `stop` despite documenting it;
        # honor it by truncating at the first occurrence of any stop string.
        if stop:
            for seq in stop:
                cut = response.find(seq)
                if cut != -1:
                    response = response[:cut]
        return response
# db = SQLDatabase.from_uri("sqlite:///assets/Chinook.db")

if __name__ == "__main__":
    # Guarded so that importing this module no longer triggers a model
    # download/load and an inference run as a side effect.
    llm_chain = CustomChain()
    output = llm_chain('Who is the president of US?')
    print(output)
    # db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
    # db_chain.run("How many employees are there?")