Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

1688 строк
62 KiB
Plaintext

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3dda6a69",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer \n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score\n",
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7fd6636b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n"
]
}
],
"source": [
"import gensim.downloader\n",
"print(list(gensim.downloader.info()['models'].keys()))"
]
},
{
"cell_type": "markdown",
"id": "3f93b5f6",
"metadata": {},
"source": [
"# GloVe"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "be870586",
"metadata": {},
"outputs": [],
"source": [
"glove_model = gensim.downloader.load(\"glove-twitter-25\") # load glove vectors\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "599d6406",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n",
" 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n",
" -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n",
" 0.43927 0.1758 -0.56058 0.13529 ]\n"
]
},
{
"data": {
"text/plain": [
"[('dog', 0.9590820074081421),\n",
" ('monkey', 0.920357882976532),\n",
" ('bear', 0.9143136739730835),\n",
" ('pet', 0.9108031392097473),\n",
" ('girl', 0.8880629539489746),\n",
" ('horse', 0.8872726559638977),\n",
" ('kitty', 0.8870542049407959),\n",
" ('puppy', 0.886769711971283),\n",
" ('hot', 0.886525571346283),\n",
" ('lady', 0.8845519423484802)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(glove_model['cat']) # word embedding for 'cat'\n",
"glove_model.most_similar(\"cat\") # show the words most similar to 'cat'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2db71cfb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.60927683"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove_model.similarity('cat', 'bus')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7788acf5",
"metadata": {},
"outputs": [],
"source": [
"categories = ['alt.atheism', 'comp.graphics', 'sci.space'] \n",
"remove = ('headers', 'footers', 'quotes')\n",
"twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories = categories, remove = remove )\n",
"twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42, categories = categories, remove = remove )\n"
]
},
{
"cell_type": "markdown",
"id": "79dd1ac1",
"metadata": {},
"source": [
"# Векторизуем обучающую выборку\n",
"Получаем матрицу \"Документ-термин\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0565dd1a",
"metadata": {},
"outputs": [],
"source": [
"vectorizer = CountVectorizer(stop_words='english')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a681a1d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1657, 23297)\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>00</th>\n",
" <th>000</th>\n",
" <th>0000</th>\n",
" <th>00000</th>\n",
" <th>000000</th>\n",
" <th>000005102000</th>\n",
" <th>000062david42</th>\n",
" <th>000100255pixel</th>\n",
" <th>00041032</th>\n",
" <th>0004136</th>\n",
" <th>...</th>\n",
" <th>zurbrin</th>\n",
" <th>zurich</th>\n",
" <th>zus</th>\n",
" <th>zvi</th>\n",
" <th>zwaartepunten</th>\n",
" <th>zwak</th>\n",
" <th>zwakke</th>\n",
" <th>zware</th>\n",
" <th>zwarte</th>\n",
" <th>zyxel</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 23297 columns</p>\n",
"</div>"
],
"text/plain": [
" 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n",
"0 0 0 0 0 0 0 0 0 \n",
"1 0 0 0 0 0 0 0 0 \n",
"2 0 0 0 0 0 0 0 0 \n",
"3 0 0 0 0 0 0 0 0 \n",
"4 0 0 0 0 0 0 0 0 \n",
"\n",
" 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n",
"0 0 0 ... 0 0 0 0 0 0 \n",
"1 0 0 ... 0 0 0 0 0 0 \n",
"2 0 0 ... 0 0 0 0 0 0 \n",
"3 0 0 ... 0 0 0 0 0 0 \n",
"4 0 0 ... 0 0 0 0 0 0 \n",
"\n",
" zwakke zware zwarte zyxel \n",
"0 0 0 0 0 \n",
"1 0 0 0 0 \n",
"2 0 0 0 0 \n",
"3 0 0 0 0 \n",
"4 0 0 0 0 \n",
"\n",
"[5 rows x 23297 columns]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data = vectorizer.fit_transform(twenty_train['data'])\n",
"CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())\n",
"print(CV_data.shape)\n",
"CV_data.head()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b20aef46",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['00', '000', '0000', '00000', '000000', '000005102000', '000062david42',\n",
" '000100255pixel', '00041032', '0004136'],\n",
" dtype='object')\n"
]
}
],
"source": [
"# Создадим список слов, присутствующих в словаре.\n",
"words_vocab=CV_data.columns\n",
"print(words_vocab[0:10])"
]
},
{
"cell_type": "markdown",
"id": "d1893e86",
"metadata": {},
"source": [
"## Векторизуем с помощью GloVe\n",
"\n",
"Для каждого документа нужно сложить GloVe-векторы слов, из которых он состоит.\n",
"В результате получим вектор документа как сумму векторов входящих в него слов.\n",
"Каждое слово учитывается один раз; слова, которых нет в словаре CountVectorizer или в GloVe-модели, пропускаются."
]
},
{
"cell_type": "markdown",
"id": "bc36b98d",
"metadata": {},
"source": [
"### Посмотрим на примере, как будет работать векторизация"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0d6af65a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>00</th>\n",
" <th>000</th>\n",
" <th>0000</th>\n",
" <th>00000</th>\n",
" <th>000000</th>\n",
" <th>000005102000</th>\n",
" <th>000062david42</th>\n",
" <th>000100255pixel</th>\n",
" <th>00041032</th>\n",
" <th>0004136</th>\n",
" <th>...</th>\n",
" <th>zurbrin</th>\n",
" <th>zurich</th>\n",
" <th>zus</th>\n",
" <th>zvi</th>\n",
" <th>zwaartepunten</th>\n",
" <th>zwak</th>\n",
" <th>zwakke</th>\n",
" <th>zware</th>\n",
" <th>zwarte</th>\n",
" <th>zyxel</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2 rows × 23297 columns</p>\n",
"</div>"
],
"text/plain": [
" 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n",
"0 0 0 0 0 0 0 0 0 \n",
"1 1 1 0 0 0 0 0 0 \n",
"\n",
" 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n",
"0 0 0 ... 0 0 0 0 0 0 \n",
"1 0 0 ... 0 0 0 0 0 0 \n",
"\n",
" zwakke zware zwarte zyxel \n",
"0 0 0 0 0 \n",
"1 0 0 0 1 \n",
"\n",
"[2 rows x 23297 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n",
"# Векторизуем с помощью обученного CountVectorizer\n",
"X = vectorizer.transform(text_data)\n",
"CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n",
"CV_text_data\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "11dda58a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116\n",
" 1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195\n",
" -2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479\n",
" -0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928\n",
" -0.50473 ]\n",
"love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315\n",
" 0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369\n",
" -0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678\n",
" -0.016651 -0.53164 0.64236 -0.10922 ]\n",
"python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563\n",
" -0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526\n",
" 0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403\n",
" -0.27013 -0.348 -0.7293 0.2201 ]\n",
"world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513\n",
" 0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096\n",
" -0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667\n",
" -0.44426 -0.43085 -0.52875 0.25855 ]\n",
"Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 -1.00551002\n",
" -0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001\n",
" -0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997\n",
" -0.430875 0.87234702 -0.806399 0.27203003 2.23921998\n",
" -1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]\n",
"computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917\n",
" -0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233\n",
" 0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234\n",
" -0.92349 0.34651 0.061939 -0.57375 ]\n",
"game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537\n",
" 0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805\n",
" 1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619\n",
" 0.39202 0.59309 -0.55232 0.005087]\n",
"great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01\n",
" 1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01\n",
" -5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01\n",
" -1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01\n",
" -2.7928e-01]\n",
"zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268\n",
" -0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208\n",
" 0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593\n",
" -0.10225 1.7102 -0.95599 -1.3867 ]\n",
"This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998\n",
" -0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002\n",
" -1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004\n",
" -0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628\n",
" -3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Андрей\\AppData\\Local\\Temp\\ipykernel_8524\\2010506005.py:17: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
" glove_data=glove_data.append(pd.DataFrame([one_doc]))\n"
]
}
],
"source": [
"# Создадим датафрейм, в который будем сохранять вектор документа\n",
"glove_data=pd.DataFrame()\n",
"\n",
"# Пробегаем по каждой строке датафрейма (по каждому документу)\n",
"for i in range(CV_text_data.shape[0]):\n",
" \n",
" # Вектор одного документа с размерностью glove-модели:\n",
" one_doc = np.zeros(25) \n",
" \n",
" # Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре\n",
" # Суммируем glove-вектора каждого известного слова в one_doc\n",
" for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:\n",
" if word in glove_model.key_to_index.keys(): \n",
" print(word, ': ', glove_model[word])\n",
" one_doc += glove_model[word]\n",
" print(text_data[i], ': ', one_doc)\n",
" glove_data=glove_data.append(pd.DataFrame([one_doc])) \n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ff68d8dc",
"metadata": {},
"outputs": [],
"source": [
"def text2vec(text_data):\n",
" \n",
" # Векторизуем с помощью обученного CountVectorizer\n",
" X = vectorizer.transform(text_data)\n",
" CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n",
" CV_text_data\n",
" # Создадим датафрейм, в который будем сохранять вектор документа\n",
" glove_data=pd.DataFrame()\n",
"\n",
" # Пробегаем по каждой строке (по каждому документу)\n",
" for i in range(CV_text_data.shape[0]):\n",
"\n",
" # Вектор одного документа с размерностью glove-модели:\n",
" one_doc = np.zeros(25) \n",
"\n",
" # Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре\n",
" # Суммируем glove-вектора каждого известного слова в one_doc\n",
" for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:\n",
" if word in glove_model.key_to_index.keys(): \n",
" #print(word, ': ', glove_model[word])\n",
" one_doc += glove_model[word]\n",
" #print(text_data[i], ': ', one_doc)\n",
" glove_data = pd.concat([glove_data, pd.DataFrame([one_doc])], axis = 0)\n",
" #print('glove_data: ', glove_data)\n",
" return glove_data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b778776c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>15</th>\n",
" <th>16</th>\n",
" <th>17</th>\n",
" <th>18</th>\n",
" <th>19</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-1.55058</td>\n",
" <td>-0.081683</td>\n",
" <td>0.279919</td>\n",
" <td>0.588469</td>\n",
" <td>-1.00551</td>\n",
" <td>-0.826139</td>\n",
" <td>6.18643</td>\n",
" <td>1.44845</td>\n",
" <td>-0.71108</td>\n",
" <td>-1.14717</td>\n",
" <td>...</td>\n",
" <td>-0.430875</td>\n",
" <td>0.872347</td>\n",
" <td>-0.806399</td>\n",
" <td>0.27203</td>\n",
" <td>2.23922</td>\n",
" <td>-1.23572</td>\n",
" <td>-1.310711</td>\n",
" <td>-1.96934</td>\n",
" <td>-0.176410</td>\n",
" <td>-0.135300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.73610</td>\n",
" <td>0.742082</td>\n",
" <td>0.355460</td>\n",
" <td>-4.744110</td>\n",
" <td>1.41544</td>\n",
" <td>-0.342220</td>\n",
" <td>1.78697</td>\n",
" <td>-1.45404</td>\n",
" <td>2.56643</td>\n",
" <td>-1.32184</td>\n",
" <td>...</td>\n",
" <td>-0.526620</td>\n",
" <td>1.932400</td>\n",
" <td>-0.896870</td>\n",
" <td>-0.60924</td>\n",
" <td>1.51628</td>\n",
" <td>-3.16625</td>\n",
" <td>-0.892750</td>\n",
" <td>1.86970</td>\n",
" <td>-1.336071</td>\n",
" <td>-2.234643</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 6 7 \\\n",
"0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n",
"0 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n",
"\n",
" 8 9 ... 15 16 17 18 19 \\\n",
"0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n",
"0 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n",
"\n",
" 20 21 22 23 24 \n",
"0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n",
"0 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n",
"\n",
"[2 rows x 25 columns]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"glove_data\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1bdb459e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>15</th>\n",
" <th>16</th>\n",
" <th>17</th>\n",
" <th>18</th>\n",
" <th>19</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-8.521142</td>\n",
" <td>2.020376</td>\n",
" <td>-10.802921</td>\n",
" <td>3.167636</td>\n",
" <td>0.252469</td>\n",
" <td>15.544048</td>\n",
" <td>17.631184</td>\n",
" <td>-32.581192</td>\n",
" <td>9.696540</td>\n",
" <td>-11.103087</td>\n",
" <td>...</td>\n",
" <td>2.810453</td>\n",
" <td>7.900215</td>\n",
" <td>0.962129</td>\n",
" <td>17.691130</td>\n",
" <td>-1.252574</td>\n",
" <td>-10.098049</td>\n",
" <td>0.500113</td>\n",
" <td>1.348694</td>\n",
" <td>2.186150</td>\n",
" <td>-16.556824</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6.576228</td>\n",
" <td>20.336350</td>\n",
" <td>-32.675150</td>\n",
" <td>-9.073872</td>\n",
" <td>17.515655</td>\n",
" <td>-6.488794</td>\n",
" <td>59.458419</td>\n",
" <td>-75.384298</td>\n",
" <td>13.323775</td>\n",
" <td>-14.443218</td>\n",
" <td>...</td>\n",
" <td>21.407738</td>\n",
" <td>23.525118</td>\n",
" <td>0.325680</td>\n",
" <td>19.871444</td>\n",
" <td>-27.585188</td>\n",
" <td>-4.559155</td>\n",
" <td>-7.417482</td>\n",
" <td>-16.694553</td>\n",
" <td>-0.197711</td>\n",
" <td>-58.948193</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.329914</td>\n",
" <td>3.060870</td>\n",
" <td>-1.868484</td>\n",
" <td>1.392735</td>\n",
" <td>-1.335277</td>\n",
" <td>-5.014955</td>\n",
" <td>12.859476</td>\n",
" <td>-9.978156</td>\n",
" <td>-0.869613</td>\n",
" <td>-2.031490</td>\n",
" <td>...</td>\n",
" <td>2.925134</td>\n",
" <td>2.872930</td>\n",
" <td>2.184486</td>\n",
" <td>3.831770</td>\n",
" <td>-0.877866</td>\n",
" <td>-0.927770</td>\n",
" <td>0.700101</td>\n",
" <td>-9.855365</td>\n",
" <td>-5.419429</td>\n",
" <td>-2.279330</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-4.866150</td>\n",
" <td>-0.273176</td>\n",
" <td>3.515124</td>\n",
" <td>-5.008165</td>\n",
" <td>-1.236789</td>\n",
" <td>-7.951168</td>\n",
" <td>-11.015882</td>\n",
" <td>-3.496241</td>\n",
" <td>16.024286</td>\n",
" <td>-9.388742</td>\n",
" <td>...</td>\n",
" <td>-0.471141</td>\n",
" <td>3.575378</td>\n",
" <td>6.193222</td>\n",
" <td>0.349430</td>\n",
" <td>15.040248</td>\n",
" <td>-10.369132</td>\n",
" <td>-0.848717</td>\n",
" <td>-0.564796</td>\n",
" <td>-1.114126</td>\n",
" <td>-7.844431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-3.115007</td>\n",
" <td>-1.805252</td>\n",
" <td>-5.419340</td>\n",
" <td>-0.393406</td>\n",
" <td>-0.406461</td>\n",
" <td>-2.724340</td>\n",
" <td>7.898330</td>\n",
" <td>-15.619113</td>\n",
" <td>0.231822</td>\n",
" <td>-3.628156</td>\n",
" <td>...</td>\n",
" <td>5.944151</td>\n",
" <td>8.309932</td>\n",
" <td>-0.656084</td>\n",
" <td>12.178709</td>\n",
" <td>-6.118551</td>\n",
" <td>-3.286376</td>\n",
" <td>3.450946</td>\n",
" <td>2.055343</td>\n",
" <td>0.463787</td>\n",
" <td>-12.644626</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.930954</td>\n",
" <td>4.974043</td>\n",
" <td>-8.147008</td>\n",
" <td>-5.147130</td>\n",
" <td>3.960455</td>\n",
" <td>-1.344022</td>\n",
" <td>7.818063</td>\n",
" <td>-25.427420</td>\n",
" <td>4.624732</td>\n",
" <td>-7.218097</td>\n",
" <td>...</td>\n",
" <td>3.623038</td>\n",
" <td>4.453189</td>\n",
" <td>2.405320</td>\n",
" <td>8.032963</td>\n",
" <td>-8.029539</td>\n",
" <td>0.838867</td>\n",
" <td>-4.757457</td>\n",
" <td>-5.755052</td>\n",
" <td>-9.496197</td>\n",
" <td>-21.542710</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.770690</td>\n",
" <td>0.128270</td>\n",
" <td>0.331370</td>\n",
" <td>0.005089</td>\n",
" <td>-0.476050</td>\n",
" <td>-0.501160</td>\n",
" <td>1.858000</td>\n",
" <td>1.062400</td>\n",
" <td>-0.565110</td>\n",
" <td>0.133280</td>\n",
" <td>...</td>\n",
" <td>-0.449220</td>\n",
" <td>0.485910</td>\n",
" <td>-0.647900</td>\n",
" <td>-0.842380</td>\n",
" <td>0.616690</td>\n",
" <td>-0.198240</td>\n",
" <td>-0.579670</td>\n",
" <td>-0.658850</td>\n",
" <td>0.439280</td>\n",
" <td>-0.504730</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.491177</td>\n",
" <td>6.992638</td>\n",
" <td>-7.921970</td>\n",
" <td>-7.157521</td>\n",
" <td>6.641657</td>\n",
" <td>-2.958020</td>\n",
" <td>12.820770</td>\n",
" <td>-18.502946</td>\n",
" <td>6.838083</td>\n",
" <td>-2.717310</td>\n",
" <td>...</td>\n",
" <td>-1.344873</td>\n",
" <td>4.170405</td>\n",
" <td>-0.178030</td>\n",
" <td>5.699992</td>\n",
" <td>-7.295038</td>\n",
" <td>-3.683306</td>\n",
" <td>-2.718006</td>\n",
" <td>-0.117608</td>\n",
" <td>-7.205832</td>\n",
" <td>-13.863438</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2.523770</td>\n",
" <td>5.817394</td>\n",
" <td>2.184340</td>\n",
" <td>-2.996497</td>\n",
" <td>-0.267181</td>\n",
" <td>-10.059634</td>\n",
" <td>6.344402</td>\n",
" <td>-2.047127</td>\n",
" <td>2.679123</td>\n",
" <td>-7.642505</td>\n",
" <td>...</td>\n",
" <td>-1.230296</td>\n",
" <td>1.409746</td>\n",
" <td>-3.322040</td>\n",
" <td>-5.068259</td>\n",
" <td>-0.648718</td>\n",
" <td>0.753010</td>\n",
" <td>-6.220990</td>\n",
" <td>-5.012004</td>\n",
" <td>-1.518542</td>\n",
" <td>-10.156440</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.118691</td>\n",
" <td>11.860546</td>\n",
" <td>-2.567264</td>\n",
" <td>-10.955913</td>\n",
" <td>-4.239322</td>\n",
" <td>-9.340552</td>\n",
" <td>21.189778</td>\n",
" <td>-10.895375</td>\n",
" <td>2.659030</td>\n",
" <td>-3.848115</td>\n",
" <td>...</td>\n",
" <td>0.726191</td>\n",
" <td>11.634998</td>\n",
" <td>-5.447248</td>\n",
" <td>1.293007</td>\n",
" <td>-7.882002</td>\n",
" <td>-2.527453</td>\n",
" <td>0.298939</td>\n",
" <td>-6.107062</td>\n",
" <td>3.365051</td>\n",
" <td>-15.641826</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1657 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 \\\n",
"0 -8.521142 2.020376 -10.802921 3.167636 0.252469 15.544048 \n",
"0 6.576228 20.336350 -32.675150 -9.073872 17.515655 -6.488794 \n",
"0 1.329914 3.060870 -1.868484 1.392735 -1.335277 -5.014955 \n",
"0 -4.866150 -0.273176 3.515124 -5.008165 -1.236789 -7.951168 \n",
"0 -3.115007 -1.805252 -5.419340 -0.393406 -0.406461 -2.724340 \n",
".. ... ... ... ... ... ... \n",
"0 -0.930954 4.974043 -8.147008 -5.147130 3.960455 -1.344022 \n",
"0 -0.770690 0.128270 0.331370 0.005089 -0.476050 -0.501160 \n",
"0 1.491177 6.992638 -7.921970 -7.157521 6.641657 -2.958020 \n",
"0 2.523770 5.817394 2.184340 -2.996497 -0.267181 -10.059634 \n",
"0 -0.118691 11.860546 -2.567264 -10.955913 -4.239322 -9.340552 \n",
"\n",
" 6 7 8 9 ... 15 16 \\\n",
"0 17.631184 -32.581192 9.696540 -11.103087 ... 2.810453 7.900215 \n",
"0 59.458419 -75.384298 13.323775 -14.443218 ... 21.407738 23.525118 \n",
"0 12.859476 -9.978156 -0.869613 -2.031490 ... 2.925134 2.872930 \n",
"0 -11.015882 -3.496241 16.024286 -9.388742 ... -0.471141 3.575378 \n",
"0 7.898330 -15.619113 0.231822 -3.628156 ... 5.944151 8.309932 \n",
".. ... ... ... ... ... ... ... \n",
"0 7.818063 -25.427420 4.624732 -7.218097 ... 3.623038 4.453189 \n",
"0 1.858000 1.062400 -0.565110 0.133280 ... -0.449220 0.485910 \n",
"0 12.820770 -18.502946 6.838083 -2.717310 ... -1.344873 4.170405 \n",
"0 6.344402 -2.047127 2.679123 -7.642505 ... -1.230296 1.409746 \n",
"0 21.189778 -10.895375 2.659030 -3.848115 ... 0.726191 11.634998 \n",
"\n",
" 17 18 19 20 21 22 23 \\\n",
"0 0.962129 17.691130 -1.252574 -10.098049 0.500113 1.348694 2.186150 \n",
"0 0.325680 19.871444 -27.585188 -4.559155 -7.417482 -16.694553 -0.197711 \n",
"0 2.184486 3.831770 -0.877866 -0.927770 0.700101 -9.855365 -5.419429 \n",
"0 6.193222 0.349430 15.040248 -10.369132 -0.848717 -0.564796 -1.114126 \n",
"0 -0.656084 12.178709 -6.118551 -3.286376 3.450946 2.055343 0.463787 \n",
".. ... ... ... ... ... ... ... \n",
"0 2.405320 8.032963 -8.029539 0.838867 -4.757457 -5.755052 -9.496197 \n",
"0 -0.647900 -0.842380 0.616690 -0.198240 -0.579670 -0.658850 0.439280 \n",
"0 -0.178030 5.699992 -7.295038 -3.683306 -2.718006 -0.117608 -7.205832 \n",
"0 -3.322040 -5.068259 -0.648718 0.753010 -6.220990 -5.012004 -1.518542 \n",
"0 -5.447248 1.293007 -7.882002 -2.527453 0.298939 -6.107062 3.365051 \n",
"\n",
" 24 \n",
"0 -16.556824 \n",
"0 -58.948193 \n",
"0 -2.279330 \n",
"0 -7.844431 \n",
"0 -12.644626 \n",
".. ... \n",
"0 -21.542710 \n",
"0 -0.504730 \n",
"0 -13.863438 \n",
"0 -10.156440 \n",
"0 -15.641826 \n",
"\n",
"[1657 rows x 25 columns]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data_glove = text2vec(twenty_train['data']);\n",
"train_data_glove"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5ac20e79",
"metadata": {},
"outputs": [],
"source": [
"clf = KNeighborsClassifier(n_neighbors = 5)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "08164a25",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KNeighborsClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KNeighborsClassifier</label><div class=\"sk-toggleable__content\"><pre>KNeighborsClassifier()</pre></div></div></div></div></div>"
],
"text/plain": [
"KNeighborsClassifier()"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.fit(train_data_glove, twenty_train['target'])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e459faaf",
"metadata": {},
"outputs": [],
"source": [
"# Vectorize the test documents with text2vec; the result is fed to clf.predict below.\n",
"test_data_glove = text2vec(twenty_test['data'])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d8144e75",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" <th>1</th>\n",
" <th>2</th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>...</th>\n",
" <th>15</th>\n",
" <th>16</th>\n",
" <th>17</th>\n",
" <th>18</th>\n",
" <th>19</th>\n",
" <th>20</th>\n",
" <th>21</th>\n",
" <th>22</th>\n",
" <th>23</th>\n",
" <th>24</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-6.760635</td>\n",
" <td>5.063863</td>\n",
" <td>-2.779060</td>\n",
" <td>3.699120</td>\n",
" <td>-2.858086</td>\n",
" <td>0.135230</td>\n",
" <td>20.811229</td>\n",
" <td>-19.425567</td>\n",
" <td>7.302950</td>\n",
" <td>-5.826012</td>\n",
" <td>...</td>\n",
" <td>3.833378</td>\n",
" <td>6.794452</td>\n",
" <td>-0.921720</td>\n",
" <td>12.187404</td>\n",
" <td>-5.547615</td>\n",
" <td>-4.133999</td>\n",
" <td>3.588260</td>\n",
" <td>-0.497106</td>\n",
" <td>-2.542142</td>\n",
" <td>-11.362855</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.632616</td>\n",
" <td>2.512300</td>\n",
" <td>-0.745513</td>\n",
" <td>-3.081154</td>\n",
" <td>2.182067</td>\n",
" <td>-1.988816</td>\n",
" <td>7.533100</td>\n",
" <td>-1.015740</td>\n",
" <td>-0.829598</td>\n",
" <td>-2.764237</td>\n",
" <td>...</td>\n",
" <td>0.791851</td>\n",
" <td>2.114150</td>\n",
" <td>-2.249193</td>\n",
" <td>-0.163590</td>\n",
" <td>-1.177710</td>\n",
" <td>-2.496928</td>\n",
" <td>-5.074085</td>\n",
" <td>-2.666947</td>\n",
" <td>0.662050</td>\n",
" <td>-3.590550</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2.115766</td>\n",
" <td>2.142060</td>\n",
" <td>-0.445607</td>\n",
" <td>-3.229030</td>\n",
" <td>1.154580</td>\n",
" <td>-2.877278</td>\n",
" <td>6.399954</td>\n",
" <td>-10.445769</td>\n",
" <td>2.230760</td>\n",
" <td>-3.299899</td>\n",
" <td>...</td>\n",
" <td>4.388870</td>\n",
" <td>8.515056</td>\n",
" <td>-0.766260</td>\n",
" <td>3.549431</td>\n",
" <td>-1.643443</td>\n",
" <td>-0.825730</td>\n",
" <td>-2.968016</td>\n",
" <td>-0.808924</td>\n",
" <td>-0.000160</td>\n",
" <td>-7.468189</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.802784</td>\n",
" <td>5.199443</td>\n",
" <td>4.294071</td>\n",
" <td>-7.390966</td>\n",
" <td>2.747166</td>\n",
" <td>-1.359952</td>\n",
" <td>15.032628</td>\n",
" <td>-1.601590</td>\n",
" <td>1.474406</td>\n",
" <td>2.570105</td>\n",
" <td>...</td>\n",
" <td>3.043432</td>\n",
" <td>6.176236</td>\n",
" <td>-6.193988</td>\n",
" <td>-3.990476</td>\n",
" <td>-2.345854</td>\n",
" <td>-5.534376</td>\n",
" <td>-8.925422</td>\n",
" <td>1.553300</td>\n",
" <td>0.905790</td>\n",
" <td>-12.824533</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>29.926489</td>\n",
" <td>65.324993</td>\n",
" <td>-25.059592</td>\n",
" <td>-64.080130</td>\n",
" <td>77.565282</td>\n",
" <td>-34.614604</td>\n",
" <td>75.643770</td>\n",
" <td>-115.600859</td>\n",
" <td>90.847175</td>\n",
" <td>-42.971146</td>\n",
" <td>...</td>\n",
" <td>40.956031</td>\n",
" <td>50.322156</td>\n",
" <td>-19.537098</td>\n",
" <td>28.903925</td>\n",
" <td>-34.643949</td>\n",
" <td>-69.894146</td>\n",
" <td>-94.992145</td>\n",
" <td>-48.601895</td>\n",
" <td>-29.098555</td>\n",
" <td>-91.934770</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.829235</td>\n",
" <td>4.513807</td>\n",
" <td>2.916520</td>\n",
" <td>2.237308</td>\n",
" <td>-1.704831</td>\n",
" <td>-1.811192</td>\n",
" <td>22.196895</td>\n",
" <td>-12.858912</td>\n",
" <td>-4.054810</td>\n",
" <td>-3.130457</td>\n",
" <td>...</td>\n",
" <td>6.019246</td>\n",
" <td>8.949456</td>\n",
" <td>-4.682214</td>\n",
" <td>-5.648911</td>\n",
" <td>-1.026898</td>\n",
" <td>3.719006</td>\n",
" <td>2.449941</td>\n",
" <td>-6.487197</td>\n",
" <td>1.340930</td>\n",
" <td>-7.325196</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-0.963815</td>\n",
" <td>5.491164</td>\n",
" <td>3.567377</td>\n",
" <td>-6.048021</td>\n",
" <td>-5.059298</td>\n",
" <td>-0.977958</td>\n",
" <td>15.131499</td>\n",
" <td>-0.904470</td>\n",
" <td>2.185990</td>\n",
" <td>-1.459807</td>\n",
" <td>...</td>\n",
" <td>0.968499</td>\n",
" <td>4.725793</td>\n",
" <td>-0.726944</td>\n",
" <td>1.328612</td>\n",
" <td>-3.144209</td>\n",
" <td>1.643127</td>\n",
" <td>-1.259245</td>\n",
" <td>-0.880740</td>\n",
" <td>-6.713165</td>\n",
" <td>-3.115454</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6.801324</td>\n",
" <td>15.348126</td>\n",
" <td>-17.051718</td>\n",
" <td>5.030998</td>\n",
" <td>9.332448</td>\n",
" <td>-5.716691</td>\n",
" <td>56.409175</td>\n",
" <td>-56.250411</td>\n",
" <td>-4.028209</td>\n",
" <td>-11.687558</td>\n",
" <td>...</td>\n",
" <td>22.884424</td>\n",
" <td>12.940570</td>\n",
" <td>1.058664</td>\n",
" <td>21.879058</td>\n",
" <td>-20.897253</td>\n",
" <td>2.537755</td>\n",
" <td>3.774890</td>\n",
" <td>-11.495336</td>\n",
" <td>-2.609774</td>\n",
" <td>-36.597559</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.054090</td>\n",
" <td>0.764524</td>\n",
" <td>1.958340</td>\n",
" <td>-1.085245</td>\n",
" <td>-0.441392</td>\n",
" <td>-0.421970</td>\n",
" <td>6.139770</td>\n",
" <td>-0.612219</td>\n",
" <td>-2.251460</td>\n",
" <td>-0.465165</td>\n",
" <td>...</td>\n",
" <td>0.377958</td>\n",
" <td>1.957450</td>\n",
" <td>-1.705220</td>\n",
" <td>-0.509700</td>\n",
" <td>0.016110</td>\n",
" <td>1.461620</td>\n",
" <td>1.589069</td>\n",
" <td>2.267340</td>\n",
" <td>0.447919</td>\n",
" <td>-0.469250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-18.387286</td>\n",
" <td>13.274879</td>\n",
" <td>-7.895913</td>\n",
" <td>-1.831442</td>\n",
" <td>-10.424961</td>\n",
" <td>-12.248442</td>\n",
" <td>32.153890</td>\n",
" <td>-40.169293</td>\n",
" <td>13.089525</td>\n",
" <td>-21.306493</td>\n",
" <td>...</td>\n",
" <td>6.497279</td>\n",
" <td>8.340729</td>\n",
" <td>4.996109</td>\n",
" <td>23.442078</td>\n",
" <td>-3.701088</td>\n",
" <td>-11.671505</td>\n",
" <td>9.209790</td>\n",
" <td>-10.002501</td>\n",
" <td>-0.815266</td>\n",
" <td>-17.024052</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1102 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" 0 1 2 3 4 5 \\\n",
"0 -6.760635 5.063863 -2.779060 3.699120 -2.858086 0.135230 \n",
"0 1.632616 2.512300 -0.745513 -3.081154 2.182067 -1.988816 \n",
"0 2.115766 2.142060 -0.445607 -3.229030 1.154580 -2.877278 \n",
"0 -0.802784 5.199443 4.294071 -7.390966 2.747166 -1.359952 \n",
"0 29.926489 65.324993 -25.059592 -64.080130 77.565282 -34.614604 \n",
".. ... ... ... ... ... ... \n",
"0 1.829235 4.513807 2.916520 2.237308 -1.704831 -1.811192 \n",
"0 -0.963815 5.491164 3.567377 -6.048021 -5.059298 -0.977958 \n",
"0 6.801324 15.348126 -17.051718 5.030998 9.332448 -5.716691 \n",
"0 1.054090 0.764524 1.958340 -1.085245 -0.441392 -0.421970 \n",
"0 -18.387286 13.274879 -7.895913 -1.831442 -10.424961 -12.248442 \n",
"\n",
" 6 7 8 9 ... 15 16 \\\n",
"0 20.811229 -19.425567 7.302950 -5.826012 ... 3.833378 6.794452 \n",
"0 7.533100 -1.015740 -0.829598 -2.764237 ... 0.791851 2.114150 \n",
"0 6.399954 -10.445769 2.230760 -3.299899 ... 4.388870 8.515056 \n",
"0 15.032628 -1.601590 1.474406 2.570105 ... 3.043432 6.176236 \n",
"0 75.643770 -115.600859 90.847175 -42.971146 ... 40.956031 50.322156 \n",
".. ... ... ... ... ... ... ... \n",
"0 22.196895 -12.858912 -4.054810 -3.130457 ... 6.019246 8.949456 \n",
"0 15.131499 -0.904470 2.185990 -1.459807 ... 0.968499 4.725793 \n",
"0 56.409175 -56.250411 -4.028209 -11.687558 ... 22.884424 12.940570 \n",
"0 6.139770 -0.612219 -2.251460 -0.465165 ... 0.377958 1.957450 \n",
"0 32.153890 -40.169293 13.089525 -21.306493 ... 6.497279 8.340729 \n",
"\n",
" 17 18 19 20 21 22 \\\n",
"0 -0.921720 12.187404 -5.547615 -4.133999 3.588260 -0.497106 \n",
"0 -2.249193 -0.163590 -1.177710 -2.496928 -5.074085 -2.666947 \n",
"0 -0.766260 3.549431 -1.643443 -0.825730 -2.968016 -0.808924 \n",
"0 -6.193988 -3.990476 -2.345854 -5.534376 -8.925422 1.553300 \n",
"0 -19.537098 28.903925 -34.643949 -69.894146 -94.992145 -48.601895 \n",
".. ... ... ... ... ... ... \n",
"0 -4.682214 -5.648911 -1.026898 3.719006 2.449941 -6.487197 \n",
"0 -0.726944 1.328612 -3.144209 1.643127 -1.259245 -0.880740 \n",
"0 1.058664 21.879058 -20.897253 2.537755 3.774890 -11.495336 \n",
"0 -1.705220 -0.509700 0.016110 1.461620 1.589069 2.267340 \n",
"0 4.996109 23.442078 -3.701088 -11.671505 9.209790 -10.002501 \n",
"\n",
" 23 24 \n",
"0 -2.542142 -11.362855 \n",
"0 0.662050 -3.590550 \n",
"0 -0.000160 -7.468189 \n",
"0 0.905790 -12.824533 \n",
"0 -29.098555 -91.934770 \n",
".. ... ... \n",
"0 1.340930 -7.325196 \n",
"0 -6.713165 -3.115454 \n",
"0 -2.609774 -36.597559 \n",
"0 0.447919 -0.469250 \n",
"0 -0.815266 -17.024052 \n",
"\n",
"[1102 rows x 25 columns]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_data_glove"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a69830f0",
"metadata": {},
"outputs": [],
"source": [
"# Predict a class label for every row of the vectorized test set.\n",
"predict = clf.predict(test_data_glove)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "9ac5cf20",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[225 35 59]\n",
" [ 26 313 50]\n",
" [ 56 98 240]]\n",
" precision recall f1-score support\n",
"\n",
" 0 0.73 0.71 0.72 319\n",
" 1 0.70 0.80 0.75 389\n",
" 2 0.69 0.61 0.65 394\n",
"\n",
" accuracy 0.71 1102\n",
" macro avg 0.71 0.71 0.70 1102\n",
"weighted avg 0.71 0.71 0.70 1102\n",
"\n"
]
}
],
"source": [
"# Confusion matrix: rows are the true classes, columns are the predicted classes.\n",
"print(confusion_matrix(twenty_test['target'], predict))\n",
"# Per-class precision / recall / F1 and overall accuracy on the test split.\n",
"print(classification_report(twenty_test['target'], predict))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8cce5a9",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b9bff90",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}