-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8687d44
commit 7b48a0c
Showing
8 changed files
with
5,246 additions
and
657 deletions.
There are no files selected for viewing
Binary file not shown.
2,473 changes: 2,473 additions & 0 deletions
2,473
notebooks/.ipynb_checkpoints/lookback_fasttext-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
91 changes: 91 additions & 0 deletions
91
notebooks/.ipynb_checkpoints/make_bert_embeddings-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "36ccfe91-a69d-4077-8c03-32f97e3ac02a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"# Must be set before `transformers` is imported in a later cell.\n", | ||
"# The cache directory is derived from the current user's home directory so the\n", | ||
"# notebook is not tied to one machine's absolute path.\n", | ||
"os.environ['TRANSFORMERS_CACHE'] = os.path.expanduser('~/.cache/huggingface')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d50e667e-3209-42f9-8612-6e64405c5107", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from transformers import BertTokenizer, BertModel\n", | ||
"\n", | ||
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", | ||
"model = BertModel.from_pretrained('bert-base-uncased')\n", | ||
"\n", | ||
"def get_bert_embeddings(text):\n", | ||
"    \"\"\"Embed `text` with BERT and return the mean-pooled last hidden state as a NumPy array.\n", | ||
"\n", | ||
"    `text` is handed to the tokenizer unchanged; input beyond 512 tokens is truncated.\n", | ||
"    \"\"\"\n", | ||
"    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)\n", | ||
"    # Inference only: skip autograd bookkeeping so no graph is kept for each call.\n", | ||
"    with torch.no_grad():\n", | ||
"        outputs = model(**inputs)\n", | ||
"    # Get the embeddings from the last hidden state\n", | ||
"    embeddings = outputs.last_hidden_state\n", | ||
"    # Pool the embeddings (mean over dim=1)\n", | ||
"    # NOTE(review): for a padded batch this mean would include padding positions -\n", | ||
"    # confirm that callers pass one string at a time.\n", | ||
"    pooled_embeddings = torch.mean(embeddings, dim=1)\n", | ||
"    return pooled_embeddings.detach().numpy()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3df9f25b-1dbb-496b-b037-ffe692ef3831", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# NOTE(review): `data` is not defined anywhere in this notebook; it presumably holds the\n", | ||
"# headline table with a 'Text' column - confirm and load it before this cell.\n", | ||
"headlines = data['Text'].values\n", | ||
"# One embedding per headline, in the same order as `headlines`.\n", | ||
"embeddings = list(map(get_bert_embeddings, headlines))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7b42119a-45a4-4bfd-9013-e7005ebd18f1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"# Combine the per-headline arrays into one array; the shape is shown as the cell output.\n", | ||
"bert_embeddings = np.array(embeddings)\n", | ||
"bert_embeddings.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c0489af6-64dc-4386-a1e9-049f2dcee9ac", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Write the embedding array to ../bert_embeddings.npy in NumPy's binary format.\n", | ||
"np.save('../bert_embeddings.npy', bert_embeddings)\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
104 changes: 104 additions & 0 deletions
104
notebooks/.ipynb_checkpoints/make_fasttext_embeddings-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "da379302-b109-494a-a5b5-b97b52ae6e77", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"from torch.utils.data import TensorDataset, DataLoader\n", | ||
"from nltk import WordPunctTokenizer\n", | ||
"from nltk.corpus import stopwords\n", | ||
"import nltk\n", | ||
"import math" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "23567fa5-f680-4180-973a-416f7fa9e8cd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def strip_bytes_repr(headline):\n", | ||
"    \"\"\"Remove a bytes-literal wrapper (leading b plus quote, and the closing quote) from a headline.\"\"\"\n", | ||
"    # Missing values (e.g. NaN) become empty text instead of raising on indexing.\n", | ||
"    if not isinstance(headline, str):\n", | ||
"        return ''\n", | ||
"    # Require a quote right after the b and a matching closing quote, so an ordinary\n", | ||
"    # headline that merely starts with the letter b is left untouched.\n", | ||
"    wrapped = len(headline) >= 3 and headline[:2] in (\"b'\", 'b\"') and headline[-1] == headline[1]\n", | ||
"    return headline[2:-1] if wrapped else headline\n", | ||
"\n", | ||
"combined_news_djia = pd.read_csv('../data/Combined_News_DJIA.csv')\n", | ||
"for column in ('Top1', 'Top2'):\n", | ||
"    combined_news_djia[column] = combined_news_djia[column].apply(strip_bytes_repr)\n", | ||
"\n", | ||
"# Text must stay column 0 and Date column 1: get_embeddings indexes the frame's values by position.\n", | ||
"# Date is deliberately kept as a regular column (not the index) for the same reason.\n", | ||
"t_data = pd.DataFrame()\n", | ||
"t_data['Text'] = combined_news_djia['Top1'] + \" \" + combined_news_djia['Top2']\n", | ||
"t_data['Date'] = combined_news_djia['Date']\n", | ||
"\n", | ||
"nltk.download('stopwords')\n", | ||
"tokenizer = WordPunctTokenizer()\n", | ||
"stop_words = set(stopwords.words('english'))\n", | ||
"\n", | ||
"def process_headline(x):\n", | ||
"    \"\"\"Lowercase and tokenize `x`, drop English stop words, and join the remaining tokens with spaces.\"\"\"\n", | ||
"    lowered = (w.lower() for w in tokenizer.tokenize(x))\n", | ||
"    return \" \".join(w for w in lowered if w not in stop_words)\n", | ||
"\n", | ||
"t_data['Text'] = t_data['Text'].apply(process_headline)\n", | ||
"t_data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "35019d74-fd71-44b2-aa59-311f4cc1b3dc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import fasttext.util\n", | ||
"\n", | ||
"# Download FastText word vectors\n", | ||
"# fasttext.util.download_model('en', if_exists='ignore') # Download English language embeddings\n", | ||
"ft = fasttext.load_model('../../NLPstockPredictions/cc.en.300.bin') # Load the downloaded model\n", | ||
"\n", | ||
"def get_embeddings(data):\n", | ||
"    \"\"\"Average the FastText word vectors of each row's text.\n", | ||
"\n", | ||
"    Reads the text from column 0 and returns column 1 unchanged alongside the embeddings:\n", | ||
"    (array with one averaged vector per row, values of column 1).\n", | ||
"    \"\"\"\n", | ||
"    rows = data.values\n", | ||
"    dimension = ft.get_dimension()\n", | ||
"    combo = []\n", | ||
"    for row in rows:\n", | ||
"        word_vectors = [ft.get_word_vector(word) for word in row[0].split()]\n", | ||
"        if word_vectors:\n", | ||
"            news_embedding = np.mean(word_vectors, axis=0)\n", | ||
"        else:\n", | ||
"            # No tokens left in this row: use a zero vector so every row has the same\n", | ||
"            # length (the mean of an empty list would be a NaN scalar).\n", | ||
"            news_embedding = np.zeros(dimension, dtype=np.float32)\n", | ||
"        combo.append(news_embedding)\n", | ||
"    return np.array(combo), rows[:, 1]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2a214078-569d-4ef4-9fd2-ad8a7b114557", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Bind the results to new names so t_data stays a DataFrame and this cell can be re-run.\n", | ||
"fasttext_embeddings, embedding_dates = get_embeddings(t_data)\n", | ||
"with open('../fasttext_embeddings.npy', 'wb') as f:\n", | ||
"    np.save(f, fasttext_embeddings)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.