{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3dda6a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every import lives in this first cell so a fresh kernel can run the notebook top to bottom.\n",
    "import gensim.downloader\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n",
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7fd6636b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']\n"
     ]
    }
   ],
   "source": [
    "# Print the names of the models listed by gensim's downloader.\n",
    "print(list(gensim.downloader.info()['models']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f93b5f6",
   "metadata": {},
   "source": [
    "# GloVe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be870586",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"glove-twitter-25\" vectors through gensim's downloader.\n",
    "glove_model = gensim.downloader.load(\"glove-twitter-25\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "599d6406",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796\n",
      " 0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491\n",
      " -0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761\n",
      " 0.43927 0.1758 -0.56058 0.13529 ]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('dog', 0.9590820074081421),\n",
       " ('monkey', 0.920357882976532),\n",
       " ('bear', 0.9143136739730835),\n",
       " ('pet', 0.9108031392097473),\n",
       " ('girl', 0.8880629539489746),\n",
       " ('horse', 0.8872726559638977),\n",
       " ('kitty', 0.8870542049407959),\n",
       " ('puppy', 0.886769711971283),\n",
       " ('hot', 0.886525571346283),\n",
       " ('lady', 0.8845519423484802)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(glove_model['cat'])  # embedding vector stored for the word 'cat'\n",
    "glove_model.most_similar(\"cat\")  # words most similar to 'cat', with their similarity scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2db71cfb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.60927683"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Similarity score between the vectors of two individual words.\n",
    "glove_model.similarity('cat', 'bus')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7788acf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the 20 Newsgroups data to three categories.\n",
    "categories = ['alt.atheism', 'comp.graphics', 'sci.space']\n",
    "# Parts of every post that fetch_20newsgroups is asked to strip.\n",
    "remove = ('headers', 'footers', 'quotes')\n",
    "\n",
    "# Both subsets are fetched with identical arguments; only `subset` differs.\n",
    "twenty_train = fetch_20newsgroups(\n",
    "    subset='train', shuffle=True, random_state=42, categories=categories, remove=remove\n",
    ")\n",
    "twenty_test = fetch_20newsgroups(\n",
    "    subset='test', shuffle=True, random_state=42, categories=categories, remove=remove\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79dd1ac1",
   "metadata": {},
   "source": [
    "# Векторизуем обучающую выборку\n",
    "Получаем матрицу \"Документ-термин\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0565dd1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bag-of-words vectorizer using scikit-learn's built-in English stop-word list.\n",
    "vectorizer = CountVectorizer(stop_words='english')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a681a1d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1657, 23297)\n"
     ]
    },
    {
     "data": {
      "text/html": [ "
\n", " | 00 | \n", "000 | \n", "0000 | \n", "00000 | \n", "000000 | \n", "000005102000 | \n", "000062david42 | \n", "000100255pixel | \n", "00041032 | \n", "0004136 | \n", "... | \n", "zurbrin | \n", "zurich | \n", "zus | \n", "zvi | \n", "zwaartepunten | \n", "zwak | \n", "zwakke | \n", "zware | \n", "zwarte | \n", "zyxel | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
2 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
3 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
4 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
5 rows × 23297 columns
\n", "\n", " | 00 | \n", "000 | \n", "0000 | \n", "00000 | \n", "000000 | \n", "000005102000 | \n", "000062david42 | \n", "000100255pixel | \n", "00041032 | \n", "0004136 | \n", "... | \n", "zurbrin | \n", "zurich | \n", "zus | \n", "zvi | \n", "zwaartepunten | \n", "zwak | \n", "zwakke | \n", "zware | \n", "zwarte | \n", "zyxel | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "
1 | \n", "1 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "... | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "
2 rows × 23297 columns
\n", "\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "15 | \n", "16 | \n", "17 | \n", "18 | \n", "19 | \n", "20 | \n", "21 | \n", "22 | \n", "23 | \n", "24 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "-1.55058 | \n", "-0.081683 | \n", "0.279919 | \n", "0.588469 | \n", "-1.00551 | \n", "-0.826139 | \n", "6.18643 | \n", "1.44845 | \n", "-0.71108 | \n", "-1.14717 | \n", "... | \n", "-0.430875 | \n", "0.872347 | \n", "-0.806399 | \n", "0.27203 | \n", "2.23922 | \n", "-1.23572 | \n", "-1.310711 | \n", "-1.96934 | \n", "-0.176410 | \n", "-0.135300 | \n", "
0 | \n", "1.73610 | \n", "0.742082 | \n", "0.355460 | \n", "-4.744110 | \n", "1.41544 | \n", "-0.342220 | \n", "1.78697 | \n", "-1.45404 | \n", "2.56643 | \n", "-1.32184 | \n", "... | \n", "-0.526620 | \n", "1.932400 | \n", "-0.896870 | \n", "-0.60924 | \n", "1.51628 | \n", "-3.16625 | \n", "-0.892750 | \n", "1.86970 | \n", "-1.336071 | \n", "-2.234643 | \n", "
2 rows × 25 columns
\n", "\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "15 | \n", "16 | \n", "17 | \n", "18 | \n", "19 | \n", "20 | \n", "21 | \n", "22 | \n", "23 | \n", "24 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "-8.521142 | \n", "2.020376 | \n", "-10.802921 | \n", "3.167636 | \n", "0.252469 | \n", "15.544048 | \n", "17.631184 | \n", "-32.581192 | \n", "9.696540 | \n", "-11.103087 | \n", "... | \n", "2.810453 | \n", "7.900215 | \n", "0.962129 | \n", "17.691130 | \n", "-1.252574 | \n", "-10.098049 | \n", "0.500113 | \n", "1.348694 | \n", "2.186150 | \n", "-16.556824 | \n", "
0 | \n", "6.576228 | \n", "20.336350 | \n", "-32.675150 | \n", "-9.073872 | \n", "17.515655 | \n", "-6.488794 | \n", "59.458419 | \n", "-75.384298 | \n", "13.323775 | \n", "-14.443218 | \n", "... | \n", "21.407738 | \n", "23.525118 | \n", "0.325680 | \n", "19.871444 | \n", "-27.585188 | \n", "-4.559155 | \n", "-7.417482 | \n", "-16.694553 | \n", "-0.197711 | \n", "-58.948193 | \n", "
0 | \n", "1.329914 | \n", "3.060870 | \n", "-1.868484 | \n", "1.392735 | \n", "-1.335277 | \n", "-5.014955 | \n", "12.859476 | \n", "-9.978156 | \n", "-0.869613 | \n", "-2.031490 | \n", "... | \n", "2.925134 | \n", "2.872930 | \n", "2.184486 | \n", "3.831770 | \n", "-0.877866 | \n", "-0.927770 | \n", "0.700101 | \n", "-9.855365 | \n", "-5.419429 | \n", "-2.279330 | \n", "
0 | \n", "-4.866150 | \n", "-0.273176 | \n", "3.515124 | \n", "-5.008165 | \n", "-1.236789 | \n", "-7.951168 | \n", "-11.015882 | \n", "-3.496241 | \n", "16.024286 | \n", "-9.388742 | \n", "... | \n", "-0.471141 | \n", "3.575378 | \n", "6.193222 | \n", "0.349430 | \n", "15.040248 | \n", "-10.369132 | \n", "-0.848717 | \n", "-0.564796 | \n", "-1.114126 | \n", "-7.844431 | \n", "
0 | \n", "-3.115007 | \n", "-1.805252 | \n", "-5.419340 | \n", "-0.393406 | \n", "-0.406461 | \n", "-2.724340 | \n", "7.898330 | \n", "-15.619113 | \n", "0.231822 | \n", "-3.628156 | \n", "... | \n", "5.944151 | \n", "8.309932 | \n", "-0.656084 | \n", "12.178709 | \n", "-6.118551 | \n", "-3.286376 | \n", "3.450946 | \n", "2.055343 | \n", "0.463787 | \n", "-12.644626 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
0 | \n", "-0.930954 | \n", "4.974043 | \n", "-8.147008 | \n", "-5.147130 | \n", "3.960455 | \n", "-1.344022 | \n", "7.818063 | \n", "-25.427420 | \n", "4.624732 | \n", "-7.218097 | \n", "... | \n", "3.623038 | \n", "4.453189 | \n", "2.405320 | \n", "8.032963 | \n", "-8.029539 | \n", "0.838867 | \n", "-4.757457 | \n", "-5.755052 | \n", "-9.496197 | \n", "-21.542710 | \n", "
0 | \n", "-0.770690 | \n", "0.128270 | \n", "0.331370 | \n", "0.005089 | \n", "-0.476050 | \n", "-0.501160 | \n", "1.858000 | \n", "1.062400 | \n", "-0.565110 | \n", "0.133280 | \n", "... | \n", "-0.449220 | \n", "0.485910 | \n", "-0.647900 | \n", "-0.842380 | \n", "0.616690 | \n", "-0.198240 | \n", "-0.579670 | \n", "-0.658850 | \n", "0.439280 | \n", "-0.504730 | \n", "
0 | \n", "1.491177 | \n", "6.992638 | \n", "-7.921970 | \n", "-7.157521 | \n", "6.641657 | \n", "-2.958020 | \n", "12.820770 | \n", "-18.502946 | \n", "6.838083 | \n", "-2.717310 | \n", "... | \n", "-1.344873 | \n", "4.170405 | \n", "-0.178030 | \n", "5.699992 | \n", "-7.295038 | \n", "-3.683306 | \n", "-2.718006 | \n", "-0.117608 | \n", "-7.205832 | \n", "-13.863438 | \n", "
0 | \n", "2.523770 | \n", "5.817394 | \n", "2.184340 | \n", "-2.996497 | \n", "-0.267181 | \n", "-10.059634 | \n", "6.344402 | \n", "-2.047127 | \n", "2.679123 | \n", "-7.642505 | \n", "... | \n", "-1.230296 | \n", "1.409746 | \n", "-3.322040 | \n", "-5.068259 | \n", "-0.648718 | \n", "0.753010 | \n", "-6.220990 | \n", "-5.012004 | \n", "-1.518542 | \n", "-10.156440 | \n", "
0 | \n", "-0.118691 | \n", "11.860546 | \n", "-2.567264 | \n", "-10.955913 | \n", "-4.239322 | \n", "-9.340552 | \n", "21.189778 | \n", "-10.895375 | \n", "2.659030 | \n", "-3.848115 | \n", "... | \n", "0.726191 | \n", "11.634998 | \n", "-5.447248 | \n", "1.293007 | \n", "-7.882002 | \n", "-2.527453 | \n", "0.298939 | \n", "-6.107062 | \n", "3.365051 | \n", "-15.641826 | \n", "
1657 rows × 25 columns
\n", "KNeighborsClassifier()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
KNeighborsClassifier()
\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "15 | \n", "16 | \n", "17 | \n", "18 | \n", "19 | \n", "20 | \n", "21 | \n", "22 | \n", "23 | \n", "24 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "-6.760635 | \n", "5.063863 | \n", "-2.779060 | \n", "3.699120 | \n", "-2.858086 | \n", "0.135230 | \n", "20.811229 | \n", "-19.425567 | \n", "7.302950 | \n", "-5.826012 | \n", "... | \n", "3.833378 | \n", "6.794452 | \n", "-0.921720 | \n", "12.187404 | \n", "-5.547615 | \n", "-4.133999 | \n", "3.588260 | \n", "-0.497106 | \n", "-2.542142 | \n", "-11.362855 | \n", "
0 | \n", "1.632616 | \n", "2.512300 | \n", "-0.745513 | \n", "-3.081154 | \n", "2.182067 | \n", "-1.988816 | \n", "7.533100 | \n", "-1.015740 | \n", "-0.829598 | \n", "-2.764237 | \n", "... | \n", "0.791851 | \n", "2.114150 | \n", "-2.249193 | \n", "-0.163590 | \n", "-1.177710 | \n", "-2.496928 | \n", "-5.074085 | \n", "-2.666947 | \n", "0.662050 | \n", "-3.590550 | \n", "
0 | \n", "2.115766 | \n", "2.142060 | \n", "-0.445607 | \n", "-3.229030 | \n", "1.154580 | \n", "-2.877278 | \n", "6.399954 | \n", "-10.445769 | \n", "2.230760 | \n", "-3.299899 | \n", "... | \n", "4.388870 | \n", "8.515056 | \n", "-0.766260 | \n", "3.549431 | \n", "-1.643443 | \n", "-0.825730 | \n", "-2.968016 | \n", "-0.808924 | \n", "-0.000160 | \n", "-7.468189 | \n", "
0 | \n", "-0.802784 | \n", "5.199443 | \n", "4.294071 | \n", "-7.390966 | \n", "2.747166 | \n", "-1.359952 | \n", "15.032628 | \n", "-1.601590 | \n", "1.474406 | \n", "2.570105 | \n", "... | \n", "3.043432 | \n", "6.176236 | \n", "-6.193988 | \n", "-3.990476 | \n", "-2.345854 | \n", "-5.534376 | \n", "-8.925422 | \n", "1.553300 | \n", "0.905790 | \n", "-12.824533 | \n", "
0 | \n", "29.926489 | \n", "65.324993 | \n", "-25.059592 | \n", "-64.080130 | \n", "77.565282 | \n", "-34.614604 | \n", "75.643770 | \n", "-115.600859 | \n", "90.847175 | \n", "-42.971146 | \n", "... | \n", "40.956031 | \n", "50.322156 | \n", "-19.537098 | \n", "28.903925 | \n", "-34.643949 | \n", "-69.894146 | \n", "-94.992145 | \n", "-48.601895 | \n", "-29.098555 | \n", "-91.934770 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
0 | \n", "1.829235 | \n", "4.513807 | \n", "2.916520 | \n", "2.237308 | \n", "-1.704831 | \n", "-1.811192 | \n", "22.196895 | \n", "-12.858912 | \n", "-4.054810 | \n", "-3.130457 | \n", "... | \n", "6.019246 | \n", "8.949456 | \n", "-4.682214 | \n", "-5.648911 | \n", "-1.026898 | \n", "3.719006 | \n", "2.449941 | \n", "-6.487197 | \n", "1.340930 | \n", "-7.325196 | \n", "
0 | \n", "-0.963815 | \n", "5.491164 | \n", "3.567377 | \n", "-6.048021 | \n", "-5.059298 | \n", "-0.977958 | \n", "15.131499 | \n", "-0.904470 | \n", "2.185990 | \n", "-1.459807 | \n", "... | \n", "0.968499 | \n", "4.725793 | \n", "-0.726944 | \n", "1.328612 | \n", "-3.144209 | \n", "1.643127 | \n", "-1.259245 | \n", "-0.880740 | \n", "-6.713165 | \n", "-3.115454 | \n", "
0 | \n", "6.801324 | \n", "15.348126 | \n", "-17.051718 | \n", "5.030998 | \n", "9.332448 | \n", "-5.716691 | \n", "56.409175 | \n", "-56.250411 | \n", "-4.028209 | \n", "-11.687558 | \n", "... | \n", "22.884424 | \n", "12.940570 | \n", "1.058664 | \n", "21.879058 | \n", "-20.897253 | \n", "2.537755 | \n", "3.774890 | \n", "-11.495336 | \n", "-2.609774 | \n", "-36.597559 | \n", "
0 | \n", "1.054090 | \n", "0.764524 | \n", "1.958340 | \n", "-1.085245 | \n", "-0.441392 | \n", "-0.421970 | \n", "6.139770 | \n", "-0.612219 | \n", "-2.251460 | \n", "-0.465165 | \n", "... | \n", "0.377958 | \n", "1.957450 | \n", "-1.705220 | \n", "-0.509700 | \n", "0.016110 | \n", "1.461620 | \n", "1.589069 | \n", "2.267340 | \n", "0.447919 | \n", "-0.469250 | \n", "
0 | \n", "-18.387286 | \n", "13.274879 | \n", "-7.895913 | \n", "-1.831442 | \n", "-10.424961 | \n", "-12.248442 | \n", "32.153890 | \n", "-40.169293 | \n", "13.089525 | \n", "-21.306493 | \n", "... | \n", "6.497279 | \n", "8.340729 | \n", "4.996109 | \n", "23.442078 | \n", "-3.701088 | \n", "-11.671505 | \n", "9.209790 | \n", "-10.002501 | \n", "-0.815266 | \n", "-17.024052 | \n", "
1102 rows × 25 columns
\n", "