diff --git a/02_end_to_end_machine_learning_project.ipynb b/02_end_to_end_machine_learning_project.ipynb index a8386f5..41069f8 100644 --- a/02_end_to_end_machine_learning_project.ipynb +++ b/02_end_to_end_machine_learning_project.ipynb @@ -593,16 +593,35 @@ "execution_count": 14, "metadata": {}, "outputs": [], + "source": [ + "from zlib import crc32\n", + "\n", + "def test_set_check(identifier, test_ratio):\n", + " return crc32(np.int64(identifier)) & 0xffffffff < test_ratio * 2**32\n", + "\n", + "def split_train_test_by_id(data, test_ratio, id_column):\n", + " ids = data[id_column]\n", + " in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio))\n", + " return data.loc[~in_test_set], data.loc[in_test_set]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "위의 `test_set_check()` 함수는 파이썬 2와 파이썬 3에서 모두 작동되고 다음의 hashlib를 사용한 구현보다 훨씬 빠릅니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], "source": [ "import hashlib\n", "\n", - "def test_set_check(identifier, test_ratio, hash):\n", - " return hash(np.int64(identifier)).digest()[-1] < 256 * test_ratio\n", - "\n", - "def split_train_test_by_id(data, test_ratio, id_column, hash=hashlib.md5):\n", - " ids = data[id_column]\n", - " in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio, hash))\n", - " return data.loc[~in_test_set], data.loc[in_test_set]" + "def test_set_check(identifier, test_ratio, hash=hashlib.md5):\n", + " return bytearray(hash(np.int64(identifier)).digest())[-1] < 256 * test_ratio" ] }, {