Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

628 строки
18 KiB
Plaintext

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "3dda6a69",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer \n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score\n",
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7fd6636b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n"
]
}
],
"source": [
"import gensim.downloader\n",
"print(list(gensim.downloader.info()['models'].keys()))"
]
},
{
"cell_type": "markdown",
"id": "3f93b5f6",
"metadata": {},
"source": [
"# GloVe"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "be870586",
"metadata": {},
"outputs": [],
"source": [
"glove_model = gensim.downloader.load(\"glove-twitter-25\") # load glove vectors\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "599d6406",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n",
" 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n",
" -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n",
" 0.43927 0.1758 -0.56058 0.13529 ]\n"
]
},
{
"data": {
"text/plain": [
"[('dog', 0.9590820074081421),\n",
" ('monkey', 0.920357882976532),\n",
" ('bear', 0.9143136739730835),\n",
" ('pet', 0.9108031392097473),\n",
" ('girl', 0.8880629539489746),\n",
" ('horse', 0.8872726559638977),\n",
" ('kitty', 0.8870542049407959),\n",
" ('puppy', 0.886769711971283),\n",
" ('hot', 0.886525571346283),\n",
" ('lady', 0.8845519423484802)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(glove_model['cat']) # word embedding for 'cat'\n",
"glove_model.most_similar(\"cat\") # show words that similar to word 'cat'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2db71cfb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.60927683"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove_model.similarity('cat', 'bus')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7788acf5",
"metadata": {},
"outputs": [],
"source": [
"categories = ['alt.atheism', 'comp.graphics', 'sci.space'] \n",
"remove = ('headers', 'footers', 'quotes')\n",
"twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories = categories, remove = remove )\n",
"twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42, categories = categories, remove = remove )\n"
]
},
{
"cell_type": "markdown",
"id": "79dd1ac1",
"metadata": {},
"source": [
"# Векторизуем обучающую выборку\n",
"Получаем матрицу \"Документ-термин\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0565dd1a",
"metadata": {},
"outputs": [],
"source": [
"vectorizer = CountVectorizer(stop_words='english')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a681a1d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1657, 23297)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>00</th>\n",
" <th>000</th>\n",
" <th>0000</th>\n",
" <th>00000</th>\n",
" <th>000000</th>\n",
" <th>000005102000</th>\n",
" <th>000062david42</th>\n",
" <th>000100255pixel</th>\n",
" <th>00041032</th>\n",
" <th>0004136</th>\n",
" <th>...</th>\n",
" <th>zurbrin</th>\n",
" <th>zurich</th>\n",
" <th>zus</th>\n",
" <th>zvi</th>\n",
" <th>zwaartepunten</th>\n",
" <th>zwak</th>\n",
" <th>zwakke</th>\n",
" <th>zware</th>\n",
" <th>zwarte</th>\n",
" <th>zyxel</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 23297 columns</p>\n",
"</div>"
],
"text/plain": [
" 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n",
"0 0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 0 0 0 \n",
"\n",
" 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n",
"0 0 0 ... 0 0 0 0 0 0 \n",
"1 0 0 ... 0 0 0 0 0 0 \n",
"2 0 0 ... 0 0 0 0 0 0 \n",
"3 0 0 ... 0 0 0 0 0 0 \n",
"4 0 0 ... 0 0 0 0 0 0 \n",
"\n",
" zwakke zware zwarte zyxel \n",
"0 0 0 0 0 \n",
"1 0 0 0 0 \n",
"2 0 0 0 0 \n",
"3 0 0 0 0 \n",
"4 0 0 0 0 \n",
"\n",
"[5 rows x 23297 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data = vectorizer.fit_transform(twenty_train['data'])\n",
"CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())\n",
"print(CV_data.shape)\n",
"CV_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b20aef46",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['00', '000', '0000', '00000', '000000', '000005102000', '000062david42',\n",
" '000100255pixel', '00041032', '0004136'],\n",
" dtype='object')\n"
]
}
],
"source": [
"# Создадим список слов, присутствующих в словаре.\n",
"words_vocab=CV_data.columns\n",
"print(words_vocab[0:10])"
]
},
{
"cell_type": "markdown",
"id": "d1893e86",
"metadata": {},
"source": [
"## Векторизуем с помощью GloVe\n",
"\n",
"Для каждого документа нужно сложить GloVe-векторы слов, из которых он состоит.\n",
"В результате получим вектор документа как сумму векторов составляющих его слов."
]
},
{
"cell_type": "markdown",
"id": "bc36b98d",
"metadata": {},
"source": [
"### Посмотрим на примере, как будет работать векторизация"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d6af65a",
"metadata": {},
"outputs": [],
"source": [
"text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n",
"# Векторизуем с помощью обученного CountVectorizer\n",
"X = vectorizer.transform(text_data)\n",
"CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n",
"CV_text_data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11dda58a",
"metadata": {},
"outputs": [],
"source": [
"# Collect one GloVe document vector per row of CV_text_data.\n",
"doc_vectors = []\n",
"\n",
"# Walk over every row (every document).\n",
"for i in range(CV_text_data.shape[0]):\n",
"\n",
"    # Vector of a single document, sized to the GloVe model's dimensionality\n",
"    # (instead of a hard-coded 25, so other GloVe models work too).\n",
"    one_doc = np.zeros(glove_model.vector_size)\n",
"\n",
"    # Select the vocabulary words that occur in this document and\n",
"    # sum the GloVe vectors of the ones the model knows into one_doc.\n",
"    for word in words_vocab[CV_text_data.iloc[i, :] >= 1]:\n",
"        if word in glove_model.key_to_index:\n",
"            print(word, ': ', glove_model[word])\n",
"            one_doc += glove_model[word]\n",
"    print(text_data[i], ': ', one_doc)\n",
"    doc_vectors.append(one_doc)\n",
"\n",
"# DataFrame.append was removed in pandas 2.0, so build the frame in one step;\n",
"# this also gives every document its own row label (0..n-1).\n",
"glove_data = pd.DataFrame(doc_vectors)\n",
"print('glove_data: ', glove_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff68d8dc",
"metadata": {},
"outputs": [],
"source": [
"def text2vec(text_data):\n",
"    \"\"\"Convert raw documents into GloVe document vectors.\n",
"\n",
"    Every document is run through the fitted CountVectorizer; the GloVe\n",
"    vectors of the vocabulary words found in it (each distinct word once)\n",
"    are summed into a single vector.\n",
"\n",
"    Args:\n",
"        text_data: iterable of raw document strings.\n",
"\n",
"    Returns:\n",
"        pandas.DataFrame with one row per document (labelled 0..n-1) and\n",
"        one column per GloVe dimension.\n",
"    \"\"\"\n",
"    # Document-term counts from the fitted CountVectorizer, kept sparse:\n",
"    # densifying a whole corpus costs rows * vocabulary cells of memory.\n",
"    X = vectorizer.transform(text_data).tocsr()\n",
"    vocab = vectorizer.get_feature_names_out()\n",
"    dim = glove_model.vector_size\n",
"\n",
"    doc_vectors = []\n",
"    for i in range(X.shape[0]):\n",
"        one_doc = np.zeros(dim)\n",
"\n",
"        # Column ids of the words occurring in document i, in vocabulary order.\n",
"        row = slice(X.indptr[i], X.indptr[i + 1])\n",
"        present = np.sort(X.indices[row][X.data[row] >= 1])\n",
"\n",
"        # Sum the GloVe vectors of the words the model knows.\n",
"        for word in vocab[present]:\n",
"            if word in glove_model.key_to_index:\n",
"                one_doc += glove_model[word]\n",
"        doc_vectors.append(one_doc)\n",
"\n",
"    # Build the frame once: pd.concat inside the loop re-copies all rows\n",
"    # on every iteration and labels every row 0.\n",
"    return pd.DataFrame(doc_vectors, columns=range(dim))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b778776c",
"metadata": {},
"outputs": [],
"source": [
"\n",
"glove_data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb6edbdf",
"metadata": {},
"outputs": [],
"source": [
"one_doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bdb459e",
"metadata": {},
"outputs": [],
"source": [
"train_data_glove = text2vec(twenty_train['data']);\n",
"train_data_glove"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a7ea7c6",
"metadata": {},
"outputs": [],
"source": [
"train_data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ac20e79",
"metadata": {},
"outputs": [],
"source": [
"clf = KNeighborsClassifier(n_neighbors = 5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08164a25",
"metadata": {},
"outputs": [],
"source": [
"clf.fit(train_data_glove, twenty_train['target'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e459faaf",
"metadata": {},
"outputs": [],
"source": [
"test_data_glove = text2vec(twenty_test['data']);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8144e75",
"metadata": {},
"outputs": [],
"source": [
"test_data_glove"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a69830f0",
"metadata": {},
"outputs": [],
"source": [
"predict = clf.predict(test_data_glove )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ac5cf20",
"metadata": {},
"outputs": [],
"source": [
"print (confusion_matrix(twenty_test['target'], predict))\n",
"print(classification_report(twenty_test['target'], predict))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8cce5a9",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9bff90",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}