{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPython 3.6.5\n", "IPython 6.4.0\n", "\n", "numpy 1.14.3\n", "sklearn 0.19.1\n", "scipy 1.1.0\n", "matplotlib 2.2.2\n", "tensorflow 1.8.0\n", "gym 0.10.5\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -v -p numpy,sklearn,scipy,matplotlib,tensorflow,gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Chapter 16 – Reinforcement Learning**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_This notebook contains all the sample code and solutions to the exercises in chapter 16._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook supports both Python 2 and 3. It imports the common modules, sets up Matplotlib so figures are rendered inline in the notebook, and prepares a function to save the figures:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# To support both Python 2 and Python 3\n", "from __future__ import division, print_function, unicode_literals\n", "\n", "# Common imports\n", "import numpy as np\n", "import os\n", "import sys\n", "\n", "# To get consistent output, reset the graph and seed the pseudo-random number generators\n", "def reset_graph(seed=42):\n", "    tf.reset_default_graph()\n", "    tf.set_random_seed(seed)\n", "    np.random.seed(seed)\n", "\n", "# Matplotlib setup\n", "from IPython.display import HTML\n", "import matplotlib\n", "import matplotlib.animation as animation\n", "import matplotlib.pyplot as plt\n", "plt.rcParams['axes.labelsize'] = 14\n", "plt.rcParams['xtick.labelsize'] = 12\n", "plt.rcParams['ytick.labelsize'] = 12\n", "\n", "# Korean font for figure labels\n", "plt.rcParams['font.family'] = 'NanumBarunGothic'\n", "plt.rcParams['axes.unicode_minus'] = False\n", "\n", "# Where to save the figures\n", "PROJECT_ROOT_DIR = \".\"\n", "CHAPTER_ID = \"rl\"\n", "\n", "def save_fig(fig_id, tight_layout=True):\n", "    path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n", "    if tight_layout:\n", "        plt.tight_layout()\n", "    plt.savefig(path, format='png', dpi=300)" ] }, { "cell_type": "markdown",
"metadata": {}, "source": [ "# OpenAI gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. Gym provides many environments for your *agents* to learn in. Let's import `gym`:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we will load the MsPacman environment, version 0." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "env = gym.make('MsPacman-v0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's initialize the environment by calling the `reset()` method. This returns an observation:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Observations vary depending on the environment. In this case it is an RGB image stored as a 3D NumPy array of shape [width, height, channels] (with 3 channels: Red, Green and Blue). In other environments it may return different objects, as we will see later." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment). 
이 경우에는 `mode=\"rgb_array\"`로 지정해서 넘파이 배열로 환경에 대한 이미지를 받겠습니다:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지를 그려보죠:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVUAAAGoCAYAAAD2AHbpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADBpJREFUeJzt3b+u29YdB3Cx6EsUrt/A6BCgQJYUWQJ0yNYlk18hd3Mfod7uM2Tyks1DgSxGsxgokKHI0tk2+hjsYNxrWZEoSvxS5w8/n01Xonl4JH3949EhzzCO4w6AjN+VbgBAT4QqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGCfl+6AfuGYXB5F1ClcRyHOa9TqQIEVVWpvv/++9JNAFhEpQoQVFWlOuWPP/6hdBOq8f5v/zv5nH7iGJ+Zeab6aS6VKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgpqZ/H/OksnN1257bqLwWtsuoZ/KtqlE/y+hny6nUgUIEqoAQUIVIGgYx3ruC/3h7u5kY9z04ZPSY0a0x2dmnql+enJ/7ybVALcmVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAoG5uqLJEYgXFY3qbVL1WP/GRz8s8tfeTShUgqJtK1WV4QA1UqgBBQhUgSKgCBAlVgCChChAkVAGCuplS1ZoWV1OlrNZWU90qlSpAkFAFCBKqAEHGVAspNRZlDKxdJd47n5fLqVQBgrqpVP2PCtRApQoQJFQBgoQqQJBQBQgSqgBBQhUgqJspVUu0Nh3LhQNcyoUDt6NSBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUDeT/5esNHnttktWmlxr27WOdcm2pfppivdu3rZb66cElSpAkFAFCBKqAEHDOI6l2/Dow93dycZs9eYMwO1Mjcc+ub8f5vwbKlWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQ1c0OVczdnmLLWhQNr3exj6X6v1Vp7a9VaP/p+ZKlUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQ1Mzk/zWVWH2x1MToJUqvUtmD1vqwxgn8NfbTPpUqQJBQBQgSqgBBxlR3ZcZoah8XOqbFNtemtT4s1d7W+mmfShUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEFVTf5vecLvLbXWT621t1b6cZ61+mm8n/c6lSpAUFWVKizx9tmbyee//PXrm7SDbVOpAgSpVGneuQr18HUqVtakUgUIUqnStLfP3vym8jxVkT78/dg2kKJSBQgSqjTv7bM3R8dVT/0d1iRUAYK6GVNdsqRta0tUlzrWa7dd61h3u93u/X9P/5o/NW66Zpu8d8u19n3e102osl2nTvGd+lOC03+AoGEcx9JteDQ8fVlPY2YodWrEJ9dWo6ZUra+378f47sUw53UqVYAgoQoQJFQBgoQqQJBQBQgSqjRt/1f8L3/9+qLHsAahChDkiiqad1h9XvoYklSqAEFVVarnrsC4VqkrN9Y6nrWUunlMb/RjWaX7SaUKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYCgqib/l1J69cVWrNVPrfXxksnlPmvztNxP
KlWAIKEKEOT0f1f/6UQt9NNy+nCelvtJpQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChBU1RVVS66iKLGCYstXfdAen7fllvTheD/vdSpVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIEVTX5v5RrV248d8FBiW3PTW4utS0ftfbetfgZL02lChAkVAGChCpA0DCOY+k2PBqevry6MS2PwfRuydjaJV4/fzX5/Lc/fBfZz62Oh+uslQXjuxfDnNepVAGC/PpP885VqIevS1WscIxQpWmvn7/6TUieCs+Hvx/bBlKc/gMECVWa9/r5q6NDAKf+DmsSqgBBxlRp3qnxUeOmlKBSBQiqqlItsSLqEq2195xWJ62fGjftbTzV5225W/ShShUgSKgCBAlVgCChChAkVAGChCpAkFClafsT/L/94buLHsMahCpAUFWT/0uxasA8tfbTYfV56eNbqrUPa9NyP6lUAYKEKkCQUAUIMqa6q3+MphZr9VNvNwqZ4rM2T8v9pFIFCBKqAEFCFSBIqAIECVWAIKEKECRUAYKEKkDQMI5j6TY8Gp6+rKcxM5ybtN7yBGZYqrfvx/juxTDndSpVgCCXqdKN//zjX0f//qe//+Wz5x8ewxpUqgBBKlWad6pCPXx+v2JVrbIWlSpAkEqVbh2OpRpT5RaEKs07DMnDEIVbcvoPENRNpbpk9cUSKzcumRhd6liv3XatYz217anKtWSb5mzb23u3RGvf530qVYAgoQoQJFQBgtxQZYHebhjRi3O/+ptSdRu9fT/cUAWggG5+/We7zEelJipVgCCVKpthLJVbUKkCBKlUad7hjVNOPQ+3IFTphvCkBk7/AYKqqlTPTRa+VmuTjM8pdcOIGm88s5YabxTS4n5LKPWZeaBSBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUFWT/0upcVJ7T5PPe7sDvPeu7H5r/7yoVAGChCpAkFAFCDKmuiszRlNqXGhLx7oW712/+01QqQIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIKiZyf81TgausU290cftavHCgcRKrCpVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAoGauqDpnyZK21267ZPneLS39W2M/1dimc7b03pX4PqeoVAGChCpA0DCOY+k2PPpwd3eyMW6s8Ump05vSp1U98N7VbaqfntzfD3P+DZUqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAUDc3VFkisSztMb1dqVLieNZ6b85Z61h7+0xM2er3SqUKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYCgbib/t7ZcRIsrcpZYdbZGW1oltMX3rnQWqFQBgoQqQJBQBQjqZky1xrGdKaXau2S/127b2ntzTok+XLptS/tcqnSbVaoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQjqZvI/H/W2+mhvSr0/3I5KFSCom0q19O2+AHY7lSpAlFAFCBKqAEFCFSBIqAIECVWAoG6mVLWmxVUqr9XbsfZ2PFO2dKwpKlWAIKEKECRUAYKMqRaypbGo3o61t+OZsqVjTVGpAgR1U6n6HxWogUoVIEioAgQJVYAgoQoQJFQBgoQqQFA3U6qWaG06VmvtXaK3Y+3teKZs6Vj3qVQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChDUzeT/qVUfz01CvnbbJStNrrXtWse6ZNtS/TTFezdv2631U4JKFSBIqAIECVWAoGEcx9JtePTh7u5kY7Z6cwbgdqbGY5/c3w9z/g2VKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgpq5ocq5mzMA1EClChAkVAGChCpAUFU3VBmGoZ7GAOwZx9ENVQBurZlf/xN++unPu91ut/vmm39/9njfw3Pp/U7tc639wlr++cUXnz3+6y+/FGpJfTZx+j8nTA8lQm5/v3P2mdovrOUhTA9D
dD9kew1Yp/8ABWzi9P9UhXpJ5Zra79r7hDVNVaEPz52qZrdCpQoQtIlK9ZRjFeQt93nL/cKtbL1iVakCBG26Ui1RJapM6cmxavRwutXWqFQBgjZdqQLXuaRC3drYqkoVIGjTlapf/2GZrVWhc6hUAYI2ce3/g0uuYkpWkKX2C2lTv+wfzk89/Hvr5l77v6nT/1KXiF5yQxVoXS8hei2n/wBBmzr9P3Sr+6me2+ct9gtJvZ7iT3HrP4ACNl2pAsylUgUoQKgCBAlVgCChChAkVAGChCpAkFAFCNrUtf9wKz/ff3X071/d/XzjlnBrJv9D2H6gPoToYcgK1/aY/A9QgEoVQo5VqFOvmXod9VGpAhQgVAGChCpAkFAFCBKqAEF+/Ycw81T75Nd/gAJUqrACl6n2Z26lKlQBZnD6D1CAUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEFCFSBIqAIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQN4ziWbgNAN1SqAEFCFSBIqAIECVWAIKEKECRUAYKEKkCQUAUIEqoAQUIVIEioAgQJVYAgoQoQJFQBgoQqQJBQBQgSqgBBQhUgSKgCBAlVgCChChAkVAGChCpAkFAFCBKqAEH/Bx83TBg0+8FmAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(5,6))\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"MsPacman\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Welcome back to the 1980s! :)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this environment, the rendered image is identical to the observation (but in many environments this is not the case):" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(img == obs).all()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a little utility function to plot an environment:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def plot_environment(env, figsize=(5,6)):\n", "    plt.figure(figsize=figsize)\n", "    img = env.render(mode=\"rgb_array\")\n", "    plt.imshow(img)\n", "    plt.axis(\"off\")\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how to interact with an environment. The agent selects an action from an \"action space\" (the set of possible actions). This environment's action space is the following:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Discrete(9)` means that the possible actions are integers 0 through 8, which correspond to the 9 possible positions of the joystick (0=center, 1=up, 2=right, 3=left, 4=down, 5=upper-right, 6=upper-left, 7=lower-right, 8=lower-left)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we tell the environment which action to play, and it computes the next step of the game. Let's go left 110 times, then lower-left 40 times:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "env.reset()\n", "for step in range(110):\n", "    env.step(3) # left\n", "for step in range(40):\n", "    env.step(8) # lower-left" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Where are we now?"
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASYAAAFrCAYAAAB8AiRoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAACtJJREFUeJzt3c2K21gaBmB7mJsYMn0LWQR6E0gWgUDvB3IbXbvOJUzv6joKZt8QyGICtWnIYm6hO8xlaDbjQnFJsizr5z1Hz7NyXDLnSHZeviN/ko9N0xwAkvxl6wkAnBNMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTECcv249gbbj8ej6GKhY0zTHMdtFBdOfP/+89RSAAFHB1Ofv//rb1lNY1J//+G/n87Xv9x7s9b3t2++xnGMC4ggmII5gAuIIJiCOYALiCCYgjmAC4hTRx9RnqFeir0/k2r6SNcaYYq79GJrTGmNsOfaW+7f0nKa8Zs4xbqViAuIIJiDOMel35b7d3XVOZq/t+7Xv9x7s9b3t2+8X9/ejLuJVMQFxBBMQRzABcQQTEEcwAXEEExBHMAFxir4kZYpbb/nZVlIvypz7Xbu9vq8uSQEYIJiAOIIJiCOYgDiCCYgjmIA4ggmIs7s+pjUk3oKV23lf16NiAuIIJiCOpdwC1ii791baJ9jr+7rFnFRMQBzBBMQRTEAcwQTEEUxAHMEExNldu8Bev45N3O89WPq41/q+qpiAOIIJiCOYgDiCCYgjmIA4ggmII5iAOEX3MU35hdBr7xC45RhDPSq1jLHl2DWMkfp/4FYqJiCOYALiHJum2XoOT77d3XVOpta2e6hV3/Lvxf39cczrVUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAnCIuSRlqie8zV+/TGu34U/avz5w9X3POayupx2Ovn8+xVExAHMEExBFMQBzBBMQRTEAcwQTEEUxAnCL6mOY05+1ftxzjWolz2lriMdnr5/OcigmII5iAOLtbyq1RriaVxCeJc9pa4jHZ6+fznIoJiCOYgDiCCYgjmIA4ggmII5iAOIIJiBPVx1RCf8USUvc7dV5b2evxmHO/m/tx26mYgDiCCYgjmIA4ggmII5iAOIIJiCOYgDhRfUzXmvLzyIm3Lp0yp1rG2HLsWsa41lxzGnrNrVRMQBzBBMQ5Nk2z9RyeHH/4NWcy/7dFGQtjlfb5bP745ThmOxUTEEcwAXEEExBHMAFxBBMQRzABcQQTECfqkpShnoxrrdHDMed857LGpQslcTxut8V+q5iAOIIJiCOYgDiCCYgjmIA4ggmIE9UusIY17hCYaM79TjxWU77S9ln4XtJ+q5iAOIIJiLO7pVxSubqmve73kL0ekxL2W8UExBFMQBzBBMTZ3TkmyvDuw+vO5z8/PK48E7YgmIjSF0jnfxdQdbOUA+IIJmK0q6Xziujzw+N3z12qrChb1FLu2v6KNe6sV0LPR61OQdQVWFsEk8/Cc9cek+Z+3HYqJiBOVMUEJ+cVkaXbvqiYgDiCCYgjmIh06Vs56uYcE1Ha37p1tQds+a0c61ExAXF2VzFde1vRoV6pa18z1/ZrjbGFdiXUVRXNWSnV8D6t8fncgooJiCOYiNI+yd33+PRv6nVsmmbrOTw5/vDrVZOZUsYyXgkl/0lJcy3NnP/Pmj9+OY7ZTsUExBFMQBzBBMQRTEAcwQTEEUxAHMEExIm6JGWNW+VeK3FOQ/TtLMdn4bmljomKCYgjmIA4ggmII5iAOIIJiCOYgDiCCYgT1ce0hr3et2ev+z1kr8ekhP1WMQFxBBMQZ3dLuaRydU1z7ndpl2b08VnIpWIC4ggmII5gAuIIJiCOYALiCCYgjmAC4hT9E+
Fr8DPkJCvt8+knwoFiCSYgjmAC4ggmII5gAuIIJiBO0bc9mfJVaQl37+P692nK+1rLGNeaa05Dr7mVigmII5iAODq/Lyits5Z9Ke3zqfMbKJZgAuIIJiCOYALiCCYgjmAC4ggmIE7UJSlz/sJrYg9Hn8TLEOYcY061HJMaPp9LUjEBcQQTEEcwAXEEExBHMAFxBBMQRzABcaL6mNZQQ1/J4aAH51wtx6OWMW6lYgLiCCYgzu6Wcknl6i3W2I+SjlUtx6OWMW6lYgLiCCYgjmAC4ggmII5gAuIIJiCOYALiFNHHtGXfRQk9Hykcq/Wl9j3dejteFRMQRzABcQQTEEcwAXGKOPldmncfXnc+//nhceWZQJkE04z6Aun87wIKhlnKAXFUTDNpV0ufHx6f/bu9zbsPr2+umhJvwTrnnLYce645rTHGUL/QXMdqyhi3UjEt5PPD47Pw6XoOeE4wAXEs5RZwfhL80knxKRIvRZhzTluOXdIYU+aUeGzPqZiAOIIJiCOYFuCkN9xGMAFxnPyeUbtfqV0hnXd8L3EyHGoimGbUDpyu8BFIMI6lHBBndxXTrbf8bDvv7+hayp0/bm+3pjV6UeY8tn227JVKtOTneUu7C6altMNmzGOgn6UcEEcwAXEEExBHMAFxBBMQx7dyCyjpTodDc1pjP+ayxv4lvq9rcAdLgINgAgIJJiCOc0wL2OstWLdUyy1mE4+5W+sCHAQTEEgwAXEEExBHMAFxBBMQR7tAsNLuCFmaNY5vDVySAnAQTEAgwQTEEUxAHMEExBFMQBzBBMTRx7SAxNujTlHSfpQ01yG17MetVExAHMEExLGUW0AtZXdJ+1HSXIck7oc7WAIcBBMQSDABcQQTEEcwAXEEExBHMAFxdtfHpE9kOaXtR2nz7VLDPnRRMQFxiq6Y/vPPfz977uXHt09/Oz0GylJsMHWFUvv5lx/ffvcYKIelHBCnuIrp0vKtvZ1KCcpUXcX08uNbgQSFqy6YgPIVt5S7pO+k+Mm1ty6d8vPIc40x1KNSyxhbjl3DGFt+Podec6uiK6bzZZslHNSh6GAC6nRsmmbrOTz5dnfXOZl2udheqnVVSJf+Diyvb/n34v7+OOb1xVZMfaHjWzkoX7HBBNSruG/lxlZDqiYol4oJiCOYgDiCCYgjmIA4RZz8HmqJB+qjYgLiCCYgTtQlKcfjMWcywOyaphl1SUoR55im+vTpx8PhcDi8f//70+O29+9/L2IM9u23V6+eHv/09euGM1mPpRwQp9ql3KdPPz5VK12VTNvUqmbsGKompjhVSu0qqfTqafdLufOl1diQShuD/eoLntPzv716VWQ4jWEpB8SptmI611XZlDgGnPz09Wvncq8GKiYgzm4qpjUqGFUSS7l0Irw2uwkmKNX5Se6uQKptSWcpB8TZTcXk5Dclq60iuqTaBsvDYXw/0S0hMmYMIcUUfeeQ2n1M58+lG9tgaSkHxKl6KbdGJ7Zub7ZSSpU0RdVLubY1rvx3dwGWUOKSrY+lHFCs3VRMwPZUTECxBBMQRzABcQQTEEcwAXGqbrCkLF/u3zw9fnP35eI2Xa55Xd+2bE/FBMTRx8TmxlQz59sM/b39t67XjanMWMbufyWFcpzC4dIy7Vp9AdQe77SNgMpiKQfEUTFRhPOqqq+6UvnUQTBRhLHnmL7cvxFOFbCUA+KomKhW+xu49klu38rlUzEBcfQxsZmx7QFDfUlD214aS7W0vrF9TIIJWI0bxQHFEkxAHMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTECcqB+8BDgcVExAIMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAHMEExBFMQBzBBMQRTEAcwQTEEUxAnP8BDtpNmXLhFT
4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_environment(env)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `step()` function actually returns several important objects:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "obs, reward, done, info = env.step(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we saw earlier, the observation describes the visible environment. Here it is a 210x160 RGB image:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The environment also tells us how much reward we got during the last step:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reward" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the game is over, the environment returns `done=True`:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "done" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, `info` is a dictionary that provides extra information about the internal state of the environment. This is useful for debugging, but the agent should not use this information for learning (it would be cheating, not learning)."
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'ale.lives': 3}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's play one full game (with 3 lives) by picking a random direction every 10 steps, recording each frame:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", "    img = env.render(mode=\"rgb_array\")\n", "    frames.append(img)\n", "    if step % n_change_steps == 0:\n", "        action = env.action_space.sample() # play randomly\n", "    obs, reward, done, info = env.step(action)\n", "    if done:\n", "        break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's watch it as an animation:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def update_scene(num, frames, patch):\n", "    plt.close() # this looks like a matplotlib bug: without closing the previous figure, two figures get displayed\n", "    patch.set_data(frames[num])\n", "    return patch,\n", "\n", "def plot_animation(frames, figsize=(5,6), repeat=False, interval=40):\n", "    fig = plt.figure(figsize=figsize)\n", "    patch = plt.imshow(frames[0])\n", "    plt.axis('off')\n", "    return animation.FuncAnimation(fig, update_scene, fargs=(frames, patch),\n", "                                   frames=len(frames), repeat=repeat, interval=interval)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames)\n", "HTML(video.to_html5_video()) # this renders the frames as an HTML5 video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you have finished using an environment, you should close it to free up resources:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
"env.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train our first agent, we will use a simple Cart-Pole environment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A simple environment: the Cart-Pole" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Cart-Pole is a very simple environment composed of a cart that can move left or right, and a pole standing vertically on top of it. The agent must move the cart left or right to keep the pole from falling over." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-0.01028308, 0.03910706, 0.03149887, 0.00700671])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The observation is a 1D NumPy array composed of 4 floats: the cart's horizontal position, its velocity, the pole's angle (0 = vertical), and its angular velocity. To render this environment we first need to fix a couple of issues." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fixing the rendering issue" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some environments (including the Cart-Pole) require access to your display and open a separate window, even when you specify the `rgb_array` mode. In general you can safely ignore that window. However, if Jupyter is running on a headless server (i.e. without a screen), it raises an exception. One way to avoid this is to install a fake X server such as Xvfb. 
You can then start Jupyter using the `xvfb-run` command:\n", "\n", "    $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n", "\n", "If Jupyter is running on a headless server but you don't want the hassle of installing Xvfb, then you can just use the following rendering function for the Cart-Pole:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from PIL import Image, ImageDraw\n", "\n", "try:\n", "    from pyglet.gl import gl_info\n", "    openai_cart_pole_rendering = True   # no problem, we can use OpenAI gym's rendering function\n", "except Exception:\n", "    openai_cart_pole_rendering = False  # no X server available, we will use our own rendering function\n", "\n", "def render_cart_pole(env, obs):\n", "    if openai_cart_pole_rendering:\n", "        # use OpenAI gym's rendering function\n", "        return env.render(mode=\"rgb_array\")\n", "    else:\n", "        # rendering for the cart-pole environment (in case OpenAI gym can't do it)\n", "        img_w = 600\n", "        img_h = 400\n", "        cart_w = img_w // 12\n", "        cart_h = img_h // 15\n", "        pole_len = img_h // 3.5\n", "        pole_w = img_w // 80 + 1\n", "        x_width = 2\n", "        max_ang = 0.2\n", "        bg_col = (255, 255, 255)\n", "        cart_col = 0x000000 # Blue Green Red\n", "        pole_col = 0x669acc # Blue Green Red\n", "\n", "        pos, vel, ang, ang_vel = obs\n", "        img = Image.new('RGB', (img_w, img_h), bg_col)\n", "        draw = ImageDraw.Draw(img)\n", "        cart_x = pos * img_w // x_width + img_w // x_width\n", "        cart_y = img_h * 95 // 100\n", "        top_pole_x = cart_x + pole_len * np.sin(ang)\n", "        top_pole_y = cart_y - cart_h // 2 - pole_len * np.cos(ang)\n", "        draw.line((0, cart_y, img_w, cart_y), fill=0)\n", "        draw.rectangle((cart_x - cart_w // 2, cart_y - cart_h // 2, cart_x + cart_w // 2, cart_y + cart_h // 2), fill=cart_col) # draw cart\n", "        draw.line((cart_x, cart_y - cart_h // 2, top_pole_x, top_pole_y), fill=pole_col, width=pole_w) # draw pole\n", "        return np.array(img)\n", "\n", "def plot_cart_pole(env, obs):\n", "    img = render_cart_pole(env, obs)\n", "    plt.imshow(img)\n", "    plt.axis(\"off\")\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABD1JREFUeJzt3cFtGlEUQNFMRBVuw3W4DVyTacN1uI20MVmYhYMtOxKD//jfcyQWsBi9BVyeRl+wrOv6C4C5/R49AAC3J/YAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAGH0QNc8NsNAO8t117AZg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPEHAYPQDszcvp8d1r98enAZPAdmz2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPbwxkd/XAIzEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsYcv3B+fRo8AVxN7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHuAALEHCBB7gACxBwgQe4AAsQcIEHs4ezk9jh4BbkbsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7Jnasiz//bjlNWA0sQcIOIweAPbk+c/xn+cPd6dBk8C2bPZwdhl6mInYwyd8ATALsYdPuI3DLMQezoSdmS3ruo6e4a1dDcPP993HIXf2eWIeV7+Rd3UaxzllfjrvYW5hiyViV7G3FbE1mz28cs8eIEDsAQLEHiBA7AECxB4gQOwBAsQeIGBX5+xha869wyubPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AwGH0ABeW0QMAzMhmDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxDwF8mDKhPNBUEkAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "행동 공간을 확인해 보겠습니다:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(2)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "네 딱 두 개의 행동이 있네요. 왼쪽이나 오른쪽 방향으로 가속합니다. 막대가 넘어지기 전까지 카트를 왼쪽으로 밀어보죠:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(0)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAAEYCAYAAAAeWvJ8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABQ5JREFUeJzt3cFx2lAUQFErQxVpI7SRNqAOlwFtpA2njbRBFlkkY4xtgdAV0TkzbFgwb4Hm8uF/MZxOpycAqHypBwBg3YQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiA1KYe4BX3GwJ4HMMUL2JFBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQGpTDwBr8/O4f/P5b7vDzJPAMlgRAZASIgBSQgRASogASAkRACkhghnZMQfnhAiAlBDBTC6thmDthAhivpZj7YQIgJQQAZASIgBSQgRASohgBs4PwWVCBEBKiODOnB+C9wkRACkhAiAlRBCxUQH+ECIAUkIEQEqI4I6cH4KPCREAKSECICVEAKSECICUEMGd2KgAnyNEAKSECO7AjU7h84QIgJQQAZASIgBSQgQzsmMOzgkRACkhgok5PwTjCBEAKSECICVEAKSECICUEMGE3NoHxhMimIEdc3CZEMFErIbgOkIE
d2Y1BO8TIgBSQgRASogASAkRACkhggnYMQfXEyK4Izvm4GNCBEBKiABICREAKSGCG/lHVriNEAGQEiIAUkIEN/C1HNxOiABICREAKSECICVEAKSECK5kowJMQ4gASAkRACkhAiAlRACkhAiu4B9ZYTpCxOoNwzD6ccl2f5zkdWBNhAiA1KYeAB7Ny2H39OPX7uz55+dtMA08PisiGOmtCAHXEyIAUkIEQEqIYKTvX4/1CPBfGU6nUz3DvxY1DOtwyzbql8Pf34u2+/GBWtj1B2NNcgZBiFi98jzPwq4/GGuSi2dR27cd8GNtvOd5ZFN9kFpUiHw6pGBFBC2bFQBICREAKSECICVEAKSECICUEAGQEiIAUos6RwQFZ3mgZUUEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQCpTT3AK0M9AADzsiICICVEAKSECICUEAGQEiIAUkIEQEqIAEgJEQApIQIgJUQApIQIgJQQAZASIgBSQgRASogASAkRACkhAiAlRACkhAiAlBABkBIiAFJCBEBKiABICREAKSECIPUbeNJbfH1bP+8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img = render_cart_pole(env, obs)\n", "plt.imshow(img)\n", "plt.axis(\"off\")\n", "save_fig(\"cart_pole_plot\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(400, 600, 3)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "막대가 실제로 넘어지지 않더라도 너무 기울어지면 게임이 끝납니다. 환경을 다시 초기화하고 이번에는 오른쪽으로 밀어보겠습니다:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()\n", "while True:\n", " obs, reward, done, info = env.step(1)\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAABKNJREFUeJzt3MFNW0EUQNE4chW0kbRBG1ATbiNthDbShrNByHIM2Njkz/x7joRkFlizsK+e5n2x2e/33wBYt+9LHwCAryf2AAFiDxAg9gABYg8QIPYAAWIPECD2AAFiDxAg9gAB26UPcMT/bgD41+baNzDZAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEiD1AwHbpA8BonnePr69/PDwteBK4HbGHF4eRh7VxjQMQIPbwDtM+ayH2AAFiDxAg9vDCkzesmdjDB9zbswZiDxAg9gABYg8H3NuzVmIPECD2cAZLWmYn9gABYg8QIPZwxJKWNRJ7gACxhzNZ0jIzsQcIEHuAALGHEyxpWRuxBwgQe7iAJS2zEnt4g6sc1kTs4UKme2Yk9gABYg8QIPbwDvf2rIXYAwSIPXyCJS2zEXuAALEHCBB7+IAlLWsg9gABYg+fZEnLTMQeIEDsAQLEHs5gScvsxB4gQOzhCpa0zELsAQLEHiBA7OFMlrTMTOzhSu7tmYHYAwSIPUCA2AMEiD1cwJKWWYk93I
AlLaMTe7iQ6Z4ZiT1AgNjDjbjKYWRiDxAg9gABYg+fYEnLbMQeIEDs4YYsaRmV2AMEiD1AgNjDJ1nSMhOxhxM2m81ZP6c87x7P/vu33gNuTewBArZLHwDW4Nefh9fX93e7BU8Cp5ns4UqHoT/1O4xA7OEKPx9N8cxB7OEL/H4y3TMWsYcrHd/R39/tTPwMZ7Pf75c+w6GhDkPX/3wkcrDvIGO6+gM51NM4njmmyOeej9xiIBgq9iYcRmGyZ23c2QMEiD1AgNgDBIg9QIDYAwSIPUCA2AMEDPWcPYzCs++sjckeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBgu/QBjmyWPgDAGpnsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQLEHiBA7AECxB4gQOwBAsQeIEDsAQL+Ah8WVBUmj+TxAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_cart_pole(env, obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks like the same situation we just described. How can we make the pole stay upright? We will need to define a *policy* for that: this is the strategy that the agent will use to select an action at each step. It can use past actions and observations to decide what to do." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A simple hard-coded policy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's hard-code a simple policy: if the pole is tilting to the left, push the cart to the left, and vice versa. Let's see if that works:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "n_max_steps = 1000\n", "n_change_steps = 10\n", "\n", "obs = env.reset()\n", "for step in range(n_max_steps):\n", "    img = render_cart_pole(env, obs)\n", "    frames.append(img)\n", "\n", "    # hard-coded policy\n", "    position, velocity, angle, angular_velocity = obs\n", "    if angle < 0:\n", "        action = 0\n", "    else:\n", "        action = 1\n", "\n", "    obs, reward, done, info = env.step(action)\n", "    if done:\n", "        break" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # this renders the frames as an HTML5 video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nope, the system is unstable: after just a few wobbles, the pole ends up too tilted and the game is over. We will need a smarter policy than that!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural network policies" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a neural network that takes observations as inputs and outputs the action to take for each observation. To choose an action, the network will first estimate a probability for each action, then select an action randomly based on the estimated probabilities. In the case of the Cart-Pole environment, there are just two possible actions (left or right), so we only need one output neuron: it will output the probability `p` of action 0 (left), and the probability of action 1 (right) will be `1 - p`."
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# 1. Specify the network architecture\n", "n_inputs = 4  # == env.observation_space.shape[0]\n", "n_hidden = 4  # it's a simple task, we don't need more hidden neurons\n", "n_outputs = 1 # only outputs the probability of accelerating left\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. Build the neural network\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu,\n", "                         kernel_initializer=initializer)\n", "outputs = tf.layers.dense(hidden, n_outputs, activation=tf.nn.sigmoid,\n", "                          kernel_initializer=initializer)\n", "\n", "# 3. Select a random action based on the estimated probabilities\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "init = tf.global_variables_initializer()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this particular environment, the past actions and observations can safely be ignored, since each observation contains the environment's full state. If there were some hidden state, then you might need to consider past actions and observations in order to infer it. For example, if the environment only revealed the position of the cart but not its velocity, you would have to consider not only the current observation but also the previous one in order to estimate the current velocity. The same applies when the observations are noisy: you may want to use the past few observations to estimate the most likely current state. This problem is thus as simple as can be: the current observation is noise-free and contains the environment's full state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may wonder why we pick a random action based on the probabilities given by the policy network, rather than just picking the action with the highest probability. This approach lets the agent find the right balance between *exploring* new actions and *exploiting* the actions that are known to work well. Here's an analogy: suppose you go to a restaurant for the first time, and all the dishes look equally appealing, so you randomly pick one. If it turns out to be good, you can increase the probability of ordering it next time, but you shouldn't increase that probability to 100%, or else you will never try the other dishes, some of which may be even better than the one you tried."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's randomly initialize this policy neural network and use it to play one game:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "n_max_steps = 1000\n", "frames = []\n", "\n", "with tf.Session() as sess:\n", "    init.run()\n", "    obs = env.reset()\n", "    for step in range(n_max_steps):\n", "        img = render_cart_pole(env, obs)\n", "        frames.append(img)\n", "        action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", "        obs, reward, done, info = env.step(action_val[0][0])\n", "        if done:\n", "            break\n", "\n", "env.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at how well this randomly initialized policy network performs:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # this renders the frames as an HTML5 video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hmm... not great. The neural network will have to learn to do better. First let's see if it is capable of learning the basic policy we used earlier: go left if the pole is tilting left, and go right if it is tilting right. 
다음 코드는 같은 신경망이지만 타깃 확률 `y`와 훈련 연산(`cross_entropy`, `optimizer`, `training_op`)을 추가했습니다:" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "y = tf.placeholder(tf.float32, shape=[None, n_outputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # 행동 0(왼쪽)에 대한 확률\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "training_op = optimizer.minimize(cross_entropy)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "동일한 네트워크를 동시에 10개의 다른 환경에서 플레이하고 1,000번 반복동안 훈련시키겠습니다. 완료되면 환경을 리셋합니다." ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "n_environments = 10\n", "n_iterations = 1000\n", "\n", "envs = [gym.make(\"CartPole-v0\") for _ in range(n_environments)]\n", "observations = [env.reset() for env in envs]\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " target_probas = np.array([([1.] if obs[2] < 0 else [0.]) for obs in observations]) # angle<0 이면 proba(left)=1. 이 되어야 하고 그렇지 않으면 proba(left)=0. 이 되어야 합니다\n", " action_val, _ = sess.run([action, training_op], feed_dict={X: np.array(observations), y: target_probas})\n", " for env_index, env in enumerate(envs):\n", " obs, reward, done, info = env.step(action_val[env_index][0])\n", " observations[env_index] = obs if not done else env.reset()\n", " saver.save(sess, \"./my_policy_net_basic.ckpt\")\n", "\n", "for env in envs:\n", " env.close()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "def render_policy_net(model_path, action, X, n_max_steps = 1000):\n", " frames = []\n", " env = gym.make(\"CartPole-v0\")\n", " obs = env.reset()\n", " with tf.Session() as sess:\n", " saver.restore(sess, model_path)\n", " for step in range(n_max_steps):\n", " img = render_cart_pole(env, obs)\n", " frames.append(img)\n", " action_val = action.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " if done:\n", " break\n", " env.close()\n", " return frames " ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "scrolled": true }, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_policy_net_basic.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_basic.ckpt\", action, X)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "정책을 잘 학습한 것 같네요. 이제 스스로 더 나은 정책을 학습할 수 있는지 알아보겠습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 정책 그래디언트" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "신경망을 훈련하기 위해 타깃 확률 `y`를 정의할 필요가 있습니다. 행동이 좋다면 이 확률을 증가시켜야 하고 반대로 나쁘면 이를 감소시켜야 합니다. 하지만 행동이 좋은지 나쁜지 어떻게 알 수 있을까요? 문제는 대부분의 행동이 미치는 영향이 뒤늦게 나타난다는 점입니다. 게임에서 이기거나 질 때 어떤 행동이 이런 결과에 영향을 미쳤는지 명확하지 않습니다. 마지막 행동일까요? 아니면 마지막 10개의 행동일까요? 아니면 50 스텝 앞의 행동일까요? 이를 *신용 할당 문제*라고 합니다.\n", "\n", "*정책 그래디언트* 알고리즘은 먼저 게임을 여러 번 플레이하고, 성공한 게임에서 선택된 행동의 확률은 조금 더 높이고 실패한 게임에서 선택된 행동의 확률은 조금 더 낮추는 식으로 이 문제를 해결합니다. 먼저 게임을 플레이해 보고 나서 어떻게 동작하는지 살펴보겠습니다." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "reset_graph()\n", "\n", "n_inputs = 4\n", "n_hidden = 4\n", "n_outputs = 1\n", "\n", "learning_rate = 0.01\n", "\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.elu, kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs)\n", "outputs = tf.nn.sigmoid(logits) # 행동 0(왼쪽)에 대한 확률\n", "p_left_and_right = tf.concat(axis=1, values=[outputs, 1 - outputs])\n", "action = tf.multinomial(tf.log(p_left_and_right), num_samples=1)\n", "\n", "y = 1. 
- tf.to_float(action)\n", "cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "def discount_rewards(rewards, discount_rate):\n", " discounted_rewards = np.zeros(len(rewards))\n", " cumulative_rewards = 0\n", " for step in reversed(range(len(rewards))):\n", " cumulative_rewards = rewards[step] + cumulative_rewards * discount_rate\n", " discounted_rewards[step] = cumulative_rewards\n", " return discounted_rewards\n", "\n", "def discount_and_normalize_rewards(all_rewards, discount_rate):\n", " all_discounted_rewards = [discount_rewards(rewards, discount_rate) for rewards in all_rewards]\n", " flat_rewards = np.concatenate(all_discounted_rewards)\n", " reward_mean = flat_rewards.mean()\n", " reward_std = flat_rewards.std()\n", " return [(discounted_rewards - reward_mean)/reward_std for discounted_rewards in all_discounted_rewards]" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-22., -40., -50.])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_rewards([10, 0, -50], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"[array([-0.28435071, -0.86597718, -1.18910299]),\n", " array([1.26665318, 1.0727777 ])]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_rate=0.8)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "반복: 249" ] } ], "source": [ "env = gym.make(\"CartPole-v0\")\n", "\n", "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 250\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\r반복: {}\".format(iteration), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_val, gradients_val = sess.run([action, gradients], feed_dict={X: obs.reshape(1, n_inputs)})\n", " obs, reward, done, info = env.step(action_val[0][0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, 
\"./my_policy_net_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_policy_net_pg.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = render_policy_net(\"./my_policy_net_pg.ckpt\", action, X, n_max_steps=1000)\n", "video = plot_animation(frames, figsize=(6,4))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 연쇄" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태: 0 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n", "상태: 0 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 ...\n", "상태: 0 0 3 \n", "상태: 0 0 0 1 2 1 2 1 3 \n", "상태: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n" ] } ], "source": [ "transition_probabilities = [\n", " [0.7, 0.2, 0.0, 0.1], # s0에서 s0, s1, s2, s3으로\n", " [0.0, 0.0, 0.9, 0.1], # s1에서 ...\n", " [0.0, 1.0, 0.0, 0.0], # s2에서 ...\n", " [0.0, 0.0, 0.0, 1.0], # s3에서 ...\n", " ]\n", "\n", "n_max_steps = 50\n", "\n", "def print_sequence(start_state=0):\n", " current_state = start_state\n", " print(\"상태:\", end=\" \")\n", " for step in range(n_max_steps):\n", " print(current_state, end=\" \")\n", " if current_state == 3:\n", " break\n", " current_state = np.random.choice(range(4), p=transition_probabilities[current_state])\n", " else:\n", " 
print(\"...\", end=\"\")\n", " print()\n", "\n", "for _ in range(10):\n", " print_sequence()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 마르코프 결정 과정" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "policy_fire\n", "상태 (+보상): 0 (10) 0 (10) 0 1 (-50) 2 2 2 (40) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 210\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 2 (40) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 70\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 ... 전체 보상 = -10\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) ... 전체 보상 = 290\n", "요약: 평균=121.1, 표준 편차=129.333766, 최소=-330, 최대=470\n", "\n", "policy_random\n", "상태 (+보상): 0 1 (-50) 2 1 (-50) 2 (40) 0 1 (-50) 2 2 (40) 0 ... 전체 보상 = -60\n", "상태 (+보상): 0 (10) 0 0 0 0 0 (10) 0 0 0 (10) 0 ... 전체 보상 = -30\n", "상태 (+보상): 0 1 1 (-50) 2 (40) 0 0 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 (10) 0 (10) 0 0 0 0 1 (-50) 2 (40) 0 0 ... 전체 보상 = 0\n", "상태 (+보상): 0 0 (10) 0 1 (-50) 2 (40) 0 0 0 0 (10) 0 (10) ... 전체 보상 = 40\n", "요약: 평균=-22.1, 표준 편차=88.152740, 최소=-380, 최대=200\n", "\n", "policy_safe\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 전체 보상 = 0\n", "상태 (+보상): 0 (10) 0 (10) 0 (10) 0 1 1 1 1 1 1 ... 전체 보상 = 30\n", "상태 (+보상): 0 (10) 0 1 1 1 1 1 1 1 1 ... 전체 보상 = 10\n", "상태 (+보상): 0 1 1 1 1 1 1 1 1 1 ... 
전체 보상 = 0\n", "요약: 평균=22.3, 표준 편차=26.244312, 최소=0, 최대=170\n", "\n" ] } ], "source": [ "transition_probabilities = [\n", " [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]], # s0에서, 행동 a0이 선택되면 0.7의 확률로 상태 s0로 가고 0.3의 확률로 상태 s1로 가는 식입니다.\n", " [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n", " [None, [0.8, 0.1, 0.1], None],\n", " ]\n", "\n", "rewards = [\n", " [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n", " [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n", " [[0, 0, 0], [+40, 0, 0], [0, 0, 0]],\n", " ]\n", "\n", "possible_actions = [[0, 1, 2], [0, 2], [1]]\n", "\n", "def policy_fire(state):\n", " return [0, 2, 1][state]\n", "\n", "def policy_random(state):\n", " return np.random.choice(possible_actions[state])\n", "\n", "def policy_safe(state):\n", " return [0, 0, 1][state]\n", "\n", "class MDPEnvironment(object):\n", " def __init__(self, start_state=0):\n", " self.start_state=start_state\n", " self.reset()\n", " def reset(self):\n", " self.total_rewards = 0\n", " self.state = self.start_state\n", " def step(self, action):\n", " next_state = np.random.choice(range(3), p=transition_probabilities[self.state][action])\n", " reward = rewards[self.state][action][next_state]\n", " self.state = next_state\n", " self.total_rewards += reward\n", " return self.state, reward\n", "\n", "def run_episode(policy, n_steps, start_state=0, display=True):\n", " env = MDPEnvironment()\n", " if display:\n", " print(\"상태 (+보상):\", end=\" \")\n", " for step in range(n_steps):\n", " if display:\n", " if step == 10:\n", " print(\"...\", end=\" \")\n", " elif step < 10:\n", " print(env.state, end=\" \")\n", " action = policy(env.state)\n", " state, reward = env.step(action)\n", " if display and step < 10:\n", " if reward:\n", " print(\"({})\".format(reward), end=\" \")\n", " if display:\n", " print(\"전체 보상 =\", env.total_rewards)\n", " return env.total_rewards\n", "\n", "for policy in (policy_fire, policy_random, policy_safe):\n", " all_totals = []\n", " print(policy.__name__)\n", " for episode in 
range(1000):\n", " all_totals.append(run_episode(policy, n_steps=100, display=(episode<5)))\n", " print(\"요약: 평균={:.1f}, 표준 편차={:1f}, 최소={}, 최대={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Q-러닝" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Q-러닝은 에이전트가 플레이하는 것(가령, 랜덤하게)을 보고 점진적으로 Q-가치 추정을 향상시킵니다. 정확한 (또는 충분히 이에 가까운) Q-가치가 추정되면 최적의 정책은 가장 높은 Q-가치(즉, 그리디 정책)를 가진 행동을 선택하는 것이 됩니다." ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "n_states = 3\n", "n_actions = 3\n", "n_steps = 20000\n", "alpha = 0.01\n", "gamma = 0.99\n", "exploration_policy = policy_random\n", "q_values = np.full((n_states, n_actions), -np.inf)\n", "for state, actions in enumerate(possible_actions):\n", " q_values[state][actions]=0\n", "\n", "env = MDPEnvironment()\n", "for step in range(n_steps):\n", " action = exploration_policy(env.state)\n", " state = env.state\n", " next_state, reward = env.step(action)\n", " next_value = np.max(q_values[next_state]) # 그리디한 정책\n", " q_values[state, action] = (1-alpha)*q_values[state, action] + alpha*(reward + gamma * next_value)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def optimal_policy(state):\n", " return np.argmax(q_values[state])" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[39.13508139, 38.88079412, 35.23025716],\n", " [18.9117071 , -inf, 20.54567816],\n", " [ -inf, 72.53192111, -inf]])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "q_values" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "상태 (+보상): 0 (10) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) ... 
전체 보상 = 230\n", "상태 (+보상): 0 (10) 0 (10) 0 (10) 0 1 (-50) 2 2 1 (-50) 2 (40) 0 (10) ... 전체 보상 = 90\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 170\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) 0 (10) ... 전체 보상 = 220\n", "상태 (+보상): 0 1 (-50) 2 (40) 0 (10) 0 1 (-50) 2 (40) 0 (10) 0 (10) 0 (10) ... 전체 보상 = -50\n", "요약: 평균=125.6, 표준 편차=127.363464, 최소=-290, 최대=500\n", "\n" ] } ], "source": [ "all_totals = []\n", "for episode in range(1000):\n", " all_totals.append(run_episode(optimal_policy, n_steps=100, display=(episode<5)))\n", "print(\"요약: 평균={:.1f}, 표준 편차={:1f}, 최소={}, 최대={}\".format(np.mean(all_totals), np.std(all_totals), np.min(all_totals), np.max(all_totals)))\n", "print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DQN 알고리즘으로 미스팩맨 게임 학습하기" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 미스팩맨 환경 만들기" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(210, 160, 3)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env = gym.make(\"MsPacman-v0\")\n", "obs = env.reset()\n", "obs.shape" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Discrete(9)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 전처리" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이미지 전처리는 선택 사항이지만 훈련 속도를 크게 높여 줍니다." 
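] }, { "cell_type": "markdown", "metadata": {}, "source": [
"잘라낸 88×80 그레이스케일 프레임을 바이트로 저장하면 메모리가 얼마나 줄어드는지 프레임 한 장 기준으로 어림해 볼 수 있습니다. 아래는 관측 크기만 이용한 간단한 계산 스케치입니다:

```python
import numpy as np

# 원본 관측: 210x160 RGB, uint8
raw = np.zeros((210, 160, 3), dtype=np.uint8)

# 전처리 후: 88x80 그레이스케일 한 채널, 부호 있는 바이트(int8)
processed = np.zeros((88, 80, 1), dtype=np.int8)

print(raw.nbytes)        # 100800 바이트
print(processed.nbytes)  # 7040 바이트

# 같은 프레임을 64비트 부동소수로 저장하면 int8보다 8배의 메모리가 필요합니다
print(processed.astype(np.float64).nbytes // processed.nbytes)  # 8
```

재생 메모리에는 프레임이 수백만 개 저장되므로 이런 프레임당 절약이 전체로는 수십 GB 차이가 됩니다."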
] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "mspacman_color = 210 + 164 + 74\n", "\n", "def preprocess_observation(obs):\n", " img = obs[1:176:2, ::2] # 자르고 크기를 줄입니다.\n", " img = img.sum(axis=2) # 흑백 스케일로 변환합니다.\n", " img[img==mspacman_color] = 0 # 대비를 높입니다.\n", " img = (img // 3 - 128).astype(np.int8) # -128~127 사이로 정규화합니다.\n", " return img.reshape(88, 80, 1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트 `preprocess_observation()` 함수가 책에 있는 것과 조금 다릅니다. 64비트 부동소수를 -1.0~1.0 사이로 나타내지 않고 부호있는 바이트(-128~127 사이)로 표현합니다. 이렇게 하는 이유는 재생 메모리가 약 8배나 적게 소모되기 때문입니다(52GB에서 6.5GB로). 정밀도를 감소시켜도 눈에 띄이게 훈련에 미치는 영향은 없습니다." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAArsAAAGoCAYAAABGyS0qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xm8pFdZJ/DfAaOyCCPIAAY0AkLYJRhkJ4yEALKIcYGBZhhxUEChYVBcQNERRWRpVIwyiIEbEVBkhLAJDEhEIpuAYXNEA3YMhE1WxQBn/njfGyrVdatu3+q6b72nvt/P5366b73vec95T9W9/fRTT51Taq0BAIAWXWboAQAAwKoIdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGYJdkeslPJ1pZRnlFI+WEp5fynlL0spJ/XHfqKUcubEuT9WSjl7xjWeWEo5tMP1zyqlnD/x9fFSyh/1x84tpZwyo81rptpcaRf3cZlSyp+WUm6x+7tnWinlB0opvzH0OABgnQh2x+1hSU5MctNa6w2TPD3Ji+acf5dSykcnv5I8dqeTa60PrLWesP2V5AlJPjvr3FLKSaWUDyT59iT/nuSL/aGLSik3W3Afj0zyoVrr3/bXunwp5SmllFpKufVUP5ctpfxGKeW9pZS3lVJeP3HsB/vH31NKeVMp5foL+p11H1cvpfx+/5+Ht5ZSziml3HTi+FVKKVv92K4x1fYRpZQPlFLOK6WcXUq5+ozrP7GU8rH+Pwtv7r/uMnH8uFLK40op7yqlvLuU8tellFeWUu7QHz+zlPKRvv1bSylvL6V8X5LUWv8syQ1LKace7X0DQKsEu+NW0j2H28/j1y04/9W11mtMfiV56twOSrnidrY4ydWSXDTrvFrrO2utJ9ZaT0xyUpJnJfl8knvVWt8z5/qXS/K4qXH8TJL3J/nwjCZ/kORwkpvUWk9O8sP9da6X5PeT3KfWerMkz03ykqm+nl5K+eGJ769SSnldKeVqE6edlOQ1tdYb1lpvleT/JHnaxPFfSfLiGfdxSn8fd6i13iTJ2/uxzvLSWuuta623S/JTSV5aSvmW/tgLktwiyam11pvXWm+b5EFJLp5o/7t9+1sleXSS500c+/UkT9mhXwDYOIuCI9bb
GUm+M8n7Sik1yUeT3H/O+aeVUg5PPXaldIHhTk5Mclb/5/WSvHHi2C+UUk6rtf5cn3m8SZJT0mV3/yrJZ5I8spRy8yR/sUPQe+8k76m1fnz7gVrrE5OklPJLkyf2mdprJvmOJH9TSvl0umzzJ5Ocni5I/Yf+9K0kTy2l3LTW+nf9Y09K8qpSyuWTvDLJq5I8barvV02N78JM/JzUWn+yH8v0ffxIkrMmrvXMJB8vpVy51vqZGfe9fb13llI+m+SEPmC+QZJb1lovnjjnE0k+scMlvj3JuyfOfUsp5VtKKTeutb53p34BYFMIdkeqlHKjJNdP8vokb0jyDUmukOTOpZTTM5UVrbU+J8lz+raPTXJirfXHjrLbX03y6VLKlfvvX53kzf3fr5cu+/hrtdZLgq8+43qndFnoWW6ZiWBtgTskuVWS36y1PqaUcrckryylXDfJdZJ8aPvEWutXSikf7h//u/6xT/Zv8b8yyROTPKZ/63+mvgzhV5I8ZBdju04mMr611k+XUj6T5IR591dK+f4k/5HkvCQPTvKy7UC3lPJTSR6Q5JvSBfKP6Zs9vG93tXQ/w/eduuy7082rYBeAjSfYHa8Tkty+//vxSe6Y5BnpMoDvSBcI71kp5ZfTZUu/Mcm1SynnpwtYv5Tkhf1p76i1nltKeU0/hu22O11zq9Y6/QGqKyb52C6H9Z+T/FWt9XVJUmt9dT+u2/dj+8rU+V/OkaU635wum31hugB9p7FeNX1QXGv9y12Mbbf9J8l9SynflS4re26SW9da/72ft8ttn1Rr/e0kv11K+dl0mfVtv1trfXI/zu9K8opSyl0nMrmfT3LlAACC3bGqtb4yXTC2XS/6nbXWp5ZS3pYu4/dNSV5eSrlWuoCq9F9f3b7GjJKGh9RaX9P//cnpgucrJflcks/UWifb3nWi3X2zu/rv/5jx2IfTlWLsxkU58gNyX00XZB5O9x+ASdfuH0+SlFJukOTPkvxEkrcm+ZNSyuW3yyYmzrtmuhKHp9RaX7DLsR1O8m0T17h8kqtO9j/hpbXWn+g/WPY76YLTpHueHlNKuczkXM9Ta31XKeUtSU7L1zK5107yT7scNwA0zQfURq6U8qIkN9r+vtZ6cr9ywi/03x+utV4r3QeXXllrvdacr9dMXOffaq3/muT5SW43L/iqtX4xyalJPrDD12trrZ+vtc4Kds9O8l92ebtnJ7nr9ioLpZTbpwsw35Lug13f1wf3KaX8ULpVId4x0f4+Sf57rfWcWuuX0mWur9pncdO3+/Z0pSH/6ygC3aSrEX7ARInHI5K8ebIeeFqt9RXpSiye3D/04nSB+7NKKVecOPVaO12jL7X4niR/039/5SQ3TlfaAgAbT2Z3/G6QbsmxP5x6/BNJPrLCfn8tyQe3v6m1vjTJS6dPKqWcmK62d6Za63tLKR8qpdx9xofDps/9WCnlAUleUkq5OF1Jxff3HwD7TF/j+vL+2OeS3H0ySK+1PmXqehenWw1h0tOSXD3JT5dSfrp/7Eu11jstGNsbSinPSvKXff//kuR+89r0HpXkvFLKi2ut55RS7pzk55OcW0r5SrqA/aJcekWI7Zrd7dU4Hl9r3a6d/h9Jnldr/cIu+gaA5pVa69BjYAmllHelW6Hg4hmH31trPa0/7yeT/EaST8847x9qrafscP03pltl4d9nHP6lWutOy2tttz8x3ZJnJ8w554ZJ/neSO0+uQsDR6TPUr0pyWq111vMMABtHsMtaKKXcKckna63nDT2WsSql3DHJJ2qt7xt6LACwLgS7AAA0ywfUAGhW2WktRGBz1FrX5itJ9eXLl6/9/hr6d5+vY/rvyPnplvz7SP/n55Ic1x97YZIHT5z72HQ7T340ydumrnM4yfX6v/9W//3k13X7YycmOX+i3QnpVlU5f+rrpunW9j48Y8zfPeP8i9Mty3i3JG+c0ebuU+f/Yvp3axfMz+2S/NHQz9OKnvsbTD1H9+8fPzPJj6VbdvLM
ifPvnG4zok+m22b9shPHPprkhB36eW6SR+xyTKXv+8x0O4z+1cSxk5Ock+5DyO9M93mL7WOXvFaT/M8kDx96fsf8ZTUGAJpRJz4MW0p5fKa2354696lJnjpxfql9dDF13iOTPHLivMNJjpszjI/XGR/K7XeUnDWOt2dqnfBSykdz5Lri28eem+S2+doHhy9O8nPpAu//utOgSilXSPIH6XajTCnle5L8ZrrNg45L8o9JHllrvaBf4vGZ6dZtL+lW+Hl0PcrPBCzo4+vT/UfiTn0fr0m3s+X0Bj0ppXwwszfLuXqSq9VaP1Fr/WD6pRpLKWel21V0p3F9S7qA8vuSvC/Jy5L8eJLf3cVtXT/JzM2G+n5vn24N+CskeUm6QHb6vCskeXmSH621vrKUcnK6DYJOrrV+eOr0ZyT5m1LKG2qt79/F+JiyVsHu4Uc9aughANCAfhvxg0m+t5TyK0kemi5YenV//DFJtrfg/nKS/5Tkt5M8Yf9H2+m3gb8oyafSjfWTs86rtf7oRJtbpMvqvitd5nKen0q33vr2+t8vSLd04R/35R6H0mU4H5DkjCSvq7X+et/PY9KtmnO7/vvLJHlFksfWfvfGfiy/XGu990Sf8/r4+XTB6Y3TlVW+KsmjM/EfkIl7vsH0Y6WUr0u3WdHnFtz3LPdL8hf9fzRSSvm1flxzg91Syi2T3Dzd8pQvqrVeaqWiWusDJ879xSSX3eFS10/y+dptEJVa69v61ZVOSrfZ0uQ1v1pKeXq6JT+nt4dnF9TsAtCUfrfEVyT5QpLb1lp/sdZ6jSR/vn1OrfXptd9QJ8l3pCt7mLvW91G4Winl/KmvH9xFu6ckuUu6rcQvqrV+uX/8hqWUs0opJ5dSrlFK+aFSypNKKX+dLqP7qXQB45NKKQ/sM6az/LckfzLx/QX9WEuSr08X8F8wcewqpZTLllIum25HyO1jqd0a5k9It+75LfoM7h+n39Bol338SJLfqrV+tb/XZyW5/y7madtV0wWMX9rphFLKrdJlp6ddL8lklvR9Sa47r7M+G/zsJD+d7vX1gj7gnjzntqWUU0op90hy18zI6vY+mORypZR7ls5t0gXRb9vh/JcnuUsp5ZvnjZHZ1iqzO8+1XnLNoYcwuMOnX7jjMfOD18d88+aHNpRSLpfk4elqcR+V5C+SnF1KuXX/2E4eneSjtda/7t8uv0KSa0xc98wk35+vbe2dzF7bPLXW89Nn80opn0jy3f1jO5YxzHBRkrv1OyIel65+9Mx0Gb9vTJcVfEu6Lc0/M3Hvt09XZ3zEbpX9FuYnJnn3xMM/kC5oe1ySy6fbpfLn+mMPTfcW/MfSZb7fk6msYq317f1ulS9JV/9+776UYNK8Pq6TrmZ224f6x3brqkku+cHu13U/L1/bpv3lSZ6Y5LsyY9OjdKUTs/5+6ZO6QP0+6coxXlhr/b0+s/28JOeUUn681vqe/vT7pLvPKyS5VbpSkCumu+9L1Fq/2G8Z/5vpsuj/nOR+tdZZW8yn1vr5UsqH09V+v2mnsTKbzC4ArbhfklskuU2t9cW12/L8lHRv8R9Ri5skfcb1YJKblFJuUmu9QZ/t/ejUqU+sl95e/UNHXm2m40opVy2l3DQzMoyllKuXUs4rpZyX5I7p3uZ/f5JXJvm/SY5P8ula6+uS3KZ//P7pthl/80Tbt6Wr7XxE/9g3TXV1xXR1pJNvu78gXZ3st6UL7v8jX9ut8bf6OTi+/zo3yR/NuL/j0+1meXGSb51xfF4fJd2H+bZ9OUcXl1w73YfzJh2utZ7Qf72o1nqPdP/pmfYP6YL/bTfMpQPvSbdOF6A/otb6hKTLbNdaD6TbvfRFpZRr948/Ll3m93rp6o9PSPKzsy5aa31XrfXUWuu1a623rbW+YeLwy9MF7pM+n9l1yywwmswuAMxTa/3DJH/YB5fXqbX+Y//2+DOSZHIVsj4T+vPpspW3S1cr+ZpSykNrra9Y1Fcp5bh0WchvmHjsFumC
lKQLLC9OVyP8qXSB46xg8aJ0Gdnjklwuyccm35Yvpdxt4txXJHn9orElXSZw6qGPpwtKvzXJ4f4t+buky8Z+JclXSilPSRc0H0z3Qbdbbo+lr2n9YinlKrXWT/WP3TddvfBp6eKJs0spP1Nr3a6LXtTH4XRB8Pn9GL8tX8vK7saN033gbaZSypXSlU3M8sIkT+hrcN+fLph97qwTa61vSfI9Oxx7drrShu0+r5/k+enqgX9n0Q2UUh6d5BEzDl09XY312yceu3aSf1p0TY4kswtAa+6Vrv51pv5t6ZelCy5vU2v951rrnyf5wVw627ftK0ke39fe/mNf6nBukl/KRLBba/3bPit8xyTvqrVes9Z63VrrybXWe+XSgct2m9pnoG+c5M/m1Z/WWr/cB7GvSvKBHb7uPiPQTb/KxCuSfG//0KfTrbAwWUv8Q+lqSZPk7/vvJ499PMm/JpfM4alJ7lpr/Zda60fSBb33mGizqI+tJA/ra1Yvk64E5U93uv8ZbpxL193WJNcspVzQv+X/6iT3ntWw1vqJdAH9VpL/l6684/ePou8jlFKu3l/vV2utv7ibNrXWZ9Rarzf9lYn68v7aN0r3n6f3LjPGTSWzC8CmeEySL9RaaynltP5DVpfoM3hvmW5Ua31IkofMumBfJzrt8ukyxdM+leTpRznm89LVdU6O5w47jOWFmbPcVroSiacleV6t9SullHsmeUop5VHp6owvTLdKQtKVhDyjlPLOdMH+F5Lca3vO+uD54VPjujATS7Ttoo8n92N6R//9m5I8aeqeDqbLAs/yjUlO61eKeH3/PH3D9El9zfURaq2vT3KjHa693fbW6bLAi3yh1nrj7JABnnP9X01XH33RjMOTHyZ8eJJn9vPOURLsAtCie5RuPdxpL0jyM9OB7gpcY4f+U0p5Tq111hq6N9uhzWfSZWWXUmt9cynlvaWUH+5rmt+arqZ51rl/n24N2mX7nNfHv6Vb23Ze+0PpAuJB1FrPzdQayCvw9elqqqfdJsmf96URN0/3QUr2QLALQFNqrWemW71gmWtca5fnfSBTwVCt9bwc5b+vtdY3pgt69qzWer9dnPawJA9app+xqbU+eOLb39tlm2ssPmvX/T8nyXP6b28/dezxSR6/4BInJvmRusPmKCwm2AWADdEvS/achSeyNmqtLxt6DGPnA2oAADSriczuosXi5y2ov9eF+FfRblHbvRpirPs9r4uM6Xne780hxvQ8D/HzA8C4NRHsAjB70wSAxu24+902ZQwAADSrrNOSbRccPLjjYLw9uf9vbzMuXh/zzZuf4w8dWpgZGIH1+WUOsH8W/v5WxgCwAQ4cODD0ENba1tbW3OPmj73y2lrOovnbDWUMAAA0S7ALAECzBLsAADRLsAsAQLMEuwAANEuwCwBAswS7AAA0a+PX2Z230PxetbKA/yrmho7XyM5amZuxWXYt0LG3X9bQ4x+6/SqvPfb2yxp6/EO3PxZkdgEAaFYTmd1F2SWZIgCAzSSzCwBAs0qtdegxXOKCgwd3HMy87Owymd0x1RzOG+sq+lSzuzr7/XwN8ZrcqyHGevyhQ2Ulne6vub/M96MubszWoa6QNnltLWfR/CVZ+PtbZhcAgGYJdgEAaJZgFwCAZgl2AQBoVhNLj43JEMukDfHBJfZu3T6I6DUCwJjJ7AIA0CzBLgAAzRLsAgDQLDW7+2yI+kc1l+Oy38+X1wcALWsi2PWPNQAAsyhjAACgWU1kdgFYzqL95w8cOLBU+2Ut6n/sVj1/7MxrazljmD+ZXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJq18asxjGmNXpsNMI8NSwDgSDK7AAA0a+MzuwCM37LrBI+9f1Zn6Od26P5bILMLAECzBLsAADRLsAsAQLPU7AIwekPXLQ7dP6sz9HM7dP8tkNkFAKBZgl0AAJrVRBnD4dMvnHt83sL389rud7sh+mxlrIuYO2MFYDM1EewCsBx1gUCrlDEA
ANCsJjK7y7x1ude2+91uiD7HNNZVXXcT5m5TxgrAZpLZBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFml1jr0GC5xwcGDOw7Gp7CBZczbkOL4Q4fKPg5lVeb+MreOLjBGW1tbi05Z+PtbZhcAgGYJdgEAaNZoNpWY9xbkIqsogZg3nlWVXCwzBzsZ01jXzZjmboix7vfPHQDMIrMLAECzBLsAADRLsAsAQLMEuwAANGs0H1ADYHUWrWW5aJ1e7Te7/Sqvrf1mtz8WZHYBAGiWYBcAgGYJdgEAaFapde526vvqgoMHVzKYeYvb73VR/FW0W7btXoxprOtmTHM3xFiH+PnZq+MPHVq4t/oIzP39uR91cQDH2qKa3yQLf3/L7AIA0CzBLgAAzRLsAgDQrI1fZ3defWAL/S1jTGNdN2OauyHGOqb5AWDcZHYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGjWxq+zC7AJdrG//FwHDhw4RiOZbdH4hu5/WWMf/zob+9wOPf6h+98PaxXsWmh+vjHNz5jGum7M3c5WNTf10EouC8AaUMYAAECz1iqzC0fj3Ju8ccdjtz7vlH0bBwCwvgS7jM68IHf6HEEvAGw2ZQwAADRLZpfR2Clbe+5N3jjzsXltAIDNILMLAECzZHYZnVnZWhlcGNa6r+W56v6XNfbxr7Oxz+3Q4x+6/2NBZhcAgGY1kdk9fPqFc4/PW4h+XttVLGA/xFhX0W6IPg9/sPtzVvZ2UUZ30+duVWPdqzGNFYBxk9kFAKBZTWR22Syz1tndzdq7wOoMXbc3dP/LGvv419nY53bo8Q/d/7HQRLC7zFuX+/225xBj3e92q+rz3Jt8cK/D2fi5W0W7ZYxprACMmzIGAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYZTRufd4pl9o8YvL7eccAgM0l2AUAoFlNrLPLZpnO2E5ndOedCwBslrUKdhfte79X+70Q/aruYxVWNTdjmoO9Mnf7z9wAcLSUMQAA0Ky1yuwCsDct7F8/pLHP336P/6EPfeiOx5797Gfv40hWb+yvjaGtev62trYWniOzCwBAswS7AAA0S7ALAECz1OwCAHNN1+jOq8s9mnNhP8jsAgDQLMEuAADN2vgyhnmL1O/3ZhTrZtEC/nudnzHN6143MVjV3LXCzx0A+2Xjg10AFq9VuWitzGXbL2vo8bfe/pxzzpl7fJlrD31vrbdf1tjHnyhjAACgYYJdAACatfFlDOoDd2Zu9s7czWd+ANgvpdY69BguccHBgysZzCr+Yd3rB5fWzaqCjlY+gDTEfXht7WxVc3P8oUNlJRfeRwcOHFifX+Z7sA51fexseu3co2Gd3WG1/rO1tbW18Pe3MgYAAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJq1VptK7HVdziHWJR3TWrGMi9fW3u117uqhYzwQaIy1chkzmV0AAJol2AUAoFmCXQAAmrVWNbsA7M3W1tZKr3/gwIGVXn+RVd/fqq16/sY+P8swt+ttHeZPZhcAgGYJdgEAaJZgFwCAZgl2AQBo1sZ/QG3ehhTzFqhfRbsh+lzVWDfdpjzP+z1WADhaMrsAADRLsAsAQLM2voxhr2+Z7ne7Ifpc97eTz37QC3c8ds/n328fR3KkTXmeW31tcfQWraU59Dq9627V8zfm+V92nVavzeW0MH8yuwAANGvjM7uMz7yM7vQ5Q2d4AYBhyewCANAsmV1GY6ds7dkPeuHMx+a1AY6tMdTtrTPztzrmdjktzJ/MLgAAzRLsMjpnP+iFR9TtznoMAECwCwBAs9TsMjqz6m/V5AIAs6xVsHv49AuHHsKujWms84xxcf9Z5QotlTB4be1dK3MHwLGjjAEAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGjWWi09BsDeLLt//dbW1jEayd4sO35YFa/N1dqP310yu4zGPZ9/v0tt
HjH5/bxjAMDm2vjM7rxF6Me44cKxtGiB/qHmZzqInQ5y5527X9Z17taFnzsA9ovMLgAAzRLsAgDQLMEuAADN2viaXfWBO1vV3CyqZ22B19V85geA/SKzCwBAszY+swvA8hatdbloLU3tl2vPzoZ+bja9/TqQ2QUAoFmCXQAAmiXYBQCgWWp2AVjasnV72q9/3eOkhz3sYTseO+OMM/ZxJIsN/dxsevt1ILMLAECzBLsAADRrrcoYWllovpX7WIY52Lsxzd1+bxCyqrmph1ZyWQDWwFoFuwDA+pmu0Z1Xl3s058J+EOwyWu958puOeOxmP3vHSx3b/h4A2ExqdgEAaJbMLqMzK6M7fUyGFwBIBLs0Yjq4FeQCAIkyBgAAGiazy+hMZ2vf8+Q3zS1tAAA2l8wuAADNaiKzu2hh+3kL0c9ru4oF7IcY6yraDdHnTu0mM707ZXjN3WrGuldjGutYbG1tDT2EpYx9/IscOHBg6CEcU9Nr6bbMa3O19mN+ZXYBAGiWYBcAgGY1UcawzFuX+/225xBj3e92+93nbj6cZu6OfbtljGmsAIybzC4AAM1qIrPLZrHMGACwWzK7AAA0S2aXJtkmGPbXouWDhl7eaN2t+/ydccYZg/a/jHWf23XXwvzJ7AIA0CyZXUZnO2s7q3ZXRhcAmCSzCwBAs2R2GS1ZXFgfY6jbW2ernr/Wt7ydx2tzOS3M31oFu4v2vd+rVhainzc/q7jHRc/HEH2uwibM3Sb8DADALMoYAABolmAXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmrdWmEkPY9I0aVrXZQCsbHGzC3G3KaxKAzSSzCwBAszY+swvQghb2r5+n9ftb1n7Pz1lnnbXrcx/4wAeucCSrN/bX3tbW1qD9r3r+dnN/MrsAADRr4zO7+10fOEQ94qb0uQqbMHebcI8AbC6ZXQAAmrXxmV0AYO+ma3KPpp4X9oPMLgAAzRLsAgDQLMEuAADNUrMLwMK1Khetlbls+2UNPf5Na380a+eO7d5aa7+ssY8/kdkFAKBhgl0AAJo1mjKGdVuEft3G0wrzOi5j2gDj8OkXHsORADAWpdY69BguccHBgzsORhA0/x/rVczPouDAc7KzIeZuv18fYzNvfo4/dKjs41BW4sCBA+vzy3wP1qGuj71ZtK7u0dT3cuy1/rO1tbW18Pe3MgYAAJol2AUAoFmCXQAAmjWaD6gBAOthXp2uGl3WjcwuAADNEuwCANAswS4AAM1qomZ3mTVN97o26SraLWq7V0Osv7rf8zq28cwzpvWUN+HnBziSulzGRGYXAIBmCXYBAGiWYBcAgGY1UbO7TJ3eXtvud7tljKnPVY113cazTn36+WnD1tbWSq9/4MCBlV5/aIvmb9X3v+79r9LQ99Z6/0Mb8rW1TWYXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmNbHOLgDDGnotUf23u5br0Pem//G/tjY+2D18+oXH/JqtLHw/xH2s4vlYZBX32cprYBE/PwCsO2UMAAA0S7ALAECzNr6MAYDlDV23p//1r5vcq6HvTf/jf23J7AIA0CzBLgAAzRLsAgDQLMEuAADNEuwCANCsJlZjWLSw/TotUj/EWOf1Oa+/Zca61z7XzSbMnZ8fAFomswsAQLMEuwAANKuJMoYxvXU5xFj32ucyYx3TczLPJszdmJ6rMY0VgPUgswsAQLMEuwAANKuJMgYA5lv3/e3XfXytM//tGvq5Xbb/ra2tpccgswsAQLMEuwAANEuwCwBAswS7AAA0S7ALAECzrMbQiMOnX7jvfVrgf2dDPB8AwJFkdgEAaFYTmd1FWTQZSID5Fq1luWitzLG3X9bY+1/l+Nd5bLsx9v7H3v5YkNkFAKBZgl0AAJol2AUAoFlN1OwCsJxl6+bG3n5ZY++bBRIbAAAFuElEQVR/leNf57FtQv9jb38syOwCANAswS4AAM1S
xrDPNmWZtHn3OaZ7bOU+5tmU1yQAm0lmFwCAZgl2AQBolmAXAIBmqdndZ5tS/9jKfbZyH/Nswj0CsLmaCHb9Yw0AwCzKGAAAaJZgFwCAZgl2AQBoVhM1uwAMa2tra6XXP3DgwEqvP7Sh72/Vz988q773oed2aH42ZXYBAGiYYBcAgGYJdgEAaNbG1+yOaY3eMY11r1q5x1buY5FNuU8AxktmFwCAZgl2AQBolmAXAIBmbXzNLgCL1+Jc97U0hx7/sv0P3X6dDT03Q8/t0P0vax3GL7MLAECzBLsAADRLsAsAQLPU7AKw9nV/iww9/mX7H7r9Oht6boae26H7X9Y6jL+JYPfw6RfOPT5v4ft5bfe73RB9tjLWRcydsQKwmZQxAADQLMEuAADNaqKMYZm3Lvfadr/bDdHnmMa6qutuwtxtylgB2EwyuwAANEuwCwBAswS7AAA0S7ALAECzBLsAADRLsAsAQLMEuwAANKuJdXYBWM7W1tbc4+uwv/2QFs3Pqm36/C9j6OeO4Y0m2D18+oVDDwEAgJFRxgAAQLMEuwAANEuwCwBAs9aqZvdaz3zm0EMANlA9dGjoIQCwIjK7AAA0a60yu6v02teenCQ59dS3Xer7SdvHxtwnHCuvPumkS/5+t3e+c8CRAMDebUSw+9rXnjw34Jw8L1k+AN1NkHus+4RjZTvInQxwJwPf6WOwDjZ9neCW77/le9uNTb//Y0EZAwAAzdqIzO6pp77tiOzqbrKuy/Q369qzxgHrZl7WdvvYrOwvAKwjmV0AAJq1EZndWSYzrPtVLztEn7AKMrysm02vW2z5/lu+t93Y9Ps/FmR2AQBo1sZmdofIrMrmMkavPumkIzK306szAMC6ktkFAKBZG5vZBebbzXq7k4+r2x03dYFAqzY22PUBNdgdgSwAY6aMAQCAZm1sZnfSThtOtNYn7MasUoXpkgbbBwMwFjK7AAA0a2Myu6vcHnid+oT9IJMLwFjI7AIA0KyNyexum5dtXVXd7BB9wl7NqsuVyQVgrDYu2AXg2Bv7Or1jH/+yWr7/lu9tNzb9/pMNDnZtFwzzyeYC0AI1uwAANEuwCwBAswS7AAA0S7ALAECzBLsAADRLsAsAQLM2dukxOFbOOXSHIx67w8FzBhgJ7N3W1tbc44vW6tR+s9uv8trab3b7Y0FmFwCAZsnswh5NZ3S3s7nnHLrDJcdkeAFgWIJdOEo7BbmT32+fI+gFgGGVWuvQY7hEKWV9BgM7WBTs7vYc1kettQw9hmNg7u/P/aiLAzjWFtX8Jln4+1vNLgAAzRLsAgDQLMEuAADNEuwCANAsqzHAUZpcYmzyz+nHJx8DAIYh2IU92inonTwGAAxLGQMAAM2S2YUlyeICwPqS2QUAoFmCXQAAmiXYBQCgWaXWudup76tSyvoMBtgYtdaFe6uPgN+fwCZa+PtbZhcAgGYJdgEAaJZgFwCAZllnF6ABpQxbdvzZz372Ut9f6UpX2qj+YShnnnnmpb5/8IMfPMg4hrKbz57J7AIA0CzBLgAAzRLsAgDQLOvsAhuvhXV29/v353SN7CLHuoZ26P6B9bCb398yuwAANEuwCwBAswS7AAA0yzq7AA1Y97rjoT8fMnT/wHBkdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGYJdgEAaJZgFwCAZgl2AQBolmAXAIBmCXYBAGiWYBcAgGaVWuvQYwAAgJWQ2QUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFmCXQAAmiXYBQCgWYJdAACaJdgFAKBZgl0AAJol2AUAoFn/H7/x+ck6JhE4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (88×80 그레이스케일)\")\n", "plt.imshow(img.reshape(88, 80), interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "save_fig(\"preprocessing_plot\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DQN 만들기" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "reset_graph()\n", "\n", "input_height = 88\n", "input_width = 80\n", "input_channels = 1\n", "conv_n_maps = [32, 64, 64]\n", "conv_kernel_sizes = [(8,8), (4,4), (3,3)]\n", "conv_strides = [4, 2, 1]\n", "conv_paddings = [\"SAME\"] * 3 \n", "conv_activation = [tf.nn.relu] * 3\n", "n_hidden_in = 64 * 11 * 10 # conv3은 11x10 크기의 64개의 맵을 가집니다\n", "n_hidden = 512\n", "hidden_activation = tf.nn.relu\n", "n_outputs = env.action_space.n # 9개의 행동이 가능합니다\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "def q_network(X_state, name):\n", " prev_layer = X_state / 128.0 # 픽셀 강도를 [-1.0, 1.0] 범위로 스케일 변경합니다.\n", " with tf.variable_scope(name) as scope:\n", " for n_maps, kernel_size, strides, padding, activation in zip(\n", " conv_n_maps, conv_kernel_sizes, conv_strides,\n", " conv_paddings, conv_activation):\n", " prev_layer = tf.layers.conv2d(\n", " prev_layer, filters=n_maps, kernel_size=kernel_size,\n", " strides=strides, padding=padding, activation=activation,\n", " kernel_initializer=initializer)\n", " last_conv_layer_flat = tf.reshape(prev_layer, shape=[-1, n_hidden_in])\n", " hidden = tf.layers.dense(last_conv_layer_flat, n_hidden,\n", " activation=hidden_activation,\n", " kernel_initializer=initializer)\n", " outputs = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", " trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\n", " 
scope=scope.name)\n", " trainable_vars_by_name = {var.name[len(scope.name):]: var\n", " for var in trainable_vars}\n", " return outputs, trainable_vars_by_name" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n", " input_channels])\n", "online_q_values, online_vars = q_network(X_state, name=\"q_networks/online\")\n", "target_q_values, target_vars = q_network(X_state, name=\"q_networks/target\")\n", "\n", "copy_ops = [target_var.assign(online_vars[var_name])\n", " for var_name, target_var in target_vars.items()]\n", "copy_online_to_target = tf.group(*copy_ops)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'/conv2d/kernel:0': ,\n", " '/conv2d/bias:0': ,\n", " '/conv2d_1/kernel:0': ,\n", " '/conv2d_1/bias:0': ,\n", " '/conv2d_2/kernel:0': ,\n", " '/conv2d_2/bias:0': ,\n", " '/dense/kernel:0': ,\n", " '/dense/bias:0': ,\n", " '/dense_1/kernel:0': ,\n", " '/dense_1/bias:0': }" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "online_vars" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.001\n", "momentum = 0.95\n", "\n", "with tf.variable_scope(\"train\"):\n", " X_action = tf.placeholder(tf.int32, shape=[None])\n", " y = tf.placeholder(tf.float32, shape=[None, 1])\n", " q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n", " axis=1, keepdims=True)\n", " error = tf.abs(y - q_value)\n", " clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n", " linear_error = 2 * (error - clipped_error)\n", " loss = tf.reduce_mean(tf.square(clipped_error) + linear_error)\n", "\n", " global_step = tf.Variable(0, trainable=False, name='global_step')\n", " optimizer = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)\n", " training_op = 
optimizer.minimize(loss, global_step=global_step)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 처음 책을 쓸 때는 타깃 Q-가치(y)와 예측 Q-가치(q_value) 사이의 제곱 오차를 사용했습니다. 하지만 매우 잡음이 많은 경험 때문에 작은 오차(1.0 이하)에 대해서만 손실에 이차식을 사용하고, 큰 오차에 대해서는 위의 계산식처럼 선형적인 손실(절대 오차의 두 배)을 사용하는 것이 더 낫습니다. 이렇게 하면 큰 오차가 모델 파라미터를 너무 많이 변경하지 못합니다. 또 몇 가지 하이퍼파라미터를 조정했습니다(작은 학습률을 사용하고 논문에 따르면 적응적 경사 하강법 알고리즘이 이따금 나쁜 성능을 낼 수 있으므로 Adam 최적화 대신 네스테로프 가속 경사를 사용합니다). 아래에서 몇 가지 다른 하이퍼파라미터도 수정했습니다(재생 메모리 크기 확대, e-그리디 정책을 위한 감쇠 단계 증가, 할인 계수 증가, 온라인 DQN에서 타깃 DQN으로 복사 빈도 축소 등입니다)." ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "replay_memory_size = 500000\n", "replay_memory = deque([], maxlen=replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for idx in indices:\n", " memory = replay_memory[idx]\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ReplayMemory 클래스를 사용한 방법 ==================" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "랜덤 액세스(random access)가 훨씬 빠르기 때문에 deque 대신에 ReplayMemory 클래스를 사용합니다(기여해 준 @NileshPS 님 감사합니다). 또 기본적으로 중복을 허용하여 샘플링하면 큰 재생 메모리에서 중복을 허용하지 않고 샘플링하는 것보다 훨씬 빠릅니다."
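위 노트에서 설명한 손실(오차 1.0까지는 이차식, 그 이상은 선형)은 넘파이만으로도 확인해 볼 수 있습니다. 아래는 책에 없는 간단한 스케치로, 텐서플로 그래프의 계산을 그대로 재현한 것입니다:

```python
import numpy as np

def dqn_loss(y, q_value):
    # 위 셀과 같은 계산: |오차|가 1.0 이하면 제곱 오차,
    # 1.0보다 크면 선형(기울기 2)으로 증가하여 큰 오차의 영향을 제한합니다
    error = np.abs(y - q_value)
    clipped_error = np.clip(error, 0.0, 1.0)
    linear_error = 2 * (error - clipped_error)
    return np.mean(clipped_error ** 2 + linear_error)

print(dqn_loss(np.array([0.5]), np.array([0.0])))  # 0.25 (제곱 오차와 동일)
print(dqn_loss(np.array([3.0]), np.array([0.0])))  # 5.0 (= 1 + 2 * (3 - 1))
```

오차가 커져도 손실의 기울기가 2를 넘지 않으므로 한 번의 업데이트가 파라미터를 과도하게 바꾸지 못합니다.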
] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "class ReplayMemory:\n", " def __init__(self, maxlen):\n", " self.maxlen = maxlen\n", " self.buf = np.empty(shape=maxlen, dtype=np.object)\n", " self.index = 0\n", " self.length = 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", " self.length = min(self.length + 1, self.maxlen)\n", " self.index = (self.index + 1) % self.maxlen\n", " \n", " def sample(self, batch_size, with_replacement=True):\n", " if with_replacement:\n", " indices = np.random.randint(self.length, size=batch_size) # 더 빠름\n", " else:\n", " indices = np.random.permutation(self.length)[:batch_size]\n", " return self.buf[indices]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "replay_memory_size = 500000\n", "replay_memory = ReplayMemory(replay_memory_size)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "def sample_memories(batch_size):\n", " cols = [[], [], [], [], []] # 상태, 행동, 보상, 다음 상태, 계속\n", " for memory in replay_memory.sample(batch_size):\n", " for col, value in zip(cols, memory):\n", " col.append(value)\n", " cols = [np.array(col) for col in cols]\n", " return cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### =============================================" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "eps_min = 0.1\n", "eps_max = 1.0\n", "eps_decay_steps = 2000000\n", "\n", "def epsilon_greedy(q_values, step):\n", " epsilon = max(eps_min, eps_max - (eps_max-eps_min) * step/eps_decay_steps)\n", " if np.random.rand() < epsilon:\n", " return np.random.randint(n_outputs) # 랜덤 행동\n", " else:\n", " return np.argmax(q_values) # 최적 행동" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "n_steps = 4000000 # 전체 
훈련 스텝 횟수\n", "training_start = 10000 # 10,000번 게임을 반복한 후에 훈련을 시작합니다\n", "training_interval = 4 # 4번 게임을 반복하고 훈련 스텝을 실행합니다\n", "save_steps = 1000 # 1,000번 훈련 스텝마다 모델을 저장합니다\n", "copy_steps = 10000 # 10,000번 훈련 스텝마다 온라인 DQN을 타깃 DQN으로 복사합니다\n", "discount_rate = 0.99\n", "skip_start = 90 # 게임의 시작 부분은 스킵합니다 (시간 낭비이므로).\n", "batch_size = 50\n", "iteration = 0 # 게임 반복횟수\n", "checkpoint_path = \"./my_dqn.ckpt\"\n", "done = True # 환경을 리셋해야 합니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "학습 과정을 트래킹하기 위해 몇 개의 변수가 필요합니다:" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "loss_val = np.infty\n", "game_length = 0\n", "total_max_q = 0\n", "mean_max_q = 0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이제 훈련 반복 루프입니다!" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n", "반복 13992\t훈련 스텝 3999999/4000000 (100.0)%\t손실 0.534047\t평균 최대-Q 221.066981 " ] } ], "source": [ "with tf.Session() as sess:\n", " if os.path.isfile(checkpoint_path + \".index\"):\n", " saver.restore(sess, checkpoint_path)\n", " else:\n", " init.run()\n", " copy_online_to_target.run()\n", " while True:\n", " step = global_step.eval()\n", " if step >= n_steps:\n", " break\n", " iteration += 1\n", " print(\"\\r반복 {}\\t훈련 스텝 {}/{} ({:.1f})%\\t손실 {:5f}\\t평균 최대-Q {:5f} \".format(\n", " iteration, step, n_steps, step * 100 / n_steps, loss_val, mean_max_q), end=\"\")\n", " if done: # 게임이 종료되면 다시 시작합니다\n", " obs = env.reset()\n", " for skip in range(skip_start): # 게임 시작 부분은 스킵합니다\n", " obs, reward, done, info = env.step(0)\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = epsilon_greedy(q_values, step)\n", "\n", " # 온라인 DQN으로 게임을 플레이합니다.\n", " obs, reward, done, info = 
env.step(action)\n", " next_state = preprocess_observation(obs)\n", "\n", " # 재생 메모리에 기록합니다\n", " replay_memory.append((state, action, reward, next_state, 1.0 - done))\n", " state = next_state\n", "\n", " # 트래킹을 위해 통계값을 계산합니다 (책에는 없습니다)\n", " total_max_q += q_values.max()\n", " game_length += 1\n", " if done:\n", " mean_max_q = total_max_q / game_length\n", " total_max_q = 0.0\n", " game_length = 0\n", "\n", " if iteration < training_start or iteration % training_interval != 0:\n", " continue # 워밍업 시간이 지난 후에 일정 간격으로 훈련합니다\n", " \n", " # 메모리에서 샘플링하여 타깃 Q-가치를 얻기 위해 타깃 DQN을 사용합니다\n", " X_state_val, X_action_val, rewards, X_next_state_val, continues = (\n", " sample_memories(batch_size))\n", " next_q_values = target_q_values.eval(\n", " feed_dict={X_state: X_next_state_val})\n", " max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)\n", " y_val = rewards + continues * discount_rate * max_next_q_values\n", "\n", " # 온라인 DQN을 훈련시킵니다\n", " _, loss_val = sess.run([training_op, loss], feed_dict={\n", " X_state: X_state_val, X_action: X_action_val, y: y_val})\n", "\n", " # 온라인 DQN을 타깃 DQN으로 일정 간격마다 복사합니다\n", " if step % copy_steps == 0:\n", " copy_online_to_target.run()\n", "\n", " # 일정 간격으로 저장합니다\n", " if step % save_steps == 0:\n", " saver.save(sess, checkpoint_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아래 셀에서 에이전트를 테스트하기 위해 언제든지 위의 셀을 중지할 수 있습니다. 그런 다음 다시 위의 셀을 실행하면 마지막으로 저장된 파라미터를 로드하여 훈련을 다시 시작할 것입니다."
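훈련 루프의 핵심인 타깃 Q-가치 계산 `y = r + continue × γ × max Q'`는 넘파이만으로 따로 확인해 볼 수 있습니다. 아래는 책에 없는 스케치이며, 배치 값들은 설명을 위해 임의로 가정한 것입니다:

```python
import numpy as np

# 임의로 가정한 작은 배치: 보상, 계속 여부(done이면 0.0), 타깃 DQN이 출력한 다음 Q-가치
rewards = np.array([[1.0], [0.0], [2.0]])
continues = np.array([[1.0], [0.0], [1.0]])
next_q_values = np.array([[0.5, 1.5], [2.0, 0.1], [0.0, 3.0]])
discount_rate = 0.99

# 훈련 루프와 동일한 계산: 에피소드가 끝난 샘플(두 번째 행)은 보상만 남습니다
max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)
y_val = rewards + continues * discount_rate * max_next_q_values
print(y_val.ravel())  # [2.485 0.    4.97 ]
```

`continues`에 `1.0 - done`을 저장해 두었기 때문에 종료된 스텝에서는 미래 가치가 자동으로 0이 됩니다.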
] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Restoring parameters from ./my_dqn.ckpt\n" ] } ], "source": [ "frames = []\n", "n_max_steps = 10000\n", "\n", "with tf.Session() as sess:\n", " saver.restore(sess, checkpoint_path)\n", "\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " state = preprocess_observation(obs)\n", "\n", " # 온라인 DQN이 해야할 행동을 평가합니다\n", " q_values = online_q_values.eval(feed_dict={X_state: [state]})\n", " action = np.argmax(q_values)\n", "\n", " # 온라인 DQN이 게임을 플레이합니다\n", " obs, reward, done, info = env.step(action)\n", "\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", "\n", " if done:\n", " break" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "video = plot_animation(frames, figsize=(5,6))\n", "HTML(video.to_html5_video()) # HTML5 동영상으로 만들어 줍니다" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 추가 자료" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 브레이크아웃(Breakout)을 위한 전처리" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "다음은 Breakout-v0 아타리 게임을 위한 DQN을 훈련시키기 위해 사용할 수 있는 전처리 함수입니다:" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "def preprocess_observation(obs):\n", " img = obs[34:194:2, ::2] # 자르고 크기를 줄입니다.\n", " return np.mean(img, axis=2).reshape(80, 80) / 255.0" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [], "source": [ "env = gym.make(\"Breakout-v0\")\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", "\n", "img = preprocess_observation(obs)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlkAAAF2CAYAAABd6o05AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGgtJREFUeJzt3Xu4blVdL/Dvz6AECSoktVAJUVHQgkLJMLFjgJdS00qPPdVJO6e8xo6y7CLZzRBRS+3pZGZJHqOLqYhgmpapeGF7CW8nL2h4RMALXlGE3/ljziUvi3ftvcE99lpr78/neebDeuccc8wx3wW83+c3xnxXdXcAANi5brLeAwAA2B0JWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWbtYVe1VVc+oqvdX1Xur6l+r6uj52C9U1QsW2j6qqs5e0sepVfXMNfo/s6ouWtguq6q/mY+dX1XHLznnvFXn7L8D93GTqvr7qjpqx++e1arqx6rqj9Z7HADsfELWrveLSQ5PcpfuvlOSM5L87Tba36eqLlnckpyyVuPu/qnuPmRlS/JbST67rG1VHV1V70ty2yRXJvnifOjSqrrrdu7j8Uk+2N1vn/vat6pOq6quqmNXXecbquqPqurdVfXWqnrNwrGHzvvfVVX/VlV32M51l93HLarqz+bQ+paqen1V3WXh+LdV1Qvnsd1y1bmPqar3VdWFVXV2Vd1iSf+nVtUn5pD6hnm7z8LxvavqiVX1jqp6Z1W9sarOqap7zsdfUFUfnc9/S1W9rarunyTd/Y9J7lRVP3xD7xuAjU3I2vUq0/u+8t7vtZ3253b3LRe3JKdv8wJV+61Ux5IclOTSZe26e2t3H97dhyc5Oslzknw+yY9097u20f8+SZ64ahy/muS9ST6y5JS/SHJxkiO7+5gkPzH3c1iSP0vywO6+a5LnJ/mHVdc6o6p+YuH1t1XVq6vqoIVmRyc5r7vv1N13S/JPSZ6+cPwpSc5ach/Hz/dxz+4+Msnb5rEu85LuPra7fyDJ45K8pKpuPh97UZKjkvxwd393d98jyU8nuWrh/OfO598tyclJ/mrh2B8mOW2N6wKwSW3vA56d70+T3D7Je6qqk1yS5OHbaH9iVV28at/+mQLJWg5Pcub8z8OSvG7h2G9U1Ynd/etzpeXIJMdnqmb9e5Irkjy+qr47yavWCFs/muRd3X3Zyo7uPjVJqurJiw3nytStknxXkjdX1aczVdc+meQhmcLRB+bmL0xyelXdpbv/Y973+0leWVX7JjknySuTPH3VtV+5anwfz8K/29392Hksq+/jJ5OcudDXs5JcVlUHdPcVS+57pb+tVfXZJIfMQe2OSb63u69aaHN5ksvX6OK2Sd650PZNVXXzqjqiu9+91nUB2FyErF2oqu6c5A5JXpPktUm+KcnNkty7qh6SVVWg7n5ekufN556S5PDuftQNvOzvJfl0VR0wvz43yRvmnw/LVG35g+7+2of+XGG6V6aq2zLfm4WQsB33THK3JE/r7i1VdVKSc6rqdkkOTfLBlYbdfXVVfWTe/x/zvk/OU2nnJDk1yZZ5im2pebrvKUkeuQNjOzQLFa7u/nRVXZHkkG3dX1U9KMlXklyY5GeTvGwlYFXV45I8Isk3ZwqQW+bTHj2fd1Cm/+4evKrbd2Z6X4UsgN2EkLVrHZLkuPnn70zyg0mekanicUGmAHajVdXvZKoO3TTJravqokxB6ctJXjw3u6C7z6+q8+YxrJy7Vp8v7O7VC7P3S/KJHRzWtyf59+5+dZJ097nzuI6bx3b1qvZfzfWnsb81U/Xu45mC4VpjPTBzGOvuf92Bse3o9ZPkwVX1PZmqUOcnOba7r5zft31WGnX3nyT5k6r6tUyVxBXP7e6nzuP8niSvqKoTFipXn09yQADYbQhZu1B3n5MpBKysB7p9d59eVW/NVOH45iQvr6qDM32Q17xds9LHkqnDR3b3efP
PT80U2vZP8rkkV3T34rknLJz34OzYmryvLNn3kUxTnjvi0lx/4f01mcLNxZmC56Jbz/uTJFV1xyT/mOQXkrwlyd9V1b4r05ML7W6VaSrxtO5+0Q6O7eIkt1noY98kBy5ef8FLuvsX5gXrz84UipLp97Slqm6y+F5vS3e/o6relOTEXFu5unWSD+/guAHYBCx8XwdV9bdJ7rzyuruPmZ8E/I359cXdfXCmBdHndPfB29jOW+jnS939mSR/neQHtvWh391fTPLDSd63xvbP3f357l4Wss5O8kM7eLtnJzlh5anBqjouU7B5U6YF4/efQ2Wq6sczPeV4wcL5D0zyP7r79d395UyVugPnqlXm826baQr2d29AwEqmNWCPWJhKfUySNyyu91qtu1+RaSrzqfOuszIFxudU1X4LTQ9eq495SvPuSd48vz4gyRGZppAB2E2oZK2PO2b66oa/XLX/8iQfHXjdP0jy/pUX3f2SJC9Z3aiqDs+0dmup7n53VX2wqu67ZNH56rafqKpHJPmHqroq09Tlg+aF5VfMa5hePh/7XJL7LobD7j5tVX9XZXq6b9HTk9wiya9U1a/M+77c3ffaztheW1XPSfKv8/X/X5KHbeuc2ROSXFhVZ3X366vq3kmelOT8qro6U1C8NNd9wnFlTdbK06W/2d0ra+N+PslfdfcXduDaAGwS1d3rPYY9TlW9I9MTd1ctOfzu7j5xbvfYJH+U5NNL2n2gu49fo//XZXpq8Molh5/c3Wt9TcHK+Ydn+uqIQ7bR5k5J/jzJvRefquOGmStyr0xyYncv+z0DsEkJWdxoVXWvJJ/s7gvXeyybVVX9YJLLu/s96z0WAHYuIQsAYAAL3wEABhCyAAAG2FBPF85/ZgbYTXX3Wn9FAGC3o5IFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwwIb6Cofd3RlnnLHT+tqyZcumv8Za/e8qu8t9rGWt+9uo4wXY3ahkAQAMIGQBAAxgunAD2B2m/nbVNXYF02kA7AwqWQAAAwhZAAADmC6EVW7o9KbpRQCWUckCABhAyAIAGMB0IXusGzrNt9mekgRgfalkAQAMIGQBAAxgunAD2BXTULvLNXamzTZeADYXlSwAgAGELIDdVFXVeo8B9mjdvWG2JG2z2Xbfbb3/H7O7b0kuSnJxko/O//xckr3nYy9O8rMLbU9Jcsm8vXVVPxcnOWz++Y/n14vb7eZjhye5aOG8Q5JcPY9jcbtLksOSXLxkzN+3pP1VSW6f5KQkr1tyzn1Xtf/tJLUD788PJPmb9f49Dfrd/7dVv6Pj5v2vS3KfJE9NcupC+4fN7T6R5JcX9t90W/+tJvmXJPffwTFVkt9LcmqSn01y5sKxE5NsTXJZktcnOWbh2PlJjp9/fmaSB6z3+3tjN2uyAHYT3X3Iys9V9ZtJvre7r1qj7elJTl9oXz1/qq1q9/gkj19od3GSvbcxjMsWx7Fw3mFrjONtmcLZYttLknx2Wfuqen6SeyS5ct51VZJfzxT4/vtag6qqmyX5iyT3nF/fPcnTMoWKvZN8KMnju/tjVfWNmcLlvTIFhfOSbOnuq9fqf41r3i9TwOj5Olvna3yuqg5I8udJjkjyDUle2N2/v6SP/ZP83yXd3yTJvt29X5J092uSHDyf8+/z9dYa152TPD3JcZne59dV1Xu7+5wduK3bJ/n8Gv2+LtcG7W/NFOyWtbttkhcmuV93v62q7p/kZVV1WHd/YVXzJyW5oKre3N2X7cD4NhTThQC7maq6XZJfSnJqVT1lDi0PXDi+paounreLquozSZ6yXuOdx3Tnqrp5Vd0kyQFJPrmsXXf/XHcf3t2HJ3l4kg8k+fskP7+dSzwuyTkLH9QvSvKn3X23JEdnquqcNh97UqbAckSSI5PcOcnJC2O9WVX9S1V9x8K+E6rqLxZef+M8rid0993naxyQ5NfmJs9K8snuPiJTNe9hVfXQJff72e6+5eotyfdnqlTeGD+
X5Pnd/eHu/mSSZ8z7tqmqHjLfw28vm4ru7uO7+5Duvl2Sf8pUpVrm6CTvmgN2uvsVme7lDkv6/GKSv870O9l0NlQly9NeAF+fqrpjkpcm+UKSe3T3b2f6UHzxSpvuPiPJGXP7SvLOJK/cSUM4qKouWrXvlCTv2M55pyU5M8mbk1za3V+dP8fvVFVnZgol/5WpEvU9Se6dKRh9KslRSX6/qt6W5Kzu/sqS/n8m1w0SH5vHWkm+Mcm3zPuS5CczhaNrklxTVc9J8luZK3/d/YWqOj3Jq6rqAUnummla7H4L/V+daSru2+fX+yT55iQfm4PkQzOFjXT35+eA9vBMwWxH3Dxrh5gk2buqjk2y/5JjhyU5a+H1ezKF0DVV1aGZKn8Py/RePivXrXDeJMkPZqr87ZdpavZZSW63pLu3JTlyria+JcmPJrlZkvevcfm/S/KGqvrl+XeyaWyokAXAjVNV+yR5dKZA84Qkr0py9vxBe8o2Tj05ySXd/caqen+mD7tbLvT7giQPynWniNaagrwo09RXquryJN8371tzunCJS5OcNE+n7Z1pzdgLknwk0xTYHZK8Kclp3X3Fwr0fl2kd2fUCVlXtm2k68Z0Lu38sySuSPDHJvknOzjTtmCSHJvngQtsPzvsW7/WcqvpSktdkem9O6O5LFo5fXVX3SXJOVT0ryYFJnt3dz62qW83X/NCqa/zMjrxBswOTfHzhHk/KVD1aGcNemYLfd61xfq3x83UbVe2daRr2D5L8znzfr03y8qo6O8njuvvDmWbGfjLJV5J8Z5JbZwr735JpXdXXdPd/VdXDMwX92yR5X6Z1Xl9cNobu/sD8Ozw403rDTUPIAtg9PCxTRef7F4LN8ZkqFNdbazUff2imacW9qurI7r7jvP/iVU1P7e5nXq+D7du7qg5M8h2Zqhurr3+LTCElmT5svy9TRevqTBWqP0/y6e5+dVU9MMnKmqWHz+cvu6fHZHoPFqfS9ktyTa5dx5VM04XnJfmdTJ+Fz820TumXMoWOxfVXX83y5TWHJPl0psB0UK4NOCvh5KVJntbdz5tD4/+pql/KVEXqeUzbu8Zabp1p0f+i87v7+IXXL5/XSa32gUyhc8Wdct1QueiBmcLfg7r7rUnS3V+qqhOT/GqSf6qq4+b3+xfn3/frkzysu19WVb+3rNPufm2matcyf5OparnoC5mmKjcVIQtgN9Ddf5nkL6vqwKo6tLs/1N1fzbTe5jqBZK78PCnJgzN90B2d5Lyq+p/z+phtmgPEoUm+aWHfUUlePr+8JlO169xMYemSTB+cq12aqQK1d6bptE9095cX+jxpoe0rcm0g26buXr0w+7IkX84U9i6uqptneuLuR+fF7FdX1WlJ3popZF2cKfRdNJ9/m3nf11TVY5P8eKZpy0OTnFVVP9XdF8xNvjvJgd39vHlMV8wVracleU6mULVYmbneNbbjiFy3EnYdVfVtWXvx+/OT/PP8EMFnM1Uzn7isYXf/fZZMYc7v2x/O28o1j8lUdXxmd79sezcwT7k+aMmh70zyH5mDX1XdNFPl7iPb63OjsfAdYPfyI7l2Aff1zGuQXpYp1Hx/d/9Xd7800xqhw5eccnWS35wXyH9onlI8P8mTsxCyuvvt3X1wpnU57+juW3X37br7mO7+kUzrcK6jJ5/JFBj+cTFgLWn71Tk8vTLT9NKy7b5LAlbmpyZfkelrDpKp+nT5fM8rfjzXrgl6YaaqTM1rjR6dhaAxT10dMV/vc939zkzTjw9c6O+jSfavqnvN53xDkockef/8xOeLkzxmPrZPkkdlx9djZb7+exdeX5PkblX1sar6cKapw+OWndjd70nyK0lenemJx7/ewScL1zQ/sfjsJI/q7v+9I+d09yndfdjqLckFq5r+UJI3dvfSJ043MpUsgD3DliRf6O6uqhNXLyDu7jdlWuuUVfsfmeSRyzqsqmWhbN/MC7pX+VTmxfY3wIWZKj+L47nnGmN5cab1ZGt
5ZqbpwL+a10s9IMlpVfWETOvIPp7kEXPbp87tVz7s/y3XTlWuPPH2i6vG9d5M39e18vrS+Wm80+Yq4jdlCkWPnZucnOTPquqCTNOT/5CpCrR4T6fnukFw0T5Jnl1Vfzzf05MzvffXUVWPXnZyd78o05Tpmubp5NO31Wb2/u4+Mcndd6DtYv9nJjk+yWeWHF78KodHZ9W6rs1CyALY/dxvybqqZPpQ/dVd8ITWLde4fqrqeWtUJO66xjlXZKpCfV26+w1V9e6q+onuPqu735LpA35Z2y8l+V874ZrnZpoyXXbsU5mqZ9s6/5Rs+6GFodaaKtzJbpol6/UyBba3zpXAr3b3SwaPY4jq63/33Lp5xjOesXEGA+x0J598sj/zwrqZv7vqp1fWSbHxVdVPJXnpqgcZNg2VLAD2CPPXOwhYm0h3n7neY/h6WPgOADCAkAUAMIDpQoCvzw1eS7rsSzSB9XEj16bv0H/EKlkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADbOo/EL1ly5b1HgLs0c4444z1HgLAhqWSBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMMBe6z0AgD3N1q1b13sIwC6gkgUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAb3wH2MX222+/9R4CsAuoZAEADCBkAQAMIGQBAAywqddknX/SSes9BNijvXG9BwCwgalkAQAMIGQBAAwgZAEADCBkAQAMsKkXvgNsRldeeeV6DwHYBVSyAAAGELIAAAYQsgAABtjUa7KuOeyz6z0EAIClVLIAAAYQsgAABhCyAAAGELIAAAYQsgAABtjUTxcCbEb77LPPeg8B2AVUsgAABhCyAAAG2NTThZ/a/4vrPQQAgKVUsgAABhCyAAAGELIAAAYQsgAABhCyAAAG2NRPFwJsRv/5n/+53kMAZocddtiwvlWyAAAGELIAAAYQsgAABtjUa7I+dfhX1nsIsGe7fL0HALBxqWQBAAwgZAEADCBkAQAMIGQBAAwgZAEADLCpny4E2IwOOuig9R4CsAuoZAEADCBkAQAMsKmnC190zW3WewiwRzthvQcAsIGpZAEADCBkAQAMIGQBAAwgZAEADCBkAQAMsKmfLgTYjI455pj1HgIw6+5hfatkAQAMIGQBAAywqacLv/LiU9d7CLBnO+GN6z0CgA1LJQsAYAAhCwBgACELAGAAIQsAYAAhCwBgACELAGAAIQsAYAAhCwBgACELAGCATf2N7/9y7rHrPQTYoz3ghDPWewgAG5ZKFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwAB7rfcAAIBdb+v
WrV/7+eijj17Hkey+VLIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAG8BUOALAH8rUN46lkAQAMIGQBAAwgZAEADCBkAQAMIGQBAAzg6UK4Ec4/6aSv/Xzsueeu40gA2KhUsgAABhCyAAAGELIAAAYQsgAABhCyAAAG8HQh3AieKARge1SyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAbYa70HALCZXXjhhes9BPZQW7duHX6No48+evg11tvb3/72G3zOUUcdtUPtVLIAAAYQsgAABhCyAAAGELIAAAYQsgAABvB0IcDX4cgjj6wbek53jxgK7HT+Xf36qGQBAAwgZAEADLChpgvP/pbPr/cQ2AOcf9JJw69x7LnnDr/GRnCPV73qhp1w8sljBgKwAalkAQAMIGQBAAwgZAEADCBkAQAMIGQBAAwgZAEADLChvsIBdoU95esVAFhfKlkAAAMIWQAAA5guBG60Gzr16k/NAnuS2kh/YbuqNs5ggJ2uu2u9xwCwq5guBAAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGKC6e73HAACw21HJAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAY4P8DoCjPChP499IAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"원본 관측 (160×210 RGB)\")\n", "plt.imshow(obs)\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"전처리된 관측 (80×80 그레이스케일)\")\n", "plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "여기서 볼 수 있듯이 하나의 이미지는 볼의 방향과 속도에 대한 정보가 없습니다. 이 정보들은 이 게임에 아주 중요합니다. 이런 이유로 실제로 몇 개의 연속된 관측을 연결하여 환경의 상태를 표현하는 것이 좋습니다. 한 가지 방법은 관측당 하나의 채널을 할당하여 멀티 채널 이미지를 만드는 것입니다. 다른 방법은 `np.max()` 함수를 사용해 최근의 관측을 모두 싱글 채널 이미지로 합치는 것입니다. 여기에서는 이전 이미지를 흐리게 하여 DQN이 현재와 이전을 구분할 수 있도록 했습니다." ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", "def combine_observations_multichannel(preprocessed_observations):\n", " return np.array(preprocessed_observations).transpose([1, 2, 0])\n", "\n", "def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n", " dimmed_observations = [obs * dim_factor**index\n", " for index, obs in enumerate(reversed(preprocessed_observations))]\n", " return np.max(np.array(dimmed_observations), axis=0)\n", "\n", "n_observations_per_state = 3\n", "preprocessed_observations = deque([], maxlen=n_observations_per_state)\n", "\n", "obs = env.reset()\n", "for step in range(10):\n", " obs, _, _, _ = env.step(1)\n", " preprocessed_observations.append(preprocess_observation(obs))" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAlUAAAEuCAYAAACnC+ctAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADAZJREFUeJzt3W+IbPddx/HPN73kNrm3WGsSK1iM9QoF6YNsjERQE2sQ/9DaWqsFEUNpjKC2zVJTBZUSUqw0btJSChdbvQ+K+KAPamuhoGjFWAJtNhUitElTqkSNpASr/XNteu/PB3MSpuPu3rub7+7s7H29YGD2zJnf+U0gh/f8ZubcGmMEAIDn5rJlTwAA4CgQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1F1xFXVyaoaVfX5uduYe/yDVfXG6f4vVNUXF27vnNv3iaq6dotjzD/vbdO2M1V1a1XdUlX3z+07quqF28z1L6rq1h1ey51V9YGdnlNV9y/M/+VVdfd0u6WqHrjwfzXgMKqqT1bVaxa2LZ4D/mA6Vy3e/qeqNrYZ1zmMFseWPQEOxhjj1DP3q+qb2+x2MskDY4zX73LsDyX50DR27ea5VfVYkhPTn9+W5OM77H5NkscvMOR3J3n1GOMzc8f4pd3MCTi0rkryHzvtMMa4K8ldi9ur6q4kT2/zHOcwWlipukQ98+4tyWuf4zgvnVsBeyTJV6rqeRf7/DHG940xXjzGeHGSv9zhOJXkZ5P8WFUdX3j43dPr+a49vQjg0KuqlyX5/iQv3+MQL8kWQeMcRicrVZeo6QSQqvrgcxznC0lOTWOtJblnjHHuQm/2ppPWdyT5ziRrSTYvcKi3Jfm7JJ9J8r6qum2McX567M1jjDPTuHt8JcBhNQXJO5L8WZLfrapPjzEeWtjnmiR/usMwP5TkB6rqtUnuHmM8kDiH0UtUHX3nkzxWVZ+f2/bFHfZ/ZVXNP37PGOO9F3msO5Ocnvt7I8m5JJ+b2/YvmZ1Uzif5WmZL+Y8m2fJ7AlV1WWYno19J8sNjjC9X1Q1J/n6HJfGPVNU3knxzjPGyadubktye5LGLfC3AIVBVx5L8cZLvTfIjSW5M8rGqunOMMf+m8EtJbr3IYf97m+3OYTwnouqIG2N8LdO7sG28L8l/zv390d1+pypJqup1SV6X5J/mNq9nttz+9rn5XLvDGO9N8u8Lm08l+akkN40xvjyNcVtV/fLCvOe9av77CJP3JPlEkrsv8FKAw+XOJN+T5ObpfPa3VXVTkjdOwZIkmVZ9vjR96fv3thnrw2OMt271gHMYHUTVJaKq3pTZO51F12R24vh/736q6sokL8jspPBfO4z9miT3ZPYO8s+r6qkLzOW3k7x5i4e+PclvJPnCMxvGGI8kuWl63quT/Fpm71jPZfZ9sHPbzP14khdldjIGVtcfJTk/xhhTiPxikmszWyU6neTrmVt9nz5GO7M4SM1+5XzzVgdwDqOLqLpEjDHek9k7nW9RVWfm/vxKkhunj//OJ/lqkqeSfDZbnKSm5/9WktuS/OQY43NV9ROZvas7sdX+01zeleRdW4y17fe7qurnktyX2fL+p5I8L8krkrw/yTvndn0qs1/f/G+Sf0tyf5JvbDcucLiNMc4lSVX9TpKfT/KWzFaTrkzyqiR/ON32xDmMTqLqElFVb0/ym5l972DRXyXf+rPibcbYavOfJHn/GOPr0xj/muQNC7G2OM7d01y2Wv368DZPe0WSM2OMT8zvOy3Z/2iSB6fjr21zPGC1/XSSd4wxPjn9/dUkH6iqn0ny40keSZKqekuS38/Wl1746BbbnMNo45IKl5ZjSZ6/xe3GvQ44xjj7zMmoyQ9us/1vkvxqVd1cVVdW1QumpfRbMvtFDXC0/XWSt1bVdVX1/Kp6YVW9PrMg+YeFfS/L7Hy3eLt5cVDnMDrVGOPCe8EhUFWvTPLrSV6a2fcQPpvk3jHGPy5
1YsC+my6r8IbMvlP1kiRnk/xzknePMT69zLldLOewo09UAQA08PEfAEADUQUA0ODAf/1377337vrzxvX19f2YCrAHGxsbu37OHXfccZT+7Y1dn8P80yNweOzxa08X9T+xlSoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAbHlj2Bi7G+vr7sKQDs2ebm5rKnABwAK1UAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADVbiiurHjx9f9hQA9uzkyZPLngJwAKxUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQYCUu/vn0008vewoAe3b27NllTwE4AFaqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqsxBXVL7/88mVPAWDPrrjiimVPATgAVqoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGiwEhf/fPLJJ5c9BWBy1VVXLXsKK+fRRx9d9hSAyalTp/ZtbCtVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1W4orqJ06cWPYUAPbs6quvXvYUgANgpQoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAYrcfHP++67b9lTACYbGxvLnsLKueGGG5Y9BWAyxti3sa1UAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAg2PLngD7a33u/sbSZgGwN9dff/2z9x988MElzgQuzEoVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANDAJRWOornrKGy4jgKwYm6//fZn758+fXqJM4HdsVIFANBAVAEANPDx30pzvXRgdW1ubj57f21tbYkzgR5WqgAAGogqAIAGPv47inwSCKwwv/hjVVmpAgBoIKoAABr4+G+l+ZwPWF1+8cdRY6UKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABqIKAKCBqAIAaCCqAAAaiCoAgAaiCgCggagCAGggqgAAGogqAIAGogoAoIGoAgBoIKoAABqIKgCABsc
O+oBPPPHEQR8SkiTr6+v7foyNjY19P8ayPf7448uewlI9/PDDy54Cl6jNzc19P8ba2tq+H2PZHnrooV0/57rrrruo/axUAQA0EFUAAA1EFQBAA1EFANBAVAEANKgxxrLnAACw8qxUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0EFUAAA1EFQBAA1EFANBAVAEANBBVAAANRBUAQANRBQDQQFQBADQQVQAADUQVAEADUQUA0EBUAQA0+D97y2FomDBkvwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img1 = combine_observations_multichannel(preprocessed_observations)\n", "img2 = combine_observations_singlechannel(preprocessed_observations)\n", "\n", "plt.figure(figsize=(10, 6))\n", "plt.subplot(121)\n", "plt.title(\"멀티 채널 상태\")\n", "plt.imshow(img1, interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "plt.subplot(122)\n", "plt.title(\"싱글 채널 상태\")\n", "plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 연습문제 해답" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. to 7." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "부록 A 참조." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. BipedalWalker-v2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*문제: 정책 그래디언트를 사용해 OpenAI 짐의 ‘BipedalWalker-v2’를 훈련시켜 보세요*" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "import gym" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n" ] } ], "source": [ "env = gym.make(\"BipedalWalker-v2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "노트: 만약 `BipedalWalker-v2` 환경을 만들 때 \"`module 'Box2D._Box2D' has no attribute 'RAND_LIMIT'`\"와 같은 이슈가 발생하면 다음과 같이 해보세요:\n", "```\n", "$ pip uninstall Box2D-kengz\n", "$ pip install git+https://github.com/pybox2d/pybox2d\n", "```" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "obs = env.reset()" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "img = env.render(mode=\"rgb_array\")" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD/CAYAAAAQaHZxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAB2FJREFUeJzt3e112kgAhlGxJ02sUgZuw22ENrYNp420EcqItgz2R2IvxsYI62tm3nt/OT7AGTujR2MhpN3pdOoAaNtfWw8AgOWJPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QIAvWw/g3DB0rt0AvNH3W49gc7upL1BU7AHOifx8xB4oisAvQ+yBIoj8ssR+Idcm7jCsOw4omcCvR+xncM+EfX6s6JNM5Ncn9hNMmbCiTyqh34bY32GJSXr+msJPqwR+e2J/xRaTU/hpjciXQ+z/KG1SCj81K217Ijj2NU1G4acGNW1TiSJi39IkFH5K09L21bKmL4TW921PxJZ/NupgDtajqZV94sS7/Jmt9pkqcTtKUG3sTcj3OX+/feY+n1FN7E3w+1jxt8G8Zy7Fxt4kn5cVf/nMeZa0O52Kul9IUYNJIP7bEHbuNPnmJWKP4K9I5PkksWdewj8vcWcmYs9yhP8+ws6CxJ51CP9b4s6KxJ71JYZf2NmY2LOtFsMv7BRI7ClDC9EXeQom9pSllugLO5URe8pWSvzFncqJPXVYM/rCToPEnjrNGX9xJ4DYU7d7oy/shBJ72nAr+iJPOLEHCDA59k3fgxaA38QeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIMCXrQcA/DYMtx/T98uPgzbtTqfT1mM4V9RgYG5jgj6G6MfZTX4BsYd5zRX0WwQ/ithDCdYK/CXBjyH2sJatgj6G6DdP7MnxHNslw1Zy0McQ/WaJPRnei/Bnw1Z70McQ/eaIPRk+CvR7YUsI+hii3wyxp33CPZ3oV29y7H2ClqIJ/Tz8HrGyp1gCNT8r/PKM/OT05JW9yyVQJKFfxhpnNFHm/BV7CCT69ysx4PcQe4pT+0ZVk+Top80zsacoaRtgKVqJvvlzndizGhti+YahnuCbT/cRe16xAbHVKt/cW5bYV8YGwVrmWOWbr+UQ+wWY4LTifC73
vblds8jYm7BwP9tN3YqMvUkFMK+iYi/yAMtwITSAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4gwJetB5DqcPhn1OOensY9DuAju9PptPUYXgxDV85gPmlsxPf9t1GPOw7fX74WfsjU991u6muI/cwOh39Gh/we59HvOuGHJHPE3mGcSlzuQM7/ghB+4Baxr9R5/C8PHYk/cEnsG3Bt1S/6wDOnXjZoifcMgLqJPUAAsQcIIPYAAbxBu4DLc+IBtuZDVYU4PDx0+37/8u/jcOyefv7ccERAKeb4UJXDOAABxL5Q+37fHR4eth4G0AixBwgg9gABxB4gQLWnXn79+vbN6V+/Yk/mAfhQtbF/9uv/sxXf3QF0nZ0AQPWxP3ce/nOXOwHxB9I4Zl+I/X7fHYfj1sMAGtXUyv7rlVZayQPpqo/9eeBFHeB91cde4AFuc8weIED1K/vWeJMWWIKVfWH2/f7lUscuhgbMRewBAoh9wRzSAebimH3B3KkKmIuVPUAAsS/E8Xh8dQ/aZ/2Vi7sB3EPsAQKIPUAAsQcIIPYFOj/lcnDtH2AGYg8QQOwBAog9QACxr4Bz7YGpxL4A3w+HrYcANE7sAQK4EFoFnH5Zh0P/+Pobty5a+vbqGK88DT8mjQfOiX1Bzs+vd8XLeRz6x4+jeyO49zy3/z5yUCNee9h33aF7vP6Ad8Zt58BHdqdTOavGx+Pj68FM2RBvPX/iqmutgNz13FvPD/yZ7w5wxYZvf75Y4f/ZjmVeh3+v7Nj//L5/dD8mn6VRVOwP3x/LGQw07GXH8J6FdsxPf9tBdN3F4b4rv6/+4vtP36bH3mEcCPTpv3gm3DztsH9c7i/As+etsVMZE+wXUw/3zUTsgVX0x27SzmLscw/7dw6JzLyTuSvYhdxdVOyBplweAum6bpWdTOmcZw8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAuxOp9PWYwBgYVb2AAHEHiCA2AMEEHuAAGIPEEDsAQKIPUAAsQcIIPYAAcQeIIDYAwQQe4AAYg8QQOwBAog9QACxBwgg9gABxB4ggNgDBBB7gABiDxBA7AECiD1AALEHCCD2AAH+A6D4pAvtVSLqAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(img)\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 2.74731708e-03, -2.37687770e-05, 1.84897050e-03, -1.59998417e-02,\n", " 9.18241963e-02, -2.44001253e-03, 8.60346854e-01, 3.20565856e-03,\n", " 1.00000000e+00, 3.22292708e-02, -2.43984465e-03, 8.53896812e-01,\n", " 1.72559964e-03, 1.00000000e+00, 4.40814108e-01, 4.45820212e-01,\n", " 4.61422890e-01, 4.89550292e-01, 5.34102917e-01, 6.02461159e-01,\n", " 7.09149063e-01, 8.85932028e-01, 1.00000000e+00, 1.00000000e+00])" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이 24개의 숫자에 대한 의미는 [온라인 문서](https://github.com/openai/gym/wiki/BipedalWalker-v2)를 참고하세요." ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Box(4,)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-1., -1., -1., -1.], dtype=float32)" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space.low" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 1., 1., 1.], dtype=float32)" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space.high" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "이는 각 다리의 엉덩이 관절의 토크와 발목 관절 토크를 제어하는 연속적인 4D 행동 공간입니다(-1에서 1까지). 연속적인 행동 공간을 다루기 위한 한 가지 방법은 이를 불연속적으로 나누는 것입니다. 예를 들어, 가능한 토크 값을 3개의 값 -1.0, 0.0, 1.0으로 제한할 수 있습니다. 이렇게 하면 가능한 행동은 $3^4=81$개가 됩니다." 
] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [], "source": [ "from itertools import product" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(81, 4)" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "possible_torques = np.array([-1.0, 0.0, 1.0])\n", "possible_actions = np.array(list(product(possible_torques, possible_torques, possible_torques, possible_torques)))\n", "possible_actions.shape" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()\n", "\n", "# 1. 네트워크 구조를 정의합니다\n", "n_inputs = env.observation_space.shape[0] # == 24\n", "n_hidden = 10\n", "n_outputs = len(possible_actions) # == 81\n", "initializer = tf.variance_scaling_initializer()\n", "\n", "# 2. 신경망을 만듭니다\n", "X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n", "\n", "hidden = tf.layers.dense(X, n_hidden, activation=tf.nn.selu,\n", " kernel_initializer=initializer)\n", "logits = tf.layers.dense(hidden, n_outputs,\n", " kernel_initializer=initializer)\n", "outputs = tf.nn.softmax(logits)\n", "\n", "# 3. 추정 확률에 기초하여 무작위한 행동을 선택합니다\n", "action_index = tf.squeeze(tf.multinomial(logits, num_samples=1), axis=-1)\n", "\n", "# 4. 
훈련\n", "learning_rate = 0.01\n", "\n", "y = tf.one_hot(action_index, depth=len(possible_actions))\n", "cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)\n", "optimizer = tf.train.AdamOptimizer(learning_rate)\n", "grads_and_vars = optimizer.compute_gradients(cross_entropy)\n", "gradients = [grad for grad, variable in grads_and_vars]\n", "gradient_placeholders = []\n", "grads_and_vars_feed = []\n", "for grad, variable in grads_and_vars:\n", " gradient_placeholder = tf.placeholder(tf.float32, shape=grad.get_shape())\n", " gradient_placeholders.append(gradient_placeholder)\n", " grads_and_vars_feed.append((gradient_placeholder, variable))\n", "training_op = optimizer.apply_gradients(grads_and_vars_feed)\n", "\n", "init = tf.global_variables_initializer()\n", "saver = tf.train.Saver()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "아직 훈련되지 않았지만 이 정책 네트워크를 실행해 보죠." ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [], "source": [ "def run_bipedal_walker(model_path=None, n_max_steps=1000):\n", " env = gym.make(\"BipedalWalker-v2\")\n", " frames = []\n", " with tf.Session() as sess:\n", " if model_path is None:\n", " init.run()\n", " else:\n", " saver.restore(sess, model_path)\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " img = env.render(mode=\"rgb_array\")\n", " frames.append(img)\n", " action_index_val = action_index.eval(feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " if done:\n", " break\n", " env.close()\n", " return frames" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. 
Please provide explicit dtype.\u001b[0m\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = run_bipedal_walker()\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "안되네요, 걷지를 못합니다. 그럼 훈련시켜 보죠!" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 1000/1000" ] } ], "source": [ "n_games_per_update = 10\n", "n_max_steps = 1000\n", "n_iterations = 1000\n", "save_iterations = 10\n", "discount_rate = 0.95\n", "\n", "with tf.Session() as sess:\n", " init.run()\n", " for iteration in range(n_iterations):\n", " print(\"\\rIteration: {}/{}\".format(iteration + 1, n_iterations), end=\"\")\n", " all_rewards = []\n", " all_gradients = []\n", " for game in range(n_games_per_update):\n", " current_rewards = []\n", " current_gradients = []\n", " obs = env.reset()\n", " for step in range(n_max_steps):\n", " action_index_val, gradients_val = sess.run([action_index, gradients],\n", " feed_dict={X: obs.reshape(1, n_inputs)})\n", " action = possible_actions[action_index_val]\n", " obs, reward, done, info = env.step(action[0])\n", " current_rewards.append(reward)\n", " current_gradients.append(gradients_val)\n", " if done:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_gradients.append(current_gradients)\n", "\n", " all_rewards = discount_and_normalize_rewards(all_rewards, discount_rate=discount_rate)\n", " feed_dict = {}\n", " for var_index, gradient_placeholder in enumerate(gradient_placeholders):\n", " mean_gradients = np.mean([reward * all_gradients[game_index][step][var_index]\n", " for game_index, rewards in enumerate(all_rewards)\n", " for step, reward in enumerate(rewards)], axis=0)\n", " feed_dict[gradient_placeholder] = mean_gradients\n", " sess.run(training_op, 
feed_dict=feed_dict)\n", " if iteration % save_iterations == 0:\n", " saver.save(sess, \"./my_bipedal_walker_pg.ckpt\")" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n", "INFO:tensorflow:Restoring parameters from ./my_bipedal_walker_pg.ckpt\n" ] }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frames = run_bipedal_walker(\"./my_bipedal_walker_pg.ckpt\")\n", "video = plot_animation(frames)\n", "HTML(video.to_html5_video())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "최상의 결과는 아니지만 적어도 직립해서 (느리게) 오른쪽으로 이동합니다(훨씬 더 오랜 학습이 필요할 것 같습니다 :). 이 문제에 대한 더 좋은 방법은 액터-크리틱(actor-critic) 알고리즘을 사용하는 것입니다. 이 방법은 행동 공간을 이산화할 필요가 없으므로 훨씬 빠르게 수렴합니다. 이에 대한 더 자세한 내용은 Yash Patel이 쓴 멋진 [블로그 포스트](https://towardsdatascience.com/reinforcement-learning-w-keras-openai-actor-critic-models-f084612cfd69)를 참고하세요." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9.\n", "**Coming soon**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "nav_menu": {}, "toc": { "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 6, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }