Fixed Deep RNN Multi-gpu implementation to use tf
Fixed Deep RNN Multi-gpu implementation to use tf.python.ops.nn.rnn_cell DeviceWrapper class instead of implementing custom class to achieve the same result. It would appear that the custom class shown with the code does not function.
This commit is contained in:
committed by
GitHub
parent
a2fc9671ed
commit
00c6cf9d43
@@ -107,11 +107,13 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Basic RNNs"
|
||||
]
|
||||
@@ -431,7 +433,7 @@
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -711,7 +713,7 @@
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -2210,6 +2212,29 @@
|
||||
"outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Alternatively, you can use the Tensorflow class DeviceWrapper - note you can define more than one layer per gpu"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"devices = [\"/gpu:0\", \"/gpu:1\", \"/gpu:2\"] \n",
|
||||
"cells = []\n",
|
||||
"for dev in devices:\n",
|
||||
" cell = DeviceWrapper(rnn_cell.BasicRNNCell(num_units=n_neurons), dev)\n",
|
||||
" cells.append(cell)\n",
|
||||
"\n",
|
||||
"self.multiple_lstm_cells = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)\n",
|
||||
"outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
@@ -3524,7 +3549,7 @@
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
"version": 3.0
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
@@ -3538,7 +3563,7 @@
|
||||
"navigate_menu": true,
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"threshold": 6,
|
||||
"threshold": 6.0,
|
||||
"toc_cell": false,
|
||||
"toc_section_display": "block",
|
||||
"toc_window_display": false
|
||||
@@ -3546,4 +3571,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user