{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPython 3.5.5\n", "IPython 6.3.0\n", "\n", "numpy 1.14.3\n", "sklearn 0.19.1\n", "scipy 1.0.1\n", "matplotlib 2.2.2\n", "tensorflow 1.8.0\n", "gym 0.10.5\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -v -p numpy,sklearn,scipy,matplotlib,tensorflow,gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**16장 – 강화 학습**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_이 노트북은 16장에 있는 모든 샘플 코드와 연습문제 해답을 담고 있습니다._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 설정" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "파이썬 2와 3을 모두 지원합니다. 공통 모듈을 임포트하고, 맷플롯립 그림이 노트북 안에 포함되도록 설정하고, 생성한 그림을 저장하기 위한 함수를 준비합니다:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# 파이썬 2와 파이썬 3 지원\n", "from __future__ import division, print_function, unicode_literals\n", "\n", "# 공통\n", "import numpy as np\n", "import os\n", "import sys\n", "import tensorflow as tf  # reset_graph()가 사용하므로 미리 임포트합니다\n", "\n", "# 일관된 출력을 위해 유사난수 초기화\n", "def reset_graph(seed=42):\n", "    tf.reset_default_graph()\n", "    tf.set_random_seed(seed)\n", "    np.random.seed(seed)\n", "\n", "# 맷플롯립 설정\n", "from IPython.display import HTML\n", "import matplotlib\n", "import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['axes.labelsize'] = 14\n", "plt.rcParams['xtick.labelsize'] = 12\n", "plt.rcParams['ytick.labelsize'] = 12\n", "\n", "# 한글 출력\n", "plt.rcParams['font.family'] = 'NanumBarunGothic'\n", "plt.rcParams['axes.unicode_minus'] = False\n", "\n", "# 그림을 저장할 폴더\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"rl\"\n", "\n", "def save_fig(fig_id, tight_layout=True):\n", "    path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format='png', dpi=300)" ] }, { "cell_type": "markdown",
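"metadata": {}, "source": [ "`save_fig()`는 이미지 폴더가 미리 만들어져 있어야 그림을 저장할 수 있습니다. 필요한 경우 폴더를 만들어 두는 간단한 보조 셀입니다(위 설정과 동일한 경로 `./images/rl`을 사용한다고 가정합니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# PROJECT_ROOT_DIR(\".\")와 CHAPTER_ID(\"rl\") 설정에 맞춘 이미지 폴더입니다\n", "image_dir = os.path.join(\".\", \"images\", \"rl\")\n", "if not os.path.isdir(image_dir):\n", "    os.makedirs(image_dir)" ] }, { "cell_type": "markdown",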
"metadata": {}, "source": [ "# OpenAI 짐(gym)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 노트북에서는 강화 학습 알고리즘을 개발하고 비교할 수 있는 훌륭한 도구인 [OpenAI 짐(gym)](https://gym.openai.com/)을 사용합니다. 짐은 *에이전트*가 학습할 수 있는 많은 환경을 제공합니다. `gym`을 임포트해 보죠:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그다음 MsPacman 환경 버전 0을 로드합니다." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "env = gym.make('MsPacman-v0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`reset()` 메서드를 호출하여 환경을 초기화합니다. 이 메서드는 하나의 관측을 반환합니다:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측은 환경마다 다릅니다. 여기에서는 [높이, 너비, 채널] 크기의 3D 넘파이 배열로 저장되어 있는 RGB 이미지입니다(채널은 3개로 빨강, 초록, 파랑입니다). 잠시 후에 보겠지만 다른 환경에서는 다른 오브젝트가 반환될 수 있습니다." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경은 `render()` 메서드를 사용하여 화면에 나타낼 수 있고 렌더링 모드를 고를 수 있습니다(렌더링 옵션은 환경마다 다릅니다). 
이 경우에는 `mode=\"rgb_array\"`로 지정해서 넘파이 배열로 환경에 대한 이미지를 받겠습니다:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지를 그려보죠:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVUAAAGoCAYAAAD2AHbpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADBpJREFUeJzt3b+u29YdB3Cx6EsUrt/A6BCgQJYUWQJ0yNYlk18hd3Mfod7uM2Tyks1DgSxGsxgokKHI0tk2+hjsYNxrWZEoSvxS5w8/n01Xonl4JH3949EhzzCO4w6AjN+VbgBAT4QqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGCfl+6AfuGYXB5F1ClcRyHOa9TqQIEVVWpvv/++9JNAFhEpQoQVFWlOuWPP/6hdBOq8f5v/zv5nH7iGJ+Zeab6aS6VKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgpqZ/H/OksnN1257bqLwWtsuoZ/KtqlE/y+hny6nUgUIEqoAQUIVIGgYx3ruC/3h7u5kY9z04ZPSY0a0x2dmnql+enJ/7ybVALcmVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAoG5uqLJEYgXFY3qbVL1WP/GRz8s8tfeTShUgqJtK1WV4QA1UqgBBQhUgSKgCBAlVgCChChAkVAGCuplS1ZoWV1OlrNZWU90qlSpAkFAFCBKqAEHGVAspNRZlDKxdJd47n5fLqVQBgrqpVP2PCtRApQoQJFQBgoQqQJBQBQgSqgBBQhUgqJspVUu0Nh3LhQNcyoUDt6NSBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUDeT/5esNHnttktWmlxr27WOdcm2pfppivdu3rZb66cElSpAkFAFCBKqAEHDOI6l2/Dow93dycZs9eYMwO1Mjcc+ub8f5vwbKlWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQ1c0OVczdnmLLWhQNr3exj6X6v1Vp7a9VaP/p+ZKlUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQ1Mzk/zWVWH2x1MToJUqvUtmD1vqwxgn8NfbTPpUqQJBQBQgSqgBBxlR3ZcZoah8XOqbFNtemtT4s1d7W+mmfShUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEFVTf5vecLvLbXWT621t1b6cZ61+mm8n/c6lSpAUFWVKizx9tmbyee//PXrm7SDbVOpAgSpVGneuQr18HUqVtakUgUIUqnStLfP3vym8jxVkT78/dg2kKJSBQgSqjTv7bM3R8dVT/0d1iRUAYK6GVNdsqRta0tUlzrWa7dd61h3u93u/X9P/5o/NW66Zpu8d8u19n3e102osl2nTvGd+lOC03+AoGEcx9JteDQ8fVlPY2YodWrEJ9dWo6ZUra+378f47sUw53UqVYAgoQoQJFQBgoQqQJBQBQgSqjRt/1f8L3/9+qLHsAahChDkiiqad1h9XvoYklSqAEFVVarnrsC4VqkrN9Y6nrWUunlMb/RjWaX7SaUKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYCgqib/l1J69cVWrNVPrfXxksnlPmvztNxP
KlWAIKEKEOT0f1f/6UQt9NNy+nCelvtJpQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChBU1RVVS66iKLGCYstXfdAen7fllvTheD/vdSpVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIEVTX5v5RrV248d8FBiW3PTW4utS0ftfbetfgZL02lChAkVAGChCpA0DCOY+k2PBqevry6MS2PwfRuydjaJV4/fzX5/Lc/fBfZz62Oh+uslQXjuxfDnNepVAGC/PpP885VqIevS1WscIxQpWmvn7/6TUieCs+Hvx/bBlKc/gMECVWa9/r5q6NDAKf+DmsSqgBBxlRp3qnxUeOmlKBSBQiqqlItsSLqEq2195xWJ62fGjftbTzV5225W/ShShUgSKgCBAlVgCChChAkVAGChCpAkFClafsT/L/94buLHsMahCpAUFWT/0uxasA8tfbTYfV56eNbqrUPa9NyP6lUAYKEKkCQUAUIMqa6q3+MphZr9VNvNwqZ4rM2T8v9pFIFCBKqAEFCFSBIqAIECVWAIKEKECRUAYKEKkDQMI5j6TY8Gp6+rKcxM5ybtN7yBGZYqrfvx/juxTDndSpVgCCXqdKN//zjX0f//qe//+Wz5x8ewxpUqgBBKlWad6pCPXx+v2JVrbIWlSpAkEqVbh2OpRpT5RaEKs07DMnDEIVbcvoPENRNpbpk9cUSKzcumRhd6liv3XatYz217anKtWSb5mzb23u3RGvf530qVYAgoQoQJFQBgtxQZYHebhjRi3O/+ptSdRu9fT/cUAWggG5+/We7zEelJipVgCCVKpthLJVbUKkCBKlUad7hjVNOPQ+3IFTphvCkBk7/AYKqqlTPTRa+VmuTjM8pdcOIGm88s5YabxTS4n5LKPWZeaBSBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUFWT/0upcVJ7T5PPe7sDvPeu7H5r/7yoVAGChCpAkFAFCDKmuiszRlNqXGhLx7oW712/+01QqQIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIKiZyf81TgausU290cftavHCgcRKrCpVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAoGauqDpnyZK21267ZPneLS39W2M/1dimc7b03pX4PqeoVAGChCpA0DCOY+k2PPpwd3eyMW6s8Ump05vSp1U98N7VbaqfntzfD3P+DZUqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUDc3VFkisSztMb1dqVLieNZ6b85Z61h7+0xM2er3SqUKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYCgbib/t7ZcRIsrcpZYdbZGW1oltMX3rnQWqFQBgoQqQJBQBQjqZky1xrGdKaXau2S/127b2ntzTok+XLptS/tcqnSbVaoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQjqZvI/H/W2+mhvSr0/3I5KFSCom0q19O2+AHY7lSpAlFAFCBKqAEFCFSBIqAIECVWAoG6mVLWmxVUqr9XbsfZ2PFO2dKwpKlWAIKEKECRUAYKMqRaypbGo3o61t+OZsqVjTVGpAgR1U6n6HxWogUoVIEioAgQJVYAgoQoQJFQBgoQqQFA3U6qWaG06VmvtXaK3Y+3teKZs6Vj3qVQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChDUzeT/qVUfz01CvnbbJStNrrXtWse6ZNtS/TTFezdv2631U4JKFSBIqAIECVWAoGEcx9JtePTh7u5kY7Z6cwbgdqbGY5/c3w9z/g2VKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgpq5ocq5mzMA1EClChAkVAGChCpAUFU3VBmGoZ7GAOwZx9ENVQBurZlf/xN++unPu91ut/vmm39/9njfw3Pp/U7tc639wlr++cUXnz3+6y+/FGpJfTZx+j8nTA8lQm5/v3P2mdovrOUhTA9D
dD9kew1Yp/8ABWzi9P9UhXpJ5Zra79r7hDVNVaEPz52qZrdCpQoQtIlK9ZRjFeQt93nL/cKtbL1iVakCBG26Ui1RJapM6cmxavRwutXWqFQBgjZdqQLXuaRC3drYqkoVIGjTlapf/2GZrVWhc6hUAYI2ce3/g0uuYkpWkKX2C2lTv+wfzk89/Hvr5l77v6nT/1KXiF5yQxVoXS8hei2n/wBBmzr9P3Sr+6me2+ct9gtJvZ7iT3HrP4ACNl2pAsylUgUoQKgCBAlVgCChChAkVAGChCpAkFAFCNrUtf9wKz/ff3X071/d/XzjlnBrJv9D2H6gPoToYcgK1/aY/A9QgEoVQo5VqFOvmXod9VGpAhQgVAGChCpAkFAFCBKqAEF+/Ycw81T75Nd/gAJUqrACl6n2Z26lKlQBZnD6D1CAUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQN4ziWbgNAN1SqAEFCFSBIqAIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEH/Bx83TBg0+8FmAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(5,6))\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"MsPacman\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1980년대로 돌아오신 걸 환영합니다! :)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 환경에서는 렌더링된 이미지가 관측과 동일합니다(하지만 많은 경우에 그렇지 않습니다):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(img == obs).all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 그리기 위한 유틸리티 함수를 만들겠습니다:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def plot_environment(env, figsize=(5,6)):\n", "    plt.figure(figsize=figsize)\n", "    img = env.render(mode=\"rgb_array\")\n", "    plt.imshow(img)\n", "    plt.axis(\"off\")\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 어떻게 다루는지 보겠습니다. 에이전트는 \"행동 공간\"(가능한 행동의 모음)에서 하나의 행동을 선택합니다. 이 환경의 행동 공간은 다음과 같습니다:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Discrete(9)`는 가능한 행동이 정수 0부터 8까지 있다는 의미입니다. 이는 조이스틱의 9개 위치(0=중앙, 1=위, 2=오른쪽, 3=왼쪽, 4=아래, 5=오른쪽위, 6=왼쪽위, 7=오른쪽아래, 8=왼쪽아래)에 해당합니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그다음 환경에게 플레이할 행동을 알려주고 게임의 다음 단계를 진행시킵니다. 왼쪽으로 110번을 진행하고 왼쪽아래로 40번을 진행해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "env.reset()\n", "for step in range(110):\n", "    env.step(3)  # 왼쪽\n", "for step in range(40):\n", "    env.step(8)  # 왼쪽아래" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "어디에 있을까요?"
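] }, { "cell_type": "markdown", "metadata": {}, "source": [ "그림으로 확인하기 전에 참고로, MsPacman처럼 Atari(ALE) 기반 환경이라면 각 행동 번호의 의미를 다음처럼 직접 조회해 볼 수도 있습니다(짐 버전에 따라 지원 여부가 다를 수 있는 참고용 예입니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Atari 환경이 제공하는 행동 이름 목록을 확인합니다(참고용)\n", "env.unwrapped.get_action_meanings()"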
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASYAAAFrCAYAAAB8AiRoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACq1JREFUeJzt3UGO20YWBmBpMJcYeHIFLwxkE8BeBDCQfYBcI72Lj5Ds+hwNZB/AQC9ioDcBvMgVEmOOwdmkG7QssimSYv1V/L6VopZcRUr58Up6LB27rjsAJPlX6QkAnBJMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTECcf5eeQN/xeHR9DDSs67rjlMdFBdPfP/5YegpAgKhgGvLfX/9TegpX9ff3/zt7f+vHvQd7fW2HjnsqnzEBcQQTEEcwAXEEExBHMAFxBBMQRzABcaroYxoy1isx1CdyaV/JFmPMsdZxjM1pizFKjl3y+K49pznPWXOMpVRMQBzBBMQ5Jv2u3Kebm7OT2Wv7fuvHvQd7fW2HjvvF7e2ki3hVTEAcwQTEEUxAHMEExBFMQBzBBMQRTECcqi9JmWPplp99NfWirHncrdvr6+qSFIARggmII5iAOIIJiCOYgDiCCYgjmIA4u+tj2kLiFqws53XdjooJiCOYgDiWclewRdm9t9I+wV5f1xJzUjEBcQQTEEcwAXEEExBHMAFxBBMQZ3ftAnv9OjbxuPfg2ue91ddVxQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEqbqPac4vhF66Q2DJMcZ6VFoZo+TYLYyR+v/AUiomII5gAuIcu64rPYcnn25uzk6m1bZ7aNXQ8u/F7e1xyvNVTEAcwQTEEUxAHMEExBFMQBzBBMQRTECcKi5JGWuJH7JW79MW7fhzjm/Imj1fa86rlNTzsdf351QqJiCOYALiCCYgjmAC4ggmII5gAuIIJiBOFX1Ma1pz+9eSY1wqcU6lJZ6Tvb4/T6mYgDiCCYizu6XcFuVqUkn8KHFOpSWek72+P0+pmIA4ggmII5iAOIIJiCOYgDiCCYgjmIA4UX1MNfRXXEPqcafOq5S9no81j7u7nfY4FRMQRzABcQQTEEcwAXEEExBHMAFxBBMQJ6qP6VJzfh45cevSOXNqZYySY7cyxqXWmtPYc5ZSMQFxBBMQ59h1Xek5PDl+9UvOZP5RooyFqWp7f3Z//XSc8jgVExBHMAFxBBMQRzABcQQTEEcwAXEEExAn6pKUsZ6MS23Rw7HmfNeyxaULNXE+litx3ComII5gAuIIJiCOYALiCCYgjmAC4kS1C2xhix0CE6153Innas5X2t4Ln0s6bhUTEEcwAXF2t5RLKle3tNfjHrPXc1LDcauYgDiCCYgjmIA4ggmII5iAOIIJiCOYgDhRfUyX9ldssbNeDT0fbMN74UuXnpPudtrjVExAnKiKCR59+8M3Z++/v3vYeCaUIJiIMhRIp38XUG2zlAPiCCZi9Kul04ro/u7hs/ueq6yom6UcUc4t0Szb9kfFBMTZXcV06baiY71Slz5nrcdvNUYJ55Zo11q2tfA6bfH+LEHFBMQRTECc3S3lLi1Xt/gVkdQxSrq/ezj7Ld2ay7oWXqdWfuXmlIoJiLO7iols/croXN/SNSon8ggmovQDZ8tv6MhiKQfEEUzE6C/dxm7rBG+fpRxRpoTTuf+mLSomIE5UxbTFVrmXSpzTmBp6VGrlvfCla50TFRMQRzABcQQTEEcwAXEEExBHMAFxBBMQJ6qPaQs1bCt6DXs97jF7PSc1HLeKCYgjmIA4u1vKJZWrW1rzuGu7NGOI90IuFRMQRzABcQQTEEcwAXEEExBHMAFxBBMQ59h1Xek5PDl+9U
vOZP4x1rNTQz8Ibavt/dn99dNxyuNUTEAcwQTEEUxAHMEExBFMQBzBBMSpetuTOV+V1rB7H5e/TnNe11bGuNRacxp7zlIqJiCOYALi6Px+Rm2dtexLbe9Pnd9AtQQTEEcwAXEEExBHMAFxBBMQRzABcaIuSVnzF14TeziGJF6GsOYYa2rlnLTw/rwmFRMQRzABcQQTEEcwAXEEExBHMAFxBBMQJ6qPaQst9JUcDnpwTrVyPloZYykVExBHMAFxdreUSypXl9jiOGo6V62cj1bGWErFBMQRTEAcwQTEEUxAHMEExBFMQBzBBMSpoo+pZN9FDT0fKZyr7aX2PS3djlfFBMSpomJK9u0P35y9//7uYeOZQDtUTAsMhdLj38b+DgwTTEAcwTRDvxq6v3v4YtnW/29VE1zOZ0wLjAUSMJ+KCYijYlrgdJm25bItcQvWNedUcuy15rTFGGP9QmudqzljLKViAuIIJiCOpdxCjx9495dx93cPV1/WJV6KsOacSo5d0xhz5pR4bk8JpgX6AXTaInAusIBpLOWAOCqmBfrVUMlv6KA1KiYgjmCaoX8ZypTbwGUE0wL94JlyG5hGMAFxdvfh99ItP/tq2kp2i7mueW6HlOyVStTq+1nFBMQRTEAcwQTEEUxAHMEExBFMQJzdtQtsoaadDsfmtMVxrGWL40t8XbdgB0uAg2ACAgkmII7PmK5gr1uwltTKFrOJ57zEnFRMQBzBBMQRTEAcwQTEEUxAHMEExNEuEKy2HSFrs8X5bYFLUgAOggkIJJiAOIIJiCOYgDiCCYhTdbvAnz///sV9L9+9efrb422gLtUG07lQ6t//8t2bz25vKXF71DlqOo6a5jqmleNYylIOiFNdxfTc8q3/OEs5qFN1wfSccyG1tVbK7pqOo6a5jkk8DjtYAhwarJhKVkrAOqqumF6+e/PZ50g+U4I2VB1MQJsEExCn2mAaWradLu+A+lQbTEC7jl3XlZ7Dk083N2cnk9jbAQwburTmxe3tccrzVUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAnKp3F5jz08WXbl1acoyx/q1Wxig5dgtjpP4/sJSKCYgjmIA4LkkBVueSFKA5ggmII5iAOIIJiCOYgDiCCYgjmIA4VVySMtYSD7RHxQTEEUxAnKhLUo7HY85kgNV1XTfpkpQqPmOa6/37rw+Hw+Hw9u0fT7f73r79o4ox2LffXr16uv3dx48FZ7IdSzkgTrNLuffvv36qVs5VMn1zq5qpY6iamOOxUupXSbVXT7tfyp0uraaGVNoY7NdQ8Dze/9urV1WG0xSWckCcZiumU+cqmxrHgEffffx4drnXAhUTEGc3FdMWFYwqiWt57oPw1uwmmKBWpx9ynwuk1pZ0lnJAnN1UTD78pmatVUTPabbB8nCY3k+0JESmjCGkmGPoM6R+H9PpfemmNlhaygFxml7KbdGJrdubUmqpkuZoeinXt8WV/3YX4BpqXLINsZQDqrWbigkoT8UEVEswAXEEExBHMAFxBBMQp+kGS+ry4fb10+3XNx8G//ac0+cOPf/c48igYgLiqJgobko19Fx1M/RvnN7/+ubD030fbl+rmkJpsCTG2FJuzvOe+/fmjsd8GiyBalnKUS0VT7tUTEAcwQTEsZSjSlOWcaffwJ3eN/ZcylIxAXG0C1DM1G7uNb7q1/mdYWq7gGACNqOPCaiWYALiCCYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuJE/eAlwOGgYgICCSYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuIIJiCOYALiCCYgjmAC4ggmII5gAuIIJiCOYALi/B9BBT4hIRRVYwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_environment(env)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "사실 `step()` 함수는 여러 개의 중요한 객체를 반환해 줍니다:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "obs, reward, done, info = env.step(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "앞서 본 것처럼 관측은 보이는 환경을 설명합니다. 여기서는 210x160 RGB 이미지입니다:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경은 마지막 스텝에서 받은 보상을 알려 줍니다:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reward" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "게임이 종료되면 환경은 `done=True`를 반환합니다:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "마지막으로 `info`는 환경의 내부 상태에 관한 추가 정보를 제공하는 딕셔너리입니다. 디버깅에는 유용하지만 에이전트는 학습을 위해서 이 정보를 사용하면 안 됩니다(학습이 아니라 속이는 셈이 되므로)."
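] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 네 가지 반환값을 사용하는 전형적인 상호작용 루프를 스케치하면 다음과 같습니다. 에피소드가 끝날 때까지 랜덤한 행동을 하면서 보상을 모두 더합니다(실행하면 위에서 만든 `env`의 상태가 바뀌므로 개념 설명을 위한 참고용 코드입니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 에피소드가 끝날 때까지(done=True) 랜덤한 행동을 하며 총 보상을 계산합니다\n", "total_reward = 0.0\n", "obs = env.reset()\n", "while True:\n", "    obs, reward, done, info = env.step(env.action_space.sample())\n", "    total_reward += reward\n", "    if done:\n", "        break\n", "total_reward"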
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ale.lives': 3}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "10번의 스텝마다 랜덤한 방향을 선택하는 식으로 전체 게임(목숨 3개)을 플레이하고 각 프레임을 저장해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", "    img = env.render(mode=\"rgb_array\")\n", "    frames.append(img)\n", "    if step % n_change_steps == 0:\n", "        action = env.action_space.sample()  # 랜덤하게 플레이합니다\n", "    obs, reward, done, info = env.step(action)\n", "    if done:\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 애니메이션으로 한번 보죠:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def update_scene(num, frames, patch):\n", "    plt.close()  # 이전 그래프를 닫지 않으면 두 개의 그래프가 출력되는 matplotlib의 버그로 보입니다.\n", "    patch.set_data(frames[num])\n", "    return patch,\n", "\n", "def plot_animation(frames, figsize=(5,6), repeat=False, interval=40):\n", "    fig = plt.figure(figsize=figsize)\n", "    patch = plt.imshow(frames[0])\n", "    plt.axis('off')\n", "    return animation.FuncAnimation(fig, update_scene, fargs=(frames, patch),\n", "                                   frames=len(frames), repeat=repeat, interval=interval)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames)\n", "HTML(video.to_html5_video())  # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "환경을 더 이상 사용하지 않으면 환경을 종료하여 자원을 반납합니다:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ 
"env.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "첫 번째 에이전트를 학습시키기 위해 간단한 Cart-Pole 환경을 사용하겠습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 간단한 Cart-Pole 환경" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cart-Pole은 아주 간단한 환경으로 왼쪽이나 오른쪽으로 움직일 수 있는 카트와 카트 위에 수직으로 서 있는 막대로 구성되어 있습니다. 에이전트는 카트를 왼쪽이나 오른쪽으로 움직여서 막대가 넘어지지 않도록 유지시켜야 합니다." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-0.00326381, -0.00892555, 0.02935513, -0.00128701])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측은 4개의 부동소수로 구성된 1D 넘파이 배열입니다. 각각 카트의 수평 위치, 속도, 막대의 각도(0=수직), 각속도를 나타냅니다. 이 환경을 렌더링하려면 먼저 몇 가지 이슈를 해결해야 합니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 렌더링 이슈 해결하기" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "일부 환경(Cart-Pole을 포함하여)은 `rgb_array` 모드를 지정하더라도 별도의 창을 띄우므로 디스플레이 접근이 필요합니다. 일반적으로 이 창은 무시하면 됩니다. 주피터가 헤드리스(headless) 서버에서 (즉 스크린 없이) 실행 중이면 예외가 발생합니다. 이를 피하는 한 가지 방법은 Xvfb 같은 가짜 X 서버를 설치하는 것입니다. 
`xvfb-run` 명령을 사용해 주피터를 실행합니다:\n", "\n", "    $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n", "\n", "주피터가 헤드리스 서버로 실행 중이지만 Xvfb를 설치하기 번거롭다면 Cart-Pole에 대해서는 다음 렌더링 함수를 사용할 수 있습니다:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from PIL import Image, ImageDraw\n", "\n", "try:\n", "    from pyglet.gl import gl_info\n", "    openai_cart_pole_rendering = True   # 문제없음, OpenAI 짐의 렌더링 함수를 사용합니다\n", "except Exception:\n", "    openai_cart_pole_rendering = False  # 가능한 X 서버가 없다면, 자체 렌더링 함수를 사용합니다\n", "\n", "def render_cart_pole(env, obs):\n", "    if openai_cart_pole_rendering:\n", "        # OpenAI 짐의 렌더링 함수를 사용합니다\n", "        return env.render(mode=\"rgb_array\")\n", "    else:\n", "        # Cart-Pole 환경을 위한 렌더링 (OpenAI 짐이 처리할 수 없는 경우)\n", "        img_w = 600\n", "        img_h = 400\n", "        cart_w = img_w // 12\n", "        cart_h = img_h // 15\n", "        pole_len = img_h // 3.5\n", "        pole_w = img_w // 80 + 1\n", "        x_width = 2\n", "        max_ang = 0.2\n", "        bg_col = (255, 255, 255)\n", "        cart_col = 0x000000  # 파랑 초록 빨강\n", "        pole_col = 0x669acc  # 파랑 초록 빨강\n", "\n", "        pos, vel, ang, ang_vel = obs\n", "        img = Image.new('RGB', (img_w, img_h), bg_col)\n", "        draw = ImageDraw.Draw(img)\n", "        cart_x = pos * img_w // x_width + img_w // x_width\n", "        cart_y = img_h * 95 // 100\n", "        top_pole_x = cart_x + pole_len * np.sin(ang)\n", "        top_pole_y = cart_y - cart_h // 2 - pole_len * np.cos(ang)\n", "        draw.line((0, cart_y, img_w, cart_y), fill=0)\n", "        draw.rectangle((cart_x - cart_w // 2, cart_y - cart_h // 2, cart_x + cart_w // 2, cart_y + cart_h // 2), fill=cart_col)  # 카트를 그립니다\n", "        draw.line((cart_x, cart_y - cart_h // 2, top_pole_x, top_pole_y), fill=pole_col, width=pole_w)  # 막대를 그립니다\n", "        return np.array(img)\n", "\n", "def plot_cart_pole(env, obs):\n", "    img = render_cart_pole(env, obs)\n", "    plt.imshow(img)\n", "    plt.axis(\"off\")\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABDlJREFUeJzt3FFKAmEYQNEmXEXbaB1tQ9ek22gbtY22Mb1ESAUFOv3j3HNAUEH5HsbLx/DjNM/zHQDbdj96AACWJ/YAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAG70QN84b8bAL6bLv0Cmz1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QMBu9ACwNq+nw4/vP+6P/zwJXI/NHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDs4Q/8Lw63TuwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB7OvJ4Oo0eARYg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9/OJxfxw9AlxM7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEns2bpunPjyU+D2sg9gABu9EDwNo8v+0/nz89nAZOAtdjs4cz56H/6TXcKrEHCBB7gACxhw8vx/23e/Tu2bMV0zzPo2c4t6ph2Ib/PBK5st8T23HxRbyq0zjOKXPrXMMs4RpLxKpibytiCTZ7cM8eIEHsAQLEHiBA7AECxB4gQOwBAsQeIGBV5+xhCc6+g80eIEHsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBgN3qAL6bRAwBskc0eIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIOAdBBArQqzRDE4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "행동 공간을 확인해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(2)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "네 딱 두 개의 행동이 있네요. 왼쪽이나 오른쪽 방향으로 가속합니다. 막대가 넘어지기 전까지 카트를 왼쪽으로 밀어보죠:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(0)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAAEYCAYAAAAeWvJ8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABS9JREFUeJzt3cFx2lAUQFHIUEXaiNtIG6YOl2HaSBtxG2mDLLKxjbGRkHQhOmfGC3vB/IWZy5feF9vj8bgBgMq3egEArJsQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkNrVC3jH84YA7sd2ihexIwIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiC1qxcAa/Ry2L/5/cfjc7QS6NkRwQ14OexP4gRrIUQApIQIgJQQAZASIljYuXtBBhZYKyECICVEAKSECICUEAGQEiJYkEEFOCVEAKSECBbiET7wMSGCmMtyrJ0QAZASIgBSQgQLcH8IzhMiCLk/BEIEQEyIYGYuy8HnhAiAlBABkBIiiBhUgH+ECICUEMGMPG0bviZEAKSECICUEAGQEiKYiYOscBkhgoUZVIC3hAiAlBDBDFyWg8sJEQApIQIgJUSwIIMKcEqIAEgJEUzM8+VgGCECICVEAKSECICUEMGEHGSF
4YQIFmBQAc4TIpiI3RCMI0QApIQIZuayHHxOiABICREAKSGCCRhUgPGECGbk/hB8TYgASAkRACkhAiAlRHAlgwpwHSGCmRhUgMsIEQApIYIr+FpwuJ4QAZASIgBSQgRASogASAkRjGRQAaYhRACkhAiAlBABkBIiGMHz5WA6QgQTMqgAwwkRDGQ3BNMSIlZvu90O+jnnYX+Y5HVgbYQIgNSuXgDcm19/Hk/+9vP7IVgJ/B/siABICREM8Pv5dDe02Xy8SwIuI0Qwgaenh3oJcLeECIDU9ng81mt47aYWwzqMHaV+fZnuYT98WOHG3nswxiTnEISI1avO9NzYew/GmOTNc1Pj2w75sSb+37l3U32YuqkQ+YRIwY4IWoYVAEgJEQApIQIgJUQApIQIgJQQAZASIgBSN3WOCArO80DLjgiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFK7egHvbOsFALAsOyIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUn8Bm29lFSLelIwAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img = render_cart_pole(env, obs)\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"cart_pole_plot\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(400, 600, 3)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "막대가 실제로 넘어지지 않더라도 너무 기울어지면 게임이 끝납니다. 환경을 다시 초기화하고 이번에는 오른쪽으로 밀어보겠습니다:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(1)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABLBJREFUeJzt3NtpG1EUQFFPUBVpI24jbdhlGJcRt5E2kjbSxuQnBCG/ZI2SuTN7LTAYgcX5kLcv9wye5nm+AWDfPq09AAD/ntgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwQc1h7ghP/dAPDctPQNnOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiDgsPYAMJqfT/d/v/9y923FSeB6pnme157h2FDD0HIc+WOCzwCmpW/gGgcgQOwBAsQe/nBdw56JPbzjtbt82BKxBwgQe4AAsQcIEHs4YknLXok9nMGSlq0Te4AAsQcIEHs44d6ePRJ7gACxhzNZ0rJlYg8QIPYAAWIPL7CkZW/EHiBA7OEDLGnZKrGHV7jKYU/EHj7I6Z4tEnuAALEHCBB7eIN7e/ZC7AECxB4uYEnL1og9QIDYAwSIPbzDkpY9EHuAALGHC1nSsiViDxAg9gABYg9nsKRl68QeFnBvz1aIPUCA2AMEiD1AgNjDmSxp2TKxh4UsadkCsQcIEHuAALEHCBB7+A
BLWrZK7OEKLGkZndjDBznds0ViD1fidM/IxB4gQOwBAsQeIEDs4QKWtGyN2MMVWdIyKrEHCBB7gACxBwgQe7iQJS1bIvZwZZa0jEjsAQLEHl4wTdNZX0t//r33gWsRe1jg9v5p7RHgLIe1B4A9+P7r7uQVfwQYi5M9LPQ89Dc3Dw8/VpgEXif2AAFiDxAg9rDQ4+Pts9e+fnZnz1imeZ7XnuHYUMPQ9b8fhxzs95DxLP5ADvU0jueNqfLZ5y3XOAwMFXunG0bhZM/euLMHCBB7gACxBwgQe4AAsQcIEHuAALEHCBjqOXsYhefe2Rsne4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAgMPaA5yY1h4AYI+c7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECfgMZeVQP9vfpjgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아까 말했던 것과 같은 상황인 것 같습니다. 어떻게 막대가 똑바로 서 있게 만들 수 있을까요? 이를 위한 *정책*을 만들어야 합니다. 정책은 에이전트가 각 스텝에서 행동을 선택하기 위해 사용하는 전략입니다. 어떤 행동을 할지 결정하기 위해 지난 행동이나 관측을 사용할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 하드 코딩 정책" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "간단한 정책을 하드 코딩해 보겠습니다. 막대가 왼쪽으로 기울어지면 카트를 왼쪽으로 밀고 반대의 경우는 오른쪽으로 밉니다. 작동이 되는지 확인해 보죠:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", "    img = render_cart_pole(env, obs)\n", "    frames.append(img)\n", "\n", "    # 하드 코딩한 정책\n", "    position, velocity, angle, angular_velocity = obs\n", "    if angle < 0:\n", "        action = 0\n", "    else:\n", "        action = 1\n", "\n", "    obs, reward, done, info = env.step(action)\n", "    if done:\n", "        break" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video())  # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아니네요, 불안정해서 몇 번 움직이고 막대가 너무 기울어져 게임이 끝났습니다. 더 똑똑한 정책이 필요합니다!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 신경망 정책" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "관측을 입력으로 받고 각 관측에 대해 선택할 행동을 출력하는 신경망을 만들어 보겠습니다. 행동을 선택하기 위해 네트워크는 먼저 각 행동에 대한 확률을 추정하고 그다음 추정된 확률을 기반으로 랜덤하게 행동을 선택합니다. Cart-Pole 환경의 경우에는 두 개의 행동(왼쪽과 오른쪽)이 있으므로 하나의 출력 뉴런만 있으면 됩니다. 행동 0(왼쪽)에 대한 확률 `p`를 출력할 것입니다. 행동 1(오른쪽)에 대한 확률은 `1 - p`가 됩니다."
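] }, { "cell_type": "markdown", "metadata": {}, "source": [ "신경망을 만들기 전에, 확률 `p`에서 행동을 샘플링한다는 것이 어떤 의미인지 넘파이로 간단히 확인해 보겠습니다(개념 이해를 위한 스케치이며, 다음 셀의 `tf.multinomial()`이 하는 일과 비슷합니다):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sample_action(p_left):\n", "    # p_left 확률로 행동 0(왼쪽), 1 - p_left 확률로 행동 1(오른쪽)을 선택합니다\n", "    return 0 if np.random.rand() < p_left else 1\n", "\n", "actions = [sample_action(0.8) for _ in range(1000)]\n", "actions.count(0) / len(actions)  # 대략 0.8 근처의 값이 됩니다"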
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# 1. 네트워크 구조를 설정합니다\n", "n_inputs = 4 # == env.observation_space.shape[0]\n", "n_hidden = 4 # 간단한 작업이므로 너무 많은 뉴런이 필요하지 않습니다\n", "n_outputs = 1 # 왼쪽으로 가속할 확률을 출력합니다\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. 네트워크를 만듭니다\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu,\n", " kernel_initializer=initializer)\n", "outputs = tf.layers.dense(hidden, n_outputs, activation=tf.nn.sigmoid,\n", " kernel_initializer=initializer)\n", "\n", "# 3. 추정된 확률을 기반으로 랜덤하게 행동을 선택합니다\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "init = tf.global_variables_initializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 환경은 각 관측이 환경의 모든 상태를 포함하고 있기 때문에 지난 행동과 관측은 무시해도 괜찮습니다. 숨겨진 상태가 있다면 이 정보를 추측하기 위해 이전 행동과 상태를 고려해야 합니다. 예를 들어, 속도가 없고 카트의 위치만 있다면 현재 속도를 예측하기 위해 현재의 관측뿐만 아니라 이전 관측도 고려해야 합니다. 관측에 잡음이 있을 때도 같은 경우입니다. 현재 상태를 근사하게 추정하기 위해 과거 몇 개의 관측을 사용하는 것이 좋을 것입니다. 이 문제는 아주 간단해서 현재 관측에 잡음이 없고 환경의 모든 상태가 담겨 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책 네트워크에서 만든 확률을 기반으로 가장 높은 확률을 가진 행동을 고르지 않고 왜 랜덤하게 행동을 선택하는지 궁금할 수 있습니다. 이런 방식이 에이전트가 새 행동을 *탐험*하는 것과 잘 동작하는 행동을 *이용*하는 것 사이에 균형을 맞추게 합니다. 만약 어떤 레스토랑에 처음 방문했다고 가정합시다. 모든 메뉴에 대한 선호도가 동일하므로 랜덤하게 하나를 고릅니다. 이 메뉴가 맛이 좋았다면 다음에 이를 주문할 가능성을 높일 것입니다. 하지만 100% 확률이 되어서는 안됩니다. 그렇지 않으면 다른 메뉴를 전혀 선택하지 않게 되고 더 좋을 수 있는 메뉴를 시도해 보지 못하게 됩니다." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책 신경망을 랜덤하게 초기화하고 게임 하나를 플레이해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "n_max_steps = 1000\n", "frames = []\n", "\n", "with tf.Session() as sess:\n", "    init.run()\n", "    obs = env.reset()\n", "    for step in range(n_max_steps):\n", "        img = render_cart_pole(env, obs)\n", "        frames.append(img)\n", "        action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", "        obs, reward, done, info = env.step(action_val[0][0])\n", "        if done:\n", "            break\n", "\n", "env.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "랜덤하게 초기화한 정책 네트워크가 얼마나 잘 동작하는지 확인해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video())  # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "음.. 별로 좋지 않네요. 신경망이 더 잘 학습되어야 합니다. 먼저 앞서 사용한 기본 정책을 학습할 수 있는지 확인해 보겠습니다. 막대가 왼쪽으로 기울어지면 왼쪽으로 움직이고 오른쪽으로 기울어지면 오른쪽으로 이동해야 합니다. 
The following code defines the same neural network, except we added the target probabilities `y` and the training operations (`cross_entropy`, `optimizer` and `training_op`):" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "y = tf.placeholder(tf.float32, shape=[None, n_outputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # probability of action 0 (left)\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "training_op = optimizer.minimize(cross_entropy)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can make the same network play in 10 different environments in parallel and train it for 1,000 iterations. We also reset an environment whenever it is done." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. 
Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "n_environments = 10\n", "n_iterations = 1000\n", "\n", "envs = [gym.make(\"CartPole-v0\") for _ in range(n_environments)]\n", "observations = [env.reset() for env in envs]\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " target_probas = np.array([([1.] if obs[2] < 0 else [0.]) for obs in observations]) # if angle < 0 we want proba(left) = 1., otherwise proba(left) = 0.\n", " action_val, _ = sess.run([action, training_op], feed_dict={X: np.array(observations), y: target_probas})\n", " for env_index, env in enumerate(envs):\n", " obs, reward, done, info = env.step(action_val[env_index][0])\n", " observations[env_index] = obs if not done else env.reset()\n", " saver.save(sess, \"./my_policy_net_basic.ckpt\")\n", "\n", "for env in envs:\n", " env.close()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "def render_policy_net(model_path, action, X, n_max_steps = 1000):\n", " frames = []\n", " env = gym.make(\"CartPole-v0\")\n", " obs = env.reset()\n", " with tf.Session() as sess:\n", " saver.restore(sess, model_path)\n", " for step in range(n_max_steps):\n", " img = render_cart_pole(env, obs)\n", " frames.append(img)\n", " action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " if done:\n", " break\n", " env.close()\n", " return frames " ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_policy_net_basic.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_basic.ckpt\", action, X)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책을 잘 학습한 것 같네요. 이제 스스로 더 나은 정책을 학습할 수 있는지 알아 보겠습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 정책 그래디언트" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "신경망을 훈련하기 위해 타깃 확률 `y`를 정의할 필요가 있습니다. 행동이 좋다면 이 확률을 증가시켜야 하고 반대로 나쁘면 이를 감소시켜야 합니다. 하지만 행동이 좋은지 나쁜지 어떻게 알 수 있을까요? 대부분의 행동으로 인한 영향은 뒤늦게 나타나는 것이 문제입니다. 게임에서 이기거나 질 때 어떤 행동이 이런 결과에 영향을 미쳤는지 명확하지 않습니다. 마지막 행동일까요? 아니면 마지막 10개의 행동일까요? 아니면 50번 스텝 앞의 행동일까요? 이를 *신용 할당 문제*라고 합니다.\n", "\n", "*정책 그래디언트* 알고리즘은 먼저 여러번 게임을 플레이하고 성공한 게임에서의 행동을 조금 더 높게 실패한 게임에서는 조금 더 낮게 되도록 하여 이 문제를 해결합니다. 먼저 게임을 진행해 보고 다시 어떻게 한 것인지 살펴 보겠습니다." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # 행동 0(왼쪽)에 대한 확률\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "y = 1. 
- tf.to_float(action)\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "def discount_rewards(rewards, discount_rate):\n", " discounted_rewards = np.zeros(len(rewards))\n", " cumulative_rewards = 0\n", " for step in reversed(range(len(rewards))):\n", " cumulative_rewards = rewards[step] + cumulative_rewards * discount_rate\n", " discounted_rewards[step] = cumulative_rewards\n", " return discounted_rewards\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", " all_discounted_rewards = [discount_rewards(rewards, discount_rate) for rewards in all_rewards]\n", " flat_rewards = np.concatenate(all_discounted_rewards)\n", " reward_mean = flat_rewards.mean()\n", " reward_std = flat_rewards.std()\n", " return [(discounted_rewards - reward_mean)/reward_std for discounted_rewards in all_discounted_rewards]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-22., -40., -50.])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_rewards([10, 0, -50], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"[array([-0.28435071, -0.86597718, -1.18910299]),\n", " array([1.26665318, 1.0727777 ])]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "반복: 249" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")\n", "\n", "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 250\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\r반복: {}\".format(iteration), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_val, gradients_val = sess.run([action, gradients], feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, 
\"./my_policy_net_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_policy_net_pg.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_pg.ckpt\", action, X, n_max_steps=1000)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 연쇄" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태: 0 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 ...\n", "상태: 0 0 3 \n", "상태: 0 0 0 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n" ] } ], "source": [ "transition_probabilities = [\n", " [0.7, 0.2, 0.0, 0.1], # s0에서 s0, s1, s2, s3으로\n", " [0.0, 0.0, 0.9, 0.1], # s1에서 ...\n", " [0.0, 1.0, 0.0, 0.0], # s2에서 ...\n", " [0.0, 0.0, 0.0, 1.0], # s3에서 ...\n", " ]\n", "\n", "n_max_steps = 50\n", "\n", "def print_sequence(start_state=0):\n", " current_state = start_state\n", " print(\"상태:\", end=\" \")\n", " for step in range(n_max_steps):\n", " print(current_state, end=\" \")\n", " if current_state == 3:\n", " break\n", " current_state = np.random.choice(range(4), p=transition_probabilities[current_state])\n", " else:\n", " 
print(\"...\", end=\"\")\n", " print()\n", "\n", "for _ in range(10):\n", " print_sequence()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 결정 과정" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "policy_fire\n", "상태 (+보상): 0 (10) 0 (10) 0 1 (-50) 2 2 2 (40) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 210\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 2 (40) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 ... 전체 보상 = -10\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) ... 전체 보상 = 290\n", "요약: 평균=121.1, 표준 편차=129.333766, 최소=-330, 최대=470\n", "\n", "policy_random\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 1 (-50) 2 2 (40) 0 ... 전체 보상 = -60\n", "상태 (+보상): 0 (10) 0 0 0 0 0 (10) 0 0 0 (10) 0 ... 전체 보상 = -30\n", "상태 (+보상): 0 1 1 (-50) 2 (40) 0 0 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 (10) 0 (10) 0 0 0 0 1 (-50) 2 (40) 0 0 ... 전체 보상 = 0\n", "상태 (+보상): 0 0 (10) 0 1 (-50) 2 (40) 0 0 0 0 (10) 0 (10) ... 전체 보상 = 40\n", "요약: 평균=-22.1, 표준 편차=88.152740, 최소=-380, 최대=200\n", "\n", "policy_safe\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 (10) 0 (10) 0 (10) 0 1 1 1 1 1 1 ... 전체 보상 = 30\n", "상태 (+보상): 0 (10) 0 1 1 1 1 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 
Total rewards = 0\n", "Summary: mean=22.3, std=26.244312, min=0, max=170\n", "\n" ] } ], "source": [ "transition_probabilities = [\n", " [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]], # in s0, if action a0 is chosen then with proba 0.7 we go to state s0, with proba 0.3 to state s1, and so on.\n", " [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n", " [None, [0.8, 0.1, 0.1], None],\n", " ]\n", "\n", "rewards = [\n", " [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n", " [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n", " [[0, 0, 0], [+40, 0, 0], [0, 0, 0]],\n", " ]\n", "\n", "possible_actions = [[0, 1, 2], [0, 2], [1]]\n", "\n", "def policy_fire(state):\n", " return [0, 2, 1][state]\n", "\n", "def policy_random(state):\n", " return np.random.choice(possible_actions[state])\n", "\n", "def policy_safe(state):\n", " return [0, 0, 1][state]\n", "\n", "class MDPEnvironment(object):\n", " def __init__(self, start_state=0):\n", " self.start_state=start_state\n", " self.reset()\n", " def reset(self):\n", " self.total_rewards = 0\n", " self.state = self.start_state\n", " def step(self, action):\n", " next_state = np.random.choice(range(3), p=transition_probabilities[self.state][action])\n", " reward = rewards[self.state][action][next_state]\n", " self.state = next_state\n", " self.total_rewards += reward\n", " return self.state, reward\n", "\n", "def run_episode(policy, n_steps, start_state=0, display=True):\n", " env = MDPEnvironment()\n", " if display:\n", " print(\"States (+rewards):\", end=\" \")\n", " for step in range(n_steps):\n", " if display:\n", " if step == 10:\n", " print(\"...\", end=\" \")\n", " elif step < 10:\n", " print(env.state, end=\" \")\n", " action = policy(env.state)\n", " state, reward = env.step(action)\n", " if display and step < 10:\n", " if reward:\n", " print(\"({})\".format(reward), end=\" \")\n", " if display:\n", " print(\"Total rewards =\", env.total_rewards)\n", " return env.total_rewards\n", "\n", "for policy in (policy_fire, policy_random, policy_safe):\n", " all_totals = []\n", " print(policy.__name__)\n", " for episode in 
range(1000):\n", " all_totals.append(run_episode(policy, n_steps=100, display=(episode<5)))\n", " print(\"Summary: mean={:.1f}, std={:1f}, min={}, max={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Q-Learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q-Learning works by watching an agent play (e.g., randomly) and gradually improving its estimates of the Q-Values. Once it has accurate (or close enough) Q-Value estimates, the optimal policy is simply to choose the action with the highest Q-Value (i.e., the greedy policy)." ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "n_states = 3\n", "n_actions = 3\n", "n_steps = 20000\n", "alpha = 0.01\n", "gamma = 0.99\n", "exploration_policy = policy_random\n", "q_values = np.full((n_states, n_actions), -np.inf)\n", "for state, actions in enumerate(possible_actions):\n", " q_values[state][actions]=0\n", "\n", "env = MDPEnvironment()\n", "for step in range(n_steps):\n", " action = exploration_policy(env.state)\n", " state = env.state\n", " next_state, reward = env.step(action)\n", " next_value = np.max(q_values[next_state]) # greedy policy\n", " q_values[state, action] = (1-alpha)*q_values[state, action] + alpha*(reward + gamma * next_value)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def optimal_policy(state):\n", " return np.argmax(q_values[state])" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[39.13508139, 38.88079412, 35.23025716],\n", " [18.9117071 , -inf, 20.54567816],\n", " [ -inf, 72.53192111, -inf]])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "q_values" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "States (+rewards): 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) ... 
Total rewards = 230\n", "States (+rewards): 0 (10) 0 (10) 0 (10) 0 1 (-50) 2 2 1 (-50) 2 (40) 0 (10) ... Total rewards = 90\n", "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... Total rewards = 170\n", "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... Total rewards = 220\n", "States (+rewards): 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) ... Total rewards = -50\n", "Summary: mean=125.6, std=127.363464, min=-290, max=500\n", "\n" ] } ], "source": [ "all_totals = []\n", "for episode in range(1000):\n", " all_totals.append(run_episode(optimal_policy, n_steps=100, display=(episode<5)))\n", "print(\"Summary: mean={:.1f}, std={:1f}, min={}, max={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", "print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning to Play MsPacman Using the DQN Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating the MsPacman environment" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env = gym.make(\"MsPacman-v0\")\n", "obs = env.reset()\n", "obs.shape" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preprocessing the images is optional but it greatly speeds up training." 
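] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough back-of-the-envelope sketch of why: a raw RGB frame takes 210 × 160 × 3 bytes, while the preprocessed 88 × 80 single-channel `int8` frame takes 7,040 bytes, about 14× less data per observation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "raw_bytes = 210 * 160 * 3 # one raw RGB frame (uint8)\n", "small_bytes = 88 * 80 * 1 # one preprocessed frame (int8)\n", "raw_bytes / small_bytes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Smaller observations mean less memory per replay-buffer entry and faster convolutions."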
] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "mspacman_color = 210 + 164 + 74\n", "\n", "def preprocess_observation(obs):\n", " img = obs[1:176:2, ::2] # 자르고 크기를 줄입니다.\n", " img = img.sum(axis=2) # 흑백 스케일로 변환합니다.\n", " img[img==mspacman_color] = 0 # 대비를 높입니다.\n", " img = (img // 3 - 128).astype(np.int8) # -128~127 사이로 정규화합니다.\n", " return img.reshape(88, 80, 1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트 `preprocess_observation()` 함수가 책에 있는 것과 조금 다릅니다. 64비트 부동소수를 -1.0~1.0 사이로 나타내지 않고 부호있는 바이트(-128~127 사이)로 표현합니다. 이렇게 하는 이유는 재생 메모리가 약 8배나 적게 소모되기 때문입니다(52GB에서 6.5GB로). 정밀도를 감소시켜도 눈에 띄이게 훈련에 미치는 영향은 없습니다." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAArsAAAGoCAYAAABGyS0qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xm8pFdZJ/DfAaOyCCPIAAY0AkLYJRhkJ4yEALKIcYGBZhhxUEChYVBcQNERRWRpVIwyiIEbEVBkhLAJDEhEIpuAYXNEA3YMhE1WxQBn/njfGyrVdatu3+q6b72nvt/P5366b73vec95T9W9/fRTT51Taq0BAIAWXWboAQAAwKoIdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGYJdkeslPJ1pZRnlFI+WEp5fynlL0spJ/XHfqKUcubEuT9WSjl7xjWeWEo5tMP1zyqlnD/x9fFSyh/1x84tpZwyo81rptpcaRf3cZlSyp+WUm6x+7tnWinlB0opvzH0OABgnQh2x+1hSU5MctNa6w2TPD3Ji+acf5dSykcnv5I8dqeTa60PrLWesP2V5AlJPjvr3FLKSaWUDyT59iT/nuSL/aGLSik3W3Afj0zyoVrr3/bXunwp5SmllFpKufVUP5ctpfxGKeW9pZS3lVJeP3HsB/vH31NKeVMp5foL+p11H1cvpfx+/5+Ht5ZSziml3HTi+FVKKVv92K4x1fYRpZQPlFLOK6WcXUq5+ozrP7GU8rH+Pwtv7r/uMnH8uFLK40op7yqlvLuU8tellFeWUu7QHz+zlPKRvv1bSylvL6V8X5LUWv8syQ1LKace7X0DQKsEu+NW0j2H28/j1y04/9W11mtMfiV56twOSrnidrY4ydWSXDTrvFrrO2utJ9ZaT0xyUpJnJfl8knvVWt8z5/qXS/K4qXH8TJL3J/nwjCZ/kORwkpvUWk9O8sP9da6X5PeT3KfWerMkz03ykqm+nl5K+eGJ769SSnldKeVqE6edlOQ1tdYb1lpvleT/JHnaxPFfSfLiGfdxSn8fd6i13iTJ2/uxzvLSWuuta623S/JTSV5aSvmW/tgLktwiyam11pvXWm+b5EFJLp5o/7t9+1sleXSS500c+/UkT9mhXwDYOIuCI9bb
GUm+M8n7Sik1yUeT3H/O+aeVUg5PPXaldIHhTk5Mclb/5/WSvHHi2C+UUk6rtf5cn3m8SZJT0mV3/yrJZ5I8spRy8yR/sUPQe+8k76m1fnz7gVrrE5OklPJLkyf2mdprJvmOJH9TSvl0umzzJ5Ocni5I/Yf+9K0kTy2l3LTW+nf9Y09K8qpSyuWTvDLJq5I8barvV02N78JM/JzUWn+yH8v0ffxIkrMmrvXMJB8vpVy51vqZGfe9fb13llI+m+SEPmC+QZJb1lovnjjnE0k+scMlvj3JuyfOfUsp5VtKKTeutb53p34BYFMIdkeqlHKjJNdP8vokb0jyDUmukOTOpZTTM5UVrbU+J8lz+raPTXJirfXHjrLbX03y6VLKlfvvX53kzf3fr5cu+/hrtdZLgq8+43qndFnoWW6ZiWBtgTskuVWS36y1PqaUcrckryylXDfJdZJ8aPvEWutXSikf7h//u/6xT/Zv8b8yyROTPKZ/63+mvgzhV5I8ZBdju04mMr611k+XUj6T5IR591dK+f4k/5HkvCQPTvKy7UC3lPJTSR6Q5JvSBfKP6Zs9vG93tXQ/w/eduuy7082rYBeAjSfYHa8Tkty+//vxSe6Y5BnpMoDvSBcI71kp5ZfTZUu/Mcm1SynnpwtYv5Tkhf1p76i1nltKeU0/hu22O11zq9Y6/QGqKyb52C6H9Z+T/FWt9XVJUmt9dT+u2/dj+8rU+V/OkaU635wum31hugB9p7FeNX1QXGv9y12Mbbf9J8l9SynflS4re26SW9da/72ft8ttn1Rr/e0kv11K+dl0mfVtv1trfXI/zu9K8opSyl0nMrmfT3LlAACC3bGqtb4yXTC2XS/6nbXWp5ZS3pYu4/dNSV5eSrlWuoCq9F9f3b7GjJKGh9RaX9P//cnpgucrJflcks/UWifb3nWi3X2zu/rv/5jx2IfTlWLsxkU58gNyX00XZB5O9x+ASdfuH0+SlFJukOTPkvxEkrcm+ZNSyuW3yyYmzrtmuhKHp9RaX7DLsR1O8m0T17h8kqtO9j/hpbXWn+g/WPY76YLTpHueHlNKuczkXM9Ta31XKeUtSU7L1zK5107yT7scNwA0zQfURq6U8qIkN9r+vtZ6cr9ywi/03x+utV4r3QeXXllrvdacr9dMXOffaq3/muT5SW43L/iqtX4xyalJPrDD12trrZ+vtc4Kds9O8l92ebtnJ7nr9ioLpZTbpwsw35Lug13f1wf3KaX8ULpVId4x0f4+Sf57rfWcWuuX0mWur9pncdO3+/Z0pSH/6ygC3aSrEX7ARInHI5K8ebIeeFqt9RXpSiye3D/04nSB+7NKKVecOPVaO12jL7X4niR/039/5SQ3TlfaAgAbT2Z3/G6QbsmxP5x6/BNJPrLCfn8tyQe3v6m1vjTJS6dPKqWcmK62d6Za63tLKR8qpdx9xofDps/9WCnlAUleUkq5OF1Jxff3HwD7TF/j+vL+2OeS3H0ySK+1PmXqehenWw1h0tOSXD3JT5dSfrp/7Eu11jstGNsbSinPSvKXff//kuR+89r0HpXkvFLKi2ut55RS7pzk55OcW0r5SrqA/aJcekWI7Zrd7dU4Hl9r3a6d/h9Jnldr/cIu+gaA5pVa69BjYAmllHelW6Hg4hmH31trPa0/7yeT/EaST8847x9qrafscP03pltl4d9nHP6lWutOy2tttz8x3ZJnJ8w554ZJ/neSO0+uQsDR6TPUr0pyWq111vMMABtHsMtaKKXcKckna63nDT2WsSql3DHJJ2qt7xt6LACwLgS7AAA0ywfUAGhW2WktRGBz1FrX5itJ9eXLl6/9/hr6d5+vY/rvyPnplvz7SP/n55Ic1x97YZIHT5z72HQ7T340ydumrnM4yfX6v/9W//3k13X7YycmOX+i3QnpVlU5f+rrpunW9j48Y8zfPeP8i9Mty3i3JG+c0ebuU+f/Yvp3axfMz+2S/NHQz9OKnvsbTD1H9+8fPzPJj6VbdvLM
ifPvnG4zok+m22b9shPHPprkhB36eW6SR+xyTKXv+8x0O4z+1cSxk5Ock+5DyO9M93mL7WOXvFaT/M8kDx96fsf8ZTUGAJpRJz4MW0p5fKa2354696lJnjpxfql9dDF13iOTPHLivMNJjpszjI/XGR/K7XeUnDWOt2dqnfBSykdz5Lri28eem+S2+doHhy9O8nPpAu//utOgSilXSPIH6XajTCnle5L8ZrrNg45L8o9JHllrvaBf4vGZ6dZtL+lW+Hl0PcrPBCzo4+vT/UfiTn0fr0m3s+X0Bj0ppXwwszfLuXqSq9VaP1Fr/WD6pRpLKWel21V0p3F9S7qA8vuSvC/Jy5L8eJLf3cVtXT/JzM2G+n5vn24N+CskeUm6QHb6vCskeXmSH621vrKUcnK6DYJOrrV+eOr0ZyT5m1LKG2qt79/F+JiyVsHu4Uc9aughANCAfhvxg0m+t5TyK0kemi5YenV//DFJtrfg/nKS/5Tkt5M8Yf9H2+m3gb8oyafSjfWTs86rtf7oRJtbpMvqvitd5nKen0q33vr2+t8vSLd04R/35R6H0mU4H5DkjCSvq7X+et/PY9KtmnO7/vvLJHlFksfWfvfGfiy/XGu990Sf8/r4+XTB6Y3TlVW+KsmjM/EfkIl7vsH0Y6WUr0u3WdHnFtz3LPdL8hf9fzRSSvm1flxzg91Syi2T3Dzd8pQvqrVeaqWiWusDJ879xSSX3eFS10/y+dptEJVa69v61ZVOSrfZ0uQ1v1pKeXq6JT+nt4dnF9TsAtCUfrfEVyT5QpLb1lp/sdZ6jSR/vn1OrfXptd9QJ8l3pCt7mLvW91G4Winl/KmvH9xFu6ckuUu6rcQvqrV+uX/8hqWUs0opJ5dSrlFK+aFSypNKKX+dLqP7qXQB45NKKQ/sM6az/LckfzLx/QX9WEuSr08X8F8wcewqpZTLllIum25HyO1jqd0a5k9It+75LfoM7h+n39Bol338SJLfqrV+tb/XZyW5/y7madtV0wWMX9rphFLKrdJlp6ddL8lklvR9Sa47r7M+G/zsJD+d7vX1gj7gnjzntqWUU0op90hy18zI6vY+mORypZR7ls5t0gXRb9vh/JcnuUsp5ZvnjZHZ1iqzO8+1XnLNoYcwuMOnX7jjMfOD18d88+aHNpRSLpfk4elqcR+V5C+SnF1KuXX/2E4eneSjtda/7t8uv0KSa0xc98wk35+vbe2dzF7bPLXW89Nn80opn0jy3f1jO5YxzHBRkrv1OyIel65+9Mx0Gb9vTJcVfEu6Lc0/M3Hvt09XZ3zEbpX9FuYnJnn3xMM/kC5oe1ySy6fbpfLn+mMPTfcW/MfSZb7fk6msYq317f1ulS9JV/9+776UYNK8Pq6TrmZ224f6x3brqkku+cHu13U/L1/bpv3lSZ6Y5LsyY9OjdKUTs/5+6ZO6QP0+6coxXlhr/b0+s/28JOeUUn681vqe/vT7pLvPKyS5VbpSkCumu+9L1Fq/2G8Z/5vpsuj/nOR+tdZZW8yn1vr5UsqH09V+v2mnsTKbzC4ArbhfklskuU2t9cW12/L8lHRv8R9Ri5skfcb1YJKblFJuUmu9QZ/t/ejUqU+sl95e/UNHXm2m40opVy2l3DQzMoyllKuXUs4rpZyX5I7p3uZ/f5JXJvm/SY5P8ula6+uS3KZ//P7pthl/80Tbt6Wr7XxE/9g3TXV1xXR1pJNvu78gXZ3st6UL7v8jX9ut8bf6OTi+/zo3yR/NuL/j0+1meXGSb51xfF4fJd2H+bZ9OUcXl1w73YfzJh2utZ7Qf72o1nqPdP/pmfYP6YL/bTfMpQPvSbdOF6A/otb6hKTLbNdaD6TbvfRFpZRr948/Ll3m93rp6o9PSPKzsy5aa31XrfXUWuu1a623rbW+YeLwy9MF7pM+n9l1yywwmswuAMxTa/3DJH/YB5fXqbX+Y//2+DOSZHIVsj4T+vPpspW3S1cr+ZpSykNrra9Y1Fcp5bh0WchvmHjsFumC
lKQLLC9OVyP8qXSB46xg8aJ0Gdnjklwuyccm35Yvpdxt4txXJHn9orElXSZw6qGPpwtKvzXJ4f4t+buky8Z+JclXSilPSRc0H0z3Qbdbbo+lr2n9YinlKrXWT/WP3TddvfBp6eKJs0spP1Nr3a6LXtTH4XRB8Pn9GL8tX8vK7saN033gbaZSypXSlU3M8sIkT+hrcN+fLph97qwTa61vSfI9Oxx7drrShu0+r5/k+enqgX9n0Q2UUh6d5BEzDl09XY312yceu3aSf1p0TY4kswtAa+6Vrv51pv5t6ZelCy5vU2v951rrnyf5wVw627ftK0ke39fe/mNf6nBukl/KRLBba/3bPit8xyTvqrVes9Z63VrrybXWe+XSgct2m9pnoG+c5M/m1Z/WWr/cB7GvSvKBHb7uPiPQTb/KxCuSfG//0KfTrbAwWUv8Q+lqSZPk7/vvJ499PMm/JpfM4alJ7lpr/Zda60fSBb33mGizqI+tJA/ra1Yvk64E5U93uv8ZbpxL193WJNcspVzQv+X/6iT3ntWw1vqJdAH9VpL/l6684/ePou8jlFKu3l/vV2utv7ibNrXWZ9Rarzf9lYn68v7aN0r3n6f3LjPGTSWzC8CmeEySL9RaaynltP5DVpfoM3hvmW5Ua31IkofMumBfJzrt8ukyxdM+leTpRznm89LVdU6O5w47jOWFmbPcVroSiacleV6t9SullHsmeUop5VHp6owvTLdKQtKVhDyjlPLOdMH+F5Lca3vO+uD54VPjujATS7Ttoo8n92N6R//9m5I8aeqeDqbLAs/yjUlO61eKeH3/PH3D9El9zfURaq2vT3KjHa693fbW6bLAi3yh1nrj7JABnnP9X01XH33RjMOTHyZ8eJJn9vPOURLsAtCie5RuPdxpL0jyM9OB7gpcY4f+U0p5Tq111hq6N9uhzWfSZWWXUmt9cynlvaWUH+5rmt+arqZ51rl/n24N2mX7nNfHv6Vb23Ze+0PpAuJB1FrPzdQayCvw9elqqqfdJsmf96URN0/3QUr2QLALQFNqrWemW71gmWtca5fnfSBTwVCt9bwc5b+vtdY3pgt69qzWer9dnPawJA9app+xqbU+eOLb39tlm2ssPmvX/T8nyXP6b28/dezxSR6/4BInJvmRusPmKCwm2AWADdEvS/achSeyNmqtLxt6DGPnA2oAADSriczuosXi5y2ov9eF+FfRblHbvRpirPs9r4uM6Xne780hxvQ8D/HzA8C4NRHsAjB70wSAxu24+902ZQwAADSrrNOSbRccPLjjYLw9uf9vbzMuXh/zzZuf4w8dWpgZGIH1+WUOsH8W/v5WxgCwAQ4cODD0ENba1tbW3OPmj73y2lrOovnbDWUMAAA0S7ALAECzBLsAADRLsAsAQLMEuwAANEuwCwBAswS7AAA0a+PX2Z230PxetbKA/yrmho7XyM5amZuxWXYt0LG3X9bQ4x+6/SqvPfb2yxp6/EO3PxZkdgEAaFYTmd1F2SWZIgCAzSSzCwBAs0qtdegxXOKCgwd3HMy87Owymd0x1RzOG+sq+lSzuzr7/XwN8ZrcqyHGevyhQ2Ulne6vub/M96MubszWoa6QNnltLWfR/CVZ+PtbZhcAgGYJdgEAaJZgFwCAZgl2AQBoVhNLj43JEMukDfHBJfZu3T6I6DUCwJjJ7AIA0CzBLgAAzRLsAgDQLDW7+2yI+kc1l+Oy38+X1wcALWsi2PWPNQAAsyhjAACgWU1kdgFYzqL95w8cOLBU+2Ut6n/sVj1/7MxrazljmD+ZXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJq18asxjGmNXpsNMI8NSwDgSDK7AAA0a+MzuwCM37LrBI+9f1Zn6Od26P5bILMLAECzBLsAADRLsAsAQLPU7AIwekPXLQ7dP6sz9HM7dP8tkNkFAKBZgl0AAJrVRBnD4dMvnHt83sL389rud7sh+mxlrIuYO2MFYDM1EewCsBx1gUCrlDEA
ANCsJjK7y7x1ude2+91uiD7HNNZVXXcT5m5TxgrAZpLZBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFml1jr0GC5xwcGDOw7Gp7CBZczbkOL4Q4fKPg5lVeb+MreOLjBGW1tbi05Z+PtbZhcAgGYJdgEAaNZoNpWY9xbkIqsogZg3nlWVXCwzBzsZ01jXzZjmboix7vfPHQDMIrMLAECzBLsAADRLsAsAQLMEuwAANGs0H1ADYHUWrWW5aJ1e7Te7/Sqvrf1mtz8WZHYBAGiWYBcAgGYJdgEAaFapde526vvqgoMHVzKYeYvb73VR/FW0W7btXoxprOtmTHM3xFiH+PnZq+MPHVq4t/oIzP39uR91cQDH2qKa3yQLf3/L7AIA0CzBLgAAzRLsAgDQrI1fZ3defWAL/S1jTGNdN2OauyHGOqb5AWDcZHYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGjWxq+zC7AJdrG//FwHDhw4RiOZbdH4hu5/WWMf/zob+9wOPf6h+98PaxXsWmh+vjHNz5jGum7M3c5WNTf10EouC8AaUMYAAECz1iqzC0fj3Ju8ccdjtz7vlH0bBwCwvgS7jM68IHf6HEEvAGw2ZQwAADRLZpfR2Clbe+5N3jjzsXltAIDNILMLAECzZHYZnVnZWhlcGNa6r+W56v6XNfbxr7Oxz+3Q4x+6/2NBZhcAgGY1kdk9fPqFc4/PW4h+XttVLGA/xFhX0W6IPg9/sPtzVvZ2UUZ30+duVWPdqzGNFYBxk9kFAKBZTWR22Syz1tndzdq7wOoMXbc3dP/LGvv419nY53bo8Q/d/7HQRLC7zFuX+/225xBj3e92q+rz3Jt8cK/D2fi5W0W7ZYxprACMmzIGAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYZTRufd4pl9o8YvL7eccAgM0l2AUAoFlNrLPLZpnO2E5ndOedCwBslrUKdhfte79X+70Q/aruYxVWNTdjmoO9Mnf7z9wAcLSUMQAA0Ky1yuwCsDct7F8/pLHP336P/6EPfeiOx5797Gfv40hWb+yvjaGtev62trYWniOzCwBAswS7AAA0S7ALAECz1OwCAHNN1+jOq8s9mnNhP8jsAgDQLMEuAADN2vgyhnmL1O/3ZhTrZtEC/nudnzHN6143MVjV3LXCzx0A+2Xjg10AFq9VuWitzGXbL2vo8bfe/pxzzpl7fJlrD31vrbdf1tjHnyhjAACgYYJdAACatfFlDOoDd2Zu9s7czWd+ANgvpdY69BguccHBgysZzCr+Yd3rB5fWzaqCjlY+gDTEfXht7WxVc3P8oUNlJRfeRwcOHFifX+Z7sA51fexseu3co2Gd3WG1/rO1tbW18Pe3MgYAAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJq1VptK7HVdziHWJR3TWrGMi9fW3u117uqhYzwQaIy1chkzmV0AAJol2AUAoFmCXQAAmrVWNbsA7M3W1tZKr3/gwIGVXn+RVd/fqq16/sY+P8swt+ttHeZPZhcAgGYJdgEAaJZgFwCAZgl2AQBo1sZ/QG3ehhTzFqhfRbsh+lzVWDfdpjzP+z1WADhaMrsAADRLsAsAQLM2voxhr2+Z7ne7Ifpc97eTz37QC3c8ds/n328fR3KkTXmeW31tcfQWraU59Dq9627V8zfm+V92nVavzeW0MH8yuwAANGvjM7uMz7yM7vQ5Q2d4AYBhyewCANAsmV1GY6ds7dkPeuHMx+a1AY6tMdTtrTPztzrmdjktzJ/MLgAAzRLsMjpnP+iFR9TtznoMAECwCwBAs9TsMjqz6m/V5AIAs6xVsHv49AuHHsKujWms84xxcf9Z5QotlTB4be1dK3MHwLGjjAEAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGjWWi09BsDeLLt//dbW1jEayd4sO35YFa/N1dqP310yu4zGPZ9/v0tt
HjH5/bxjAMDm2vjM7rxF6Me44cKxtGiB/qHmZzqInQ5y5527X9Z17taFnzsA9ovMLgAAzRLsAgDQLMEuAADN2viaXfWBO1vV3CyqZ22B19V85geA/SKzCwBAszY+swvA8hatdbloLU3tl2vPzoZ+bja9/TqQ2QUAoFmCXQAAmiXYBQCgWWp2AVjasnV72q9/3eOkhz3sYTseO+OMM/ZxJIsN/dxsevt1ILMLAECzBLsAADRrrcoYWllovpX7WIY52Lsxzd1+bxCyqrmph1ZyWQDWwFoFuwDA+pmu0Z1Xl3s058J+EOwyWu958puOeOxmP3vHSx3b/h4A2ExqdgEAaJbMLqMzK6M7fUyGFwBIBLs0Yjq4FeQCAIkyBgAAGiazy+hMZ2vf8+Q3zS1tAAA2l8wuAADNaiKzu2hh+3kL0c9ru4oF7IcY6yraDdHnTu0mM707ZXjN3WrGuldjGutYbG1tDT2EpYx9/IscOHBg6CEcU9Nr6bbMa3O19mN+ZXYBAGiWYBcAgGY1UcawzFuX+/225xBj3e92+93nbj6cZu6OfbtljGmsAIybzC4AAM1qIrPLZrHMGACwWzK7AAA0S2aXJtkmGPbXouWDhl7eaN2t+/ydccYZg/a/jHWf23XXwvzJ7AIA0CyZXUZnO2s7q3ZXRhcAmCSzCwBAs2R2GS1ZXFgfY6jbW2ernr/Wt7ydx2tzOS3M31oFu4v2vd+rVhainzc/q7jHRc/HEH2uwibM3Sb8DADALMoYAABolmAXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmrdWmEkPY9I0aVrXZQCsbHGzC3G3KaxKAzSSzCwBAszY+swvQghb2r5+n9ftb1n7Pz1lnnbXrcx/4wAeucCSrN/bX3tbW1qD9r3r+dnN/MrsAADRr4zO7+10fOEQ94qb0uQqbMHebcI8AbC6ZXQAAmrXxmV0AYO+ma3KPpp4X9oPMLgAAzRLsAgDQLMEuAADNUrMLwMK1Khetlbls+2UNPf5Na380a+eO7d5aa7+ssY8/kdkFAKBhgl0AAJo1mjKGdVuEft3G0wrzOi5j2gDj8OkXHsORADAWpdY69BguccHBgzsORhA0/x/rVczPouDAc7KzIeZuv18fYzNvfo4/dKjs41BW4sCBA+vzy3wP1qGuj71ZtK7u0dT3cuy1/rO1tbW18Pe3MgYAAJol2AUAoFmCXQAAmjWaD6gBAOthXp2uGl3WjcwuAADNEuwCANAswS4AAM1qomZ3mTVN97o26SraLWq7V0Osv7rf8zq28cwzpvWUN+HnBziSulzGRGYXAIBmCXYBAGiWYBcAgGY1UbO7TJ3eXtvud7tljKnPVY113cazTn36+WnD1tbWSq9/4MCBlV5/aIvmb9X3v+79r9LQ99Z6/0Mb8rW1TWYXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmNbHOLgDDGnotUf23u5br0Pem//G/tjY+2D18+oXH/JqtLHw/xH2s4vlYZBX32cprYBE/PwCsO2UMAAA0S7ALAECzNr6MAYDlDV23p//1r5vcq6HvTf/jf23J7AIA0CzBLgAAzRLsAgDQLMEuAADNEuwCANCsJlZjWLSw/TotUj/EWOf1Oa+/Zca61z7XzSbMnZ8fAFomswsAQLMEuwAANKuJMoYxvXU5xFj32ucyYx3TczLPJszdmJ6rMY0VgPUgswsAQLMEuwAANKuJMgYA5lv3/e3XfXytM//tGvq5Xbb/ra2tpccgswsAQLMEuwAANEuwCwBAswS7AAA0S7ALAECzrMbQiMOnX7jvfVrgf2dDPB8AwJFkdgEAaFYTmd1FWTQZSID5Fq1luWitzLG3X9bY+1/l+Nd5bLsx9v7H3v5YkNkFAKBZgl0AAJol2AUAoFlN1OwCsJxl6+bG3n5ZY++bBRIbAAAFuElEQVR/leNf57FtQv9jb38syOwCANAswS4AAM1S
xrDPNmWZtHn3OaZ7bOU+5tmU1yQAm0lmFwCAZgl2AQBolmAXAIBmqdndZ5tS/9jKfbZyH/Nswj0CsLmaCHb9Yw0AwCzKGAAAaJZgFwCAZgl2AQBoVhM1uwAMa2tra6XXP3DgwEqvP7Sh72/Vz988q773oed2aH42ZXYBAGiYYBcAgGYJdgEAaNbG1+yOaY3eMY11r1q5x1buY5FNuU8AxktmFwCAZgl2AQBolmAXAIBmbXzNLgCL1+Jc97U0hx7/sv0P3X6dDT03Q8/t0P0vax3GL7MLAECzBLsAADRLsAsAQLPU7AKw9nV/iww9/mX7H7r9Oht6boae26H7X9Y6jL+JYPfw6RfOPT5v4ft5bfe73RB9tjLWRcydsQKwmZQxAADQLMEuAADNaqKMYZm3Lvfadr/bDdHnmMa6qutuwtxtylgB2EwyuwAANEuwCwBAswS7AAA0S7ALAECzBLsAADRLsAsAQLMEuwAANKuJdXYBWM7W1tbc4+uwv/2QFs3Pqm36/C9j6OeO4Y0m2D18+oVDDwEAgJFRxgAAQLMEuwAANEuwCwBAs9aqZvdaz3zm0EMANlA9dGjoIQCwIjK7AAA0a60yu6v02teenCQ59dS3Xer7SdvHxtwnHCuvPumkS/5+t3e+c8CRAMDebUSw+9rXnjw34Jw8L1k+AN1NkHus+4RjZTvInQxwJwPf6WOwDjZ9neCW77/le9uNTb//Y0EZAwAAzdqIzO6pp77tiOzqbrKuy/Q369qzxgHrZl7WdvvYrOwvAKwjmV0AAJq1EZndWSYzrPtVLztEn7AKMrysm02vW2z5/lu+t93Y9Ps/FmR2AQBo1sZmdofIrMrmMkavPumkIzK306szAMC6ktkFAKBZG5vZBebbzXq7k4+r2x03dYFAqzY22PUBNdgdgSwAY6aMAQCAZm1sZnfSThtOtNYn7MasUoXpkgbbBwMwFjK7AAA0a2Myu6vcHnid+oT9IJMLwFjI7AIA0KyNyexum5dtXVXd7BB9wl7NqsuVyQVgrDYu2AXg2Bv7Or1jH/+yWr7/lu9tNzb9/pMNDnZtFwzzyeYC0AI1uwAANEuwCwBAswS7AAA0S7ALAECzBLsAADRLsAsAQLM2dukxOFbOOXSHIx67w8FzBhgJ7N3W1tbc44vW6tR+s9uv8trab3b7Y0FmFwCAZsnswh5NZ3S3s7nnHLrDJcdkeAFgWIJdOEo7BbmT32+fI+gFgGGVWuvQY7hEKWV9BgM7WBTs7vYc1kettQw9hmNg7u/P/aiLAzjWFtX8Jln4+1vNLgAAzRLsAgDQLMEuAADNEuwCANAsqzHAUZpcYmzyz+nHJx8DAIYh2IU92inonTwGAAxLGQMAAM2S2YUlyeICwPqS2QUAoFmCXQAAmiXYBQCgWaXWudup76tSyvoMBtgYtdaFe6uPgN+fwCZa+PtbZhcAgGYJdgEAaJZgFwCAZllnF6ABpQxbdvzZz372Ut9f6UpX2qj+YShnnnnmpb5/8IMfPMg4hrKbz57J7AIA0CzBLgAAzRLsAgDQLOvsAhuvhXV29/v353SN7CLHuoZ26P6B9bCb398yuwAANEuwCwBAswS7AAA0yzq7AA1Y97rjoT8fMnT/wHBkdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGaVWuvQYwAAgJWQ2QUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFn/H7/x+ck6JhE4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (88×80 그레이스케일)\")\n", "plt.imshow(img.reshape(88, 80), interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "save_fig(\"preprocessing_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DQN 만들기" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "reset_graph()\n", "\n", "input_height = 88\n", "input_width = 80\n", "input_channels = 1\n", "conv_n_maps = [32, 64, 64]\n", "conv_kernel_sizes = [(8,8), (4,4), (3,3)]\n", "conv_strides = [4, 2, 1]\n", "conv_paddings = [\"SAME\"] * 3 \n", "conv_activation = [tf.nn.relu] * 3\n", "n_hidden_in = 64 * 11 * 10 # conv3은 11x10 크기의 64개의 맵을 가집니다\n", "n_hidden = 512\n", "hidden_activation = tf.nn.relu\n", "n_outputs = env.action_space.n # 9개의 행동이 가능합니다\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "def q_network(X_state, name):\n", " prev_layer = X_state / 128.0 # 픽셀 강도를 [-1.0, 1.0] 범위로 스케일 변경합니다.\n", " with tf.variable_scope(name) as scope:\n", " for n_maps, kernel_size, strides, padding, activation in zip(\n", " conv_n_maps, conv_kernel_sizes, conv_strides,\n", " conv_paddings, conv_activation):\n", " prev_layer = tf.layers.conv2d(\n", " prev_layer, filters=n_maps, kernel_size=kernel_size,\n", " strides=strides, padding=padding, activation=activation,\n", " kernel_initializer=initializer)\n", " last_conv_layer_flat = tf.reshape(prev_layer, shape=[-1, n_hidden_in])\n", " hidden = tf.layers.dense(last_conv_layer_flat, n_hidden,\n", " activation=hidden_activation,\n", " kernel_initializer=initializer)\n", " outputs = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", " trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\n", " 
scope=scope.name)\n", " trainable_vars_by_name = {var.name[len(scope.name):]: var\n", " for var in trainable_vars}\n", " return outputs, trainable_vars_by_name" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n", " input_channels])\n", "online_q_values, online_vars = q_network(X_state, name=\"q_networks/online\")\n", "target_q_values, target_vars = q_network(X_state, name=\"q_networks/target\")\n", "\n", "copy_ops = [target_var.assign(online_vars[var_name])\n", " for var_name, target_var in target_vars.items()]\n", "copy_online_to_target = tf.group(*copy_ops)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'/conv2d/bias:0': ,\n", " '/conv2d/kernel:0': ,\n", " '/conv2d_1/bias:0': ,\n", " '/conv2d_1/kernel:0': ,\n", " '/conv2d_2/bias:0': ,\n", " '/conv2d_2/kernel:0': ,\n", " '/dense/bias:0': ,\n", " '/dense/kernel:0': ,\n", " '/dense_1/bias:0': ,\n", " '/dense_1/kernel:0': }" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "online_vars" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.001\n", "momentum = 0.95\n", "\n", "with tf.variable_scope(\"train\"):\n", " X_action = tf.placeholder(tf.int32, shape=[None])\n", " y = tf.placeholder(tf.float32, shape=[None, 1])\n", " q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n", " axis=1, keepdims=True)\n", " error = tf.abs(y - q_value)\n", " clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n", " linear_error = 2 * (error - clipped_error)\n", " loss = tf.reduce_mean(tf.square(clipped_error) + linear_error)\n", "\n", " global_step = tf.Variable(0, trainable=False, name='global_step')\n", " optimizer = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)\n", " training_op = 
optimizer.minimize(loss, global_step=global_step)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 처음 책을 쓸 때는 타깃 Q-가치(y)와 예측 Q-가치(q_value) 사이의 제곱 오차를 사용했습니다. 하지만 매우 잡음이 많은 경험 때문에 작은 오차(1.0 이하)에 대해서만 손실에 이차식을 사용하고, 큰 오차에 대해서는 위의 계산식처럼 선형적인 손실(절대 오차의 두 배)을 사용하는 것이 더 낫습니다. 이렇게 하면 큰 오차가 모델 파라미터를 너무 많이 변경하지 못합니다. 또 몇 가지 하이퍼파라미터를 조정했습니다(작은 학습률을 사용하고 논문에 따르면 적응적 경사 하강법 알고리즘이 이따금 나쁜 성능을 낼 수 있으므로 Adam 최적화 대신 네스테로프 가속 경사를 사용합니다). 아래에서 몇 가지 다른 하이퍼파라미터도 수정했습니다(재생 메모리 크기 확대, e-그리디 정책을 위한 감쇠 단계 증가, 할인 계수 증가, 온라인 DQN에서 타깃 DQN으로 복사 빈도 축소 등입니다)." ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "replay_memory_size = 500000\n", "replay_memory = deque([], maxlen=replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for idx in indices:\n", " memory = replay_memory[idx]\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ReplayMemory 클래스를 사용한 방법 ==================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "랜덤 액세스(random access)가 훨씬 빠르기 때문에 deque 대신에 ReplayMemory 클래스를 사용합니다(기여해 준 @NileshPS 님 감사합니다). 또 기본적으로 중복을 허용하여 샘플링하면 큰 재생 메모리에서 중복을 허용하지 않고 샘플링하는 것보다 훨씬 빠릅니다."
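, "\n", "\n", "참고로 생일 문제(birthday problem) 근사를 사용해 보면, 배치 크기 $n=50$, 재생 메모리 크기 $N=500{,}000$일 때 한 배치 안에 중복 샘플이 하나라도 포함될 확률은 대략 다음과 같습니다:\n", "\n", "$$P(\\text{중복}) \\approx 1 - e^{-n(n-1)/(2N)} = 1 - e^{-50 \\cdot 49 / 1{,}000{,}000} \\approx 0.24\\%$$\n", "\n", "따라서 중복을 허용하고 샘플링해도 실제 훈련에는 거의 영향이 없습니다."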
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ReplayMemory:\n", " def __init__(self, maxlen):\n", " self.maxlen = maxlen\n", " self.buf = np.empty(shape=maxlen, dtype=np.object)\n", " self.index = 0\n", " self.length = 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", " self.length = min(self.length + 1, self.maxlen)\n", " self.index = (self.index + 1) % self.maxlen\n", " \n", " def sample(self, batch_size, with_replacement=True):\n", " if with_replacement:\n", " indices = np.random.randint(self.length, size=batch_size) # 더 빠름\n", " else:\n", " indices = np.random.permutation(self.length)[:batch_size]\n", " return self.buf[indices]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "replay_memory_size = 500000\n", "replay_memory = ReplayMemory(replay_memory_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sample_memories(batch_size):\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for memory in replay_memory.sample(batch_size):\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### =============================================" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "eps_min = 0.1\n", "eps_max = 1.0\n", "eps_decay_steps = 2000000\n", "\n", "def epsilon_greedy(q_values, step):\n", " epsilon = max(eps_min, eps_max - (eps_max-eps_min) * step/eps_decay_steps)\n", " if np.random.rand() < epsilon:\n", " return np.random.randint(n_outputs) # 랜덤 행동\n", " else:\n", " return np.argmax(q_values) # 최적 행동" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "n_steps = 4000000 
# 전체 훈련 스텝 횟수\n", "training_start = 10000 # 10,000번 게임을 반복한 후에 훈련을 시작합니다\n", "training_interval = 4 # 4번 게임을 반복하고 훈련 스텝을 실행합니다\n", "save_steps = 1000 # 1,000번 훈련 스텝마다 모델을 저장합니다\n", "copy_steps = 10000 # 10,000번 훈련 스텝마다 온라인 DQN을 타깃 DQN으로 복사합니다\n", "discount_rate = 0.99\n", "skip_start = 90 # 게임의 시작 부분은 스킵합니다 (시간 낭비이므로).\n", "batch_size = 50\n", "iteration = 0 # 게임 반복횟수\n", "checkpoint_path = \"./my_dqn.ckpt\"\n", "done = True # 환경을 리셋해야 합니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "학습 과정을 트래킹하기 위해 몇 개의 변수가 필요합니다:" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "loss_val = np.infty\n", "game_length = 0\n", "total_max_q = 0\n", "mean_max_q = 0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 훈련 반복 루프입니다!" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n", "반복 13269992\t훈련 스텝 3999999/4000000 (100.0)%\t손실 1.899633\t평균 최대-Q 220.667374 " ] } ], "source": [ "with tf.Session() as sess:\n", " if os.path.isfile(checkpoint_path + \".index\"):\n", " saver.restore(sess, checkpoint_path)\n", " else:\n", " init.run()\n", " copy_online_to_target.run()\n", " while True:\n", " step = global_step.eval()\n", " if step >= n_steps:\n", " break\n", " iteration += 1\n", " print(\"\\r반복 {}\\t훈련 스텝 {}/{} ({:.1f})%\\t손실 {:5f}\\t평균 최대-Q {:5f} \".format(\n", " iteration, step, n_steps, step * 100 / n_steps, loss_val, mean_max_q), end=\"\")\n", " if done: # 게임이 종료되면 다시 시작합니다\n", " obs = env.reset()\n", " for skip in range(skip_start): # 게임 시작 부분은 스킵합니다\n", " obs, reward, done, info = env.step(0)\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = epsilon_greedy(q_values, step)\n", "\n", " # 온라인 DQN으로 게임을 플레이합니다.\n", " obs, reward, done, info = 
env.step(action)\n", " next_state = preprocess_observation(obs)\n", "\n", " # 재생 메모리에 기록합니다\n", " replay_memory.append((state, action, reward, next_state, 1.0 - done))\n", " state = next_state\n", "\n", " # 트래킹을 위해 통계값을 계산합니다 (책에는 없습니다)\n", " total_max_q += q_values.max()\n", " game_length += 1\n", " if done:\n", " mean_max_q = total_max_q / game_length\n", " total_max_q = 0.0\n", " game_length = 0\n", "\n", " if iteration < training_start or iteration % training_interval != 0:\n", " continue # 워밍업 시간이 지난 후에 일정 간격으로 훈련합니다\n", " \n", " # 메모리에서 샘플링하여 타깃 Q-가치를 얻기 위해 타깃 DQN을 사용합니다\n", " X_state_val, X_action_val, rewards, X_next_state_val, continues = (\n", " sample_memories(batch_size))\n", " next_q_values = target_q_values.eval(\n", " feed_dict={X_state: X_next_state_val})\n", " max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)\n", " y_val = rewards + continues * discount_rate * max_next_q_values\n", "\n", " # 온라인 DQN을 훈련시킵니다\n", " _, loss_val = sess.run([training_op, loss], feed_dict={\n", " X_state: X_state_val, X_action: X_action_val, y: y_val})\n", "\n", " # 온라인 DQN을 타깃 DQN으로 일정 간격마다 복사합니다\n", " if step % copy_steps == 0:\n", " copy_online_to_target.run()\n", "\n", " # 일정 간격으로 저장합니다\n", " if step % save_steps == 0:\n", " saver.save(sess, checkpoint_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아래 셀에서 에이전트를 테스트하기 위해 언제든지 위의 셀을 중지할 수 있습니다. 그런 다음 다시 위의 셀을 실행하면 마지막으로 저장된 파라미터를 로드하여 훈련을 다시 시작할 것입니다."
] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n" ] } ], "source": [ "frames = []\n", "n_max_steps = 10000\n", "\n", "with tf.Session() as sess:\n", " saver.restore(sess, checkpoint_path)\n", "\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = np.argmax(q_values)\n", "\n", " # 온라인 DQN이 게임을 플레이합니다\n", " obs, reward, done, info = env.step(action)\n", "\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", "\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(5,6))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 추가 자료" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 브레이크아웃(Breakout)을 위한 전처리" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다음은 Breakout-v0 아타리 게임을 위한 DQN을 훈련시키기 위해 사용할 수 있는 전처리 함수입니다:" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "def preprocess_observation(obs):\n", " img = obs[34:194:2, ::2] # 자르고 크기를 줄입니다.\n", " return np.mean(img, axis=2).reshape(80, 80) / 255.0" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "env = gym.make(\"Breakout-v0\")\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlkAAAF2CAYAAABd6o05AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGgtJREFUeJzt3Xu4blVdL/Dvz6AECSoktVAJt4qCFhRKhokdA7yUmlZ67KlO2jnlNXaUZRfJboaIWmpPJzNL8hhdTEUE07RMxQvbS3g7eUHbHhHwglcU4Xf+mHPJy+Jde29wj7322vvzeZ75sN45xxxzzHcB7/f5jTHfVd0dAAB2rpus9wAAAPZEQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQtYuVlX7VNUzqur9VfXeqvrXqjpmPvYLVfWChbaPqqpzlvRxWlU9c43+z6qqixe2y6rqb+ZjF1TVCUvOOX/VOQfuwH3cpKr+vqqO3vG7Z7Wq+rGq+qP1HgcAO5+Qtev9YpIjktylu++U5Mwkf7uN9vepqksWtySnrtW4u3+quw9b2ZL8VpLPLmtbVcdU1fuS3DbJlUm+OB+6tKruup37eHySD3b32+e+9q+q06uqq+q4Vdf5hqr6o6p6d1W9tapes3DsofP+d1XVv1XVHbZz3WX3cYuq+rM5tL6lql5fVXdZOP5tVfXCeWy3XHXuY6rqfVV1UVWdU1W3WNL/aVX1iTmkvmHe7rNwfN+qemJVvaOq3llVb6yqc6vqnvPxF1TVR+fz31JVb6uq+ydJd/9jkjtV1Q/f0PsGYPcmZO16lel9X3nv99lO+/O6+5aLW5IztnmBqgNWqmNJDkly6bJ23b2lu4/o7iOSHJPkOUk+n+RHuvtd2+h/vyRPXDWOX03y3iQfWXLKXyTZmuSo7j42yU/M/WxK8mdJHtjdd03y/CT/sOpaZ1bVTyy8/raqenVVHbLQ7Jgk53f3nbr7bkn+KcnTF44/JcnZS+7jhPk+7tndRyV52zzWZV7S3cd19w8keVySl1TVzedjL0pydJIf7u7v7u57JPnpJFctnP/c+fy7JTklyV8tHPvDJKevcV0ANqjtfcCz8/1pktsneU9VdZJLkjx8G+1Pqqqtq/YdmCmQrOWIJGfN/9yU5HULx36jqk7q7l+fKy1HJTkhUzXr35NckeTxVfXdSV61Rtj60STv6u7LVnZ092lJUlVPXmw4V6ZuleS7kry5qj6dqbr2ySQPyRSOPjA3f2GSM6rqLt39H/O+30/yyqraP8m5SV6Z5Omrrv3KVeP7eBb+3e7ux85jWX0fP5nkrIW+npXksqo6qLuvWHLfK/1tqarPJjlsDmp3TPK93X3VQpvLk1y+Rhe3TfLOhbZvqqqbV9WR3f3uta4LwMYiZO1CVXXnJHdI8pokr03yTUluluTeVfWQrKoCdffzkjxvPvfUJEd096Nu4GV/L8mnq+qg+fV5Sd4w/7wpU7XlD7r7ax/6c4XpXpmqbst8bxZCwnbcM8ndkjytuzdX1clJzq2q2yU5PMkHVxp299VV9ZF5/3/M+z45T6Wdm+S0JJvnKbal5um+pyR55A6M7fAsVLi6+9NVdUWSw7Z1f1X1oCRfSXJRkp9N8rKVgFVVj0vyiCTfnClAbp5Pe/R83iGZ/rt78Kpu35npfRWyAPYQQtaudViS4+efvzPJDyZ5RqaKx4WZAtiNVlW/k6k6dNMkt66qizMFpS8nefHc7MLuvqCqzp/HsHLuWn2+sLtXL8w+IMkndnBY357k37v71UnS3efN4zp+HtvVq9p/Ndefxv7WTNW7j2cKhmuN9eDMYay7/3UHxraj10+SB1fV92SqQl2Q5LjuvnJ+3/ZbadTdf5LkT6rq1zJVElc8t7ufOo/ze5K8oqpOXKhcfT7JQQFgjyFk7ULdfW6mELCyHuj23X1GVb01U4Xjm5O8vKoOzfRBXvN2zUofS6YOH9nd588
/PzVTaDswyeeSXNHdi+eeuHDeg7Nja/K+smTfRzJNee6IS3P9hffXZAo3WzMFz0W3nvcnSarqjkn+MckvJHlLkr+rqv1XpicX2t0q01Ti6d39oh0c29Ykt1noY/8kBy9ef8FLuvsX5gXrz84UipLp97S5qm6y+F5vS3e/o6relOSkXFu5unWSD+/guAHYACx8XwdV9bdJ7rzyuruPnZ8E/I359dbuPjTTguhzu/vQbWznL/Tzpe7+TJK/TvID2/rQ7+4vJvnhJO9bY/vn7v58dy8LWeck+aEdvN1zkpy48tRgVR2fKdi8KdOC8fvPoTJV9eOZnnK8cOH8Byb5H939+u7+cqZK3cFz1SrzebfNNAX7uzcgYCXTGrBHLEylPibJGxbXe63W3a/INJX51HnX2ZkC43Oq6oCFpoeu1cc8pXn3JG+eXx+U5MhMU8gA7CFUstbHHTN9dcNfrtp/eZKPDrzuHyR5/8qL7n5JkpesblRVR2Rau7VUd7+7qj5YVfddsuh8ddtPVNUjkvxDVV2VaeryQfPC8ivmNUwvn499Lsl9F8Nhd5++qr+rMj3dt+jpSW6R5Feq6lfmfV/u7nttZ2yvrarnJPnX+fr/L8nDtnXO7AlJLqqqs7v79VV17yRPSnJBVV2dKShemus+4biyJmvl6dLf7O6VtXE/n+SvuvsLO3BtADaI6u71HsNep6rekemJu6uWHH53d580t3tskj9K8ukl7T7Q3Ses0f/rMj01eOWSw0/u7rW+pmDl/CMyfXXEYdtoc6ckf57k3otP1XHDzBW5VyY5qbuX/Z4B2KCELG60qrpXkk9290XrPZaNqqp+MMnl3f2e9R4LADuXkAUAMICF7wAAAwhZAAAD7FZPF85/ZgbYQ3X3Wn9FAGCPo5IFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwwG71FQ57ujPPPHOn9bV58+YNf421+t9V9pT7WMta97e7jhdgT6OSBQAwgJAFADCA6cLdwJ4w9berrrErmE4DYGdQyQIAGEDIAgAYwHQhrHJDpzdNLwKwjEoWAMAAQhYAwACmC9lr3dBpvo32lCQA60slCwBgACELAGAA04W7gV0xDbWnXGNn2mjjBWBjUckCABhAyALYQ1VVrfcYYK/W3bvNlqRtNtueu633/2P29C3JxUm2Jvno/M/PJdl3PvbiJD+70PbUJJfM21tX9bM1yab55z+eXy9ut5uPHZHk4oXzDkty9TyOxe0uSTYl2bpkzN+3pP1VSW6f5OQkr1tyzn1Xtf/tJLUD788PJPmb9f49Dfrd/7dVv6Pj5/2vS3KfJE9NctpC+4fN7T6R5JcX9t90W/+tJvmXJPffwTFVkt9LclqSn01y1sKxk5JsSXJZktcnOXbh2AVJTph/fmaSB6z3+3tjN2uyAPYQ3X3Yys9V9ZtJvre7r1qj7RlJzlhoXz1/qq1q9/gkj19otzXJvtsYxmWL41g4b9Ma43hbpnC22PaSJJ9d1r6qnp/kHkmunHddleTXMwW+/77WoKrqZkn+Isk959d3T/K0TKFi3yQfSvL47v5YVX1jpnB5r0xB4fwkm7v76rX6X+Oa98sUMHq+zpb5Gp+rqoOS/HmSI5N8Q5IXdvfvL+njwCT/d0n3N0myf3cfkCTd/Zokh87n/Pt8vbXGdeckT09yfKb3+XVV9d7uPncHbuv2ST6/Rr+vy7VB+1szBbtl7W6b5IVJ7tfdb6uq+yd5WVVt6u4vrGr+pCQXVtWbu/uyHRjfbsV0IcAepqpul+SXkpxWVU+ZQ8sDF45vrqqt83ZxVX0myVPWa7zzmO5cVTevqpskOSjJJ5e16+6f6+4juvuIJA9P8oEkf5/k57dzicclOXfhg/pFSf60u++W5JhMVZ3T52NPyhRYjkxyVJI7JzllYaw3q6p/qarvWNh3YlX9xcLrb5zH9YTuvvt8jYOS/Nrc5FlJPtndR2aq5j2sqh665H4/2923XL0l+f5Mlcob4+e
SPL+7P9zdn0zyjHnfNlXVQ+Z7+O1lU9HdfUJ3H9bdt0vyT5mqVMsck+Rdc8BOd78i073cYUmfX0zy15l+JxvOblXJ8rQXwNenqu6Y5KVJvpDkHt3925k+FF+80qa7z0xy5ty+krwzySt30hAOqaqLV+07Nck7tnPe6UnOSvLmJJd291fnz/E7VdVZmULJf2WqRH1PkntnCkafSnJ0kt+vqrclObu7v7Kk/5/JdYPEx+axVpJvTPIt874k+clM4eiaJNdU1XOS/Fbmyl93f6Gqzkjyqqp6QJK7ZpoWu99C/1dnmor79vn1fkm+OcnH5iD50ExhI939+TmgPTxTMNsRN8/aISZJ9q2q45IcuOTYpiRnL7x+T6YQuqaqOjxT5e9hmd7LZ+W6Fc6bJPnBTJW/AzJNzT4rye2WdPe2JEfN1cS3JPnRJDdL8v41Lv93Sd5QVb88/042jN0qZAFw41TVfkkenSnQPCHJq5KcM3/QnrqNU09Jckl3v7Gq3p/pw+6WC/2+IMmDct0porWmIC/ONPWVqro8yffN+9acLlzi0iQnz9Np+2ZaM/aCJB/JNAV2hyRvSnJ6d1+xcO/HZ1pHdr2AVVX7Z5pOfOfC7h9L8ookT0yyf5JzMk07JsnhST640PaD877Fez23qr6U5DWZ3psTu/uSheNXV9V9kpxbVc9KcnCSZ3f3c6vqVvM1P7TqGj+zI2/Q7OAkH1+4x5MzVY9WxrBPpuD3XWucX2v8fN1GVftmmob9gyS/M9/3a5O8vKrOSfK47v5wppmxn0zylSTfmeTWmcL+t2RaV/U13f1fVfXwTEH/Nknel2md1xeXjaG7PzD/Dg/NtN5wwxCyAPYMD8tU0fn+hWBzQqYKxfXWWs3HH5ppWnGfqjqqu+8479+6qulp3f3M63WwfftW1cFJviNTdWP19W+RKaQk04ft92WqaF2dqUL150k+3d2vrqoHJllZs/Tw+fxl9/SYTO/B4lTaAUmuybXruJJpuvD8JL+T6bPwuZnWKf1SptCxuP7qq1m+vOawJJ/OFJgOybUBZyWcvDTJ07r7eXNo/D9V9UuZqkg9j2l711jLrTMt+l90QXefsPD65fM6qdU+kCl0rrhTrhsqFz0wU/h7UHe/NUm6+0tVdVKSX03yT1V1/Px+/+L8+359kod198uq6veWddrdr81U7VrmbzJVLRd9IdNU5YYiZAHsAbr7L5P8ZVUdXFWHd/eHuvurmdbbXCeQzJWfJyV5cKYPumOSnF9V/3NeH7NNc4A4PMk3Lew7OsnL55fXZKp2nZcpLF2S6YNztUszVaD2zTSd9onu/vJCnycvtH1Frg1k29TdqxdmX5bky5nC3taqunmmJ+5+dF7MfnVVnZ7krZlC1tZMoe/i+fzbzPu+pqoem+THM01bHp7k7Kr6qe6+cG7y3UkO7u7nzWO6Yq5oPS3JczKFqsXKzPWusR1H5rqVsOuoqm/L2ovfn5/kn+eHCD6bqZr5xGUNu/vvs2QKc37f/nDeVq55bKaq4zO7+2Xbu4F5yvVBSw59Z5L/yBz8quqmmSp3H9len7sbC98B9iw/kmsXcF/PvAbpZZlCzfd3939190szrRE6YskpVyf5zXmB/IfmKcULkjw5CyGru9/e3YdmWpfzju6+VXffrruP7e4fybQO5zp68plMgeEfFwPWkrZfncPTKzNNLy3b7rskYGV+avIVmb7mIJmqT5fP97zix3PtmqAXZqrK1LzW6NFZCBrz1NWR8/U+193vzDT9+MCF/j6a5MCqutd8zjckeUiS989PfL44yWPmY/sleVR2fD1W5uu/d+H1NUnuVlUfq6oPZ5o6PH7Zid39niS/kuTVmZ54/OsdfLJwTfMTi89O8qju/t87ck53n9rdm1ZvSS5c1fSHkryxu5c+cbo7U8kC2DtsTvKF7u6qOmn1AuLuflOmtU5Ztf+RSR65rMOqWhbK9s+8oHuVT2VebH8DXJSp8rM4nnuuMZYXZ1pPtpZ
nZpoO/Kt5vdQDkpxeVU/ItI7s40keMbd96tx+5cP+33LtVOXKE2+/uGpc7830fV0rry+dn8Y7fa4iflOmUPTYuckpSf6sqi7MND35D5mqQIv3dEauGwQX7Zfk2VX1x/M9PTnTe38dVfXoZSd394syTZmuaZ5OPmNbbWbv7+6Tktx9B9ou9n9WkhOSfGbJ4cWvcnh0Vq3r2iiELIA9z/2WrKtKpg/VX90FT2jdco3rp6qet0ZF4q5rnHNFpirU16W731BV766qn+jus7v7LZk+4Je1/VKS/7UTrnlepinTZcc+lal6tq3zT822H1oYaq2pwp3splmyXi9TYHvrXAn8ane/ZPA4hqi+/nfPrZtnPOMZu89ggJ3ulFNO8WdeWDfzd1f99Mo6KXZ/VfVTSV666kGGDUMlC4C9wvz1DgLWBtLdZ633GL4eFr4DAAwgZAEADGC6EODrc4PXki77Ek1gfdzItek79B+xShYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAG/oPRG/evHm9hwB7tTPPPHO9hwCw21LJAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGGCf9R4AwN5my5Yt6z0EYBdQyQIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABjAN74D7GIHHHDAeg8B2AVUsgAABhCyAAAGELIAAAbY0GuyLjj55PUeAuzV3rjeAwDYjalkAQAMIGQBAAwgZAEADCBkAQAMsKEXvgNsRFdeeeV6DwHYBVSyAAAGELIAAAYQsgAABtjQa7Ku2fTZ9R4CAMBSKlkAAAMIWQAAAwhZAAADCFkAAAMIWQAAA2zopwsBNqL99ttvvYcA7AIqWQAAAwhZAAADbOjpwk8d+MX1HgIAwFIqWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADbOinCwE2ov/8z/9c7yEAs02bNg3rWyULAGAAIQsAYAAhCwBggA29JutTR3xlvYcAe7fL13sAALsvlSwAgAGELACAAYQsAIABhCwAgAGELACAATb004UAG9Ehhxyy3kMAdgGVLACAAYQsAIABNvR04Yuuuc16DwH2aieu9wAAdmMqWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADbOinCwE2omOPPXa9hwDMuntY3ypZAAADCFkAAANs6OnCr7z4tPUeAuzdTnzjeo8AYLelkgUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADDAhv7G938577j1HgLs1R5w4pnrPQSA3ZZKFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMA
AQhYAwAD7rPcAANh7bNmy5Ws/H3PMMes4EhhPJQsAYAAhCwBgACELAGAAIQsAYAAhCwBgACELAGAAX+EAwC7jaxvYm6hkAQAMIGQBAAwgZAEADGBN1gZzwcknf+3n4847bx1HAgBsi0oWAMAAQhYAwABCFgDAAEIWAMAAQhYAwACeLtxgPFEIABuDShYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwAD7rPcAADayiy66aL2HwF5qy5Ytw69xzDHHDL/Genv7299+g885+uijd6idShYAwABCFgDAAEIWAMAAQhYAwABCFgDAAJ4uBPg6HHXUUXVDz+nuEUOBnc6/q18flSwAgAGELACAAXar6cJzvuXz6z0E9gIXnHzy8Gscd955w6+xO7jHq151w0445ZQxAwHYDalkAQAMIGQBAAwgZAEADCBkAQAMIGQBAAwgZAEADLBbfYUD7Ap7y9crALC+VLIAAAYQsgAABjBdCNxoN3Tq1Z+aBfYmtTv9he2q2n0GA+x03V3rPQaAXcV0IQDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwADV3es9BgCAPY5KFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAP8f4IPPCvV/SDoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (80×80 그레이스케일)\")\n", "plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "여기서 볼 수 있듯이 하나의 이미지는 볼의 방향과 속도에 대한 정보가 없습니다. 이 정보들은 이 게임에 아주 중요합니다. 이런 이유로 실제로 몇 개의 연속된 관측을 연결하여 환경의 상태를 표현하는 것이 좋습니다. 한 가지 방법은 관측당 하나의 채널을 할당하여 멀티 채널 이미지를 만드는 것입니다. 다른 방법은 `np.max()` 함수를 사용해 최근의 관측을 모두 싱글 채널 이미지로 합치는 것입니다. 여기에서는 이전 이미지를 흐리게 하여 DQN이 현재와 이전을 구분할 수 있도록 했습니다." ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "def combine_observations_multichannel(preprocessed_observations):\n", " return np.array(preprocessed_observations).transpose([1, 2, 0])\n", "\n", "def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n", " dimmed_observations = [obs * dim_factor**index\n", " for index, obs in enumerate(reversed(preprocessed_observations))]\n", " return np.max(np.array(dimmed_observations), axis=0)\n", "\n", "n_observations_per_state = 3\n", "preprocessed_observations = deque([], maxlen=n_observations_per_state)\n", "\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", " preprocessed_observations.append(preprocess_observation(obs))" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlUAAAEuCAYAAACnC+ctAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADAdJREFUeJzt3W2IpWd9x/HfPy5Zk12p1iS1UDG1Kwjii0yakkJrUg2iFq0PtRWkNIgxBVs1g40tqEiI1GKcRBFhUdt9IaUvfGHVQqGltTSVQM1EIYImRqykbUok1NaHrXH36otzJxxPZ2Z3Zv87Z2b284EDM/fc57qvE8jF91znYWuMEQAAzs1Fy54AAMBBIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAai6oCrqqNVNarqG3O3Mff3T1XVm6eff7OqvrVw+8DcuY9U1ZUbXGP+fu+ajp2oqhur6oaqunvu3FFVT99krn9ZVTdu8VhurapPbnWfqrp7Yf4vrKrbp9sNVXXPmf+rAXtRVX2xql6zcGxxDXjvtFYt3v6nqtY2GdcaRotDy54Au2OMceyJn6vqx5ucdjTJPWOMN2xz7E8n+fQ0dm3nvlX1UJIj068/leRvtjj9iiQPn2HIn0vy6jHGl+eu8dvbmROwZ12W5D+2OmGMcVuS2xaPV9VtSR7f5D7WMFrYqbpAPfHsLcnrznGc587tgD2Q5HtV9ZSzvf8Y4xfGGM8aYzwryV9tcZ1K8utJXlRVhxf+/OHp8fzsjh4EsOdV1fOTPC/JC3c4xLOzQdBYw+hkp+oCNS0AqapPneM430xybBprJckdY4xTZ3qyNy1az0zyM0lWkqyf4VLvSvIPSb6c5GNVddMY4/T0t7ePMU5M4+7wkQB71RQk70/y50n+uKq+NMa4b+GcK5L82RbD/FKSF1TV65LcPsa4J7GG0UtUHXynkzxUVd+YO/atLc5/ZVXN//2OMcZHz/JatyY5Pvf7WpJTSb4+d+xfM1tUTif5QWZb+Q8m2fB9AlV1UWaL0e8k+eUxxner6pok/7jFlvhnq+pHSX48xnj+dOxtSW5O8tBZPhZgD6iqQ0k+lOTnk/xKkmuT/HVV3TrGmH9S+J0kN57lsP+9yXFrGOdEVB1wY4wfZHoWtomPJfnPud8/t933VCVJVb0+yeuTfGXu8Gpm2+3vm5vPlVuM8dEk/75w+FiSlyW5bozx3WmMm6rqjQvznveq+fcjTD6S5AtJbj/DQwH2lluTPCfJ9dN69vdVdV2SN0/BkiSZdn2+M73p+92bjPWZMcY7N/qDNYwOouoCUVVvy+yZzqIrMls4/t+zn6q6NMnTMlsU/muLsV+T5I7MnkH+RVU9doa5/GGSt2/wp2ckeWuSbz5xYIzxQJLrpvu9OslbMnvGeiqz94Od2mTuh5P8dGaLMbB//WmS02OMMYXIbyW5MrNdouNJfpi53ffpZbQTi4PU7FPO1290AWsYXUTVBWKM8ZHMnun8hKo6Mffr95JcO738dzrJ95M8luRr2WCRmu7/B0luSvLSMcbXq+olmT2rO7LR+dNcPpjkgxuMten7u6rqN5Lcldn2/r8keUqSFyf5RJIPzJ36WGafvvnfJP+W5O4kP9psXGBvG2OcSpKq+qMkr03yjsx2ky5N8qokfzLddsQaRidRdYGoqvcl+f3M3new6PPJT36seJMxNjr88SSfGGP8cBrj20netBBri+PcPs1lo92vz2xytxcnOTHG+ML8udOW/a8muXe6/som1wP2t5cnef8Y44vT799P8smqekWSX0vyQJJU1TuSvCcbf/XC5zY4Zg2jja9UuLAcSvLUDW7X7nTAMcbJJxajJr+4yfG/S/K7VXV9VV1aVU+bttJvyOwTNcDB9rdJ3llVV1XVU6vq6VX1hsyC5J8Wzr0os/Vu8Xb94qDWMDrVGOPMZ8EeUFWvTPJ7SZ6b2fsQvpbkzjHGPy91YsB5N32
twpsye0/Vs5OcTPLVJB8eY3xpmXM7W9awg09UAQA08PIfAEADUQUA0GDXP/135513bvv1xtXV1fMxFWAH1tbWtn2fW2655SD92xvbXsP80yOwd+zwbU9n9T+xnSoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaHlj2Bs7G6urrsKQDs2Pr6+rKnAOwCO1UAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADfbFN6ofPnx42VMA2LGjR48uewrALrBTBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAg33x5Z+PP/74sqcAsGMnT55c9hSAXWCnCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKDBvvhG9YsvvnjZUwDYsUsuuWTZUwB2gZ0qAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAa7Isv/3z00UeXPQVgctllly17CvvOgw8+uOwpAJNjx46dt7HtVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA02BffqH7kyJFlTwFgxy6//PJlTwHYBXaqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBosC++/POuu+5a9hSAydra2rKnsO9cc801y54CMBljnLex7VQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAECDQ8ueALtnde7ntaXNAmBnrr766id/vvfee5c4E9iYnSoAgAaiCgCggagCAGjgPVUHkTdPAfvYzTff/OTPx48fX+JMYHvsVAEANBBVAAANvPx3YJz5NT+vBAJ71fr6+pM/b/aSn69RYK+zUwUA0EBUAQA08PLfQeR1PmAf84k/9is7VQAADUQVAEADL/8dGF7zA/avlZWVZU8BzpmdKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaHB
oty/4yCOP7PYlIUmyurp63q+xtrZ23q+xbA8//PCyp7BU999//7KnwAVqfX39vF9jZWXlvF9j2e67775t3+eqq646q/PsVAEANBBVAAANRBUAQANRBQDQQFQBADSoMcay5wAAsO/ZqQIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaPB/+oBlBA1K15MAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img1 = combine_observations_multichannel(preprocessed_observations)\n", "img2 = combine_observations_singlechannel(preprocessed_observations)\n", "\n", "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"멀티 채널 상태\")\n", "plt.imshow(img1, interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"싱글 채널 상태\")\n", "plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 연습문제 해답" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. to 7." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "부록 A 참조." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. BipedalWalker-v2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*문제: 정책 그래디언트를 사용해 OpenAI 짐의 ‘BipedalWalker-v2’를 훈련시켜 보세요*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env = gym.make(\"BipedalWalker-v2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 만약 `BipedalWalker-v2` 환경을 만들 때 \"`module 'Box2D._Box2D' has no attribute 'RAND_LIMIT'`\"와 같은 이슈가 발생하면 다음과 같이 해보세요:\n", "```\n", "$ pip uninstall Box2D-kengz\n", "$ pip install git+https://github.com/pybox2d/pybox2d\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.imshow(img)\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "obs" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "이 24개의 숫자에 대한 의미는 [온라인 문서](https://github.com/openai/gym/wiki/BipedalWalker-v2)를 참고하세요." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env.action_space" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env.action_space.low" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "env.action_space.high" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이는 각 다리의 엉덩이 관절의 토크와 발목 관절 토크를 제어하는 연속적인 4D 행동 공간입니다(-1에서 1까지). 연속적인 행동 공간을 다루기 위한 한 가지 방법은 이를 불연속적으로 나누는 것입니다. 예를 들어, 가능한 토크 값을 3개의 값 -1.0, 0.0, 1.0으로 제한할 수 있습니다. 이렇게 하면 가능한 행동은 $3^4=81$개가 됩니다." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import product" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "possible_torques = np.array([-1.0, 0.0, 1.0])\n", "possible_actions = np.array(list(product(possible_torques, possible_torques, possible_torques, possible_torques)))\n", "possible_actions.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()\n", "\n", "# 1. 네트워크 구조를 정의합니다\n", "n_inputs = env.observation_space.shape[0] # == 24\n", "n_hidden = 10\n", "n_outputs = len(possible_actions) # == 81\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. 신경망을 만듭니다\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.selu,\n", " kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", "outputs = tf.nn.softmax(logits)\n", "\n", "# 3. 추정 확률에 기초하여 무작위한 행동을 선택합니다\n", "action_index = tf.squeeze(tf.multinomial(logits, num_samples=1), axis=-1)\n", "\n", "# 4. 
훈련\n", "learning_rate = 0.01\n", "\n", "y = tf.one_hot(action_index, depth=len(possible_actions))\n", "cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아직 훈련되지 않았지만 이 정책 네트워크를 실행해 보죠." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def run_bipedal_walker(model_path=None, n_max_steps = 1000):\n", " env = gym.make(\"BipedalWalker-v2\")\n", " frames = []\n", " with tf.Session() as sess:\n", " if model_path is None:\n", " init.run()\n", " else:\n", " saver.restore(sess, model_path)\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", " action_index_val = action_index.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " if done:\n", " break\n", " env.close()\n", " return frames" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "frames = run_bipedal_walker()\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "안되네요, 걷지를 못합니다. 그럼 훈련시켜 보죠!" 
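] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training loop below relies on the `discount_and_normalize_rewards()` helper defined earlier in this notebook. In case you are running this section on its own, here is a minimal sketch of such a helper (an assumption following the book's convention, not necessarily the exact earlier definition):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch of the helper used below (assumption: mirrors the definition earlier in this notebook)\n", "def discount_rewards(rewards, discount_rate):\n", "    # Walk backwards, accumulating the discounted sum of future rewards\n", "    discounted = np.zeros(len(rewards))\n", "    cumulative = 0.0\n", "    for step in reversed(range(len(rewards))):\n", "        cumulative = rewards[step] + cumulative * discount_rate\n", "        discounted[step] = cumulative\n", "    return discounted\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", "    # Normalize across all episodes so above-average actions get positive scores\n", "    all_discounted = [discount_rewards(rewards, discount_rate) for rewards in all_rewards]\n", "    flat = np.concatenate(all_discounted)\n", "    reward_mean, reward_std = flat.mean(), flat.std()\n", "    return [(discounted - reward_mean) / reward_std for discounted in all_discounted]"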
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 1000\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\rIteration: {}/{}\".format(iteration + 1, n_iterations), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_index_val, gradients_val = sess.run([action_index, gradients],\n", " feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, \"./my_bipedal_walker_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "frames = run_bipedal_walker(\"./my_bipedal_walker_pg.ckpt\")\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is not the best result, but at least it stays upright and moves (slowly) to the right. 
A better approach for this task would be to use an actor-critic algorithm, since it does not require discretizing the action space and converges much faster. For more details, check out this nice [blog post](https://towardsdatascience.com/reinforcement-learning-w-keras-openai-actor-critic-models-f084612cfd69) by Yash Patel." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9.\n", "**Coming soon**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }