lab04: EOS in to_matrix
Этот коммит содержится в:
@@ -466,20 +466,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"id": "YYGc7vo2FoTw"
|
||||
},
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def to_matrix(data, char2id, max_len=None, dtype='int64', batch_first = True):\n",
|
||||
"\n",
|
||||
" max_len = max_len if max_len else max(map(len, data))\n",
|
||||
" data = [text[:max_len] + \" EOS\" for text in data]\n",
|
||||
" data_ix = np.zeros([len(data), max_len], dtype)\n",
|
||||
" data = [text[:max_len] for text in data]\n",
|
||||
" data_ix = np.zeros([len(data), max_len+1], dtype)\n",
|
||||
"\n",
|
||||
" for i in range(len(data)):\n",
|
||||
" line_ix = [char2id[c] for c in data[i][:max_len]]\n",
|
||||
" line_ix = [char2id[c] for c in data[i][:max_len]] + [char2id[\"EOS\"]]\n",
|
||||
" data_ix[i, :len(line_ix)] = line_ix\n",
|
||||
"\n",
|
||||
" if not batch_first: # convert [batch, time] into [time, batch]\n",
|
||||
|
||||
Ссылка в новой задаче
Block a user