-
Notifications
You must be signed in to change notification settings - Fork 121
/
Copy pathworkflow.py
147 lines (128 loc) · 5.56 KB
/
workflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import logging
import os
from typing import List
import httpx
from pydantic import BaseModel, Field
from tensorlake import RemoteGraph
from tensorlake.functions_sdk.data_objects import File
from tensorlake.functions_sdk.graph import Graph
from tensorlake.functions_sdk.functions import tensorlake_function
from tensorlake.functions_sdk.image import Image
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Data Models
class WebsiteData(BaseModel):
    """Output of `scrape_website`: the source URL plus the raw text fetched for it."""
    # Original URL that was scraped (passed through unchanged).
    url: str = Field(..., description="URL of the website")
    # Page text as returned by the r.jina.ai reader proxy.
    content: str = Field(..., description="Content of the website")
class SummarizeWebsite(BaseModel):
    """Output of `summarize_website`: the podcast-style summary text fed into TTS."""
    # LLM-generated summary of the scraped page content.
    summary: str = Field(..., description="Summary of the website content")
class Audio(BaseModel):
    """Output of `generate_tts`: the synthesized MP3 wrapped in a Tensorlake File."""
    # Raw MP3 bytes produced by the ElevenLabs TTS call.
    file: File = Field(..., description="Audio file of the generated TTS")
# Define custom images
# Container images for the three pipeline stages; each installs only the
# single third-party package that its stage imports at runtime.
scraper_image = Image().name("tensorlake/scraper-image").run("pip install httpx")
openai_image = Image().name("tensorlake/openai-image").run("pip install openai")
elevenlabs_image = Image().name("tensorlake/elevenlabs-image").run("pip install elevenlabs")
@tensorlake_function(image=scraper_image)
def scrape_website(url: str) -> WebsiteData:
    """Fetch *url* through the r.jina.ai reader proxy and return its text as WebsiteData.

    Raises on any network/HTTP failure after logging it.
    """
    try:
        resp = httpx.get(f"https://r.jina.ai/{url}")
        resp.raise_for_status()
        return WebsiteData(url=url, content=resp.text)
    except Exception as e:
        logging.error(f"Error scraping website: {e}")
        raise
@tensorlake_function(image=openai_image)
def summarize_website(website_data: WebsiteData) -> SummarizeWebsite:
    """Turn scraped page content into a podcast-ready summary via GPT-4.

    Raises on any OpenAI API failure after logging it.
    """
    try:
        import openai

        client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        # System prompt shapes the output for direct TTS consumption (no stage directions).
        system_prompt = "You are a helpful assistant that summarizes website content in a form which is hearable as a podcast. Remove all the marketing content and only keep the core content in the summary. Call it Tensorlake Daily, don't add words such as host, pause, etc. The transcript of the summary is going to be fed into a TTS model which will add the pauses."
        chat_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Summarize the following content: {website_data.content}"},
        ]
        completion = client.chat.completions.create(
            model="gpt-4",
            messages=chat_messages,
            max_tokens=3000,
            temperature=0.7,
        )
        return SummarizeWebsite(summary=completion.choices[0].message.content.strip())
    except Exception as e:
        logging.error(f"Error summarizing website: {e}")
        raise
@tensorlake_function(image=elevenlabs_image)
def generate_tts(summary: SummarizeWebsite) -> Audio:
    """Synthesize the summary to MP3 with ElevenLabs and return the bytes as Audio.

    Writes a local scratch file "tensorlake-daily.mp3", then reads it back.
    Raises on any TTS failure after logging it.
    """
    try:
        import elevenlabs
        from elevenlabs import save

        tts_client = elevenlabs.ElevenLabs(api_key=os.environ.get("ELEVENLABS_API_KEY"))
        # "Rachel" is the chosen narrator; any other available voice works too.
        audio_stream = tts_client.generate(text=summary.summary, voice="Rachel")
        save(audio_stream, "tensorlake-daily.mp3")
        with open("tensorlake-daily.mp3", "rb") as fh:
            audio_bytes = fh.read()
        return Audio(file=File(data=audio_bytes))
    except Exception as e:
        logging.error(f"Error generating TTS: {e}")
        raise
def create_graph() -> Graph:
    """Build the three-stage pipeline: scrape -> summarize -> text-to-speech."""
    graph = Graph(name="tensorlake-daily-website-summarizer", start_node=scrape_website)
    graph.add_edge(scrape_website, summarize_website)
    graph.add_edge(summarize_website, generate_tts)
    return graph
def deploy_graphs(server_url: str):
    """Deploy the summarizer graph to the Indexify server at *server_url*."""
    RemoteGraph.deploy(create_graph(), server_url=server_url)
    logging.info("Graph deployed successfully")
def run_workflow(mode: str, server_url: str = 'http://localhost:8900',
                 url: str = "https://www.cidrap.umn.edu/avian-influenza-bird-flu/h5n1-avian-flu-virus-detected-wastewater-10-texas-cities"):
    """Run the summarizer graph either in-process or against a remote server.

    The page URL, previously hard-coded inside the body, is now a keyword
    parameter defaulting to the original sample article (backward-compatible).

    Args:
        mode: 'in-process-run' (build the graph locally) or 'remote-run'
            (look it up on the server by name).
        server_url: Indexify server URL; only consulted for 'remote-run'.
        url: Page to scrape and summarize.

    Raises:
        ValueError: if *mode* is not one of the two supported values.
    """
    if mode == 'in-process-run':
        graph = create_graph()
    elif mode == 'remote-run':
        graph = RemoteGraph.by_name("tensorlake-daily-website-summarizer", server_url=server_url)
    else:
        raise ValueError("Invalid mode. Choose 'in-process-run' or 'remote-run'.")
    logging.info(f"Processing URL: {url}")
    # Block until the whole scrape -> summarize -> TTS chain finishes.
    invocation_id = graph.run(block_until_done=True, url=url)
    output: List[Audio] = graph.output(invocation_id, "generate_tts")
    if output:
        with open("tensorlake-daily-saved.mp3", "wb") as f:
            f.write(output[0].file.data)
        logging.info("Audio file saved as tensorlake-daily-saved.mp3")
    else:
        logging.warning("No output found")
if __name__ == "__main__":
    import argparse

    # CLI entry point: pick a mode, then either deploy the graph or run it.
    parser = argparse.ArgumentParser(description="Run Tensorlake Daily Website Summarizer")
    parser.add_argument('--mode', choices=['in-process-run', 'remote-deploy', 'remote-run'], required=True,
                        help='Mode of operation: in-process-run, remote-deploy, or remote-run')
    parser.add_argument('--server-url', default='http://localhost:8900', help='Indexify server URL for remote mode or deployment')
    args = parser.parse_args()
    try:
        if args.mode == 'remote-deploy':
            deploy_graphs(args.server_url)
        elif args.mode in ['in-process-run', 'remote-run']:
            run_workflow(args.mode, args.server_url)
        logging.info("Operation completed successfully!")
    except Exception as e:
        logging.error(f"An error occurred during execution: {str(e)}")
        # Fix: previously the error was only logged and the process exited 0,
        # hiding failures from shells/CI. Exit non-zero instead.
        raise SystemExit(1) from e