From de9f98f44ee7e6caafe36b7be46b72f9894b63b3 Mon Sep 17 00:00:00 2001 From: Andrey Date: Tue, 17 Mar 2026 08:45:31 +0300 Subject: [PATCH] =?UTF-8?q?LR2=20=D0=BC=D0=B5=D1=82=D0=BE=D0=B4=D0=B8?= =?UTF-8?q?=D1=87=D0=B5=D1=81=D0=BA=D0=B8=D0=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- labs/OATD_LR2_metod.ipynb | 4508 ++++++++++++++++++++++++++++++++++++- 1 file changed, 4460 insertions(+), 48 deletions(-) diff --git a/labs/OATD_LR2_metod.ipynb b/labs/OATD_LR2_metod.ipynb index 98e1023..b6b20b2 100644 --- a/labs/OATD_LR2_metod.ipynb +++ b/labs/OATD_LR2_metod.ipynb @@ -36,14 +36,16 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer \n", - "import numpy as np" + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score\n" ] }, { @@ -63,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -82,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -125,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -150,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -193,14 +195,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[('image', 489), ('don', 417), ('graphics', 410), ('god', 409), ('people', 384), ('does', 364), ('edu', 349), ('like', 329), ('just', 327), ('know', 319)]\n" + "[('image', np.int64(489)), ('don', np.int64(417)), ('graphics', np.int64(410)), ('god', np.int64(409)), ('people', np.int64(384)), ('does', np.int64(364)), ('edu', np.int64(349)), ('like', np.int64(329)), ('just', np.int64(327)), ('know', np.int64(319))]\n" ] } ], @@ -212,26 +214,6 @@ "print (x[:10])\n" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Если для каждого класса по-отдельности получить подобные списки наиболее частотных слов, то можно оценить пересекаемость терминов двух классов, например, с помощью [меры сходства Жаккара](https://ru.wikipedia.org/wiki/Коэффициент_Жаккара).\n", - "Ниже приведена функция, которая на вход принимает два списка слов и возвращает значение коэффициента Жаккара:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "def jaccard_similarity(list1, list2):\n", - " intersection = len(list(set(list1).intersection(list2)))\n", - " union = (len(set(list1)) + len(set(list2))) - intersection\n", - " return float(intersection) / union\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -243,27 +225,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Существует целый ряд алгоритмов стемминга. В работе предлагается использовать алгоритм Портера, реализация которого приведена в библиотеке [nltk](https://www.nltk.org/)\n", + "Существует целый ряд алгоритмов стемминга. Один из популярных - алгоритм Портера, реализация которого приведена в библиотеке [nltk](https://www.nltk.org/)\n", "Для проведения стемминга нужно создать объект `PorterStemmer()`. Стеммер работает таким образом: у созданного объекта `PorterStemmer` есть метод `stem`, производящий стемминга. Таким образом, необходимо каждую из частей выборки (обучающую и тестовую) разбить на отдельные документы, затем, проходя в цикле по каждому слову в документе, произвести стемминг и объединить эти слова в новый документ.\n", - "\n" + "\n", + "В ЛР проводить стемминг не требуется.\n" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " i 'll take a wild guess and say freedom is object valuabl . i base thi on the assumpt that if everyon in the world were depriv utterli of their freedom ( so that their everi act wa contrari to their volit ) , almost all would want to complain . therefor i take it that to assert or believ that `` freedom is not veri valuabl '' , when almost everyon can see that it is , is everi bit as absurd as to assert `` it is not rain '' on a raini day . i take thi to be a candid for an object valu , and it it is a necessari condit for object moral that object valu such as thi exist .\n" - ] - } - ], + "outputs": [], "source": [ + "import nltk\n", "from nltk.stem import *\n", "from nltk import word_tokenize\n", + "nltk.download('punkt_tab')\n", "\n", "porter_stemmer = PorterStemmer()\n", "stem_train = []\n", @@ -293,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -342,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -359,7 +336,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -377,7 +354,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -393,13 +370,4448 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prediction = text_clf.predict(twenty_test.data)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Настройка параметров с использованием grid search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import GridSearchCV" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "С помощью конвейерной обработки (`pipelines`) стало гораздо проще указывать параметры обучения модели, однако перебирать все возможные варианты вручную даже в нашем простом случае выйдет затратно. В нашем случае имеется четыре настраиваемых параметра: `max_features`, `stop_words`, `use_idf` и `n_neighbors`.\n", + "\n", + "Вместо поиска лучших параметров в конвейере вручную, можно запустить поиск (методом полного перебора) лучших параметров в сетке возможных значений. Сделать это можно с помощью объекта класса `GridSearchCV`.\n", + " \n", + "Для того чтобы задать сетку параметров необходимо создать переменную-словарь, ключами которого являются конструкции вида: «НазваниеШагаКонвейера__НазваниеПараметра», а значениями – кортеж из значений параметра\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parameters = {'vect__max_features': (100,500,1000,5000,10000),\n", + " 'vect__stop_words': ('english', None),\n", + " 'tfidf__use_idf': (True, False), \n", + " 'clf__n_neighbors': (1,3,5,7)} " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Далее необходимо создать объект класса `GridSearchCV`, передав в него объект `pipeline` или классификатор, список параметров сетки, а также при необходимости, задав прочие параметры, такие так количество задействованых ядер процессора `n_jobs`, количество фолдов кросс-валидации `cv`, метрику, по которой будем судить о качестве модели `scoring`, и другие" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1, cv=3, scoring = 'f1_weighted')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Теперь объект `gs_clf` можно обучить по всем параметрам, как обычный классификатор, методом `fit()`.\n", + "После того как прошло обучение, узнать лучшую совокупность параметров можно, обратившись к атрибуту `best_params_`. Для необученной модели атрибуты будут отсутствовать.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
GridSearchCV(cv=3,\n",
+       "             estimator=Pipeline(steps=[('vect',\n",
+       "                                        CountVectorizer(max_features=1000,\n",
+       "                                                        stop_words='english')),\n",
+       "                                       ('tfidf', TfidfTransformer()),\n",
+       "                                       ('clf',\n",
+       "                                        KNeighborsClassifier(n_neighbors=1))]),\n",
+       "             n_jobs=-1,\n",
+       "             param_grid={'clf__n_neighbors': (1, 3, 5, 7),\n",
+       "                         'tfidf__use_idf': (True, False),\n",
+       "                         'vect__max_features': (100, 500, 1000, 5000, 10000),\n",
+       "                         'vect__stop_words': ('english', None)},\n",
+       "             scoring='f1_weighted')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "GridSearchCV(cv=3,\n", + " estimator=Pipeline(steps=[('vect',\n", + " CountVectorizer(max_features=1000,\n", + " stop_words='english')),\n", + " ('tfidf', TfidfTransformer()),\n", + " ('clf',\n", + " KNeighborsClassifier(n_neighbors=1))]),\n", + " n_jobs=-1,\n", + " param_grid={'clf__n_neighbors': (1, 3, 5, 7),\n", + " 'tfidf__use_idf': (True, False),\n", + " 'vect__max_features': (100, 500, 1000, 5000, 10000),\n", + " 'vect__stop_words': ('english', None)},\n", + " scoring='f1_weighted')" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gs_clf.fit(twenty_train.data, twenty_train.target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gs_clf.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'mean_fit_time': array([0.14711912, 0.1838789 , 0.16524474, 0.13485821, 0.17077629,\n", + " 0.171736 , 0.13478573, 0.17135119, 0.1455044 , 0.16278481,\n", + " 0.14029678, 0.17597159, 0.14752992, 0.1515433 , 0.2238245 ,\n", + " 0.16685406, 0.19492237, 0.18400868, 0.1637164 , 0.23180572,\n", + " 0.13470483, 0.19008255, 0.168998 , 0.14402525, 0.14959741,\n", + " 0.21908172, 0.12458984, 0.17676258, 0.181101 , 0.15396865,\n", + " 0.13980309, 0.17127609, 0.12496074, 0.21157686, 0.14910388,\n", + " 0.16302236, 0.20699684, 0.14392487, 0.13715124, 0.19204704,\n", + " 0.12343788, 0.202528 , 0.1306417 , 0.14448158, 0.16433088,\n", + " 0.17828568, 0.15127762, 0.16031146, 0.15140502, 0.17451549,\n", + " 0.13383325, 0.19785619, 0.14872464, 0.19805606, 0.17880678,\n", + " 0.14933427, 0.18098283, 0.17271209, 0.19098576, 0.15773439,\n", + " 0.14277951, 0.14034979, 0.17362897, 0.14588912, 0.15714224,\n", + " 0.19136397, 0.14665333, 0.18524861, 0.1639146 , 0.15949496,\n", + " 0.14914314, 0.19116807, 0.16394456, 0.17290258, 0.14597694,\n", + " 0.18869273, 0.17521318, 0.18450022, 0.15120141, 0.14013394]),\n", + " 'std_fit_time': array([0.0426388 , 0.00865779, 0.03047002, 0.01959866, 0.01996265,\n", + " 0.03991559, 0.02089307, 0.0115939 , 0.01577824, 0.01259437,\n", + " 0.02861695, 0.02803293, 0.02632727, 0.03513436, 0.10327593,\n", + " 0.06408524, 0.04651852, 0.03295957, 0.0471307 , 0.11679622,\n", + " 0.0110411 , 0.02803024, 0.05824908, 0.01165112, 0.01597652,\n", + " 0.06360679, 0.0225909 , 0.03841467, 0.06331909, 0.02901401,\n", + " 0.02955105, 0.00180821, 0.01985555, 0.0523942 , 0.03268085,\n", + " 0.03894983, 0.06517167, 0.02243237, 0.02200533, 0.0133168 ,\n", + " 0.02014168, 0.05538047, 0.00691204, 0.02019455, 0.04787927,\n", + " 0.02202076, 0.03367995, 0.04032929, 0.04035338, 0.05086764,\n", + " 0.03783577, 0.02842441, 0.01027949, 0.02261441, 0.07413437,\n", + " 0.03057069, 0.07406925, 0.06374613, 0.07499723, 0.00702891,\n", + " 0.01449466, 0.01535535, 0.06266348, 0.01594578, 0.03962651,\n", + " 0.03872874, 0.02924308, 0.03921288, 0.03093211, 0.02737551,\n", + " 0.03913621, 0.03922343, 0.06958397, 0.0321589 , 0.03550322,\n", + " 0.06330303, 0.04521662, 0.01877706, 0.01843914, 0.02003679]),\n", + " 'mean_score_time': array([0.07024606, 0.09829744, 0.06123559, 0.0934546 , 0.06724668,\n", + " 0.08956035, 0.06278229, 0.08930341, 0.07657051, 0.08451796,\n", + " 0.06686974, 0.06430848, 0.06159838, 0.06787952, 0.0685486 ,\n", + " 0.08175969, 0.07736111, 0.08548617, 0.10391927, 0.11483145,\n", + " 0.06140908, 0.08571712, 0.05673496, 0.07094908, 0.07356191,\n", + " 0.08419561, 0.06646315, 0.07401323, 0.0662907 , 0.07332158,\n", + " 0.06078386, 0.06057103, 0.05959137, 0.07239731, 0.06452703,\n", + " 0.06926155, 0.07553283, 0.07713358, 0.06963356, 0.09399152,\n", + " 0.05499752, 0.06101354, 0.06951396, 0.06801232, 0.06193964,\n", + " 0.08509183, 0.06340122, 0.09640757, 0.09270819, 0.07950346,\n", + " 0.06135122, 0.07454085, 0.06276194, 0.08477139, 0.06570975,\n", + " 0.06936661, 0.07554563, 0.08220371, 0.06743741, 0.10737284,\n", + " 0.06237586, 0.06374478, 0.07077082, 0.06676245, 0.06271029,\n", + " 0.10701195, 0.07044252, 0.07724508, 0.0787247 , 0.07522114,\n", + " 0.06066545, 0.08716504, 0.06308579, 0.08094724, 0.08140651,\n", + " 0.07310184, 0.07593838, 0.07488211, 0.06320556, 0.04453659]),\n", + " 'std_score_time': array([0.0307069 , 0.0182529 , 0.02101209, 0.03205303, 0.0257132 ,\n", + " 0.02687809, 0.0173987 , 0.02640571, 0.01539859, 0.01465485,\n", + " 0.02122433, 0.01356409, 0.00834315, 0.01751938, 0.01171655,\n", + " 0.03144068, 0.02179221, 0.01969381, 0.04494128, 0.00622263,\n", + " 0.01579416, 0.02897917, 0.01304806, 0.0202541 , 0.02151299,\n", + " 0.03211671, 0.01358637, 0.01938508, 0.01422302, 0.01688134,\n", + " 0.00869969, 0.0141093 , 0.01194079, 0.01273514, 0.01648673,\n", + " 0.01545342, 0.01899566, 0.0178712 , 0.02332053, 0.03043117,\n", + " 0.01200954, 0.01238742, 0.02325767, 0.01629294, 0.01778484,\n", + " 0.03084458, 0.0157806 , 0.04616991, 0.02142426, 0.01231431,\n", + " 0.0059662 , 0.01826184, 0.02046084, 0.02321078, 0.00858559,\n", + " 0.01727098, 0.01823882, 0.02170296, 0.0142323 , 0.06476885,\n", + " 0.02283514, 0.01249623, 0.00957304, 0.01621181, 0.00941332,\n", + " 0.04093458, 0.01894285, 0.0226533 , 0.02176937, 0.01821806,\n", + " 0.01055958, 0.03213191, 0.00837758, 0.02668343, 0.03007748,\n", + " 0.01731415, 0.01637027, 0.01601808, 0.02890825, 0.01039945]),\n", + " 'param_clf__n_neighbors': masked_array(data=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n", + " 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n", + " 7, 7, 7, 7, 7, 7, 7, 7],\n", + " mask=[False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False],\n", + " fill_value=999999),\n", + " 'param_tfidf__use_idf': masked_array(data=[True, True, True, True, True, True, True, True, True,\n", + " True, False, False, False, False, False, False, False,\n", + " False, False, False, True, True, True, True, True,\n", + " True, True, True, True, True, False, False, False,\n", + " False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, True, True, True, True, True, True, True,\n", + " True, True, True, False, False, False, False, False,\n", + " False, False, False, False, False],\n", + " mask=[False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False],\n", + " fill_value=True),\n", + " 'param_vect__max_features': masked_array(data=[100, 100, 500, 500, 1000, 1000, 5000, 5000, 10000,\n", + " 10000, 100, 100, 500, 500, 1000, 1000, 5000, 5000,\n", + " 10000, 10000, 100, 100, 500, 500, 1000, 1000, 5000,\n", + " 5000, 10000, 10000, 100, 100, 500, 500, 1000, 1000,\n", + " 5000, 5000, 10000, 10000, 100, 100, 500, 500, 1000,\n", + " 1000, 5000, 5000, 10000, 10000, 100, 100, 500, 500,\n", + " 1000, 1000, 5000, 5000, 10000, 10000, 100, 100, 500,\n", + " 500, 1000, 1000, 5000, 5000, 10000, 10000, 100, 100,\n", + " 500, 500, 1000, 1000, 5000, 5000, 10000, 10000],\n", + " mask=[False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False],\n", + " fill_value=999999),\n", + " 'param_vect__stop_words': masked_array(data=['english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None, 'english', None, 'english', None,\n", + " 'english', None],\n", + " mask=[False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False],\n", + " fill_value=np.str_('?'),\n", + " dtype=object),\n", + " 'params': [{'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 1,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 3,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 5,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': True,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 100,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 500,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 1000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 5000,\n", + " 'vect__stop_words': None},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': 'english'},\n", + " {'clf__n_neighbors': 7,\n", + " 'tfidf__use_idf': False,\n", + " 'vect__max_features': 10000,\n", + " 'vect__stop_words': None}],\n", + " 'split0_test_score': array([0.76817583, 0.72894636, 0.53358157, 0.6374826 , 0.52798315,\n", + " 0.56519764, 0.51309075, 0.49836392, 0.49549296, 0.48037932,\n", + " 0.74678068, 0.6781238 , 0.53815493, 0.67256907, 0.59194787,\n", + " 0.61495561, 0.50607973, 0.58348819, 0.48282372, 0.56647032,\n", + " 0.78847084, 0.7252709 , 0.48067615, 0.52118859, 0.51721694,\n", + " 0.48760057, 0.51577841, 0.46346843, 0.43209944, 0.42022184,\n", + " 0.79654539, 0.72168041, 0.50540586, 0.62161402, 0.53994855,\n", + " 0.57362945, 0.53619645, 0.54414205, 0.43884702, 0.4762766 ,\n", + " 0.75153416, 0.72862518, 0.45804881, 0.49005168, 0.52767212,\n", + " 0.49479873, 0.48059474, 0.45872108, 0.43320814, 0.46831129,\n", + " 0.77799622, 0.71691323, 0.468479 , 0.58540742, 0.53125944,\n", + " 0.55188163, 0.53165214, 0.49549296, 0.46225076, 0.46334516,\n", + " 0.73731272, 0.722074 , 0.42768101, 0.49765146, 0.48090102,\n", + " 0.59104363, 0.51802701, 0.48304986, 0.44659485, 0.52984156,\n", + " 0.76674346, 0.71632941, 0.42110969, 0.55322776, 0.49049555,\n", + " 0.55175372, 0.55288072, 0.52752419, 0.47634782, 0.48746189]),\n", + " 'split1_test_score': array([0.79662124, 0.74654727, 0.55840557, 0.65284461, 0.56848052,\n", + " 0.58959475, 0.50824228, 0.52616761, 0.55281361, 0.53900106,\n", + " 0.79198638, 0.74396927, 0.60547906, 0.71852445, 0.55601273,\n", + " 0.70906678, 0.55288072, 0.70132316, 0.54064496, 0.72369393,\n", + " 0.8200291 , 0.76394456, 0.50142662, 0.65973942, 0.57835078,\n", + " 0.58411319, 0.56231765, 0.54929577, 0.53755372, 0.54631601,\n", + " 0.76828781, 0.73851721, 0.57218138, 0.70341759, 0.54466558,\n", + " 0.70088681, 0.52052675, 0.68262612, 0.56159728, 0.69962829,\n", + " 0.78081641, 0.76676231, 0.47784768, 0.65979122, 0.55877981,\n", + " 0.58351195, 0.55934603, 0.5526496 , 0.53805724, 0.54275916,\n", + " 0.7590099 , 0.74428157, 0.50962843, 0.69741916, 0.56721274,\n", + " 0.6845754 , 0.53674839, 0.65880645, 0.56642635, 0.66528626,\n", + " 0.75249835, 0.76114625, 0.5033107 , 0.61728125, 0.54473029,\n", + " 0.54902822, 0.55716104, 0.55466537, 0.51280439, 0.54540585,\n", + " 0.72900844, 0.73304581, 0.49250319, 0.69558085, 0.57262651,\n", + " 0.66997019, 0.54339107, 0.66489176, 0.56953075, 0.65745565]),\n", + " 'split2_test_score': array([0.82526209, 0.7519849 , 0.59510748, 0.68219353, 0.5764876 ,\n", + " 0.6305322 , 0.54587929, 0.56161317, 0.52387661, 0.56501726,\n", + " 0.82242212, 0.71247211, 0.60517653, 0.7037331 , 0.54624902,\n", + " 0.68817364, 0.5608424 , 0.65146866, 0.53526413, 0.66897713,\n", + " 0.81950927, 0.79423613, 0.52825543, 0.66907312, 0.56248767,\n", + " 0.60389162, 0.49652131, 0.55180864, 0.50857805, 0.52741727,\n", + " 0.80244926, 0.69231892, 0.53887943, 0.71969825, 0.53158497,\n", + " 0.6792325 , 0.49764733, 0.65713584, 0.49078794, 0.66897713,\n", + " 0.80172943, 0.75452358, 0.54031316, 0.63620099, 0.52529445,\n", + " 0.59626871, 0.51327659, 0.53187959, 0.50624241, 0.51743203,\n", + " 0.76711749, 0.70123506, 0.49774614, 0.71460958, 0.50316727,\n", + " 0.65013293, 0.49732848, 0.64249816, 0.52623393, 0.65444272,\n", + " 0.76791783, 0.73768667, 0.53419234, 0.63714302, 0.50165182,\n", + " 0.5932877 , 0.49723701, 0.51859739, 0.49993588, 0.52478814,\n", + " 0.76415908, 0.69550001, 0.48572333, 0.68222059, 0.50972411,\n", + " 0.66549536, 0.47022833, 0.64386329, 0.50648577, 0.6415092 ]),\n", + " 'mean_test_score': array([0.79668639, 0.74249284, 0.56236487, 0.65750691, 0.55765042,\n", + " 0.5951082 , 0.52240411, 0.5287149 , 0.52406106, 0.52813255,\n", + " 0.78706306, 0.71152173, 0.58293684, 0.69827554, 0.56473654,\n", + " 0.67073201, 0.53993428, 0.64542667, 0.5195776 , 0.65304713,\n", + " 0.8093364 , 0.76115053, 0.50345273, 0.61666704, 0.55268513,\n", + " 0.55853513, 0.52487246, 0.52152428, 0.49274373, 0.49798504,\n", + " 0.78909415, 0.71750551, 0.53882222, 0.68157662, 0.53873303,\n", + " 0.65124959, 0.51812351, 0.627968 , 0.49707742, 0.61496068,\n", + " 0.77802667, 0.74997036, 0.49206988, 0.59534796, 0.53724879,\n", + " 0.55819313, 0.51773912, 0.51441676, 0.4925026 , 0.50950083,\n", + " 0.7680412 , 0.72080995, 0.49195119, 0.66581205, 0.53387982,\n", + " 0.62886332, 0.52190967, 0.59893252, 0.51830368, 0.59435805,\n", + " 0.7525763 , 0.74030231, 0.48839468, 0.58402524, 0.50909438,\n", + " 0.57778652, 0.52414169, 0.51877087, 0.48644504, 0.53334518,\n", + " 0.75330366, 0.71495841, 0.4664454 , 0.6436764 , 0.52428206,\n", + " 0.62907309, 0.52216671, 0.61209308, 0.51745478, 0.59547558]),\n", + " 'std_test_score': array([0.02330541, 0.00983268, 0.02527339, 0.01854849, 0.02123109,\n", + " 0.02695613, 0.01671706, 0.02588415, 0.02340142, 0.03539763,\n", + " 0.0310761 , 0.0268897 , 0.03166583, 0.01915399, 0.01964985,\n", + " 0.04035167, 0.02415844, 0.04829527, 0.02608159, 0.06516717,\n", + " 0.01475571, 0.02822417, 0.01947692, 0.06762091, 0.02590243,\n", + " 0.05080407, 0.02762023, 0.0410645 , 0.04448367, 0.05552553,\n", + " 0.01490843, 0.01909001, 0.02726102, 0.04291775, 0.00540887,\n", + " 0.05559311, 0.0158291 , 0.06018046, 0.05030954, 0.09885959,\n", + " 0.02058686, 0.01589883, 0.03505766, 0.07507599, 0.01525564,\n", + " 0.04512812, 0.03230456, 0.04028527, 0.04389321, 0.0309063 ,\n", + " 0.0077786 , 0.01778837, 0.01729171, 0.05728616, 0.02621202,\n", + " 0.05622103, 0.0175056 , 0.07344521, 0.04289759, 0.09274581,\n", + " 0.0124946 , 0.01605805, 0.04474395, 0.06161139, 0.0265843 ,\n", + " 0.02035581, 0.02484303, 0.02923717, 0.02866389, 0.00877417,\n", + " 0.01721168, 0.01535864, 0.03217646, 0.064189 , 0.03507444,\n", + " 0.05470356, 0.03692975, 0.06041232, 0.03882443, 0.07665416]),\n", + " 'rank_test_score': array([ 2, 11, 41, 21, 44, 35, 58, 52, 57, 53, 4, 16, 38, 17, 40, 19, 46,\n", + " 24, 62, 22, 1, 7, 71, 29, 45, 42, 54, 61, 74, 72, 3, 14, 47, 18,\n", + " 48, 23, 65, 28, 73, 30, 5, 10, 76, 34, 49, 43, 66, 68, 75, 69, 6,\n", + " 13, 77, 20, 50, 27, 60, 32, 64, 36, 9, 12, 78, 37, 70, 39, 56, 63,\n", + " 79, 51, 8, 15, 80, 25, 55, 26, 59, 31, 67, 33], dtype=int32)}" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gs_clf.cv_results_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(0.8093364049132378)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gs_clf.best_score_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Word embedding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n" + ] + } + ], + "source": [ + "# Импортируем библиотеку gensim и посмотрим какие обученные модели в ней есть.\n", + "import gensim.downloader\n", + "print(list(gensim.downloader.info()['models'].keys()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## GloVe\n", + "Загрузим \"glove-twitter-25\" и будем работать с ней" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[==================================================] 100.0% 104.8/104.8MB downloaded\n" + ] + } + ], + "source": [ + "glove_model = gensim.downloader.load(\"glove-twitter-25\") # load glove vectors\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n", + " 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n", + " -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n", + " 0.43927 0.1758 -0.56058 0.13529 ]\n" + ] + }, + { + "data": { + "text/plain": [ + "[('dog', 0.9590820074081421),\n", + " ('monkey', 0.920357882976532),\n", + " ('bear', 0.9143136739730835),\n", + " ('pet', 0.9108031392097473),\n", + " ('girl', 0.8880629539489746),\n", + " ('horse', 0.8872726559638977),\n", + " ('kitty', 0.8870542049407959),\n", + " ('puppy', 0.886769711971283),\n", + " ('hot', 0.886525571346283),\n", + " ('lady', 0.8845519423484802)]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Можем посмотреть, что за векторы у слова \"cat\"\n", + "print(glove_model['cat']) # word embedding for 'cat'\n", + "glove_model.most_similar(\"cat\") # show words that similar to word 'cat'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Векторизуем обучающую выборку\n", + "Получаем матрицу \"Документ-термин\". В столбцах - все слова выборки, в строках - каждый документ. Значения в ячейках - количество встречаний слова в документе. \n", + "В отображенной части видим только нули - наглядная демонстрация проблем разреженности матрицы" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1064, 16170)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
00000000005102000000100255pixel0007000usd001200201pixel0019600280038...zorgzornzsoftzueszugzurzurichzuszvizyxel
00000000000...0000000000
10000000000...0000000000
20000000000...0000000000
30000000000...0000000000
40000000000...0000000000
\n", + "

5 rows × 16170 columns

\n", + "
" + ], + "text/plain": [ + " 00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 \\\n", + "0 0 0 0 0 0 0 0 0 \n", + "1 0 0 0 0 0 0 0 0 \n", + "2 0 0 0 0 0 0 0 0 \n", + "3 0 0 0 0 0 0 0 0 \n", + "4 0 0 0 0 0 0 0 0 \n", + "\n", + " 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel \n", + "0 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "1 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "2 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "3 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "4 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "\n", + "[5 rows x 16170 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectorizer = CountVectorizer(stop_words='english')\n", + "train_data = vectorizer.fit_transform(twenty_train['data'])\n", + "CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())\n", + "print(CV_data.shape)\n", + "CV_data.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Index(['00', '000', '000005102000', '000100255pixel', '0007', '000usd',\n", + " '001200201pixel', '00196', '0028', '0038'],\n", + " dtype='str')\n" + ] + } + ], + "source": [ + "# Создадим список уникальных слов, присутствующих в словаре.\n", + "words_vocab=CV_data.columns\n", + "print(words_vocab[0:10])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Векторизуем с помощью GloVe\n", + "\n", + "Нужно для каждого документа сложить glove-вектора слов, из которых он состоит.\n", + "В результате получим вектор документа как сумму векторов слов, из него состоящих" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Посмотрим на примере как будет работать векторизация" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
00000000005102000000100255pixel0007000usd001200201pixel0019600280038...zorgzornzsoftzueszugzurzurichzuszvizyxel
00000000000...0000000000
11100000000...0000000001
\n", + "

2 rows × 16170 columns

\n", + "
" + ], + "text/plain": [ + " 00 000 000005102000 000100255pixel 0007 000usd 001200201pixel 00196 \\\n", + "0 0 0 0 0 0 0 0 0 \n", + "1 1 1 0 0 0 0 0 0 \n", + "\n", + " 0028 0038 ... zorg zorn zsoft zues zug zur zurich zus zvi zyxel \n", + "0 0 0 ... 0 0 0 0 0 0 0 0 0 0 \n", + "1 0 0 ... 0 0 0 0 0 0 0 0 0 1 \n", + "\n", + "[2 rows x 16170 columns]" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Пусть выборка состоит из двух документов: \n", + "text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']\n", + "# Векторизуем с помощью обученного CountVectorizer\n", + "X = vectorizer.transform(text_data)\n", + "CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n", + "CV_text_data\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116\n", + " 1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195\n", + " -2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479\n", + " -0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928\n", + " -0.50473 ]\n", + "love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315\n", + " 0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369\n", + " -0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678\n", + " -0.016651 -0.53164 0.64236 -0.10922 ]\n", + "python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563\n", + " -0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526\n", + " 0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403\n", + " -0.27013 -0.348 -0.7293 0.2201 ]\n", + "world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513\n", + " 0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096\n", + " -0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667\n", + " -0.44426 -0.43085 -0.52875 0.25855 ]\n", + "Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 -1.00551002\n", + " -0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001\n", + " -0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997\n", + " -0.430875 0.87234702 -0.806399 0.27203003 2.23921998\n", + " -1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]\n", + "computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917\n", + " -0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233\n", + " 0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234\n", + " -0.92349 0.34651 0.061939 -0.57375 ]\n", + "game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537\n", + " 0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805\n", + " 1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619\n", + " 0.39202 0.59309 -0.55232 0.005087]\n", + "great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01\n", + " 1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01\n", + " -5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01\n", + " -1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01\n", + " -2.7928e-01]\n", + "zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268\n", + " -0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208\n", + " 0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593\n", + " -0.10225 1.7102 -0.95599 -1.3867 ]\n", + "This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998\n", + " -0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002\n", + " -1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004\n", + " -0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628\n", + " -3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]\n" + ] + } + ], + "source": [ + "# Создадим датафрейм, в который будем сохранять вектор документа\n", + "glove_data=pd.DataFrame()\n", + "\n", + "# Пробегаем по каждой строке датафрейма (по каждому документу)\n", + "for i in range(CV_text_data.shape[0]):\n", + " \n", + " # Вектор одного документа с размерностью glove-модели:\n", + " one_doc = np.zeros(25) \n", + " \n", + " # Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре\n", + " # Суммируем glove-вектора каждого известного слова в one_doc\n", + " for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:\n", + " if word in glove_model.key_to_index.keys(): \n", + " print(word, ': ', glove_model[word])\n", + " one_doc += glove_model[word]\n", + " print(text_data[i], ': ', one_doc)\n", + " glove_data= pd.concat([glove_data, pd.DataFrame([one_doc])], ignore_index=True) \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "На основе ячейки выше, напишем функцию, которая для каждого слова в тексте ищет это слово в glove_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def text2vec(text_data):\n", + " \n", + " # Векторизуем с помощью обученного CountVectorizer\n", + " X = vectorizer.transform(text_data)\n", + " CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())\n", + " CV_text_data\n", + " # Создадим датафрейм, в который будем сохранять вектор документа\n", + " glove_data=pd.DataFrame()\n", + "\n", + " # Пробегаем по каждой строке (по каждому документу)\n", + " for i in range(CV_text_data.shape[0]):\n", + "\n", + " # Вектор одного документа с размерностью glove-модели:\n", + " one_doc = np.zeros(25) \n", + "\n", + " # Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре\n", + " # Суммируем glove-вектора каждого известного слова в one_doc\n", + " for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:\n", + " if word in glove_model.key_to_index.keys(): \n", + " #print(word, ': ', glove_model[word])\n", + " one_doc += glove_model[word]\n", + " #print(text_data[i], ': ', one_doc)\n", + " glove_data = pd.concat([glove_data, pd.DataFrame([one_doc])], axis = 0)\n", + " #print('glove_data: ', glove_data)\n", + " return glove_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123456789...15161718192021222324
0-1.55058-0.0816830.2799190.588469-1.00551-0.8261396.186431.44845-0.71108-1.14717...-0.4308750.872347-0.8063990.272032.23922-1.23572-1.310711-1.96934-0.176410-0.135300
11.736100.7420820.355460-4.7441101.41544-0.3422201.78697-1.454042.56643-1.32184...-0.5266201.932400-0.896870-0.609241.51628-3.16625-0.8927501.86970-1.336071-2.234643
\n", + "

2 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 7 \\\n", + "0 -1.55058 -0.081683 0.279919 0.588469 -1.00551 -0.826139 6.18643 1.44845 \n", + "1 1.73610 0.742082 0.355460 -4.744110 1.41544 -0.342220 1.78697 -1.45404 \n", + "\n", + " 8 9 ... 15 16 17 18 19 \\\n", + "0 -0.71108 -1.14717 ... -0.430875 0.872347 -0.806399 0.27203 2.23922 \n", + "1 2.56643 -1.32184 ... -0.526620 1.932400 -0.896870 -0.60924 1.51628 \n", + "\n", + " 20 21 22 23 24 \n", + "0 -1.23572 -1.310711 -1.96934 -0.176410 -0.135300 \n", + "1 -3.16625 -0.892750 1.86970 -1.336071 -2.234643 \n", + "\n", + "[2 rows x 25 columns]" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Наша выборка, векторизованная через glove выглядит так. \n", + "# Попытаться понять, почему такая размерность матрицы и что за значения в ячейках\n", + "glove_data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Можем использовать написанную функцию на нашей реальной выборке\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123456789...15161718192021222324
00.4992583.090499-17.267230-0.997832-2.7465231.58607225.890665-27.914878-1.656527-6.599336...6.8400565.4752144.48820817.680875-8.2541981.7723136.184074-8.879014-0.216338-13.408250
02.579560-1.4981301.7523500.4079071.5289700.1523900.2836401.937990-0.659960-2.136370...0.434040-0.2744300.445033-0.6457501.055500-0.850250-0.4332000.212190-1.1643300.905880
0-3.8389983.3394243.943270-1.4281084.1711741.9184819.177921-3.2410921.280560-1.781829...-3.7194439.627485-1.198027-4.0109741.4643050.398107-1.013318-1.5918934.290867-3.888199
0-1.18281251.118736-28.117878-28.58650024.244726-9.567325111.896700-114.38735512.647457-24.559720...24.08517239.112016-23.01194441.767239-42.362471-16.060568-20.090755-4.457053-3.181086-90.876348
00.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
..................................................................
0-5.83920912.729176-6.201106-11.2552872.878065-11.99390811.418405-15.47951716.634253-10.534109...11.66580913.761738-2.579413-7.8772403.299712-9.241404-16.1028345.343221-3.646908-17.216441
01.2864244.341360-29.1102725.0576445.610822-10.95486529.974885-38.26112910.369041-15.652271...12.2018555.91388315.94180126.251438-9.280780-6.6560356.158630-14.878060-1.138172-25.532792
06.6982758.135166-10.3473490.1655462.089456-7.69573431.593479-38.582911-3.285587-6.087527...16.57544715.6966280.78560213.384657-13.3292731.983042-4.041431-5.5714110.581912-32.037549
0-0.4613644.499502-10.4469041.502338-10.415269-8.80125218.272721-10.398248-6.686713-0.250122...5.7584984.540718-0.37242210.668546-6.3779959.30577411.568183-18.814704-1.204665-14.551387
08.34887510.837893-1.528737-18.24693333.950209-0.37625448.017779-57.63616025.231878-13.371779...29.29760934.342018-8.52763513.555851-12.649073-28.706501-27.0426544.2923382.614328-48.834758
\n", + "

1064 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 \\\n", + "0 0.499258 3.090499 -17.267230 -0.997832 -2.746523 1.586072 \n", + "0 2.579560 -1.498130 1.752350 0.407907 1.528970 0.152390 \n", + "0 -3.838998 3.339424 3.943270 -1.428108 4.171174 1.918481 \n", + "0 -1.182812 51.118736 -28.117878 -28.586500 24.244726 -9.567325 \n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + ".. ... ... ... ... ... ... \n", + "0 -5.839209 12.729176 -6.201106 -11.255287 2.878065 -11.993908 \n", + "0 1.286424 4.341360 -29.110272 5.057644 5.610822 -10.954865 \n", + "0 6.698275 8.135166 -10.347349 0.165546 2.089456 -7.695734 \n", + "0 -0.461364 4.499502 -10.446904 1.502338 -10.415269 -8.801252 \n", + "0 8.348875 10.837893 -1.528737 -18.246933 33.950209 -0.376254 \n", + "\n", + " 6 7 8 9 ... 15 16 \\\n", + "0 25.890665 -27.914878 -1.656527 -6.599336 ... 6.840056 5.475214 \n", + "0 0.283640 1.937990 -0.659960 -2.136370 ... 0.434040 -0.274430 \n", + "0 9.177921 -3.241092 1.280560 -1.781829 ... -3.719443 9.627485 \n", + "0 111.896700 -114.387355 12.647457 -24.559720 ... 24.085172 39.112016 \n", + "0 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 \n", + ".. ... ... ... ... ... ... ... \n", + "0 11.418405 -15.479517 16.634253 -10.534109 ... 11.665809 13.761738 \n", + "0 29.974885 -38.261129 10.369041 -15.652271 ... 12.201855 5.913883 \n", + "0 31.593479 -38.582911 -3.285587 -6.087527 ... 16.575447 15.696628 \n", + "0 18.272721 -10.398248 -6.686713 -0.250122 ... 5.758498 4.540718 \n", + "0 48.017779 -57.636160 25.231878 -13.371779 ... 29.297609 34.342018 \n", + "\n", + " 17 18 19 20 21 22 \\\n", + "0 4.488208 17.680875 -8.254198 1.772313 6.184074 -8.879014 \n", + "0 0.445033 -0.645750 1.055500 -0.850250 -0.433200 0.212190 \n", + "0 -1.198027 -4.010974 1.464305 0.398107 -1.013318 -1.591893 \n", + "0 -23.011944 41.767239 -42.362471 -16.060568 -20.090755 -4.457053 \n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + ".. ... ... ... ... ... ... \n", + "0 -2.579413 -7.877240 3.299712 -9.241404 -16.102834 5.343221 \n", + "0 15.941801 26.251438 -9.280780 -6.656035 6.158630 -14.878060 \n", + "0 0.785602 13.384657 -13.329273 1.983042 -4.041431 -5.571411 \n", + "0 -0.372422 10.668546 -6.377995 9.305774 11.568183 -18.814704 \n", + "0 -8.527635 13.555851 -12.649073 -28.706501 -27.042654 4.292338 \n", + "\n", + " 23 24 \n", + "0 -0.216338 -13.408250 \n", + "0 -1.164330 0.905880 \n", + "0 4.290867 -3.888199 \n", + "0 -3.181086 -90.876348 \n", + "0 0.000000 0.000000 \n", + ".. ... ... \n", + "0 -3.646908 -17.216441 \n", + "0 -1.138172 -25.532792 \n", + "0 0.581912 -32.037549 \n", + "0 -1.204665 -14.551387 \n", + "0 2.614328 -48.834758 \n", + "\n", + "[1064 rows x 25 columns]" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_data_glove = text2vec(twenty_train['data'])\n", + "train_data_glove" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Мы получили матрицу, где каждый документ представлен вектором-строкой. Можем подавать эти вектора на вход для обучения классифкатора" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
KNeighborsClassifier()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + ], + "text/plain": [ + "KNeighborsClassifier()" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clf = KNeighborsClassifier(n_neighbors = 5)\n", + "clf.fit(train_data_glove, twenty_train['target'])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Аналогичную операци проводим с тестовой частью выборки и оцениваем качество модели" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
0123456789...15161718192021222324
02.11816710.172361-12.929062-10.235129-5.895256-2.96058513.365803-26.7792600.811607-3.139421...4.8932074.6521842.90399013.077540-12.4387262.602789-7.450777-5.202635-20.983743-14.107781
00.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000
04.87332027.474242-10.7738768.1128890.623212-2.19947559.094818-59.831612-2.963759-18.048856...23.89045312.469234-3.08745712.981009-29.85896713.10575610.670630-17.378340-6.079635-41.408141
0-3.5430911.986200-2.1912320.006059-2.3512080.3543235.127530-4.2327050.855633-1.511475...-1.0342412.211705-3.2551033.1452050.115450-0.6961870.6476622.1885601.237610-3.731906
0-0.810280-0.0419860.3090500.268330-0.735400-0.585930-0.2073100.084744-0.024049-1.729000...-0.981390-0.348830-0.012710-0.3991600.2379300.6089600.257370-0.611980-0.569530-0.627590
..................................................................
0-2.71357017.1576893.372045-5.867599-6.276993-12.74500525.654686-28.6366490.390143-9.184376...1.31323111.260476-3.97492213.881956-5.0408064.2696409.495090-6.392856-17.443242-8.663642
00.3309704.718122-1.751891-2.035309-0.343631-4.5700006.742370-4.367192-2.187082-1.016773...-0.0327563.431770-4.601260-0.623544-1.379208-2.430931-7.340540-4.9094101.812950-7.518150
02.3749835.077650-3.002268-3.9336302.078387-5.4703357.654599-12.2126197.815100-2.849330...-0.2346395.634821-3.0272211.6613570.393759-6.279475-4.927666-2.5929871.981002-12.255251
01.4101590.770070-3.955777-2.7911905.2107001.6219527.780780-9.8597188.394491-5.347944...1.8969474.079237-1.2073505.7723891.186500-7.286901-2.485993-0.9471970.788222-9.518580
05.7103716.6381140.178107-6.4063379.719197-8.7792008.866295-11.4372204.871641-11.415244...0.8510294.182781-3.3025571.213709-4.294990-7.751608-12.556040-10.3394371.448986-10.394629
\n", + "

708 rows × 25 columns

\n", + "
" + ], + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 2.118167 10.172361 -12.929062 -10.235129 -5.895256 -2.960585 13.365803 \n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "0 4.873320 27.474242 -10.773876 8.112889 0.623212 -2.199475 59.094818 \n", + "0 -3.543091 1.986200 -2.191232 0.006059 -2.351208 0.354323 5.127530 \n", + "0 -0.810280 -0.041986 0.309050 0.268330 -0.735400 -0.585930 -0.207310 \n", + ".. ... ... ... ... ... ... ... \n", + "0 -2.713570 17.157689 3.372045 -5.867599 -6.276993 -12.745005 25.654686 \n", + "0 0.330970 4.718122 -1.751891 -2.035309 -0.343631 -4.570000 6.742370 \n", + "0 2.374983 5.077650 -3.002268 -3.933630 2.078387 -5.470335 7.654599 \n", + "0 1.410159 0.770070 -3.955777 -2.791190 5.210700 1.621952 7.780780 \n", + "0 5.710371 6.638114 0.178107 -6.406337 9.719197 -8.779200 8.866295 \n", + "\n", + " 7 8 9 ... 15 16 17 \\\n", + "0 -26.779260 0.811607 -3.139421 ... 4.893207 4.652184 2.903990 \n", + "0 0.000000 0.000000 0.000000 ... 0.000000 0.000000 0.000000 \n", + "0 -59.831612 -2.963759 -18.048856 ... 23.890453 12.469234 -3.087457 \n", + "0 -4.232705 0.855633 -1.511475 ... -1.034241 2.211705 -3.255103 \n", + "0 0.084744 -0.024049 -1.729000 ... -0.981390 -0.348830 -0.012710 \n", + ".. ... ... ... ... ... ... ... \n", + "0 -28.636649 0.390143 -9.184376 ... 1.313231 11.260476 -3.974922 \n", + "0 -4.367192 -2.187082 -1.016773 ... -0.032756 3.431770 -4.601260 \n", + "0 -12.212619 7.815100 -2.849330 ... -0.234639 5.634821 -3.027221 \n", + "0 -9.859718 8.394491 -5.347944 ... 1.896947 4.079237 -1.207350 \n", + "0 -11.437220 4.871641 -11.415244 ... 0.851029 4.182781 -3.302557 \n", + "\n", + " 18 19 20 21 22 23 \\\n", + "0 13.077540 -12.438726 2.602789 -7.450777 -5.202635 -20.983743 \n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "0 12.981009 -29.858967 13.105756 10.670630 -17.378340 -6.079635 \n", + "0 3.145205 0.115450 -0.696187 0.647662 2.188560 1.237610 \n", + "0 -0.399160 0.237930 0.608960 0.257370 -0.611980 -0.569530 \n", + ".. ... ... ... ... ... ... \n", + "0 13.881956 -5.040806 4.269640 9.495090 -6.392856 -17.443242 \n", + "0 -0.623544 -1.379208 -2.430931 -7.340540 -4.909410 1.812950 \n", + "0 1.661357 0.393759 -6.279475 -4.927666 -2.592987 1.981002 \n", + "0 5.772389 1.186500 -7.286901 -2.485993 -0.947197 0.788222 \n", + "0 1.213709 -4.294990 -7.751608 -12.556040 -10.339437 1.448986 \n", + "\n", + " 24 \n", + "0 -14.107781 \n", + "0 0.000000 \n", + "0 -41.408141 \n", + "0 -3.731906 \n", + "0 -0.627590 \n", + ".. ... \n", + "0 -8.663642 \n", + "0 -7.518150 \n", + "0 -12.255251 \n", + "0 -9.518580 \n", + "0 -10.394629 \n", + "\n", + "[708 rows x 25 columns]" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_data_glove = text2vec(twenty_test['data'])\n", + "test_data_glove" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predict = clf.predict(test_data_glove )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[258 61]\n", + " [ 31 358]]\n", + " precision recall f1-score support\n", + "\n", + " 0 0.89 0.81 0.85 319\n", + " 1 0.85 0.92 0.89 389\n", + "\n", + " accuracy 0.87 708\n", + " macro avg 0.87 0.86 0.87 708\n", + "weighted avg 0.87 0.87 0.87 708\n", + "\n" + ] + } + ], + "source": [ + "print (confusion_matrix(twenty_test['target'], predict))\n", + "print(classification_report(twenty_test['target'], predict))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -410,7 +4822,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -424,7 +4836,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.12.3" } }, "nbformat": 4,