Andrey 1 год назад
Родитель ade1f01236
Сommit 13bfecb611

@ -466,20 +466,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": null,
"metadata": { "metadata": {},
"id": "YYGc7vo2FoTw"
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def to_matrix(data, char2id, max_len=None, dtype='int64', batch_first = True):\n", "def to_matrix(data, char2id, max_len=None, dtype='int64', batch_first = True):\n",
"\n", "\n",
" max_len = max_len if max_len else max(map(len, data))\n", " max_len = max_len if max_len else max(map(len, data))\n",
" data = [text[:max_len] + \" EOS\" for text in data]\n", " data = [text[:max_len] for text in data]\n",
" data_ix = np.zeros([len(data), max_len], dtype)\n", " data_ix = np.zeros([len(data), max_len+1], dtype)\n",
"\n", "\n",
" for i in range(len(data)):\n", " for i in range(len(data)):\n",
" line_ix = [char2id[c] for c in data[i][:max_len]]\n", " line_ix = [char2id[c] for c in data[i][:max_len]] + [char2id[\"EOS\"]]\n",
" data_ix[i, :len(line_ix)] = line_ix\n", " data_ix[i, :len(line_ix)] = line_ix\n",
"\n", "\n",
" if not batch_first: # convert [batch, time] into [time, batch]\n", " if not batch_first: # convert [batch, time] into [time, batch]\n",

Загрузка…
Отмена
Сохранить