231 changes: 106 additions & 125 deletions embeddings_converter.py
@@ -1,142 +1,123 @@
import os
import sys
import logging
from pathlib import Path
from typing import Optional, Union, List
from sentence_transformers import SentenceTransformer
from transformers import CLIPProcessor, CLIPModel
import pdfplumber
import torch
from PIL import Image
from io import BytesIO

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Custom exceptions
class EmbeddingError(Exception): pass
class FileProcessingError(EmbeddingError): pass

# Text embedding using SentenceTransformer
def get_hf_text_embedding(text: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> list:
    try:
        model = SentenceTransformer(model_name)
        embedding = model.encode(text, convert_to_tensor=False)
        return embedding.tolist()
    except Exception as e:
        logger.error(f"Hugging Face text embedding failed: {e}")
        raise EmbeddingError(f"Failed to generate HF text embedding: {e}")
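
# Note: all-MiniLM-L6-v2 truncates input at roughly 256 word pieces, so embedding an
# entire annual report as one string only captures its opening text. A minimal chunking
# sketch; the helper name, chunk size, and overlap are illustrative assumptions:
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> List[str]:
    words = text.split()
    step = max(chunk_size - overlap, 1)
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]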

# Image embedding using CLIP
def get_clip_image_embedding(pil_image: Image.Image) -> list:
    try:
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        inputs = processor(images=pil_image, return_tensors="pt")
        with torch.no_grad():
            outputs = model.get_image_features(**inputs)
        return outputs.squeeze().tolist()
    except Exception as e:
        logger.error(f"CLIP image embedding failed: {e}")
        raise EmbeddingError(f"Failed to generate image embedding: {e}")
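
# The two embedding helpers above reload their models on every call, which gets slow
# once a PDF yields many tables and images. A minimal caching sketch; the loader
# names below are illustrative assumptions, not functions used elsewhere in this file:
from functools import lru_cache

@lru_cache(maxsize=1)
def _cached_text_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2") -> SentenceTransformer:
    return SentenceTransformer(model_name)

@lru_cache(maxsize=1)
def _cached_clip_model():
    return (CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
            CLIPModel.from_pretrained("openai/clip-vit-base-patch32"))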

# Extract text from all PDF pages
def extract_text_from_pdf(pdf_path: Union[str, Path]) -> str:
    try:
        with pdfplumber.open(pdf_path) as pdf:
            full_text = "\n".join([page.extract_text() or "" for page in pdf.pages])
        return full_text.strip()
    except Exception as e:
        logger.error(f"Failed to extract text from PDF: {e}")
        raise FileProcessingError(f"PDF processing error: {e}")

# Extract tables as flattened text
def extract_tables_from_pdf(pdf_path: Union[str, Path]) -> List[str]:
    tables_text = []
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                tables = page.extract_tables()
                for table in tables:
                    # extract_tables() can return None for empty cells, so guard the join
                    table_text = "\n".join([", ".join(cell or "" for cell in row) for row in table if row])
                    tables_text.append(table_text)
    except Exception as e:
        logger.warning(f"Failed to extract tables: {e}")
    return tables_text

# Extract images as PIL images
def extract_images_from_pdf(pdf_path: Union[str, Path]) -> List[Image.Image]:
    images = []
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                for image_dict in page.images:
                    x0, top, x1, bottom = image_dict["x0"], image_dict["top"], image_dict["x1"], image_dict["bottom"]
                    bbox = (x0, top, x1, bottom)
                    cropped = page.crop(bbox)
                    # to_image(...).original is already a PIL image; round-trip it through a
                    # PNG buffer so the result is detached from pdfplumber's page object
                    buffer = BytesIO()
                    cropped.to_image(resolution=300).original.save(buffer, format="PNG")
                    buffer.seek(0)
                    images.append(Image.open(buffer))
    except Exception as e:
        logger.warning(f"Image extraction failed: {e}")
    return images

# Full processor
def process_pdf_file(file_path: Union[str, Path]) -> dict:
    file_path = Path(file_path)
    if not file_path.exists() or file_path.suffix.lower() != '.pdf':
        raise FileProcessingError(f"Invalid file: {file_path}")

    result = {}

    # Text Embedding
    text = extract_text_from_pdf(file_path)
    if text:
        result["text_embedding"] = get_hf_text_embedding(text)

    # Table Embeddings
    table_texts = extract_tables_from_pdf(file_path)
    result["table_embeddings"] = [get_hf_text_embedding(t) for t in table_texts if t]

    # Image Embeddings
    images = extract_images_from_pdf(file_path)
    result["image_embeddings"] = [get_clip_image_embedding(img) for img in images]

    return result

def main():
    pdf_file = "ltimindtree_annual_report.pdf"

    if os.path.exists(pdf_file):
        logger.info(f"Processing PDF: {pdf_file}")
        try:
            embeddings = process_pdf_file(pdf_file)
            logger.info(f"✅ Text embedding vector length: {len(embeddings.get('text_embedding', []))}")
            logger.info(f"📊 Table embeddings: {len(embeddings.get('table_embeddings', []))} tables processed")
            logger.info(f"🖼️ Image embeddings: {len(embeddings.get('image_embeddings', []))} images processed")
        except EmbeddingError as e:
            logger.error(f"Processing failed: {e}")
    else:
        logger.error(f"PDF not found: {pdf_file}")

if __name__ == "__main__":
    main()
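
As a rough illustration of how the dictionary returned by process_pdf_file could be used downstream, the sketch below embeds a query with the same MiniLM model and ranks the table embeddings by cosine similarity; the query text, helper name, and scoring logic are assumptions for illustration, not part of the module above.

import numpy as np
from embeddings_converter import process_pdf_file, get_hf_text_embedding

def cosine_similarity(a, b) -> float:
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

embeddings = process_pdf_file("ltimindtree_annual_report.pdf")
query_vec = get_hf_text_embedding("revenue growth during the financial year")
ranked_tables = sorted(
    range(len(embeddings["table_embeddings"])),
    key=lambda i: cosine_similarity(query_vec, embeddings["table_embeddings"][i]),
    reverse=True,
)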
106 changes: 106 additions & 0 deletions hf_embedder.py
@@ -0,0 +1,106 @@
import os
import logging
from pathlib import Path
from typing import Optional, Union
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Custom error classes
class EmbeddingError(Exception):
    pass

class FileProcessingError(EmbeddingError):
    pass

# Initialize the BGE-M3 model and processor
MODEL_NAME = "BAAI/bge-m3"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # Load text components
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    text_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    logger.info(f"Successfully loaded {MODEL_NAME} model on {DEVICE}")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise EmbeddingError(f"Model loading failed: {e}")

# BAAI/bge-m3 is a text-only embedding model and ships no image processor, so this
# load will typically fail; keep it optional so the text path still works.
try:
    image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
except Exception as e:
    logger.warning(f"No image processor available for {MODEL_NAME}: {e}")
    image_processor = None

def get_text_embedding(text: str) -> list:
    """Generate text embeddings using the BAAI/bge-m3 model."""
    try:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
        with torch.no_grad():
            outputs = text_model(**inputs)
        # Use the [CLS] token embedding as the sentence embedding
        embedding = outputs.last_hidden_state[:, 0, :].cpu().squeeze().numpy()
        return embedding.tolist()
    except Exception as e:
        logger.error(f"Text embedding failed: {e}")
        raise EmbeddingError(f"Failed to generate text embedding: {e}")

def get_image_embedding(image_path: Union[str, Path]) -> list:
    """Generate image embeddings (not supported by the text-only BAAI/bge-m3 model)."""
    if image_processor is None:
        raise EmbeddingError(
            f"{MODEL_NAME} does not provide an image processor; use a vision model such as CLIP for images"
        )
    try:
        image = Image.open(image_path)
        inputs = image_processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = text_model(**inputs)
        # Use the [CLS] token embedding as the image embedding
        embedding = outputs.last_hidden_state[:, 0, :].cpu().squeeze().numpy()
        return embedding.tolist()
    except Exception as e:
        logger.error(f"Image embedding failed: {e}")
        raise EmbeddingError(f"Failed to generate image embedding: {e}")

def process_file(file_path: Union[str, Path]) -> Optional[dict]:
    file_path = Path(file_path)

    if not file_path.exists():
        raise FileProcessingError(f"File not found: {file_path}")

    if not file_path.is_file():
        raise FileProcessingError(f"Not a file: {file_path}")

    if file_path.suffix.lower() in ['.txt', '.md', '.py', '.json']:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            return {"embedding": get_text_embedding(content)}
        except Exception as e:
            logger.error(f"Text processing error: {e}")
            raise FileProcessingError(f"Text file processing error: {e}")
    elif file_path.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']:
        try:
            return {"embedding": get_image_embedding(file_path)}
        except Exception as e:
            logger.error(f"Image processing error: {e}")
            raise FileProcessingError(f"Image file processing error: {e}")
    else:
        raise FileProcessingError(f"Unsupported file type: {file_path.suffix}")

# Entry point
def main():
    test_file = "test.txt"  # Change this to test with different files

    if os.path.exists(test_file):
        logger.info(f"Processing file: {test_file}")
        try:
            result = process_file(test_file)
            logger.info("Embedding generated successfully")
            logger.info(f"Embedding length: {len(result['embedding'])}")
        except (EmbeddingError, FileProcessingError) as e:
            logger.error(f"Failed to process file: {e}")

if __name__ == "__main__":
    main()
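
One small follow-up on get_text_embedding: BGE-family models are usually compared with cosine similarity on L2-normalized vectors, so a caller would typically normalize the returned list before scoring. A minimal sketch, with the variable names below being illustrative assumptions:

import numpy as np
from hf_embedder import get_text_embedding

vec = np.asarray(get_text_embedding("LTIMindtree annual report"), dtype=float)
unit_vec = vec / (np.linalg.norm(vec) + 1e-12)  # L2-normalize so a dot product acts as cosine similarity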