|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Lab 09 XOR - Logistic Regression - Eager Excuetion\n", |
| 8 | + "* XOR 문제를 Logistic Regression을 활용해 풀어보도록 하겠습니다.\n", |
| 9 | + "\n", |
| 10 | + "### 기본 Library 선언 및 Tensorflow 버전 확인" |
| 11 | + ] |
| 12 | + }, |
| 13 | + { |
| 14 | + "cell_type": "code", |
| 15 | + "execution_count": 1, |
| 16 | + "metadata": { |
| 17 | + "scrolled": true |
| 18 | + }, |
| 19 | + "outputs": [ |
| 20 | + { |
| 21 | + "name": "stdout", |
| 22 | + "output_type": "stream", |
| 23 | + "text": [ |
| 24 | + "1.12.0\n" |
| 25 | + ] |
| 26 | + } |
| 27 | + ], |
| 28 | + "source": [ |
| 29 | + "import numpy as np\n", |
| 30 | + "import matplotlib.pyplot as plt\n", |
| 31 | + "%matplotlib inline\n", |
| 32 | + "import tensorflow as tf\n", |
| 33 | + "import tensorflow.contrib.eager as tfe\n", |
| 34 | + "\n", |
| 35 | + "tf.enable_eager_execution()\n", |
| 36 | + "tf.set_random_seed(777) # for reproducibility\n", |
| 37 | + "\n", |
| 38 | + "print(tf.__version__)" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "metadata": {}, |
| 44 | + "source": [ |
| 45 | + "### 강의에 설명할 Data입니다\n", |
| 46 | + "* x_data가 2차원 배열이기에 2차원 공간에 표현하여 x1과 x2를 기준으로 y_data 0과 1로 구분하는 예제입니다\n", |
| 47 | + "* 붉은색과 푸른색으로 0과 1을 표시해 보도록 하겠습니다." |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": 2, |
| 53 | + "metadata": { |
| 54 | + "scrolled": true |
| 55 | + }, |
| 56 | + "outputs": [ |
| 57 | + { |
| 58 | + "data": { |
| 59 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEFNJREFUeJzt3W2MXGd5h/HrH5sQCgGqeJFQbOM0dVTcgEq6hFSoJTRp5eSDrQqEEolXpViChlYFoaalBWr3C0WlEpJbcAUJhEIwVCIrMHUlCIoEOPVGKRFOGro1Ads4igkhQkph83L3w4wfJht7d7zZM+NdXz/J8syZRzP38dpz+czZnUlVIUkSwFnjHkCSdPowCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpWT3uAU7VmjVrasOGDeMeQ5KWlTvvvPPHVTWx0LplF4UNGzYwPT097jEkaVlJ8oNh1vnykSSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElqOotCkk8meTDJd09ye5J8NMlMkruTXNLVLHMdPQoXXggPPDCqR5SkRRjDk1WXRwo3AZvnuf0qYGP/1zbgnzuc5Sl27ID77+/9LkmnrTE8WXUWhaq6HfjJPEu2Ap+unn3AC5O8uKt5jjt6FG68EZ58sve7RwuSTktjerIa5zmF84FDA9cP97c9TZJtSaaTTB87duwZPeiOHb0/Y4AnnvBoQdJpakxPVsviRHNV7aqqyaqanJhY8E3+Tup4eGdne9dnZz1akHQaGuOT1TijcARYN3B9bX9bZwbDe5xHC5JOO2N8shpnFKaAN/e/C+ky4JGqOtrpA079MrzHzc7Crbd2+aiSdIrG+GTV2ecpJPkccDmwJslh4APAswCq6mPAHuBqYAZ4FHhbV7Mcd/hw148gSUtgjE9WnUWhqq5d4PYC/qSrx5cknbplcaJZkjQaRkGS1BgFSVJjFCRJjVGQJDVGQZLUGAVJUmMUJEmNUZAkNUZBktQYBUlSYxQkSY1RkCQ1RkGS1BgFSVJjFCRJjVGQJDVGQZLUGAVJUmMUJEmNUZAkNUZBktQYBUlSYxQkSY1RkCQ1RkGS1BgFSVJjFCRJjVGQJDWdRiHJ5iT3JZlJcsMJbl+f5LYkdyW5O8nVXc4jSZpfZ1FIsgrYCVwFbAKuTbJpzrK/BnZX1SuAa4B/6moeSdLCujxSuBSYqaqDVTUL3AJsnbOmgOf3L78A+FGH80iSFrC6w/s+Hzg0cP0w8Ko5az4I/EeSdwHPBa7scB5J0gLGfaL5WuCmqloLXA3cnORpMyXZlmQ6yfSxY8dGPqQknSm6jMIRYN3A9bX9bYOuA3YDVNW3gXOANXPvqKp2VdVkVU1OTEx0NK4kqcso7Ac2Jrkgydn0TiRPzVnzQ+AKgCQvpRcFDwUkaUw6i0JVPQ5cD+wF7qX3XUYHkmxPsqW/7D3A25N8B/gc8Naqqq5mkiTNr8sTzVTVHmDPnG3vH7h8D/DqLmeQJA1v3CeaJUmnEaMgSWqMgiSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpMYoSJKaTqOQZHOS+5LMJLnhJGvekOSeJAeSfLbLeSRJ81vd1R0nWQXsBP4AOAzsTzJVVfcMrNkI/CXw6qp6OMmLuppHkrSwLo8ULgVmqupgVc0CtwBb56x5O7Czqh4GqKoHO5xHkrSALqNwPnBo4Prh/rZBFwEXJflmkn1JNnc4jyRpAZ29fHQKj78RuBxYC9ye5GVV9dPBRUm2AdsA1q9fP+oZJemM0eWRwhFg3cD1tf1tgw4DU1X1WFV9H/gevUg8RVXtqqrJqpqcmJjobGBJOtN1GYX9wMYkFyQ5G7gGmJqz5kv0jhJIsobey0kHO5xJkjSPzqJQVY8D1wN7gXuB3VV1IMn2JFv6y/YCDyW5B7gNeG9VPdTVTJKk+aWqxj3DKZmcnKzp6elxjyFJy0qSO6tqcqF1/kSzJKkxCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpMYoSJKaeaOQ5PlJLjzB9pd3N5IkaVxOGoUkbwD+G/i3JAeSvHLg5pu6HkySNHrzHSn8FfDbVfVbwNuAm5P8Uf+2dD6ZJGnk5vs4zlVVdRSgqv4zyWuBLydZByyv99uWJA1lviOFnw2eT+gH4nJgK/CbHc8lSRqD+aLwDuCsJJuOb6iqnwGbgT/uejBJ0uidNApV9Z2q+h9gd5K/SM9zgI8A7xzZhJKkkRnm5xReBawDvgXsB34EvLrLoSRJ4zFMFB4D/g94DnAO8P2qerLTqSRJYzFMFPbTi8Irgd8Frk3yhU6nkiSNxXzfknrcdVU13b98FNia5E0dziRJGpMFjxQGgjC47eZuxpEkjZNviCdJaoyCJKkxCpKkxihIkhqjIElqjIIkqTEKkqSm0ygk2ZzkviQzSW6YZ93rklSSyS7nkSTNr7MoJFkF7ASuAjbRe3uMTSdYdy7wZ8AdXc0iSRpOl0cKlwIzVXWwqmaBW+h9QM9cO4APAT/vcBZJ0hC6jML5wKGB64f725oklwDrquor891Rkm1JppNMHzt2bOknlSQBYzzRnOQseh/Y856F1lbVrqqarKrJiYmJ7oeTpDNUl1E4Qu/DeY5b29923LnAxcA3ktwPXAZMebJZksanyyjsBzYmuSDJ2cA1wNTxG6vqkapaU1UbqmoDsA/YcqJ3ZZUkjUZnUaiqx4Hrgb3AvcDuqjqQZHuSLV09riRp8Yb5kJ1Fq6o9wJ45295/krWXdzmLJGlh/kSzJKkxCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElqjIIkqek0Ckk2J7kvyUySG05w+7uT3JPk7iRfS/KSLueRJM2vsygkWQXsBK4CNgHXJtk0Z9ldwGRVvRz4IvD3Xc0jSVpYl0cKlwIzVXWwqmaBW4Ctgwuq6raqerR/dR+wtsN5JEkL6DIK5wOHBq4f7m87meuAr3Y4jyRpAavHPQBAkjcCk8BrTnL7NmAbwPr160c4mSSdWbo8UjgCrBu4vra/7SmSXAm8D9hSVb840R1V1a6qmqyqyYmJiU6GlSR1G4X9wMYkFyQ5G7gGmBpckOQVwMfpBeHBDmeRJA2hsyhU1ePA9cBe4F5gd1UdSLI9yZb+sg8DzwO+kOS/kkyd5O4kSSPQ6TmFqtoD7Jmz7f0Dl6/s8vElSafGn2iWJDVGQZLUGAVJUmMUJEmNUZAkNUZBktQYBUlSYxQkSY1RkCQ1RkGS1BgFSVJjFCRJjVGQJDVGQZLUGAVJUmMUJEmNUZAkNUZBktQYBUlSYxQkSY1RkCQ1RkGS1BgFSVJjFCRJjVGQJDVGQZLUGAVJUmMUJEmNUZAkNZ1GIcnmJPclmUlywwluf3aSz/dvvyPJhi7nkSTNr7MoJFkF7ASuAjYB1ybZNGfZdcDDVfXrwD8CH+pqnqc4ehQuvBAeeGAkDydJizGOp6oujxQuBWaq6mBVzQK3AFvnrNkKfKp/+YvAFUnS4Uw9O3bA/ff3fpek09Q4nqq6jML5wKGB64f72064pqoeBx4Bzutwpl56b7wRnnyy97tHC5JOQ+N6qloWJ5qTbEsynWT62LFjz+zOduzo/SkDPPGERwuSTkvjeqrqMgpHgHUD19f2t51wTZLVwAuAh+beUVXtqqrJqpqcmJhY/ETH0zs727s+O+vRgqTTzjifqrqMwn5gY5ILkpwNXANMzVkzBbylf/n1wNerqjqbaDC9x3m0IOk0M86nqs6i0D9HcD2wF7gX2F1VB5JsT7Klv+wTwHlJZoB3A0/7ttUlNTX1y/QeNzsLt97a6cNK0qkY51NVuvyPeRcmJydrenp63GNI0rKS5M6qmlxo3bI40SxJGg2jIElqjIIkqTEKkqTGKEiSGqMgSWqMgiSpMQqSpGbZ/fBakmPAD5bgrtYAP16C+1ku3N+V60zaV3B/F+slVbXgm8ctuygslSTTw/x030rh/q5cZ9K+gvvbNV8+kiQ1RkGS1JzJUdg17gFGzP1duc6kfQX3t1Nn7DkFSdLTnclHCpKkOVZ8FJJsTnJfkpkkT/sQnyTPTvL5/u13JNkw+imXzhD7++4k9yS5O8nXkrxkHHMuhYX2dWDd65JUkmX9HSvD7G+SN/S/vgeSfHbUMy6lIf4ur09yW5K7+n+frx7HnEshySeTPJjkuye5PUk+2v+zuDvJJZ0NU1Ur9hewCvhf4NeAs4HvAJvmrHkn8LH+5WuAz4977o7397XAr/Qvv2O57u8w+9pfdy5wO7APmBz33B1/bTcCdwG/2r/+onHP3fH+7gLe0b+8Cbh/3HM/g/39PeAS4Lsnuf1q4KtAgMuAO7qaZaUfKVwKzFTVwaqaBW4Bts5ZsxX4VP/yF4ErkmSEMy6lBfe3qm6rqkf7V/cBa0c841IZ5msLsAP4EPDzUQ7XgWH29+3Azqp6GKCqHhzxjEtpmP0t4Pn9yy8AfjTC+ZZUVd0O/GSeJVuBT1fPPuCFSV7cxSwrPQrnA4cGrh/ubzvhmup9rvQjwHkjmW7pDbO/g66j97+P5WjBfe0fYq+rqq+McrCODPO1vQi4KMk3k+xLsnlk0y29Yfb3g8AbkxwG9gDvGs1oY3Gq/7YXbXUXd6rTX5I3ApPAa8Y9SxeSnAV8BHjrmEcZpdX0XkK6nN4R4O1JXlZVPx3rVN25Fripqv4hye8ANye5uKqeHPdgy9lKP1I4AqwbuL62v+2Ea5KspncY+tBIplt6w+wvSa4E3gdsqapfjGi2pbbQvp4LXAx8I8n99F6HnVrGJ5uH+doeBqaq6rGq+j7wPXqRWI6G2d/rgN0AVfVt4Bx67xO0Eg31b3sprPQo7Ac2Jrkgydn0TiRPzVkzBbylf/n1wNerf2ZnGVpwf5O8Avg4vSAs59ec593XqnqkqtZU1Yaq2kDv/MmWqpoez7jP2DB/l79E7yiBJGvovZx0cJRDLqFh9veHwBUASV5KLwrHRjrl6EwBb+5/F9JlwCNVdbSLB1rRLx9V1eNJrgf20vtuhk9W1YEk24HpqpoCPkHvsHOG3omea8Y38TMz5P5+GHge8IX++fQfVtWWsQ29SEPu64ox5P7uBf4wyT3AE8B7q2pZHvUOub/vAf4lyZ/TO+n81uX6H7okn6MX9DX9cyQfAJ4FUFUfo3fO5GpgBngUeFtnsyzTP0NJUgdW+stHkqRTYBQkSY1RkCQ1RkGS1BgFSVJjFKQllOTfk/w0yZfHPYu0GEZBWlofBt407iGkxTIK0iIkeWX/fe3PSfLc/ucXXFxVXwN+Nu75pMVa0T/RLHWlqvYnmQL+DngO8JmqOuEHpEjLiVGQFm87vffo+Tnwp2OeRVoSvnwkLd559N5H6lx6b8YmLXtGQVq8jwN/A/wrvU93k5Y9Xz6SFiHJm4HHquqzSVYB30ry+8DfAr8BPK//bpfXVdXecc4qnQrfJVWS1PjykSSpMQqSpMYoSJIaoyBJaoyCJKkxCpKkxihIkhqjIElq/h+CL7kPwpocagAAAABJRU5ErkJggg==\n", |
| 60 | + "text/plain": [ |
| 61 | + "<Figure size 432x288 with 1 Axes>" |
| 62 | + ] |
| 63 | + }, |
| 64 | + "metadata": { |
| 65 | + "needs_background": "light" |
| 66 | + }, |
| 67 | + "output_type": "display_data" |
| 68 | + } |
| 69 | + ], |
| 70 | + "source": [ |
| 71 | + "x_data = [[0, 0],\n", |
| 72 | + " [0, 1],\n", |
| 73 | + " [1, 0],\n", |
| 74 | + " [1, 1]]\n", |
| 75 | + "y_data = [[0],\n", |
| 76 | + " [1],\n", |
| 77 | + " [1],\n", |
| 78 | + " [0]]\n", |
| 79 | + "\n", |
| 80 | + "plt.scatter(x_data[0][0],x_data[0][1], c='red' , marker='^')\n", |
| 81 | + "plt.scatter(x_data[3][0],x_data[3][1], c='red' , marker='^')\n", |
| 82 | + "plt.scatter(x_data[1][0],x_data[1][1], c='blue' , marker='^')\n", |
| 83 | + "plt.scatter(x_data[2][0],x_data[2][1], c='blue' , marker='^')\n", |
| 84 | + "\n", |
| 85 | + "plt.xlabel(\"x1\")\n", |
| 86 | + "plt.ylabel(\"x2\")\n", |
| 87 | + "plt.show()" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "markdown", |
| 92 | + "metadata": {}, |
| 93 | + "source": [ |
| 94 | + "## Tensorflow Eager\n", |
| 95 | + "### 위 Data를 기준으로 XOR처리를 위한 모델을 만들도록 하겠습니다\n", |
| 96 | + "* Tensorflow data API를 통해 학습시킬 값들을 담는다 (Batch Size는 한번에 학습시킬 Size로 정한다)\n", |
| 97 | + "* preprocess function으로 features,labels는 실재 학습에 쓰일 Data 연산을 위해 Type를 맞춰준다" |
| 98 | + ] |
| 99 | + }, |
| 100 | + { |
| 101 | + "cell_type": "code", |
| 102 | + "execution_count": 3, |
| 103 | + "metadata": {}, |
| 104 | + "outputs": [], |
| 105 | + "source": [ |
| 106 | + "dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(len(x_data))\n", |
| 107 | + "\n", |
| 108 | + "def preprocess_data(features, labels):\n", |
| 109 | + " features = tf.cast(features, tf.float32)\n", |
| 110 | + " labels = tf.cast(labels, tf.float32)\n", |
| 111 | + " return features, labels" |
| 112 | + ] |
| 113 | + }, |
| 114 | + { |
| 115 | + "cell_type": "markdown", |
| 116 | + "metadata": {}, |
| 117 | + "source": [ |
| 118 | + "## 1) Logistic Regression으로 XOR모델을 만들어 보겠습니다\n", |
| 119 | + "### W와 b은 학습을 통해 생성되는 모델에 쓰이는 Wegith와 Bias (초기값을 variable : 0이나 Random값으로 가능 tf.random_normal([2, 1]) )" |
| 120 | + ] |
| 121 | + }, |
| 122 | + { |
| 123 | + "cell_type": "code", |
| 124 | + "execution_count": 4, |
| 125 | + "metadata": {}, |
| 126 | + "outputs": [ |
| 127 | + { |
| 128 | + "name": "stdout", |
| 129 | + "output_type": "stream", |
| 130 | + "text": [ |
| 131 | + "W = [[0.]\n", |
| 132 | + " [0.]], B = [0.]\n" |
| 133 | + ] |
| 134 | + } |
| 135 | + ], |
| 136 | + "source": [ |
| 137 | + "W = tf.Variable(tf.zeros([2,1]), name='weight')\n", |
| 138 | + "b = tf.Variable(tf.zeros([1]), name='bias')\n", |
| 139 | + "print(\"W = {}, B = {}\".format(W.numpy(), b.numpy()))" |
| 140 | + ] |
| 141 | + }, |
| 142 | + { |
| 143 | + "cell_type": "markdown", |
| 144 | + "metadata": {}, |
| 145 | + "source": [ |
| 146 | + "### Sigmoid 함수를 가설로 선언합니다\n", |
| 147 | + "* Sigmoid는 아래 그래프와 같이 0과 1의 값만을 리턴합니다 tf.sigmoid(tf.matmul(X, W) + b)와 같습니다\n", |
| 148 | + "\n", |
| 149 | + "$$\n", |
| 150 | + "\\begin{align}\n", |
| 151 | + "sigmoid(x) & = \\frac{1}{1+e^{-x}} \\\\\\\\\\\n", |
| 152 | + "\\end{align}\n", |
| 153 | + "$$" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "code", |
| 158 | + "execution_count": 5, |
| 159 | + "metadata": {}, |
| 160 | + "outputs": [], |
| 161 | + "source": [ |
| 162 | + "def logistic_regression(features):\n", |
| 163 | + " hypothesis = tf.div(1., 1. + tf.exp(tf.matmul(features, W) + b))\n", |
| 164 | + " return hypothesis" |
| 165 | + ] |
| 166 | + }, |
| 167 | + { |
| 168 | + "cell_type": "markdown", |
| 169 | + "metadata": {}, |
| 170 | + "source": [ |
| 171 | + "### 가설을 검증할 Cost 함수를 정의합니다\n", |
| 172 | + "$$\n", |
| 173 | + "\\begin{align}\n", |
| 174 | + "cost(h(x),y) & = −log(h(x)) & if & y=1 \\\\\\\\\\\n", |
| 175 | + "cost(h(x),y) & = -log(1−h(x)) & if & y=0\n", |
| 176 | + "\\end{align}\n", |
| 177 | + "$$" |
| 178 | + ] |
| 179 | + }, |
| 180 | + { |
| 181 | + "cell_type": "markdown", |
| 182 | + "metadata": {}, |
| 183 | + "source": [ |
| 184 | + "* 위 두수식을 합치면 아래과 같습니다\n", |
| 185 | + "$$\n", |
| 186 | + "\\begin{align}\n", |
| 187 | + "cost(h(x),y) & = −y log(h(x))−(1−y)log(1−h(x))\n", |
| 188 | + "\\end{align}\n", |
| 189 | + "$$" |
| 190 | + ] |
| 191 | + }, |
| 192 | + { |
| 193 | + "cell_type": "code", |
| 194 | + "execution_count": 6, |
| 195 | + "metadata": {}, |
| 196 | + "outputs": [], |
| 197 | + "source": [ |
| 198 | + "def loss_fn(hypothesis, features, labels):\n", |
| 199 | + " cost = -tf.reduce_mean(labels * tf.log(logistic_regression(features)) + (1 - labels) * tf.log(1 - hypothesis))\n", |
| 200 | + " return cost\n", |
| 201 | + "\n", |
| 202 | + "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)" |
| 203 | + ] |
| 204 | + }, |
| 205 | + { |
| 206 | + "cell_type": "markdown", |
| 207 | + "metadata": {}, |
| 208 | + "source": [ |
| 209 | + "### 추론한 값은 0.5를 기준(Sigmoid 그래프 참조)로 0과 1의 값을 리턴합니다.\n", |
| 210 | + "* Sigmoid 함수를 통해 예측값이 0.5보다 크면 1을 반환하고 0.5보다 작으면 0으로 반환합니다." |
| 211 | + ] |
| 212 | + }, |
| 213 | + { |
| 214 | + "cell_type": "code", |
| 215 | + "execution_count": 7, |
| 216 | + "metadata": {}, |
| 217 | + "outputs": [], |
| 218 | + "source": [ |
| 219 | + "def accuracy_fn(hypothesis, labels):\n", |
| 220 | + " predicted = tf.cast(hypothesis > 0.5, dtype=tf.float32)\n", |
| 221 | + " accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, labels), dtype=tf.float32))\n", |
| 222 | + " return accuracy" |
| 223 | + ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "markdown", |
| 227 | + "metadata": {}, |
| 228 | + "source": [ |
| 229 | + "### GradientTape를 통해 경사값을 계산합니다." |
| 230 | + ] |
| 231 | + }, |
| 232 | + { |
| 233 | + "cell_type": "code", |
| 234 | + "execution_count": 8, |
| 235 | + "metadata": {}, |
| 236 | + "outputs": [], |
| 237 | + "source": [ |
| 238 | + "def grad(hypothesis, features, labels):\n", |
| 239 | + " with tf.GradientTape() as tape:\n", |
| 240 | + " loss_value = loss_fn(logistic_regression(features),features,labels)\n", |
| 241 | + " return tape.gradient(loss_value, [W,b])" |
| 242 | + ] |
| 243 | + }, |
| 244 | + { |
| 245 | + "cell_type": "markdown", |
| 246 | + "metadata": {}, |
| 247 | + "source": [ |
| 248 | + "### Tensorflow를 통한 실행을 위해 Session를 선언합니다.\n", |
| 249 | + "* 위의 Data를 Cost함수를 통해 학습시킨 후 모델을 생성합니다. " |
| 250 | + ] |
| 251 | + }, |
| 252 | + { |
| 253 | + "cell_type": "code", |
| 254 | + "execution_count": 9, |
| 255 | + "metadata": { |
| 256 | + "scrolled": true |
| 257 | + }, |
| 258 | + "outputs": [ |
| 259 | + { |
| 260 | + "name": "stdout", |
| 261 | + "output_type": "stream", |
| 262 | + "text": [ |
| 263 | + "Iter: 0, Loss: 0.6931\n", |
| 264 | + "Iter: 100, Loss: 0.6931\n", |
| 265 | + "Iter: 200, Loss: 0.6931\n", |
| 266 | + "Iter: 300, Loss: 0.6931\n", |
| 267 | + "Iter: 400, Loss: 0.6931\n", |
| 268 | + "Iter: 500, Loss: 0.6931\n", |
| 269 | + "Iter: 600, Loss: 0.6931\n", |
| 270 | + "Iter: 700, Loss: 0.6931\n", |
| 271 | + "Iter: 800, Loss: 0.6931\n", |
| 272 | + "Iter: 900, Loss: 0.6931\n", |
| 273 | + "Iter: 1000, Loss: 0.6931\n", |
| 274 | + "W = [[0.]\n", |
| 275 | + " [0.]], B = [0.]\n", |
| 276 | + "Testset Accuracy: 0.5000\n" |
| 277 | + ] |
| 278 | + } |
| 279 | + ], |
| 280 | + "source": [ |
| 281 | + "EPOCHS = 1001\n", |
| 282 | + "\n", |
| 283 | + "for step in range(EPOCHS):\n", |
| 284 | + " for features, labels in tfe.Iterator(dataset):\n", |
| 285 | + " features, labels = preprocess_data(features, labels)\n", |
| 286 | + " grads = grad(logistic_regression(features), features, labels)\n", |
| 287 | + " optimizer.apply_gradients(grads_and_vars=zip(grads,[W,b]))\n", |
| 288 | + " if step % 100 == 0:\n", |
| 289 | + " print(\"Iter: {}, Loss: {:.4f}\".format(step, loss_fn(logistic_regression(features),features,labels)))\n", |
| 290 | + "print(\"W = {}, B = {}\".format(W.numpy(), b.numpy()))\n", |
| 291 | + "x_data, y_data = preprocess_data(x_data, y_data)\n", |
| 292 | + "test_acc = accuracy_fn(logistic_regression(x_data),y_data)\n", |
| 293 | + "print(\"Testset Accuracy: {:.4f}\".format(test_acc))" |
| 294 | + ] |
| 295 | + }, |
| 296 | + { |
| 297 | + "cell_type": "code", |
| 298 | + "execution_count": null, |
| 299 | + "metadata": {}, |
| 300 | + "outputs": [], |
| 301 | + "source": [] |
| 302 | + } |
| 303 | + ], |
| 304 | + "metadata": { |
| 305 | + "kernelspec": { |
| 306 | + "display_name": "Python 3", |
| 307 | + "language": "python", |
| 308 | + "name": "python3" |
| 309 | + }, |
| 310 | + "language_info": { |
| 311 | + "codemirror_mode": { |
| 312 | + "name": "ipython", |
| 313 | + "version": 3 |
| 314 | + }, |
| 315 | + "file_extension": ".py", |
| 316 | + "mimetype": "text/x-python", |
| 317 | + "name": "python", |
| 318 | + "nbconvert_exporter": "python", |
| 319 | + "pygments_lexer": "ipython3", |
| 320 | + "version": "3.6.5" |
| 321 | + } |
| 322 | + }, |
| 323 | + "nbformat": 4, |
| 324 | + "nbformat_minor": 2 |
| 325 | +} |
0 commit comments