форкнуто от main/is_dnn
Вы не можете выбрать более 25 тем
Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.
412 строк
12 KiB
Plaintext
412 строк
12 KiB
Plaintext
{
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0,
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": [],
|
|
"gpuType": "T4"
|
|
},
|
|
"kernelspec": {
|
|
"name": "python3",
|
|
"display_name": "Python 3"
|
|
},
|
|
"language_info": {
|
|
"name": "python"
|
|
},
|
|
"accelerator": "GPU"
|
|
},
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"1 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "6tuecB2YaGZd"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "EexxrLenVcsK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"os.chdir('/content/drive/MyDrive/Colab Notebooks/IS_LR4')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"import tensorflow as tf\n",
|
|
"device_name = tf.test.gpu_device_name()\n",
|
|
"if device_name != '/device:GPU:0':\n",
|
|
" raise SystemError('GPU device not found')\n",
|
|
"print('Found GPU at: {}'.format(device_name))"
|
|
],
|
|
"metadata": {
|
|
"id": "JZQBWhlnZ3Kz"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"2 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "BMVDcB8saN1d"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# загрузка датасета\n",
|
|
"from keras.datasets import imdb\n",
|
|
"vocabulary_size = 5000\n",
|
|
"index_from = 3\n",
|
|
"(X_train, y_train), (X_test, y_test) = imdb.load_data(path=\"imdb.npz\",\n",
|
|
"num_words=vocabulary_size,\n",
|
|
"skip_top=0,\n",
|
|
"maxlen=None,\n",
|
|
"seed=19,\n",
|
|
"start_char=1,\n",
|
|
"oov_char=2,\n",
|
|
"index_from=index_from\n",
|
|
")"
|
|
],
|
|
"metadata": {
|
|
"id": "dVZOh4OjaNfT"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"print(\"Размер обучающего множества X_train:\", X_train.shape)\n",
|
|
"print(\"Размер обучающих меток y_train:\", y_train.shape)\n",
|
|
"print(\"Размер тестового множества X_test:\", X_test.shape)\n",
|
|
"print(\"Размер тестовых меток y_test:\", y_test.shape)"
|
|
],
|
|
"metadata": {
|
|
"id": "ONS-UYoqb7VM"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"3 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "vrPoEAUFcZ6P"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# создание словаря для перевода индексов в слова\n",
|
|
"# загрузка словаря \"слово:индекс\"\n",
|
|
"word_to_id = imdb.get_word_index()\n",
|
|
"# уточнение словаря\n",
|
|
"word_to_id = {key:(value + index_from) for key,value in word_to_id.items()}\n",
|
|
"word_to_id[\"<PAD>\"] = 0\n",
|
|
"word_to_id[\"<START>\"] = 1\n",
|
|
"word_to_id[\"<UNK>\"] = 2\n",
|
|
"word_to_id[\"<UNUSED>\"] = 3\n",
|
|
"# создание обратного словаря \"индекс:слово\"\n",
|
|
"id_to_word = {value:key for key,value in word_to_id.items()}"
|
|
],
|
|
"metadata": {
|
|
"id": "ldO9EDteccDO"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"idx = 19\n",
|
|
"review_indices = X_train[idx]\n",
|
|
"print(\"Отзыв в виде индексов:\\n\", review_indices)\n",
|
|
"\n",
|
|
"review_text = \" \".join(id_to_word.get(i, \"?\") for i in review_indices)\n",
|
|
"print(\"\\nОтзыв в виде текста:\\n\", review_text)\n",
|
|
"\n",
|
|
"print(\"\\nДлина отзыва (количество индексов):\", len(review_indices))\n",
|
|
"\n",
|
|
"label = y_train[idx]\n",
|
|
"class_name = \"Positive\" if label == 1 else \"Negative\"\n",
|
|
"print(\"Метка класса:\", label, \"| Класс:\", class_name)\n"
|
|
],
|
|
"metadata": {
|
|
"id": "dgLjpdWVcaT1"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"4 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "sJRCzgSXdyaX"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"print(\"Максимальная длина отзыва:\", len(max(X_train, key=len)))\n",
|
|
"print(\"Минимальная длина отзыва:\", len(min(X_train, key=len)))"
|
|
],
|
|
"metadata": {
|
|
"id": "NKuA8LSfd4nq"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"5 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "xRZqJEnkensA"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# предобработка данных\n",
|
|
"from tensorflow.keras.utils import pad_sequences\n",
|
|
"max_words = 500\n",
|
|
"X_train = pad_sequences(X_train, maxlen=max_words, value=0, padding='pre', truncating='post')\n",
|
|
"X_test = pad_sequences(X_test, maxlen=max_words, value=0, padding='pre', truncating='post')\n",
|
|
"\n"
|
|
],
|
|
"metadata": {
|
|
"id": "FcSsuqWGeqDE"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"6 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "Mrh00gk2gHaS"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"print(\"Максимальная длина отзыва после предобработки:\", len(max(X_train, key=len)))\n",
|
|
"print(\"Минимальная длина отзыва после предобработки:\", len(min(X_train, key=len)))"
|
|
],
|
|
"metadata": {
|
|
"id": "woU0e9UMeqQi"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"7 пункт\n"
|
|
],
|
|
"metadata": {
|
|
"id": "jvOmAItEgJsq"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"idx = 19\n",
|
|
"review_indices = X_train[idx]\n",
|
|
"print(\"Отзыв в виде индексов:\\n\", review_indices)\n",
|
|
"\n",
|
|
"review_text = \" \".join(id_to_word.get(i, \"?\") for i in review_indices)\n",
|
|
"print(\"\\nОтзыв в виде текста:\\n\", review_text)\n",
|
|
"\n",
|
|
"print(\"\\nДлина отзыва (количество индексов):\", len(review_indices))\n",
|
|
"\n",
|
|
"label = y_train[idx]\n",
|
|
"class_name = \"Positive\" if label == 1 else \"Negative\"\n",
|
|
"print(\"Метка класса:\", label, \"| Класс:\", class_name)"
|
|
],
|
|
"metadata": {
|
|
"id": "LGoRw4AHgJ9s"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"8 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "5CAdHfC_hVfo"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"\n",
|
|
"print(\"Предобработанное обучающее множество X_train (первые 5 примеров):\")\n",
|
|
"print(X_train[:5])\n",
|
|
"\n",
|
|
"print(\"\\nПредобработанное тестовое множество X_test (первые 5 примеров):\")\n",
|
|
"print(X_test[:5])\n",
|
|
"\n",
|
|
"\n",
|
|
"print(\"Размер обучающего множества X_train:\", X_train.shape)\n",
|
|
"print(\"Размер обучающих меток y_train:\", y_train.shape)\n",
|
|
"print(\"Размер тестового множества X_test:\", X_test.shape)\n",
|
|
"print(\"Размер тестовых меток y_test:\", y_test.shape)"
|
|
],
|
|
"metadata": {
|
|
"id": "OAsrX4WdhYAx"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"9 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "GTQnxs9AjEaz"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
"from tensorflow.keras.layers import Embedding, LSTM, Dropout, Dense\n",
|
|
"\n",
|
|
"\n",
|
|
"model = Sequential()\n",
|
|
"model.add(Embedding(input_dim=vocabulary_size, output_dim=32, input_length=max_words))\n",
|
|
"model.add(LSTM(64))\n",
|
|
"model.add(Dropout(0.5))\n",
|
|
"model.add(Dense(1, activation='sigmoid'))\n",
|
|
"\n",
|
|
"\n",
|
|
"model.compile(\n",
|
|
" optimizer='adam',\n",
|
|
" loss='binary_crossentropy',\n",
|
|
" metrics=['accuracy']\n",
|
|
")\n",
|
|
"\n",
|
|
"model.build(input_shape=(None, max_words))\n",
|
|
"model.summary()"
|
|
],
|
|
"metadata": {
|
|
"id": "R_-5hCfGjibD"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"# Обучение модели\n",
|
|
"history = model.fit(\n",
|
|
" X_train,\n",
|
|
" y_train,\n",
|
|
" epochs=5,\n",
|
|
" batch_size=64,\n",
|
|
" validation_split=0.2,\n",
|
|
" verbose=1\n",
|
|
")"
|
|
],
|
|
"metadata": {
|
|
"id": "JyegJgdBlU4P"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"print(\"\\nКачество обучения по эпохам\")\n",
|
|
"for i in range(5):\n",
|
|
" train_acc = history.history['accuracy'][i]\n",
|
|
" val_acc = history.history['val_accuracy'][i]\n",
|
|
" print(f\"Эпоха {i+1}: accuracy = {train_acc:.4f}, val_accuracy = {val_acc:.4f}\")"
|
|
],
|
|
"metadata": {
|
|
"id": "9etKZpeVmeNj"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"10 пункт"
|
|
],
|
|
"metadata": {
|
|
"id": "sulUG0iDmukX"
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)\n",
|
|
"\n",
|
|
"print(\"Качество классификации на тестовой выборке\")\n",
|
|
"print(f\"Test accuracy: {test_accuracy:.4f}\")"
|
|
],
|
|
"metadata": {
|
|
"id": "hVdh7SIAnWrx"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"y_score = model.predict(X_test)\n",
|
|
"y_pred = [1 if y_score[i,0]>=0.5 else 0 for i in range(len(y_score))]\n",
|
|
"from sklearn.metrics import classification_report\n",
|
|
"print(classification_report(y_test, y_pred, labels = [0, 1], target_names=['Negative', 'Positive']))"
|
|
],
|
|
"metadata": {
|
|
"id": "-p7pfGE7mwZi"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"source": [
|
|
"from sklearn.metrics import roc_curve, auc\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"fpr, tpr, thresholds = roc_curve(y_test, y_score)\n",
|
|
"plt.plot(fpr, tpr)\n",
|
|
"plt.grid()\n",
|
|
"plt.xlabel('False Positive Rate')\n",
|
|
"plt.ylabel('True Positive Rate')\n",
|
|
"plt.title('ROC')\n",
|
|
"plt.show()\n",
|
|
"print('Area under ROC is', auc(fpr, tpr))"
|
|
],
|
|
"metadata": {
|
|
"id": "A_ZMEN_YpmAq"
|
|
},
|
|
"execution_count": null,
|
|
"outputs": []
|
|
}
|
|
]
|
|
} |