{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3dda6a69", "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer \n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score\n", "from sklearn.pipeline import Pipeline" ] }, { "cell_type": "code", "execution_count": 2, "id": "7fd6636b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n" ] } ], "source": [ "import gensim.downloader\n", "print(list(gensim.downloader.info()['models'].keys()))" ] }, { "cell_type": "markdown", "id": "3f93b5f6", "metadata": {}, "source": [ "# GloVe" ] }, { "cell_type": "code", "execution_count": 3, "id": "be870586", "metadata": {}, "outputs": [], "source": [ "glove_model = gensim.downloader.load(\"glove-twitter-25\") # load glove vectors\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "599d6406", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n", " 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n", " -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n", " 0.43927 0.1758 -0.56058 0.13529 ]\n" ] }, { "data": { "text/plain": [ "[('dog', 0.9590820074081421),\n", " ('monkey', 0.920357882976532),\n", " ('bear', 0.9143136739730835),\n", " ('pet', 0.9108031392097473),\n", " ('girl', 0.8880629539489746),\n", " 
('horse', 0.8872726559638977),\n", " ('kitty', 0.8870542049407959),\n", " ('puppy', 0.886769711971283),\n", " ('hot', 0.886525571346283),\n", " ('lady', 0.8845519423484802)]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(glove_model['cat']) # word embedding for 'cat'\n", "glove_model.most_similar(\"cat\") # show words that similar to word 'cat'" ] }, { "cell_type": "code", "execution_count": 5, "id": "2db71cfb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.60927683" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "glove_model.similarity('cat', 'bus')" ] }, { "cell_type": "code", "execution_count": 6, "id": "7788acf5", "metadata": {}, "outputs": [], "source": [ "categories = ['alt.atheism', 'comp.graphics', 'sci.space'] \n", "remove = ('headers', 'footers', 'quotes')\n", "twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories = categories, remove = remove )\n", "twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42, categories = categories, remove = remove )\n" ] }, { "cell_type": "markdown", "id": "79dd1ac1", "metadata": {}, "source": [ "# Векторизуем обучающую выборку\n", "Получаем матрицу \"Документ-термин\"" ] }, { "cell_type": "code", "execution_count": 7, "id": "0565dd1a", "metadata": {}, "outputs": [], "source": [ "vectorizer = CountVectorizer(stop_words='english')" ] }, { "cell_type": "code", "execution_count": 8, "id": "a681a1d6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1657, 23297)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
00000000000000000000000005102000000062david42000100255pixel000410320004136...zurbrinzurichzuszvizwaartepuntenzwakzwakkezwarezwartezyxel
00000000000...0000000000
10000000000...0000000000
20000000000...0000000000
30000000000...0000000000
40000000000...0000000000
\n", "

5 rows × 23297 columns

\n", "
" ], "text/plain": [ " 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n", "0 0 0 0 0 0 0 0 0 \n", "1 0 0 0 0 0 0 0 0 \n", "2 0 0 0 0 0 0 0 0 \n", "3 0 0 0 0 0 0 0 0 \n", "4 0 0 0 0 0 0 0 0 \n", "\n", " 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n", "0 0 0 ... 0 0 0 0 0 0 \n", "1 0 0 ... 0 0 0 0 0 0 \n", "2 0 0 ... 0 0 0 0 0 0 \n", "3 0 0 ... 0 0 0 0 0 0 \n", "4 0 0 ... 0 0 0 0 0 0 \n", "\n", " zwakke zware zwarte zyxel \n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", "[5 rows x 23297 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data = vectorizer.fit_transform(twenty_train['data'])\n", "CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())\n", "print(CV_data.shape)\n", "CV_data.head()" ] }, { "cell_type": "code", "execution_count": 9, "id": "b20aef46", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['00', '000', '0000', '00000', '000000', '000005102000', '000062david42',\n", " '000100255pixel', '00041032', '0004136'],\n", " dtype='object')\n" ] } ], "source": [ "# Создадим список слов, присутствующих в словаре.\n", "words_vocab=CV_data.columns\n", "print(words_vocab[0:10])" ] }, { "cell_type": "markdown", "id": "d1893e86", "metadata": {}, "source": [ "## Векторизуем с помощью GloVe\n", "\n", "Нужно для каждого документа сложить glove-вектора слов, из которых он состоит.\n", "В результате получим вектор документа как сумму векторов слов, из него состоящих" ] }, { "cell_type": "markdown", "id": "bc36b98d", "metadata": {}, "source": [ "### Посмотрим на примере как будет работать векторизация" ] }, { "cell_type": "code", "execution_count": 10, "id": "0d6af65a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
00000000000000000000000005102000000062david42000100255pixel000410320004136...zurbrinzurichzuszvizwaartepuntenzwakzwakkezwarezwartezyxel
00000000000...0000000000
11100000000...0000000001
\n", "

2 rows × 23297 columns

\n", "
" ], "text/plain": [ " 00 000 0000 00000 000000 000005102000 000062david42 000100255pixel \\\n", "0 0 0 0 0 0 0 0 0 \n", "1 1 1 0 0 0 0 0 0 \n", "\n", " 00041032 0004136 ... zurbrin zurich zus zvi zwaartepunten zwak \\\n", "0 0 0 ... 0 0 0 0 0 0 \n", "1 0 0 ... 0 0 0 0 0 0 \n", "\n", " zwakke zware zwarte zyxel \n", "0 0 0 0 0 \n", "1 0 0 0 1 \n", "\n", "[2 rows x 23297 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n", "# Векторизуем с помощью обученного CountVectorizer\n", "X = vectorizer.transform(text_data)\n", "CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n", "CV_text_data\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "11dda58a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116\n", " 1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195\n", " -2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479\n", " -0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928\n", " -0.50473 ]\n", "love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315\n", " 0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369\n", " -0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678\n", " -0.016651 -0.53164 0.64236 -0.10922 ]\n", "python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563\n", " -0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526\n", " 0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403\n", " -0.27013 -0.348 -0.7293 0.2201 ]\n", "world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513\n", " 0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096\n", " -0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667\n", " -0.44426 -0.43085 -0.52875 0.25855 ]\n", "Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 
-1.00551002\n", " -0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001\n", " -0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997\n", " -0.430875 0.87234702 -0.806399 0.27203003 2.23921998\n", " -1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]\n", "computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917\n", " -0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233\n", " 0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234\n", " -0.92349 0.34651 0.061939 -0.57375 ]\n", "game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537\n", " 0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805\n", " 1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619\n", " 0.39202 0.59309 -0.55232 0.005087]\n", "great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01\n", " 1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01\n", " -5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01\n", " -1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01\n", " -2.7928e-01]\n", "zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268\n", " -0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208\n", " 0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593\n", " -0.10225 1.7102 -0.95599 -1.3867 ]\n", "This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998\n", " -0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002\n", " -1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004\n", " -0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628\n", " -3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Андрей\\AppData\\Local\\Temp\\ipykernel_8524\\2010506005.py:17: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. 
Use pandas.concat instead.\n", " glove_data=glove_data.append(pd.DataFrame([one_doc]))\n" ] } ], "source": [
 "# Demo: build each document vector step by step, printing every word that GloVe knows.\n",
 "# The vectors are collected in a list and turned into a DataFrame once at the end:\n",
 "# DataFrame.append is deprecated (removed in pandas 2.0) and growing a frame row by row is quadratic.\n",
 "doc_vectors = []\n",
 "\n",
 "# Walk over every row of the document-term frame (one row per document)\n",
 "for i in range(CV_text_data.shape[0]):\n",
 "\n",
 "    # Accumulator for one document, sized to the dimensionality of the GloVe model\n",
 "    one_doc = np.zeros(glove_model.vector_size)\n",
 "\n",
 "    # Take the words of this document that are present in the CountVectorizer vocabulary\n",
 "    # and add the GloVe vector of every such word that the GloVe model also knows\n",
 "    for word in words_vocab[CV_text_data.iloc[i, :] >= 1]:\n",
 "        if word in glove_model.key_to_index:\n",
 "            print(word, ': ', glove_model[word])\n",
 "            one_doc += glove_model[word]\n",
 "    print(text_data[i], ': ', one_doc)\n",
 "    doc_vectors.append(one_doc)\n",
 "\n",
 "# One row per document, one column per GloVe dimension\n",
 "glove_data = pd.DataFrame(doc_vectors)"
 ] }, { "cell_type": "code", "execution_count": 12, "id": "ff68d8dc", "metadata": {}, "outputs": [], "source": [
 "def text2vec(text_data):\n",
 "    \"\"\"Turn raw documents into GloVe document vectors.\n",
 "\n",
 "    Each document is vectorized with the already fitted global ``vectorizer``;\n",
 "    the GloVe vectors of the distinct vocabulary words it contains are summed\n",
 "    (every word counts once, however often it occurs in the document).\n",
 "    Words that ``glove_model`` does not know contribute nothing.\n",
 "\n",
 "    Parameters\n",
 "    ----------\n",
 "    text_data : iterable of str\n",
 "        Raw documents.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    pandas.DataFrame\n",
 "        One row per document, one column per GloVe dimension, rows labelled 0..n-1.\n",
 "    \"\"\"\n",
 "    # Document-term counts from the fitted CountVectorizer; kept sparse so the\n",
 "    # full (n_docs, n_vocab) matrix is never materialised as a dense array\n",
 "    counts = vectorizer.transform(text_data)\n",
 "    vocab = vectorizer.get_feature_names_out()\n",
 "\n",
 "    # Embedding matrix aligned with the vectorizer vocabulary: row j holds the GloVe\n",
 "    # vector of vocab[j]; words unknown to GloVe keep an all-zero row\n",
 "    embeddings = np.zeros((len(vocab), glove_model.vector_size))\n",
 "    for col, word in enumerate(vocab):\n",
 "        if word in glove_model.key_to_index:\n",
 "            embeddings[col] = glove_model[word]\n",
 "\n",
 "    # Presence (not frequency) of each word, so every distinct word is added once\n",
 "    presence = (counts > 0).astype(np.float64)\n",
 "\n",
 "    # (n_docs, n_vocab) @ (n_vocab, dim) -> sum of word vectors per document\n",
 "    return pd.DataFrame(presence @ embeddings)"
 ] }, { "cell_type": "code", "execution_count": 13, "id": "b778776c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...15161718192021222324
0-1.55058-0.0816830.2799190.588469-1.00551-0.8261396.186431.44845-0.71108-1.14717...-0.4308750.872347-0.8063990.272032.23922-1.23572-1.310711-1.96934-0.176410-0.135300
01.736100.7420820.355460-4.7441101.41544-0.3422201.78697-1.454042.56643-1.32184...-0.5266201.932400-0.896870-0.609241.51628-3.16625-0.8927501.86970-1.336071-2.234643
\n", "

2 rows × 25 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 \\\n", "0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n", "0 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n", "\n", " 8 9 ... 15 16 17 18 19 \\\n", "0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n", "0 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n", "\n", " 20 21 22 23 24 \n", "0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n", "0 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n", "\n", "[2 rows x 25 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "glove_data\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "1bdb459e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...15161718192021222324
0-8.5211422.020376-10.8029213.1676360.25246915.54404817.631184-32.5811929.696540-11.103087...2.8104537.9002150.96212917.691130-1.252574-10.0980490.5001131.3486942.186150-16.556824
06.57622820.336350-32.675150-9.07387217.515655-6.48879459.458419-75.38429813.323775-14.443218...21.40773823.5251180.32568019.871444-27.585188-4.559155-7.417482-16.694553-0.197711-58.948193
01.3299143.060870-1.8684841.392735-1.335277-5.01495512.859476-9.978156-0.869613-2.031490...2.9251342.8729302.1844863.831770-0.877866-0.9277700.700101-9.855365-5.419429-2.279330
0-4.866150-0.2731763.515124-5.008165-1.236789-7.951168-11.015882-3.49624116.024286-9.388742...-0.4711413.5753786.1932220.34943015.040248-10.369132-0.848717-0.564796-1.114126-7.844431
0-3.115007-1.805252-5.419340-0.393406-0.406461-2.7243407.898330-15.6191130.231822-3.628156...5.9441518.309932-0.65608412.178709-6.118551-3.2863763.4509462.0553430.463787-12.644626
..................................................................
0-0.9309544.974043-8.147008-5.1471303.960455-1.3440227.818063-25.4274204.624732-7.218097...3.6230384.4531892.4053208.032963-8.0295390.838867-4.757457-5.755052-9.496197-21.542710
0-0.7706900.1282700.3313700.005089-0.476050-0.5011601.8580001.062400-0.5651100.133280...-0.4492200.485910-0.647900-0.8423800.616690-0.198240-0.579670-0.6588500.439280-0.504730
01.4911776.992638-7.921970-7.1575216.641657-2.95802012.820770-18.5029466.838083-2.717310...-1.3448734.170405-0.1780305.699992-7.295038-3.683306-2.718006-0.117608-7.205832-13.863438
02.5237705.8173942.184340-2.996497-0.267181-10.0596346.344402-2.0471272.679123-7.642505...-1.2302961.409746-3.322040-5.068259-0.6487180.753010-6.220990-5.012004-1.518542-10.156440
0-0.11869111.860546-2.567264-10.955913-4.239322-9.34055221.189778-10.8953752.659030-3.848115...0.72619111.634998-5.4472481.293007-7.882002-2.5274530.298939-6.1070623.365051-15.641826
\n", "

1657 rows × 25 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 \\\n", "0 -8.521142 2.020376 -10.802921 3.167636 0.252469 15.544048 \n", "0 6.576228 20.336350 -32.675150 -9.073872 17.515655 -6.488794 \n", "0 1.329914 3.060870 -1.868484 1.392735 -1.335277 -5.014955 \n", "0 -4.866150 -0.273176 3.515124 -5.008165 -1.236789 -7.951168 \n", "0 -3.115007 -1.805252 -5.419340 -0.393406 -0.406461 -2.724340 \n", ".. ... ... ... ... ... ... \n", "0 -0.930954 4.974043 -8.147008 -5.147130 3.960455 -1.344022 \n", "0 -0.770690 0.128270 0.331370 0.005089 -0.476050 -0.501160 \n", "0 1.491177 6.992638 -7.921970 -7.157521 6.641657 -2.958020 \n", "0 2.523770 5.817394 2.184340 -2.996497 -0.267181 -10.059634 \n", "0 -0.118691 11.860546 -2.567264 -10.955913 -4.239322 -9.340552 \n", "\n", " 6 7 8 9 ... 15 16 \\\n", "0 17.631184 -32.581192 9.696540 -11.103087 ... 2.810453 7.900215 \n", "0 59.458419 -75.384298 13.323775 -14.443218 ... 21.407738 23.525118 \n", "0 12.859476 -9.978156 -0.869613 -2.031490 ... 2.925134 2.872930 \n", "0 -11.015882 -3.496241 16.024286 -9.388742 ... -0.471141 3.575378 \n", "0 7.898330 -15.619113 0.231822 -3.628156 ... 5.944151 8.309932 \n", ".. ... ... ... ... ... ... ... \n", "0 7.818063 -25.427420 4.624732 -7.218097 ... 3.623038 4.453189 \n", "0 1.858000 1.062400 -0.565110 0.133280 ... -0.449220 0.485910 \n", "0 12.820770 -18.502946 6.838083 -2.717310 ... -1.344873 4.170405 \n", "0 6.344402 -2.047127 2.679123 -7.642505 ... -1.230296 1.409746 \n", "0 21.189778 -10.895375 2.659030 -3.848115 ... 0.726191 11.634998 \n", "\n", " 17 18 19 20 21 22 23 \\\n", "0 0.962129 17.691130 -1.252574 -10.098049 0.500113 1.348694 2.186150 \n", "0 0.325680 19.871444 -27.585188 -4.559155 -7.417482 -16.694553 -0.197711 \n", "0 2.184486 3.831770 -0.877866 -0.927770 0.700101 -9.855365 -5.419429 \n", "0 6.193222 0.349430 15.040248 -10.369132 -0.848717 -0.564796 -1.114126 \n", "0 -0.656084 12.178709 -6.118551 -3.286376 3.450946 2.055343 0.463787 \n", ".. ... ... ... ... ... ... ... 
\n", "0 2.405320 8.032963 -8.029539 0.838867 -4.757457 -5.755052 -9.496197 \n", "0 -0.647900 -0.842380 0.616690 -0.198240 -0.579670 -0.658850 0.439280 \n", "0 -0.178030 5.699992 -7.295038 -3.683306 -2.718006 -0.117608 -7.205832 \n", "0 -3.322040 -5.068259 -0.648718 0.753010 -6.220990 -5.012004 -1.518542 \n", "0 -5.447248 1.293007 -7.882002 -2.527453 0.298939 -6.107062 3.365051 \n", "\n", " 24 \n", "0 -16.556824 \n", "0 -58.948193 \n", "0 -2.279330 \n", "0 -7.844431 \n", "0 -12.644626 \n", ".. ... \n", "0 -21.542710 \n", "0 -0.504730 \n", "0 -13.863438 \n", "0 -10.156440 \n", "0 -15.641826 \n", "\n", "[1657 rows x 25 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data_glove = text2vec(twenty_train['data']);\n", "train_data_glove" ] }, { "cell_type": "code", "execution_count": 17, "id": "5ac20e79", "metadata": {}, "outputs": [], "source": [ "clf = KNeighborsClassifier(n_neighbors = 5)" ] }, { "cell_type": "code", "execution_count": 18, "id": "08164a25", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
KNeighborsClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "KNeighborsClassifier()" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.fit(train_data_glove, twenty_train['target'])" ] }, { "cell_type": "code", "execution_count": 19, "id": "e459faaf", "metadata": {}, "outputs": [], "source": [ "test_data_glove = text2vec(twenty_test['data']);" ] }, { "cell_type": "code", "execution_count": 20, "id": "d8144e75", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...15161718192021222324
0-6.7606355.063863-2.7790603.699120-2.8580860.13523020.811229-19.4255677.302950-5.826012...3.8333786.794452-0.92172012.187404-5.547615-4.1339993.588260-0.497106-2.542142-11.362855
01.6326162.512300-0.745513-3.0811542.182067-1.9888167.533100-1.015740-0.829598-2.764237...0.7918512.114150-2.249193-0.163590-1.177710-2.496928-5.074085-2.6669470.662050-3.590550
02.1157662.142060-0.445607-3.2290301.154580-2.8772786.399954-10.4457692.230760-3.299899...4.3888708.515056-0.7662603.549431-1.643443-0.825730-2.968016-0.808924-0.000160-7.468189
0-0.8027845.1994434.294071-7.3909662.747166-1.35995215.032628-1.6015901.4744062.570105...3.0434326.176236-6.193988-3.990476-2.345854-5.534376-8.9254221.5533000.905790-12.824533
029.92648965.324993-25.059592-64.08013077.565282-34.61460475.643770-115.60085990.847175-42.971146...40.95603150.322156-19.53709828.903925-34.643949-69.894146-94.992145-48.601895-29.098555-91.934770
..................................................................
01.8292354.5138072.9165202.237308-1.704831-1.81119222.196895-12.858912-4.054810-3.130457...6.0192468.949456-4.682214-5.648911-1.0268983.7190062.449941-6.4871971.340930-7.325196
0-0.9638155.4911643.567377-6.048021-5.059298-0.97795815.131499-0.9044702.185990-1.459807...0.9684994.725793-0.7269441.328612-3.1442091.643127-1.259245-0.880740-6.713165-3.115454
06.80132415.348126-17.0517185.0309989.332448-5.71669156.409175-56.250411-4.028209-11.687558...22.88442412.9405701.05866421.879058-20.8972532.5377553.774890-11.495336-2.609774-36.597559
01.0540900.7645241.958340-1.085245-0.441392-0.4219706.139770-0.612219-2.251460-0.465165...0.3779581.957450-1.705220-0.5097000.0161101.4616201.5890692.2673400.447919-0.469250
0-18.38728613.274879-7.895913-1.831442-10.424961-12.24844232.153890-40.16929313.089525-21.306493...6.4972798.3407294.99610923.442078-3.701088-11.6715059.209790-10.002501-0.815266-17.024052
\n", "

1102 rows × 25 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 \\\n", "0 -6.760635 5.063863 -2.779060 3.699120 -2.858086 0.135230 \n", "0 1.632616 2.512300 -0.745513 -3.081154 2.182067 -1.988816 \n", "0 2.115766 2.142060 -0.445607 -3.229030 1.154580 -2.877278 \n", "0 -0.802784 5.199443 4.294071 -7.390966 2.747166 -1.359952 \n", "0 29.926489 65.324993 -25.059592 -64.080130 77.565282 -34.614604 \n", ".. ... ... ... ... ... ... \n", "0 1.829235 4.513807 2.916520 2.237308 -1.704831 -1.811192 \n", "0 -0.963815 5.491164 3.567377 -6.048021 -5.059298 -0.977958 \n", "0 6.801324 15.348126 -17.051718 5.030998 9.332448 -5.716691 \n", "0 1.054090 0.764524 1.958340 -1.085245 -0.441392 -0.421970 \n", "0 -18.387286 13.274879 -7.895913 -1.831442 -10.424961 -12.248442 \n", "\n", " 6 7 8 9 ... 15 16 \\\n", "0 20.811229 -19.425567 7.302950 -5.826012 ... 3.833378 6.794452 \n", "0 7.533100 -1.015740 -0.829598 -2.764237 ... 0.791851 2.114150 \n", "0 6.399954 -10.445769 2.230760 -3.299899 ... 4.388870 8.515056 \n", "0 15.032628 -1.601590 1.474406 2.570105 ... 3.043432 6.176236 \n", "0 75.643770 -115.600859 90.847175 -42.971146 ... 40.956031 50.322156 \n", ".. ... ... ... ... ... ... ... \n", "0 22.196895 -12.858912 -4.054810 -3.130457 ... 6.019246 8.949456 \n", "0 15.131499 -0.904470 2.185990 -1.459807 ... 0.968499 4.725793 \n", "0 56.409175 -56.250411 -4.028209 -11.687558 ... 22.884424 12.940570 \n", "0 6.139770 -0.612219 -2.251460 -0.465165 ... 0.377958 1.957450 \n", "0 32.153890 -40.169293 13.089525 -21.306493 ... 6.497279 8.340729 \n", "\n", " 17 18 19 20 21 22 \\\n", "0 -0.921720 12.187404 -5.547615 -4.133999 3.588260 -0.497106 \n", "0 -2.249193 -0.163590 -1.177710 -2.496928 -5.074085 -2.666947 \n", "0 -0.766260 3.549431 -1.643443 -0.825730 -2.968016 -0.808924 \n", "0 -6.193988 -3.990476 -2.345854 -5.534376 -8.925422 1.553300 \n", "0 -19.537098 28.903925 -34.643949 -69.894146 -94.992145 -48.601895 \n", ".. ... ... ... ... ... ... 
\n", "0 -4.682214 -5.648911 -1.026898 3.719006 2.449941 -6.487197 \n", "0 -0.726944 1.328612 -3.144209 1.643127 -1.259245 -0.880740 \n", "0 1.058664 21.879058 -20.897253 2.537755 3.774890 -11.495336 \n", "0 -1.705220 -0.509700 0.016110 1.461620 1.589069 2.267340 \n", "0 4.996109 23.442078 -3.701088 -11.671505 9.209790 -10.002501 \n", "\n", " 23 24 \n", "0 -2.542142 -11.362855 \n", "0 0.662050 -3.590550 \n", "0 -0.000160 -7.468189 \n", "0 0.905790 -12.824533 \n", "0 -29.098555 -91.934770 \n", ".. ... ... \n", "0 1.340930 -7.325196 \n", "0 -6.713165 -3.115454 \n", "0 -2.609774 -36.597559 \n", "0 0.447919 -0.469250 \n", "0 -0.815266 -17.024052 \n", "\n", "[1102 rows x 25 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data_glove" ] }, { "cell_type": "code", "execution_count": 21, "id": "a69830f0", "metadata": {}, "outputs": [], "source": [ "predict = clf.predict(test_data_glove )" ] }, { "cell_type": "code", "execution_count": 22, "id": "9ac5cf20", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[225 35 59]\n", " [ 26 313 50]\n", " [ 56 98 240]]\n", " precision recall f1-score support\n", "\n", " 0 0.73 0.71 0.72 319\n", " 1 0.70 0.80 0.75 389\n", " 2 0.69 0.61 0.65 394\n", "\n", " accuracy 0.71 1102\n", " macro avg 0.71 0.71 0.70 1102\n", "weighted avg 0.71 0.71 0.70 1102\n", "\n" ] } ], "source": [ "print (confusion_matrix(twenty_test['target'], predict))\n", "print(classification_report(twenty_test['target'], predict))" ] }, { "cell_type": "code", "execution_count": null, "id": "b8cce5a9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "1b9bff90", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }