233 KiB
Методические указания к лабораторной работе №2
В данной работе мы продолжаем работать с библиотекой scikit-learn (http://scikit-learn.org), и хотим выяснить ее возможности при работе с текстовыми документами.
Ниже приведены новые модули, которые будут использованы в данной работе:
- fetch_20newsgroups - http://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_20newsgroups.html - загружает датасет «20 news groups», состоящий приблизительно из 18000 сообщений на английском языке по 20 тематикам, разбитым на обучающую и тестовую выборки. Векторизаторы текста:
- CountVectorizer - http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
- TfidfTransformer - http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html#sklearn.feature_extraction.text.TfidfVectorizer
- Pipeline - http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html - конвейерный классификатор
- MultinominalNB - http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.MultinomialNB.html - Полиномиальный (Мультиномиальный) Наивный Байесовский метод – разновидность Наивного Байесовского метода, которая хорошо работает с текстами, длины которых сильно варьируются.
Для проведения стемминга предлагается использовать библиотеку NLTK и стеммер Портера
Импорт библиотек
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_scoreЗагрузка выборки
Выборка 20 news groups представляет собой сообщения, котороые состоят из заголовка (header), основной части, подписи или сноски (footer), а также могут содержать в себе цитирование предыдущего сообщения (quotes). Модуль fetch_20newsgroups позволяет выбирать интересующие тематики и удалять ненужные части сообщений. Для того чтобы выбрать сообщения по интересующим тематикам, необходимо передать список тематик в параметр categories.
Для того чтобы удалить ненужные части сообщений, нужно передать их в параметр remove. Кроме того, важным параметром fetch_20newsgroups является subset - тип выборки – обучающая или тестовая.
Выберем сообщения по тематикам Атеизм и Компьютерная графика, а также укажем, что нас не интересуют заголовки, цитаты и подписи:
categories = ['alt.atheism', 'comp.graphics']
remove = ('headers', 'footers', 'quotes')
twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42, categories = categories, remove = remove )
twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42, categories = categories, remove = remove )Возвращаемый набор данных — это scikit-learn совокупность: одномерный контейнер с полями, которые могут интерпретироваться как признаки объекта (object attributes). Например, target_names содержит список названий запрошенных категорий, target - тематику сообщения, а data – непосредственно текст сообщения:
print (twenty_train.data[2])Does anyone know of any good shareware animation or paint software for an SGI
machine? I've exhausted everyplace on the net I can find and still don't hava
a nice piece of software.
Thanks alot!
Chad
Векторизация
Чтобы использовать машинное обучение на текстовых документах, первым делом, нужно перевести текстовое содержимое в числовой вектор признаков. Предобработка текста, токенизация и отбрасывание стоп-слов включены в состав модуля CountVectorizer, который позволяет создать словарь характерных признаков и перевести документы в векторы признаков. Создадим объект-векторизатор vect со следующими параметрами:
max_features = 10000- количество наиболее частотных терминов, из которых будет состоять словарь. По умолчанию используются все слова.stop_words = 'english'– на данный момент модулем поддерживается отсечение английских стоп-слов. Кроме того, здесь можно указать список стоп-слов вручную. Если параметр не указывать, будут использованы все термины словаря.
vect = CountVectorizer(max_features = 10000, stop_words = 'english')Также, в работе может потребоваться настройка следующих параметров:
max_df-floatв диапазоне [0.0, 1.0] илиint, по умолчанию = 1.0. При построении словаря игнорирует термины, частота которых в документе строго превышает заданный порог (стоп-слова для конкретного корпуса). Еслиfloat, параметр обозначает долю документов, если целое число – то абсолютное значение.min_df-floatв диапазоне [0.0, 1.0] илиint, по умолчанию = 1.0. При построении словаря игнорируйте термины, частота которых в документе строго ниже заданного порога. В литературе это значение также называется порогом. Еслиfloat, параметр обозначает долю документов, если целое число – то абсолютное значение.
После того как объект-векторизатор создан, необходимо создать словарь характерных признаков с помощью метода fit() и перевести документы в векторы признаков c помощью метода transform(), подав на него обучающую выборку:
vect.fit(twenty_train.data)
train_data = vect.transform(twenty_train.data)
test_data = vect.transform(twenty_test.data)Также, можно отметить что эти два действия могут быть объединены одним методом fit_transform(). Однако, в этом случае нужно учесть, что для перевода тестовой выборки в вектор признаков, по-прежнему нужно использовать метод transform():
train_data = vect.fit_transform(twenty_train.data)
test_data = vect.transform(twenty_test.data)Если для тестовых данных также воспользоваться методом fit_transform(), это приведет к перестроению словаря признаков и неправильным результатам классификации.
Следующий блок кода позволит вывести первые 10 терминов, упорядоченных по частоте встречаемости:
x = list(zip(vect.get_feature_names_out(), np.ravel(train_data.sum(axis=0))))
def SortbyTF(inputStr):
return inputStr[1]
x.sort(key=SortbyTF, reverse = True)
print (x[:10])[('image', np.int64(489)), ('don', np.int64(417)), ('graphics', np.int64(410)), ('god', np.int64(409)), ('people', np.int64(384)), ('does', np.int64(364)), ('edu', np.int64(349)), ('like', np.int64(329)), ('just', np.int64(327)), ('know', np.int64(319))]
Стемминг
Существует целый ряд алгоритмов стемминга. Один из популярных - алгоритм Портера, реализация которого приведена в библиотеке nltk Для проведения стемминга нужно создать объект PorterStemmer(). Стеммер работает таким образом: у созданного объекта PorterStemmer есть метод stem, производящий стемминга. Таким образом, необходимо каждую из частей выборки (обучающую и тестовую) разбить на отдельные документы, затем, проходя в цикле по каждому слову в документе, произвести стемминг и объединить эти слова в новый документ.
В ЛР проводить стемминг не требуется.
import nltk
from nltk.stem import *
from nltk import word_tokenize
nltk.download('punkt_tab')
porter_stemmer = PorterStemmer()
stem_train = []
for text in twenty_train.data:
nltk_tokens = word_tokenize(text)
line = ''
for word in nltk_tokens:
line += ' ' + porter_stemmer.stem(word)
stem_train.append(line)
print (stem_train[0])TF- и TF-IDF взвешивание
CountVectorizer позволяет лишь определять частоту встречаемости термина во всей выборке, но такой подход к выявлению информативных терминов не всегда дает качественный результат. На практике используют более продвинутые способы, наибольшее распространение из которых получили TF- и TF-IDF взвешивания. Воспользуемся методом fit() класса TfidfTransformer(), который переводит матрицу частот встречаемости в TF- и TF-IDF веса.
tfidf = TfidfTransformer(use_idf = True).fit(train_data)
train_data_tfidf = tfidf.transform(train_data)Отметим, что в метод fit() нужно передавать не исходные текстовые данные, а вектор слов и их частот, полученный с помощью метода transform() класса CountVectorizer. Для того, чтобы получить tf-idf значения, необходимо установить параметр use_idf = True, в противном случае на выходе мы получим значения tf
Классификация
После того как мы провели векторизацию текста, обучение модели и классификация для текстовых данных выглядит абсолютно идентично классификации объектов в первой лабораторной работе.
Задача обучения модели заключается не только в выборе подходящих данных обучающей выборки, способных качественно охарактеризовать объекты, но и в настройке многочисленных параметров метода классификации, предварительной обработке данных и т.д. Рассмотрим, какие возможности предлагаются в библиотеке scikit-learn для автоматизации и упрощения данной задачи.
Pipeline
Чтобы с цепочкой vectorizer => transformer => classifier было проще работать, в scikit-learn есть класс Pipeline (конвейер), который функционирует как составной (конвейерный) классификатор.
from sklearn.pipeline import PipelineПромежуточными шагами конвейера должны быть преобразования, то есть должны выполняться методы fit() и transform(), а последний шаг – только fit(). При этом, pipeline позволяет устанавливать различные параметры на каждом своем шаге. Таким образом, проделанные нами действия по векторизации данных, взвешиванию с помощью TF-IDF и классификации методом К-БС с использованием pipeline будут выглядеть следующим образом:
text_clf = Pipeline([('vect', CountVectorizer(max_features= 1000, stop_words = 'english')),
('tfidf', TfidfTransformer(use_idf = True)),
('clf', KNeighborsClassifier (n_neighbors=1)),]) Названия vect, tfidf и clf выбраны нами произвольно. Мы рассмотрим их использование в следующей лабораторной работе. Теперь обучим модель с помощью всего 1 команды:
text_clf = text_clf.fit(twenty_train.data, twenty_train.target)И проведем классификацию на тестовой выборке:
prediction = text_clf.predict(twenty_test.data)Настройка параметров с использованием grid search
from sklearn.model_selection import GridSearchCVС помощью конвейерной обработки (pipelines) стало гораздо проще указывать параметры обучения модели, однако перебирать все возможные варианты вручную даже в нашем простом случае выйдет затратно. В нашем случае имеется четыре настраиваемых параметра: max_features, stop_words, use_idf и n_neighbors.
Вместо поиска лучших параметров в конвейере вручную, можно запустить поиск (методом полного перебора) лучших параметров в сетке возможных значений. Сделать это можно с помощью объекта класса GridSearchCV.
Для того чтобы задать сетку параметров необходимо создать переменную-словарь, ключами которого являются конструкции вида: «НазваниеШагаКонвейера__НазваниеПараметра», а значениями – кортеж из значений параметра
parameters = {'vect__max_features': (100,500,1000,5000,10000),
'vect__stop_words': ('english', None),
'tfidf__use_idf': (True, False),
'clf__n_neighbors': (1,3,5,7)} Далее необходимо создать объект класса GridSearchCV, передав в него объект pipeline или классификатор, список параметров сетки, а также при необходимости, задав прочие параметры, такие так количество задействованых ядер процессора n_jobs, количество фолдов кросс-валидации cv, метрику, по которой будем судить о качестве модели scoring, и другие
gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1, cv=3, scoring = 'f1_weighted')Теперь объект gs_clf можно обучить по всем параметрам, как обычный классификатор, методом fit(). После того как прошло обучение, узнать лучшую совокупность параметров можно, обратившись к атрибуту best_params_. Для необученной модели атрибуты будут отсутствовать.
gs_clf.fit(twenty_train.data, twenty_train.target)GridSearchCV(cv=3,
estimator=Pipeline(steps=[('vect',
CountVectorizer(max_features=1000,
stop_words='english')),
('tfidf', TfidfTransformer()),
('clf',
KNeighborsClassifier(n_neighbors=1))]),
n_jobs=-1,
param_grid={'clf__n_neighbors': (1, 3, 5, 7),
'tfidf__use_idf': (True, False),
'vect__max_features': (100, 500, 1000, 5000, 10000),
'vect__stop_words': ('english', None)},
scoring='f1_weighted')In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Parameters
Parameters
Parameters
Parameters
403466d1d5/skrub/_reporting/_data/templates/report.js (L789)
* @returns "light" or "dark"
*/
function detectTheme(element) {
const body = document.querySelector('body');
// Check VSCode theme
const themeKindAttr = body.getAttribute('data-vscode-theme-kind');
const themeNameAttr = body.getAttribute('data-vscode-theme-name');
if (themeKindAttr && themeNameAttr) {
const themeKind = themeKindAttr.toLowerCase();
const themeName = themeNameAttr.toLowerCase();
if (themeKind.includes("dark") || themeName.includes("dark")) {
return "dark";
}
if (themeKind.includes("light") || themeName.includes("light")) {
return "light";
}
}
// Check Jupyter theme
if (body.getAttribute('data-jp-theme-light') === 'false') {
return 'dark';
} else if (body.getAttribute('data-jp-theme-light') === 'true') {
return 'light';
}
// Guess based on a parent element's color
const color = window.getComputedStyle(element.parentNode, null).getPropertyValue('color');
const match = color.match(/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)\s*$/i);
if (match) {
const [r, g, b] = [
parseFloat(match[1]),
parseFloat(match[2]),
parseFloat(match[3])
];
// https://en.wikipedia.org/wiki/HSL_and_HSV#Lightness
const luma = 0.299 * r + 0.587 * g + 0.114 * b;
if (luma > 180) {
// If the text is very bright we have a dark theme
return 'dark';
}
if (luma < 75) {
// If the text is very dark we have a light theme
return 'light';
}
// Otherwise fall back to the next heuristic.
}
// Fallback to system preference
return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
}
function forceTheme(elementId) {
const estimatorElement = document.querySelector(`#${elementId}`);
if (estimatorElement === null) {
console.error(`Element with id ${elementId} not found.`);
} else {
const theme = detectTheme(estimatorElement);
estimatorElement.classList.add(theme);
}
}
forceTheme('sk-container-id-1');</script>
gs_clf.best_params_{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': 'english'}
gs_clf.cv_results_{'mean_fit_time': array([0.14711912, 0.1838789 , 0.16524474, 0.13485821, 0.17077629,
0.171736 , 0.13478573, 0.17135119, 0.1455044 , 0.16278481,
0.14029678, 0.17597159, 0.14752992, 0.1515433 , 0.2238245 ,
0.16685406, 0.19492237, 0.18400868, 0.1637164 , 0.23180572,
0.13470483, 0.19008255, 0.168998 , 0.14402525, 0.14959741,
0.21908172, 0.12458984, 0.17676258, 0.181101 , 0.15396865,
0.13980309, 0.17127609, 0.12496074, 0.21157686, 0.14910388,
0.16302236, 0.20699684, 0.14392487, 0.13715124, 0.19204704,
0.12343788, 0.202528 , 0.1306417 , 0.14448158, 0.16433088,
0.17828568, 0.15127762, 0.16031146, 0.15140502, 0.17451549,
0.13383325, 0.19785619, 0.14872464, 0.19805606, 0.17880678,
0.14933427, 0.18098283, 0.17271209, 0.19098576, 0.15773439,
0.14277951, 0.14034979, 0.17362897, 0.14588912, 0.15714224,
0.19136397, 0.14665333, 0.18524861, 0.1639146 , 0.15949496,
0.14914314, 0.19116807, 0.16394456, 0.17290258, 0.14597694,
0.18869273, 0.17521318, 0.18450022, 0.15120141, 0.14013394]),
'std_fit_time': array([0.0426388 , 0.00865779, 0.03047002, 0.01959866, 0.01996265,
0.03991559, 0.02089307, 0.0115939 , 0.01577824, 0.01259437,
0.02861695, 0.02803293, 0.02632727, 0.03513436, 0.10327593,
0.06408524, 0.04651852, 0.03295957, 0.0471307 , 0.11679622,
0.0110411 , 0.02803024, 0.05824908, 0.01165112, 0.01597652,
0.06360679, 0.0225909 , 0.03841467, 0.06331909, 0.02901401,
0.02955105, 0.00180821, 0.01985555, 0.0523942 , 0.03268085,
0.03894983, 0.06517167, 0.02243237, 0.02200533, 0.0133168 ,
0.02014168, 0.05538047, 0.00691204, 0.02019455, 0.04787927,
0.02202076, 0.03367995, 0.04032929, 0.04035338, 0.05086764,
0.03783577, 0.02842441, 0.01027949, 0.02261441, 0.07413437,
0.03057069, 0.07406925, 0.06374613, 0.07499723, 0.00702891,
0.01449466, 0.01535535, 0.06266348, 0.01594578, 0.03962651,
0.03872874, 0.02924308, 0.03921288, 0.03093211, 0.02737551,
0.03913621, 0.03922343, 0.06958397, 0.0321589 , 0.03550322,
0.06330303, 0.04521662, 0.01877706, 0.01843914, 0.02003679]),
'mean_score_time': array([0.07024606, 0.09829744, 0.06123559, 0.0934546 , 0.06724668,
0.08956035, 0.06278229, 0.08930341, 0.07657051, 0.08451796,
0.06686974, 0.06430848, 0.06159838, 0.06787952, 0.0685486 ,
0.08175969, 0.07736111, 0.08548617, 0.10391927, 0.11483145,
0.06140908, 0.08571712, 0.05673496, 0.07094908, 0.07356191,
0.08419561, 0.06646315, 0.07401323, 0.0662907 , 0.07332158,
0.06078386, 0.06057103, 0.05959137, 0.07239731, 0.06452703,
0.06926155, 0.07553283, 0.07713358, 0.06963356, 0.09399152,
0.05499752, 0.06101354, 0.06951396, 0.06801232, 0.06193964,
0.08509183, 0.06340122, 0.09640757, 0.09270819, 0.07950346,
0.06135122, 0.07454085, 0.06276194, 0.08477139, 0.06570975,
0.06936661, 0.07554563, 0.08220371, 0.06743741, 0.10737284,
0.06237586, 0.06374478, 0.07077082, 0.06676245, 0.06271029,
0.10701195, 0.07044252, 0.07724508, 0.0787247 , 0.07522114,
0.06066545, 0.08716504, 0.06308579, 0.08094724, 0.08140651,
0.07310184, 0.07593838, 0.07488211, 0.06320556, 0.04453659]),
'std_score_time': array([0.0307069 , 0.0182529 , 0.02101209, 0.03205303, 0.0257132 ,
0.02687809, 0.0173987 , 0.02640571, 0.01539859, 0.01465485,
0.02122433, 0.01356409, 0.00834315, 0.01751938, 0.01171655,
0.03144068, 0.02179221, 0.01969381, 0.04494128, 0.00622263,
0.01579416, 0.02897917, 0.01304806, 0.0202541 , 0.02151299,
0.03211671, 0.01358637, 0.01938508, 0.01422302, 0.01688134,
0.00869969, 0.0141093 , 0.01194079, 0.01273514, 0.01648673,
0.01545342, 0.01899566, 0.0178712 , 0.02332053, 0.03043117,
0.01200954, 0.01238742, 0.02325767, 0.01629294, 0.01778484,
0.03084458, 0.0157806 , 0.04616991, 0.02142426, 0.01231431,
0.0059662 , 0.01826184, 0.02046084, 0.02321078, 0.00858559,
0.01727098, 0.01823882, 0.02170296, 0.0142323 , 0.06476885,
0.02283514, 0.01249623, 0.00957304, 0.01621181, 0.00941332,
0.04093458, 0.01894285, 0.0226533 , 0.02176937, 0.01821806,
0.01055958, 0.03213191, 0.00837758, 0.02668343, 0.03007748,
0.01731415, 0.01637027, 0.01601808, 0.02890825, 0.01039945]),
'param_clf__n_neighbors': masked_array(data=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value=999999),
'param_tfidf__use_idf': masked_array(data=[True, True, True, True, True, True, True, True, True,
True, False, False, False, False, False, False, False,
False, False, False, True, True, True, True, True,
True, True, True, True, True, False, False, False,
False, False, False, False, False, False, False, True,
True, True, True, True, True, True, True, True, True,
False, False, False, False, False, False, False, False,
False, False, True, True, True, True, True, True, True,
True, True, True, False, False, False, False, False,
False, False, False, False, False],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value=True),
'param_vect__max_features': masked_array(data=[100, 100, 500, 500, 1000, 1000, 5000, 5000, 10000,
10000, 100, 100, 500, 500, 1000, 1000, 5000, 5000,
10000, 10000, 100, 100, 500, 500, 1000, 1000, 5000,
5000, 10000, 10000, 100, 100, 500, 500, 1000, 1000,
5000, 5000, 10000, 10000, 100, 100, 500, 500, 1000,
1000, 5000, 5000, 10000, 10000, 100, 100, 500, 500,
1000, 1000, 5000, 5000, 10000, 10000, 100, 100, 500,
500, 1000, 1000, 5000, 5000, 10000, 10000, 100, 100,
500, 500, 1000, 1000, 5000, 5000, 10000, 10000],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value=999999),
'param_vect__stop_words': masked_array(data=['english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None, 'english', None, 'english', None,
'english', None],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False],
fill_value=np.str_('?'),
dtype=object),
'params': [{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 1,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 3,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 5,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': True,
'vect__max_features': 10000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 100,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 500,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 1000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 5000,
'vect__stop_words': None},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': 'english'},
{'clf__n_neighbors': 7,
'tfidf__use_idf': False,
'vect__max_features': 10000,
'vect__stop_words': None}],
'split0_test_score': array([0.76817583, 0.72894636, 0.53358157, 0.6374826 , 0.52798315,
0.56519764, 0.51309075, 0.49836392, 0.49549296, 0.48037932,
0.74678068, 0.6781238 , 0.53815493, 0.67256907, 0.59194787,
0.61495561, 0.50607973, 0.58348819, 0.48282372, 0.56647032,
0.78847084, 0.7252709 , 0.48067615, 0.52118859, 0.51721694,
0.48760057, 0.51577841, 0.46346843, 0.43209944, 0.42022184,
0.79654539, 0.72168041, 0.50540586, 0.62161402, 0.53994855,
0.57362945, 0.53619645, 0.54414205, 0.43884702, 0.4762766 ,
0.75153416, 0.72862518, 0.45804881, 0.49005168, 0.52767212,
0.49479873, 0.48059474, 0.45872108, 0.43320814, 0.46831129,
0.77799622, 0.71691323, 0.468479 , 0.58540742, 0.53125944,
0.55188163, 0.53165214, 0.49549296, 0.46225076, 0.46334516,
0.73731272, 0.722074 , 0.42768101, 0.49765146, 0.48090102,
0.59104363, 0.51802701, 0.48304986, 0.44659485, 0.52984156,
0.76674346, 0.71632941, 0.42110969, 0.55322776, 0.49049555,
0.55175372, 0.55288072, 0.52752419, 0.47634782, 0.48746189]),
'split1_test_score': array([0.79662124, 0.74654727, 0.55840557, 0.65284461, 0.56848052,
0.58959475, 0.50824228, 0.52616761, 0.55281361, 0.53900106,
0.79198638, 0.74396927, 0.60547906, 0.71852445, 0.55601273,
0.70906678, 0.55288072, 0.70132316, 0.54064496, 0.72369393,
0.8200291 , 0.76394456, 0.50142662, 0.65973942, 0.57835078,
0.58411319, 0.56231765, 0.54929577, 0.53755372, 0.54631601,
0.76828781, 0.73851721, 0.57218138, 0.70341759, 0.54466558,
0.70088681, 0.52052675, 0.68262612, 0.56159728, 0.69962829,
0.78081641, 0.76676231, 0.47784768, 0.65979122, 0.55877981,
0.58351195, 0.55934603, 0.5526496 , 0.53805724, 0.54275916,
0.7590099 , 0.74428157, 0.50962843, 0.69741916, 0.56721274,
0.6845754 , 0.53674839, 0.65880645, 0.56642635, 0.66528626,
0.75249835, 0.76114625, 0.5033107 , 0.61728125, 0.54473029,
0.54902822, 0.55716104, 0.55466537, 0.51280439, 0.54540585,
0.72900844, 0.73304581, 0.49250319, 0.69558085, 0.57262651,
0.66997019, 0.54339107, 0.66489176, 0.56953075, 0.65745565]),
'split2_test_score': array([0.82526209, 0.7519849 , 0.59510748, 0.68219353, 0.5764876 ,
0.6305322 , 0.54587929, 0.56161317, 0.52387661, 0.56501726,
0.82242212, 0.71247211, 0.60517653, 0.7037331 , 0.54624902,
0.68817364, 0.5608424 , 0.65146866, 0.53526413, 0.66897713,
0.81950927, 0.79423613, 0.52825543, 0.66907312, 0.56248767,
0.60389162, 0.49652131, 0.55180864, 0.50857805, 0.52741727,
0.80244926, 0.69231892, 0.53887943, 0.71969825, 0.53158497,
0.6792325 , 0.49764733, 0.65713584, 0.49078794, 0.66897713,
0.80172943, 0.75452358, 0.54031316, 0.63620099, 0.52529445,
0.59626871, 0.51327659, 0.53187959, 0.50624241, 0.51743203,
0.76711749, 0.70123506, 0.49774614, 0.71460958, 0.50316727,
0.65013293, 0.49732848, 0.64249816, 0.52623393, 0.65444272,
0.76791783, 0.73768667, 0.53419234, 0.63714302, 0.50165182,
0.5932877 , 0.49723701, 0.51859739, 0.49993588, 0.52478814,
0.76415908, 0.69550001, 0.48572333, 0.68222059, 0.50972411,
0.66549536, 0.47022833, 0.64386329, 0.50648577, 0.6415092 ]),
'mean_test_score': array([0.79668639, 0.74249284, 0.56236487, 0.65750691, 0.55765042,
0.5951082 , 0.52240411, 0.5287149 , 0.52406106, 0.52813255,
0.78706306, 0.71152173, 0.58293684, 0.69827554, 0.56473654,
0.67073201, 0.53993428, 0.64542667, 0.5195776 , 0.65304713,
0.8093364 , 0.76115053, 0.50345273, 0.61666704, 0.55268513,
0.55853513, 0.52487246, 0.52152428, 0.49274373, 0.49798504,
0.78909415, 0.71750551, 0.53882222, 0.68157662, 0.53873303,
0.65124959, 0.51812351, 0.627968 , 0.49707742, 0.61496068,
0.77802667, 0.74997036, 0.49206988, 0.59534796, 0.53724879,
0.55819313, 0.51773912, 0.51441676, 0.4925026 , 0.50950083,
0.7680412 , 0.72080995, 0.49195119, 0.66581205, 0.53387982,
0.62886332, 0.52190967, 0.59893252, 0.51830368, 0.59435805,
0.7525763 , 0.74030231, 0.48839468, 0.58402524, 0.50909438,
0.57778652, 0.52414169, 0.51877087, 0.48644504, 0.53334518,
0.75330366, 0.71495841, 0.4664454 , 0.6436764 , 0.52428206,
0.62907309, 0.52216671, 0.61209308, 0.51745478, 0.59547558]),
'std_test_score': array([0.02330541, 0.00983268, 0.02527339, 0.01854849, 0.02123109,
0.02695613, 0.01671706, 0.02588415, 0.02340142, 0.03539763,
0.0310761 , 0.0268897 , 0.03166583, 0.01915399, 0.01964985,
0.04035167, 0.02415844, 0.04829527, 0.02608159, 0.06516717,
0.01475571, 0.02822417, 0.01947692, 0.06762091, 0.02590243,
0.05080407, 0.02762023, 0.0410645 , 0.04448367, 0.05552553,
0.01490843, 0.01909001, 0.02726102, 0.04291775, 0.00540887,
0.05559311, 0.0158291 , 0.06018046, 0.05030954, 0.09885959,
0.02058686, 0.01589883, 0.03505766, 0.07507599, 0.01525564,
0.04512812, 0.03230456, 0.04028527, 0.04389321, 0.0309063 ,
0.0077786 , 0.01778837, 0.01729171, 0.05728616, 0.02621202,
0.05622103, 0.0175056 , 0.07344521, 0.04289759, 0.09274581,
0.0124946 , 0.01605805, 0.04474395, 0.06161139, 0.0265843 ,
0.02035581, 0.02484303, 0.02923717, 0.02866389, 0.00877417,
0.01721168, 0.01535864, 0.03217646, 0.064189 , 0.03507444,
0.05470356, 0.03692975, 0.06041232, 0.03882443, 0.07665416]),
'rank_test_score': array([ 2, 11, 41, 21, 44, 35, 58, 52, 57, 53, 4, 16, 38, 17, 40, 19, 46,
24, 62, 22, 1, 7, 71, 29, 45, 42, 54, 61, 74, 72, 3, 14, 47, 18,
48, 23, 65, 28, 73, 30, 5, 10, 76, 34, 49, 43, 66, 68, 75, 69, 6,
13, 77, 20, 50, 27, 60, 32, 64, 36, 9, 12, 78, 37, 70, 39, 56, 63,
79, 51, 8, 15, 80, 25, 55, 26, 59, 31, 67, 33], dtype=int32)}
gs_clf.best_score_np.float64(0.8093364049132378)
Word embedding
# Импортируем библиотеку gensim и посмотрим какие обученные модели в ней есть.
import gensim.downloader
print(list(gensim.downloader.info()['models'].keys()))['fasttext-wiki-news-subwords-300', 'conceptnet-numberbatch-17-06-300', 'word2vec-ruscorpora-300', 'word2vec-google-news-300', 'glove-wiki-gigaword-50', 'glove-wiki-gigaword-100', 'glove-wiki-gigaword-200', 'glove-wiki-gigaword-300', 'glove-twitter-25', 'glove-twitter-50', 'glove-twitter-100', 'glove-twitter-200', '__testing_word2vec-matrix-synopsis']
GloVe
Загрузим "glove-twitter-25" и будем работать с ней
glove_model = gensim.downloader.load("glove-twitter-25") # load glove vectors[==================================================] 100.0% 104.8/104.8MB downloaded
# Можем посмотреть, что за векторы у слова "cat"
print(glove_model['cat']) # word embedding for 'cat'
glove_model.most_similar("cat") # show words that similar to word 'cat'[-0.96419 -0.60978 0.67449 0.35113 0.41317 -0.21241 1.3796
0.12854 0.31567 0.66325 0.3391 -0.18934 -3.325 -1.1491
-0.4129 0.2195 0.8706 -0.50616 -0.12781 -0.066965 0.065761
0.43927 0.1758 -0.56058 0.13529 ]
[('dog', 0.9590820074081421),
('monkey', 0.920357882976532),
('bear', 0.9143136739730835),
('pet', 0.9108031392097473),
('girl', 0.8880629539489746),
('horse', 0.8872726559638977),
('kitty', 0.8870542049407959),
('puppy', 0.886769711971283),
('hot', 0.886525571346283),
('lady', 0.8845519423484802)]
Векторизуем обучающую выборку
Получаем матрицу "Документ-термин". В столбцах - все слова выборки, в строках - каждый документ. Значения в ячейках - количество встречаний слова в документе. В отображенной части видим только нули - наглядная демонстрация проблем разреженности матрицы
vectorizer = CountVectorizer(stop_words='english')
train_data = vectorizer.fit_transform(twenty_train['data'])
CV_data=pd.DataFrame(train_data.toarray(), columns=vectorizer.get_feature_names_out())
print(CV_data.shape)
CV_data.head()(1064, 16170)
| 00 | 000 | 000005102000 | 000100255pixel | 0007 | 000usd | 001200201pixel | 00196 | 0028 | 0038 | ... | zorg | zorn | zsoft | zues | zug | zur | zurich | zus | zvi | zyxel | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 16170 columns
# Создадим список уникальных слов, присутствующих в словаре.
words_vocab=CV_data.columns
print(words_vocab[0:10])Index(['00', '000', '000005102000', '000100255pixel', '0007', '000usd',
'001200201pixel', '00196', '0028', '0038'],
dtype='str')
Векторизуем с помощью GloVe
Нужно для каждого документа сложить glove-вектора слов, из которых он состоит. В результате получим вектор документа как сумму векторов слов, из него состоящих
Посмотрим на примере как будет работать векторизация
# Пусть выборка состоит из двух документов:
text_data = ['Hello world I love python', 'This is a great computer game! 00 000 zyxel']
# Векторизуем с помощью обученного CountVectorizer
X = vectorizer.transform(text_data)
CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())
CV_text_data| 00 | 000 | 000005102000 | 000100255pixel | 0007 | 000usd | 001200201pixel | 00196 | 0028 | 0038 | ... | zorg | zorn | zsoft | zues | zug | zur | zurich | zus | zvi | zyxel | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
2 rows × 16170 columns
# Создадим датафрейм, в который будем сохранять вектор документа
glove_data=pd.DataFrame()
# Пробегаем по каждой строке датафрейма (по каждому документу)
for i in range(CV_text_data.shape[0]):
# Вектор одного документа с размерностью glove-модели:
one_doc = np.zeros(25)
# Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре
# Суммируем glove-вектора каждого известного слова в one_doc
for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:
if word in glove_model.key_to_index.keys():
print(word, ': ', glove_model[word])
one_doc += glove_model[word]
print(text_data[i], ': ', one_doc)
glove_data= pd.concat([glove_data, pd.DataFrame([one_doc])], ignore_index=True) hello : [-0.77069 0.12827 0.33137 0.0050893 -0.47605 -0.50116
1.858 1.0624 -0.56511 0.13328 -0.41918 -0.14195
-2.8555 -0.57131 -0.13418 -0.44922 0.48591 -0.6479
-0.84238 0.61669 -0.19824 -0.57967 -0.65885 0.43928
-0.50473 ]
love : [-0.62645 -0.082389 0.070538 0.5782 -0.87199 -0.14816 2.2315
0.98573 -1.3154 -0.34921 -0.8847 0.14585 -4.97 -0.73369
-0.94359 0.035859 -0.026733 -0.77538 -0.30014 0.48853 -0.16678
-0.016651 -0.53164 0.64236 -0.10922 ]
python : [-0.25645 -0.22323 0.025901 0.22901 0.49028 -0.060829 0.24563
-0.84854 1.5882 -0.7274 0.60603 0.25205 -1.8064 -0.95526
0.44867 0.013614 0.60856 0.65423 0.82506 0.99459 -0.29403
-0.27013 -0.348 -0.7293 0.2201 ]
world : [ 0.10301 0.095666 -0.14789 -0.22383 -0.14775 -0.11599 1.8513
0.24886 -0.41877 -0.20384 -0.08509 0.33246 -4.6946 0.84096
-0.46666 -0.031128 -0.19539 -0.037349 0.58949 0.13941 -0.57667
-0.44426 -0.43085 -0.52875 0.25855 ]
Hello world I love python : [ -1.55058002 -0.081683 0.27991899 0.58846928 -1.00551002
-0.82613902 6.18642995 1.44844997 -0.71108004 -1.14717001
-0.78294002 0.58841 -14.32649982 -1.41929996 -1.09575997
-0.430875 0.87234702 -0.806399 0.27203003 2.23921998
-1.23571999 -1.31071102 -1.96934 -0.17641005 -0.1353 ]
computer : [ 0.64005 -0.019514 0.70148 -0.66123 1.1723 -0.58859 0.25917
-0.81541 1.1708 1.1413 -0.15405 -0.11369 -3.8414 -0.87233
0.47489 1.1541 0.97678 1.1107 -0.14572 -0.52013 -0.52234
-0.92349 0.34651 0.061939 -0.57375 ]
game : [ 1.146 0.3291 0.26878 -1.3945 -0.30044 0.77901 1.3537
0.37393 0.50478 -0.44266 -0.048706 0.51396 -4.3136 0.39805
1.197 0.10287 -0.17618 -1.2881 -0.59801 0.26131 -1.2619
0.39202 0.59309 -0.55232 0.005087]
great : [-8.4229e-01 3.6512e-01 -3.8841e-01 -4.6118e-01 2.4301e-01 3.2412e-01
1.9009e+00 -2.2630e-01 -3.1335e-01 -1.0970e+00 -4.1494e-03 6.2074e-01
-5.0964e+00 6.7418e-01 5.0080e-01 -6.2119e-01 5.1765e-01 -4.4122e-01
-1.4364e-01 1.9130e-01 -7.4608e-01 -2.5903e-01 -7.8010e-01 1.1030e-01
-2.7928e-01]
zyxel : [ 0.79234 0.067376 -0.22639 -2.2272 0.30057 -0.85676 -1.7268
-0.78626 1.2042 -0.92348 -0.83987 -0.74233 0.29689 -1.208
0.98706 -1.1624 0.61415 -0.27825 0.27813 1.5838 -0.63593
-0.10225 1.7102 -0.95599 -1.3867 ]
This is a great computer game! 00 000 zyxel : [ 1.73610002 0.74208201 0.35545996 -4.74411008 1.41543998
-0.34222007 1.78697008 -1.45404002 2.56643 -1.32184002
-1.04677537 0.27867999 -12.95450976 -1.00809997 3.15975004
-0.52662008 1.93239999 -0.89686999 -0.60924001 1.51628
-3.16624993 -0.89275002 1.86969995 -1.33607102 -2.23464306]
На основе ячейки выше, напишем функцию, которая для каждого слова в тексте ищет это слово в glove_model
def text2vec(text_data):
# Векторизуем с помощью обученного CountVectorizer
X = vectorizer.transform(text_data)
CV_text_data=pd.DataFrame(X.toarray(), columns=vectorizer.get_feature_names_out())
CV_text_data
# Создадим датафрейм, в который будем сохранять вектор документа
glove_data=pd.DataFrame()
# Пробегаем по каждой строке (по каждому документу)
for i in range(CV_text_data.shape[0]):
# Вектор одного документа с размерностью glove-модели:
one_doc = np.zeros(25)
# Пробегаемся по каждому документу, смотрим, какие слова документа присутствуют в нашем словаре
# Суммируем glove-вектора каждого известного слова в one_doc
for word in words_vocab[CV_text_data.iloc[i,:] >= 1]:
if word in glove_model.key_to_index.keys():
#print(word, ': ', glove_model[word])
one_doc += glove_model[word]
#print(text_data[i], ': ', one_doc)
glove_data = pd.concat([glove_data, pd.DataFrame([one_doc])], axis = 0)
#print('glove_data: ', glove_data)
return glove_data# Наша выборка, векторизованная через glove выглядит так.
# Попытаться понять, почему такая размерность матрицы и что за значения в ячейках
glove_data| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.55058 | -0.081683 | 0.279919 | 0.588469 | -1.00551 | -0.826139 | 6.18643 | 1.44845 | -0.71108 | -1.14717 | ... | -0.430875 | 0.872347 | -0.806399 | 0.27203 | 2.23922 | -1.23572 | -1.310711 | -1.96934 | -0.176410 | -0.135300 |
| 1 | 1.73610 | 0.742082 | 0.355460 | -4.744110 | 1.41544 | -0.342220 | 1.78697 | -1.45404 | 2.56643 | -1.32184 | ... | -0.526620 | 1.932400 | -0.896870 | -0.60924 | 1.51628 | -3.16625 | -0.892750 | 1.86970 | -1.336071 | -2.234643 |
2 rows × 25 columns
Можем использовать написанную функцию на нашей реальной выборке
train_data_glove = text2vec(twenty_train['data'])
train_data_glove| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.499258 | 3.090499 | -17.267230 | -0.997832 | -2.746523 | 1.586072 | 25.890665 | -27.914878 | -1.656527 | -6.599336 | ... | 6.840056 | 5.475214 | 4.488208 | 17.680875 | -8.254198 | 1.772313 | 6.184074 | -8.879014 | -0.216338 | -13.408250 |
| 0 | 2.579560 | -1.498130 | 1.752350 | 0.407907 | 1.528970 | 0.152390 | 0.283640 | 1.937990 | -0.659960 | -2.136370 | ... | 0.434040 | -0.274430 | 0.445033 | -0.645750 | 1.055500 | -0.850250 | -0.433200 | 0.212190 | -1.164330 | 0.905880 |
| 0 | -3.838998 | 3.339424 | 3.943270 | -1.428108 | 4.171174 | 1.918481 | 9.177921 | -3.241092 | 1.280560 | -1.781829 | ... | -3.719443 | 9.627485 | -1.198027 | -4.010974 | 1.464305 | 0.398107 | -1.013318 | -1.591893 | 4.290867 | -3.888199 |
| 0 | -1.182812 | 51.118736 | -28.117878 | -28.586500 | 24.244726 | -9.567325 | 111.896700 | -114.387355 | 12.647457 | -24.559720 | ... | 24.085172 | 39.112016 | -23.011944 | 41.767239 | -42.362471 | -16.060568 | -20.090755 | -4.457053 | -3.181086 | -90.876348 |
| 0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 0 | -5.839209 | 12.729176 | -6.201106 | -11.255287 | 2.878065 | -11.993908 | 11.418405 | -15.479517 | 16.634253 | -10.534109 | ... | 11.665809 | 13.761738 | -2.579413 | -7.877240 | 3.299712 | -9.241404 | -16.102834 | 5.343221 | -3.646908 | -17.216441 |
| 0 | 1.286424 | 4.341360 | -29.110272 | 5.057644 | 5.610822 | -10.954865 | 29.974885 | -38.261129 | 10.369041 | -15.652271 | ... | 12.201855 | 5.913883 | 15.941801 | 26.251438 | -9.280780 | -6.656035 | 6.158630 | -14.878060 | -1.138172 | -25.532792 |
| 0 | 6.698275 | 8.135166 | -10.347349 | 0.165546 | 2.089456 | -7.695734 | 31.593479 | -38.582911 | -3.285587 | -6.087527 | ... | 16.575447 | 15.696628 | 0.785602 | 13.384657 | -13.329273 | 1.983042 | -4.041431 | -5.571411 | 0.581912 | -32.037549 |
| 0 | -0.461364 | 4.499502 | -10.446904 | 1.502338 | -10.415269 | -8.801252 | 18.272721 | -10.398248 | -6.686713 | -0.250122 | ... | 5.758498 | 4.540718 | -0.372422 | 10.668546 | -6.377995 | 9.305774 | 11.568183 | -18.814704 | -1.204665 | -14.551387 |
| 0 | 8.348875 | 10.837893 | -1.528737 | -18.246933 | 33.950209 | -0.376254 | 48.017779 | -57.636160 | 25.231878 | -13.371779 | ... | 29.297609 | 34.342018 | -8.527635 | 13.555851 | -12.649073 | -28.706501 | -27.042654 | 4.292338 | 2.614328 | -48.834758 |
1064 rows × 25 columns
Мы получили матрицу, где каждый документ представлен вектором-строкой. Можем подавать эти вектора на вход для обучения классифкатора
clf = KNeighborsClassifier(n_neighbors = 5)
clf.fit(train_data_glove, twenty_train['target'])KNeighborsClassifier()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Parameters
403466d1d5/skrub/_reporting/_data/templates/report.js (L789)
* @returns "light" or "dark"
*/
function detectTheme(element) {
const body = document.querySelector('body');
// Check VSCode theme
const themeKindAttr = body.getAttribute('data-vscode-theme-kind');
const themeNameAttr = body.getAttribute('data-vscode-theme-name');
if (themeKindAttr && themeNameAttr) {
const themeKind = themeKindAttr.toLowerCase();
const themeName = themeNameAttr.toLowerCase();
if (themeKind.includes("dark") || themeName.includes("dark")) {
return "dark";
}
if (themeKind.includes("light") || themeName.includes("light")) {
return "light";
}
}
// Check Jupyter theme
if (body.getAttribute('data-jp-theme-light') === 'false') {
return 'dark';
} else if (body.getAttribute('data-jp-theme-light') === 'true') {
return 'light';
}
// Guess based on a parent element's color
const color = window.getComputedStyle(element.parentNode, null).getPropertyValue('color');
const match = color.match(/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)\s*$/i);
if (match) {
const [r, g, b] = [
parseFloat(match[1]),
parseFloat(match[2]),
parseFloat(match[3])
];
// https://en.wikipedia.org/wiki/HSL_and_HSV#Lightness
const luma = 0.299 * r + 0.587 * g + 0.114 * b;
if (luma > 180) {
// If the text is very bright we have a dark theme
return 'dark';
}
if (luma < 75) {
// If the text is very dark we have a light theme
return 'light';
}
// Otherwise fall back to the next heuristic.
}
// Fallback to system preference
return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
}
function forceTheme(elementId) {
const estimatorElement = document.querySelector(`#${elementId}`);
if (estimatorElement === null) {
console.error(`Element with id ${elementId} not found.`);
} else {
const theme = detectTheme(estimatorElement);
estimatorElement.classList.add(theme);
}
}
forceTheme('sk-container-id-2');</script>
Аналогичную операци проводим с тестовой частью выборки и оцениваем качество модели
test_data_glove = text2vec(twenty_test['data'])
test_data_glove| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2.118167 | 10.172361 | -12.929062 | -10.235129 | -5.895256 | -2.960585 | 13.365803 | -26.779260 | 0.811607 | -3.139421 | ... | 4.893207 | 4.652184 | 2.903990 | 13.077540 | -12.438726 | 2.602789 | -7.450777 | -5.202635 | -20.983743 | -14.107781 |
| 0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
| 0 | 4.873320 | 27.474242 | -10.773876 | 8.112889 | 0.623212 | -2.199475 | 59.094818 | -59.831612 | -2.963759 | -18.048856 | ... | 23.890453 | 12.469234 | -3.087457 | 12.981009 | -29.858967 | 13.105756 | 10.670630 | -17.378340 | -6.079635 | -41.408141 |
| 0 | -3.543091 | 1.986200 | -2.191232 | 0.006059 | -2.351208 | 0.354323 | 5.127530 | -4.232705 | 0.855633 | -1.511475 | ... | -1.034241 | 2.211705 | -3.255103 | 3.145205 | 0.115450 | -0.696187 | 0.647662 | 2.188560 | 1.237610 | -3.731906 |
| 0 | -0.810280 | -0.041986 | 0.309050 | 0.268330 | -0.735400 | -0.585930 | -0.207310 | 0.084744 | -0.024049 | -1.729000 | ... | -0.981390 | -0.348830 | -0.012710 | -0.399160 | 0.237930 | 0.608960 | 0.257370 | -0.611980 | -0.569530 | -0.627590 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 0 | -2.713570 | 17.157689 | 3.372045 | -5.867599 | -6.276993 | -12.745005 | 25.654686 | -28.636649 | 0.390143 | -9.184376 | ... | 1.313231 | 11.260476 | -3.974922 | 13.881956 | -5.040806 | 4.269640 | 9.495090 | -6.392856 | -17.443242 | -8.663642 |
| 0 | 0.330970 | 4.718122 | -1.751891 | -2.035309 | -0.343631 | -4.570000 | 6.742370 | -4.367192 | -2.187082 | -1.016773 | ... | -0.032756 | 3.431770 | -4.601260 | -0.623544 | -1.379208 | -2.430931 | -7.340540 | -4.909410 | 1.812950 | -7.518150 |
| 0 | 2.374983 | 5.077650 | -3.002268 | -3.933630 | 2.078387 | -5.470335 | 7.654599 | -12.212619 | 7.815100 | -2.849330 | ... | -0.234639 | 5.634821 | -3.027221 | 1.661357 | 0.393759 | -6.279475 | -4.927666 | -2.592987 | 1.981002 | -12.255251 |
| 0 | 1.410159 | 0.770070 | -3.955777 | -2.791190 | 5.210700 | 1.621952 | 7.780780 | -9.859718 | 8.394491 | -5.347944 | ... | 1.896947 | 4.079237 | -1.207350 | 5.772389 | 1.186500 | -7.286901 | -2.485993 | -0.947197 | 0.788222 | -9.518580 |
| 0 | 5.710371 | 6.638114 | 0.178107 | -6.406337 | 9.719197 | -8.779200 | 8.866295 | -11.437220 | 4.871641 | -11.415244 | ... | 0.851029 | 4.182781 | -3.302557 | 1.213709 | -4.294990 | -7.751608 | -12.556040 | -10.339437 | 1.448986 | -10.394629 |
708 rows × 25 columns
predict = clf.predict(test_data_glove )print (confusion_matrix(twenty_test['target'], predict))
print(classification_report(twenty_test['target'], predict))[[258 61]
[ 31 358]]
precision recall f1-score support
0 0.89 0.81 0.85 319
1 0.85 0.92 0.89 389
accuracy 0.87 708
macro avg 0.87 0.86 0.87 708
weighted avg 0.87 0.87 0.87 708