# BriefBot

This project demonstrates a text summarization model built on the BART (Bidirectional and Auto-Regressive Transformers) architecture using Hugging Face's `transformers` library. The model is trained on the CNN/DailyMail dataset to generate concise, coherent summaries of input articles.
## Table of Contents

- Introduction
- Features
- Installation
- Usage
- Training the Model
- Generating Summaries
- Memory Management
- Files in the Repository
- Contributing
- License
## Introduction

This project showcases a text summarization system built using the BART model from Hugging Face's `transformers` library. The system is trained on the CNN/DailyMail dataset and uses beam search to generate high-quality summaries.
## Features

- Text Summarization: Generates concise summaries of input articles.
- Beam Search: Uses beam search decoding to improve the quality of generated summaries.
- Memory Management: Implements memory usage monitoring to ensure efficient use of resources.
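Each feature maps to code shown later in this README: the training script fine-tunes BART on CNN/DailyMail, the `generate_summary` function configures beam search, and `limit_memory_usage` enforces the memory ceiling.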
## Installation

Requirements:

- Python 3.6+
- `transformers` library
- `datasets` library
- `torch`
- `psutil` (for memory management)
Steps:

- Clone the repository:

  ```bash
  git clone https://github.com/prempraneethkota/BriefBot.git
  cd BriefBot
  ```

- Install the required libraries:

  ```bash
  pip install transformers datasets torch psutil
  ```

- Download the `model.safetensors` file from the provided link and place it in the `summarization_model` directory.
## Usage

- Run the interactive script:

  ```bash
  python Run.py
  ```

- Enter an article when prompted to receive a summary. Type `exit` to quit.

Example session:

```text
Please enter an article to summarize (or type 'exit' to quit):
[Your article here]
Generated Summary:
[Summary]
```
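For reference, the interactive loop in `Run.py` likely looks something like the following minimal sketch; it reuses the `generate_summary` logic shown later in this README, but the actual script in the repository may differ:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the fine-tuned model and tokenizer from the summarization_model directory.
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("./summarization_model")
model = AutoModelForSeq2SeqLM.from_pretrained("./summarization_model").to(device)

def generate_summary(text):
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt").to(device)
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=150, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Prompt for articles until the user types 'exit'.
while True:
    article = input("Please enter an article to summarize (or type 'exit' to quit):\n")
    if article.strip().lower() == "exit":
        break
    print("Generated Summary:")
    print(generate_summary(article))
```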
## Training the Model

The training script (`Brief3.py`) sets up the model, tokenizer, and training loop:
```python
import os

import numpy as np
import psutil
import torch
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM,
                          Seq2SeqTrainingArguments, Seq2SeqTrainer,
                          default_data_collator)

# Function to monitor memory usage
def limit_memory_usage(max_usage_percent=85):
    process = psutil.Process(os.getpid())
    max_usage = (max_usage_percent / 100) * psutil.virtual_memory().total
    if process.memory_info().rss > max_usage:
        raise MemoryError("Memory usage exceeded limit.")

# Set device to CPU
device = torch.device('cpu')

# Load the dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")
train_data = dataset['train'].select(range(250))      # Using a smaller subset for demonstration
valid_data = dataset['validation'].select(range(150))

# Load pre-trained tokenizer and model
model_name = "facebook/bart-base"  # Using a smaller model for faster training
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Move the model to CPU
model.to(device)

# Preprocess function for tokenizing data
def preprocess_function(examples):
    inputs = tokenizer(examples["article"], max_length=512, truncation=True, padding="max_length")
    targets = tokenizer(examples["highlights"], max_length=128, truncation=True, padding="max_length")
    inputs["labels"] = targets["input_ids"]
    return inputs

# Apply preprocessing to datasets
train_dataset = train_data.map(preprocess_function, batched=True)
valid_dataset = valid_data.map(preprocess_function, batched=True)

# Set dataset format for PyTorch
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
valid_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Custom data collator function to handle labels efficiently
def collate_fn(batch):
    batch = default_data_collator(batch)
    if "labels" in batch:
        batch["labels"] = torch.tensor(np.array(batch["labels"]), dtype=torch.int64)
    return batch

# Training arguments with increased logging verbosity
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=3,  # Adjust batch size for CPU usage
    per_device_eval_batch_size=3,   # Adjust eval batch size for CPU usage
    weight_decay=0.01,
    save_total_limit=4,
    num_train_epochs=4,
    predict_with_generate=True,
    logging_dir='./logs',
    logging_steps=10,         # Log every 10 steps
    report_to="all",          # Report to all available logging integrations (e.g., console, TensorBoard)
    logging_first_step=True,  # Log the very first step as well
)

# Trainer with memory monitoring before every training step
class MemoryLimitedSeq2SeqTrainer(Seq2SeqTrainer):
    def training_step(self, *args, **kwargs):
        limit_memory_usage(max_usage_percent=85)
        return super().training_step(*args, **kwargs)

trainer = MemoryLimitedSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
)

# Train the model
trainer.train()

# Save the model and tokenizer
trainer.save_model("./summarization_model")
tokenizer.save_pretrained("./summarization_model")
```
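Once training completes, the saved artifacts can be reloaded for inference. A minimal sketch (the paths match the `save_model` and `save_pretrained` calls above):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Reload the fine-tuned model and tokenizer written by the training script.
tokenizer = AutoTokenizer.from_pretrained("./summarization_model")
model = AutoModelForSeq2SeqLM.from_pretrained("./summarization_model")
model.to(torch.device("cpu"))  # inference on CPU, matching the training setup
```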
## Generating Summaries

The `generate_summary` function takes an article as input and returns its summary using the trained model:
```python
def generate_summary(text):
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt").to(device)
    generation_kwargs = {
        "num_beams": 4,
        "max_length": 150,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
        "forced_bos_token_id": model.config.bos_token_id,
    }
    summary_ids = model.generate(
        inputs["input_ids"],
        **generation_kwargs
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Example usage
input_article = """
Lung cancer is a type of cancer that begins in the lungs. Your lungs are two spongy organs in your chest that take in oxygen when you inhale and release carbon dioxide when you exhale.
Lung cancer is the leading cause of cancer deaths worldwide. People who smoke have the greatest risk of lung cancer, though lung cancer can also occur in people who have never smoked.
The risk of lung cancer increases with the length of time and number of cigarettes you've smoked. If you quit smoking, even after smoking for many years, you can significantly reduce your chances of developing lung cancer.
"""
print("Generated Summary:")
print(generate_summary(input_article))
```
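A note on the generation settings: `num_beams=4` keeps four candidate sequences alive during beam search, `no_repeat_ngram_size=3` prevents any trigram from appearing twice in the output, and `early_stopping=True` ends decoding as soon as all beams have produced a complete candidate.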
## Memory Management

The project includes a function to monitor and limit memory usage, ensuring efficient use of resources:
```python
def limit_memory_usage(max_usage_percent=85):
    # Raise an error if this process's resident memory exceeds the allowed
    # share of total system RAM.
    process = psutil.Process(os.getpid())
    max_usage = (max_usage_percent / 100) * psutil.virtual_memory().total
    if process.memory_info().rss > max_usage:
        raise MemoryError("Memory usage exceeded limit.")
```
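Callers can treat the raised `MemoryError` as a signal to back off. A minimal usage sketch (the handling below is illustrative, not part of the repository):

```python
try:
    limit_memory_usage(max_usage_percent=85)
except MemoryError:
    # Illustrative handling: abort the current work before the process
    # can exhaust system memory.
    print("Memory limit reached; stopping.")
```

In the training script, `MemoryLimitedSeq2SeqTrainer` runs this check before every training step, so a runaway run fails fast instead of thrashing.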
## Files in the Repository

- `Run.py`: interactive script for entering articles and printing their summaries.
- `Brief3.py`: training script that fine-tunes BART on the CNN/DailyMail dataset.
- `Readme.md`: this documentation.
- `summarization_model/`: directory containing the trained model and tokenizer (including `model.safetensors`).
## Contributing

Feel free to fork this repository and contribute by submitting pull requests. For major changes, please open an issue first to discuss what you would like to change.
Note: To use the model, ensure that you download the `model.safetensors` file from the link provided and place it in the `summarization_model` directory.