※ 김성훈 교수님의 [모두를 위한 딥러닝] 강의 정리
- https://www.youtube.com/watch?reload=9&v=BS6O0zOGX4E&feature=youtu.be&list=PLlMkM4tgfjnLSOjrEJN31gZATbcj_MpUm&fbclid=IwAR07UnOxQEOxSKkH6bQ8PzYj2vDop_J0Pbzkg3IVQeQ_zTKcXdNOwaSf_k0
- 참고자료 : Andrew Ng's ML class
1) https://class.coursera.org/ml-003/lecture
2) http://www.holehouse.org/mlclass/ (note)
1. (Linear) Hypothesis and cost function
* Hypothesis : H(x) = Wx + b
* Cost function : cost(W, b) = (1/m) * Σ ( H(x_i) - y_i ) ^ 2 // How well the line fits our (training) data (m = number of training examples)
* Goal = Minimize cost
2. How to minimize cost
* 학습 : W,b 값을 조정하여 cost 값을 최소화 하는 과정
(1) 그래프 생성
import tensorflow as tf

# Training data: the targets follow y = 1 * x + 0.
x_train = [1, 2, 3]
y_train = [1, 2, 3]

# Try to find values for W and b to compute y_data = x_data * W + b.
# We know that W should be 1 and b should be 0,
# but let TensorFlow figure it out.
# NOTE: unlike a variable in an ordinary programming language, a tf.Variable
# is a value that TensorFlow itself uses (and adjusts) for training.
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.random_normal([1]), name="bias")

# Our hypothesis XW+b
hypothesis = x_train * W + b

# cost/loss function: mean of the squared differences between the
# hypothesis and the targets
cost = tf.reduce_mean(tf.square(hypothesis - y_train))
* GradientDescent : 학습
# Minimize the cost with gradient descent (learning rate 0.01).
# `train` is the graph operation that applies one update step when run.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)
(2) 세션 실행 : 데이터 입력 및 그래프 실행
# Launch the graph in a session.
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())

# Fit the line: run one training step per iteration and report the
# current cost, W and b every 20th step.
for step in range(2001):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(cost), sess.run(W), sess.run(b))
(3) 그래프 업데이트 및 결과값 반환 : 학습에 의해 cost를 최소화하는 W, b 값 추론
...
(0, 3.5240757, array([2.1286771], dtype=float32), array([-0.8523567], dtype=float32))
(20, 0.19749945, array([1.533928], dtype=float32), array([-1.0505961], dtype=float32))
(40, 0.15214379, array([1.4572546], dtype=float32), array([-1.0239124], dtype=float32))
(60, 0.1379325, array([1.4308538], dtype=float32), array([-0.9779527], dtype=float32))
(80, 0.12527025, array([1.4101374], dtype=float32), array([-0.93219817], dtype=float32))
(100, 0.11377233, array([1.3908179], dtype=float32), array([-0.8884077], dtype=float32))
(120, 0.10332986, array([1.3724468], dtype=float32), array([-0.8466577], dtype=float32))
(140, 0.093845844, array([1.3549428], dtype=float32), array([-0.80686814], dtype=float32))
(160, 0.08523229, array([1.3382617], dtype=float32), array([-0.7689483], dtype=float32))
(180, 0.07740932, array([1.3223647], dtype=float32), array([-0.73281056], dtype=float32))
(200, 0.07030439, array([1.3072149], dtype=float32), array([-0.6983712], dtype=float32))
(220, 0.06385162, array([1.2927768], dtype=float32), array([-0.6655505], dtype=float32))
(240, 0.05799109, array([1.2790174], dtype=float32), array([-0.63427216], dtype=float32))
(260, 0.05266844, array([1.2659047], dtype=float32), array([-0.6044637], dtype=float32))
(280, 0.047834318, array([1.2534081], dtype=float32), array([-0.57605624], dtype=float32))
(300, 0.043443877, array([1.2414987], dtype=float32), array([-0.5489836], dtype=float32))
(320, 0.0394564, array([1.2301493], dtype=float32), array([-0.5231833], dtype=float32))
(340, 0.035834935, array([1.2193329], dtype=float32), array([-0.49859545], dtype=float32))
(360, 0.032545824, array([1.2090251], dtype=float32), array([-0.47516325], dtype=float32))
(380, 0.029558638, array([1.1992016], dtype=float32), array([-0.45283225], dtype=float32))
(400, 0.026845641, array([1.18984], dtype=float32), array([-0.4315508], dtype=float32))
(420, 0.024381675, array([1.1809182], dtype=float32), array([-0.41126958], dtype=float32))
(440, 0.02214382, array([1.1724157], dtype=float32), array([-0.39194146], dtype=float32))
(460, 0.020111356, array([1.1643128], dtype=float32), array([-0.37352163], dtype=float32))
(480, 0.018265454, array([1.1565907], dtype=float32), array([-0.35596743], dtype=float32))
(500, 0.016588978, array([1.1492316], dtype=float32), array([-0.33923826], dtype=float32))
(520, 0.015066384, array([1.1422179], dtype=float32), array([-0.3232953], dtype=float32))
(540, 0.01368351, array([1.1355343], dtype=float32), array([-0.30810148], dtype=float32))
(560, 0.012427575, array([1.1291647], dtype=float32), array([-0.29362184], dtype=float32))
(580, 0.011286932, array([1.1230947], dtype=float32), array([-0.2798227], dtype=float32))
(600, 0.010250964, array([1.1173096], dtype=float32), array([-0.26667204], dtype=float32))
(620, 0.009310094, array([1.1117964], dtype=float32), array([-0.25413945], dtype=float32))
(640, 0.008455581, array([1.1065423], dtype=float32), array([-0.24219586], dtype=float32))
(660, 0.0076795053, array([1.1015354], dtype=float32), array([-0.23081362], dtype=float32))
(680, 0.006974643, array([1.0967635], dtype=float32), array([-0.21996623], dtype=float32))
(700, 0.0063344706, array([1.0922159], dtype=float32), array([-0.20962858], dtype=float32))
(720, 0.0057530706, array([1.0878822], dtype=float32), array([-0.19977672], dtype=float32))
(740, 0.0052250377, array([1.0837522], dtype=float32), array([-0.19038804], dtype=float32))
(760, 0.004745458, array([1.0798159], dtype=float32), array([-0.18144041], dtype=float32))
(780, 0.004309906, array([1.076065], dtype=float32), array([-0.17291337], dtype=float32))
(800, 0.003914324, array([1.0724902], dtype=float32), array([-0.16478711], dtype=float32))
(820, 0.0035550483, array([1.0690835], dtype=float32), array([-0.1570428], dtype=float32))
(840, 0.0032287557, array([1.0658368], dtype=float32), array([-0.14966238], dtype=float32))
(860, 0.0029324207, array([1.0627428], dtype=float32), array([-0.14262886], dtype=float32))
(880, 0.0026632652, array([1.059794], dtype=float32), array([-0.13592596], dtype=float32))
(900, 0.0024188235, array([1.056984], dtype=float32), array([-0.12953788], dtype=float32))
(920, 0.0021968128, array([1.0543059], dtype=float32), array([-0.12345006], dtype=float32))
(940, 0.001995178, array([1.0517538], dtype=float32), array([-0.11764836], dtype=float32))
(960, 0.0018120449, array([1.0493214], dtype=float32), array([-0.11211928], dtype=float32))
(980, 0.0016457299, array([1.0470035], dtype=float32), array([-0.10685005], dtype=float32))
(1000, 0.0014946823, array([1.0447946], dtype=float32), array([-0.10182849], dtype=float32))
(1020, 0.0013574976, array([1.0426894], dtype=float32), array([-0.09704296], dtype=float32))
(1040, 0.001232898, array([1.0406833], dtype=float32), array([-0.09248237], dtype=float32))
(1060, 0.0011197334, array([1.038771], dtype=float32), array([-0.08813594], dtype=float32))
(1080, 0.0010169626, array([1.0369489], dtype=float32), array([-0.08399385], dtype=float32))
(1100, 0.0009236224, array([1.0352125], dtype=float32), array([-0.08004645], dtype=float32))
(1120, 0.0008388485, array([1.0335577], dtype=float32), array([-0.07628451], dtype=float32))
(1140, 0.0007618535, array([1.0319806], dtype=float32), array([-0.07269943], dtype=float32))
(1160, 0.0006919258, array([1.0304775], dtype=float32), array([-0.06928282], dtype=float32))
(1180, 0.00062842044, array([1.0290452], dtype=float32), array([-0.06602671], dtype=float32))
(1200, 0.0005707396, array([1.0276802], dtype=float32), array([-0.06292368], dtype=float32))
(1220, 0.00051835255, array([1.0263793], dtype=float32), array([-0.05996648], dtype=float32))
(1240, 0.00047077626, array([1.0251396], dtype=float32), array([-0.05714824], dtype=float32))
(1260, 0.00042756708, array([1.0239582], dtype=float32), array([-0.0544625], dtype=float32))
(1280, 0.00038832307, array([1.0228322], dtype=float32), array([-0.05190301], dtype=float32))
(1300, 0.00035268333, array([1.0217593], dtype=float32), array([-0.04946378], dtype=float32))
(1320, 0.0003203152, array([1.0207369], dtype=float32), array([-0.04713925], dtype=float32))
(1340, 0.0002909189, array([1.0197623], dtype=float32), array([-0.0449241], dtype=float32))
(1360, 0.00026421514, array([1.0188333], dtype=float32), array([-0.04281275], dtype=float32))
(1380, 0.0002399599, array([1.0179482], dtype=float32), array([-0.04080062], dtype=float32))
(1400, 0.00021793543, array([1.0171047], dtype=float32), array([-0.03888312], dtype=float32))
(1420, 0.00019793434, array([1.0163009], dtype=float32), array([-0.03705578], dtype=float32))
(1440, 0.00017976768, array([1.0155348], dtype=float32), array([-0.03531429], dtype=float32))
(1460, 0.00016326748, array([1.0148047], dtype=float32), array([-0.03365463], dtype=float32))
(1480, 0.00014828023, array([1.0141089], dtype=float32), array([-0.03207294], dtype=float32))
(1500, 0.00013467176, array([1.0134459], dtype=float32), array([-0.03056567], dtype=float32))
(1520, 0.00012231102, array([1.0128139], dtype=float32), array([-0.02912918], dtype=float32))
(1540, 0.0001110848, array([1.0122118], dtype=float32), array([-0.0277602], dtype=float32))
(1560, 0.000100889745, array([1.0116379], dtype=float32), array([-0.02645557], dtype=float32))
(1580, 9.162913e-05, array([1.011091], dtype=float32), array([-0.02521228], dtype=float32))
(1600, 8.322027e-05, array([1.0105698], dtype=float32), array([-0.02402747], dtype=float32))
(1620, 7.5580865e-05, array([1.0100728], dtype=float32), array([-0.02289824], dtype=float32))
(1640, 6.8643785e-05, array([1.0095996], dtype=float32), array([-0.02182201], dtype=float32))
(1660, 6.234206e-05, array([1.0091484], dtype=float32), array([-0.02079643], dtype=float32))
(1680, 5.662038e-05, array([1.0087185], dtype=float32), array([-0.01981908], dtype=float32))
(1700, 5.142322e-05, array([1.0083088], dtype=float32), array([-0.01888768], dtype=float32))
(1720, 4.6704197e-05, array([1.0079182], dtype=float32), array([-0.01800001], dtype=float32))
(1740, 4.2417145e-05, array([1.0075461], dtype=float32), array([-0.01715406], dtype=float32))
(1760, 3.852436e-05, array([1.0071915], dtype=float32), array([-0.01634789], dtype=float32))
(1780, 3.4988276e-05, array([1.0068535], dtype=float32), array([-0.01557961], dtype=float32))
(1800, 3.1776715e-05, array([1.0065314], dtype=float32), array([-0.01484741], dtype=float32))
(1820, 2.8859866e-05, array([1.0062244], dtype=float32), array([-0.0141496], dtype=float32))
(1840, 2.621177e-05, array([1.005932], dtype=float32), array([-0.01348464], dtype=float32))
(1860, 2.380544e-05, array([1.0056531], dtype=float32), array([-0.01285094], dtype=float32))
(1880, 2.1620841e-05, array([1.0053875], dtype=float32), array([-0.012247], dtype=float32))
(1900, 1.9636196e-05, array([1.0051342], dtype=float32), array([-0.01167146], dtype=float32))
(1920, 1.7834054e-05, array([1.004893], dtype=float32), array([-0.01112291], dtype=float32))
(1940, 1.6197106e-05, array([1.0046631], dtype=float32), array([-0.01060018], dtype=float32))
(1960, 1.4711059e-05, array([1.004444], dtype=float32), array([-0.01010205], dtype=float32))
(1980, 1.3360998e-05, array([1.0042351], dtype=float32), array([-0.00962736], dtype=float32))
(2000, 1.21343355e-05, array([1.0040361], dtype=float32), array([-0.00917497], dtype=float32))
3. How to minimize cost (placeholder 이용) // 에러..
import tensorflow as tf

W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
# Placeholders receive their values through feed_dict at run time;
# shape=[None] leaves the length of the 1-D input unspecified.
X = tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32, shape=[None])

# Our hypothesis XW+b
hypothesis = X * W + b

# cost/loss function (closing parenthesis of reduce_mean restored)
cost = tf.reduce_mean(tf.square(hypothesis - Y))

# Minimize
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# Launch the graph in a session.
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())

# Fit the line with new training data.
# A single sess.run evaluates cost, W and b and applies one training step;
# the result of `train` itself is discarded.
for step in range(2001):
    cost_val, W_val, b_val, _ = sess.run(
        [cost, W, b, train],
        feed_dict={X: [1, 2, 3, 4, 5], Y: [2.1, 3.1, 4.1, 5.1, 6.1]},
    )
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
(0, 1.2035878, array([1.0696986], dtype=float32), array([0.01276637], dtype=float32))
(20, 0.16904518, array([1.2650416], dtype=float32), array([0.13934135], dtype=float32))
(40, 0.14761032, array([1.2485868], dtype=float32), array([0.20250577], dtype=float32))
(60, 0.1289092, array([1.2323107], dtype=float32), array([0.26128453], dtype=float32))
(80, 0.112577364, array([1.2170966], dtype=float32), array([0.3162127], dtype=float32))
(100, 0.09831471, array([1.2028787], dtype=float32), array([0.36754355], dtype=float32))
(120, 0.08585897, array([1.189592], dtype=float32), array([0.41551268], dtype=float32))
(140, 0.07498121, array([1.1771754], dtype=float32), array([0.46034035], dtype=float32))
(160, 0.0654817, array([1.165572], dtype=float32), array([0.5022322], dtype=float32))
(180, 0.05718561, array([1.1547288], dtype=float32), array([0.54138047], dtype=float32))
(200, 0.049940635, array([1.1445953], dtype=float32), array([0.5779649], dtype=float32))
(220, 0.043613486, array([1.1351256], dtype=float32), array([0.6121535], dtype=float32))
(240, 0.038087945, array([1.1262761], dtype=float32), array([0.64410305], dtype=float32))
(260, 0.033262506, array([1.1180062], dtype=float32), array([0.6739601], dtype=float32))
(280, 0.029048424, array([1.1102779], dtype=float32), array([0.7018617], dtype=float32))
(300, 0.025368208, array([1.1030556], dtype=float32), array([0.7279361], dtype=float32))
(320, 0.022154227, array([1.0963064], dtype=float32), array([0.7523028], dtype=float32))
(340, 0.019347461, array([1.0899993], dtype=float32), array([0.7750737], dtype=float32))
(360, 0.016896311, array([1.0841053], dtype=float32), array([0.7963533], dtype=float32))
(380, 0.014755693, array([1.0785972], dtype=float32), array([0.8162392], dtype=float32))
(400, 0.012886246, array([1.0734499], dtype=float32), array([0.83482295], dtype=float32))
(420, 0.011253643, array([1.0686395], dtype=float32), array([0.85218966], dtype=float32))
(440, 0.009827888, array([1.0641443], dtype=float32), array([0.868419], dtype=float32))
(460, 0.008582776, array([1.0599433], dtype=float32), array([0.88358533], dtype=float32))
(480, 0.0074953884, array([1.0560175], dtype=float32), array([0.89775866], dtype=float32))
(500, 0.006545782, array([1.0523489], dtype=float32), array([0.9110037], dtype=float32))
(520, 0.005716468, array([1.0489205], dtype=float32), array([0.9233812], dtype=float32))
(540, 0.0049922303, array([1.0457168], dtype=float32), array([0.93494815], dtype=float32))
(560, 0.004359761, array([1.0427227], dtype=float32), array([0.94575745], dtype=float32))
(580, 0.0038074062, array([1.0399247], dtype=float32), array([0.95585895], dtype=float32))
(600, 0.0033250246, array([1.0373099], dtype=float32), array([0.96529907], dtype=float32))
(620, 0.0029037776, array([1.0348666], dtype=float32), array([0.9741207], dtype=float32))
(640, 0.0025359015, array([1.0325832], dtype=float32), array([0.9823645], dtype=float32))
(660, 0.002214623, array([1.0304493], dtype=float32), array([0.99006844], dtype=float32))
(680, 0.0019340345, array([1.028455], dtype=float32), array([0.99726814], dtype=float32))
(700, 0.00168901, array([1.0265915], dtype=float32), array([1.0039961], dtype=float32))
(720, 0.0014750187, array([1.02485], dtype=float32), array([1.0102835], dtype=float32))
(740, 0.0012881459, array([1.0232226], dtype=float32), array([1.0161589], dtype=float32))
(760, 0.0011249502, array([1.0217017], dtype=float32), array([1.0216497], dtype=float32))
(780, 0.0009824366, array([1.0202806], dtype=float32), array([1.026781], dtype=float32))
(800, 0.00085795636, array([1.0189523], dtype=float32), array([1.0315762], dtype=float32))
(820, 0.00074926845, array([1.017711], dtype=float32), array([1.0360574], dtype=float32))
(840, 0.0006543383, array([1.0165511], dtype=float32), array([1.0402449], dtype=float32))
(860, 0.00057143776, array([1.0154672], dtype=float32), array([1.0441583], dtype=float32))
(880, 0.00049904286, array([1.0144542], dtype=float32), array([1.0478154], dtype=float32))
(900, 0.0004358191, array([1.0135076], dtype=float32), array([1.0512332], dtype=float32))
(920, 0.00038059853, array([1.0126231], dtype=float32), array([1.0544269], dtype=float32))
(940, 0.00033238466, array([1.0117964], dtype=float32), array([1.0574113], dtype=float32))
(960, 0.0002902703, array([1.0110238], dtype=float32), array([1.0602009], dtype=float32))
(980, 0.00025349384, array([1.0103018], dtype=float32), array([1.0628073], dtype=float32))
(1000, 0.00022137808, array([1.009627], dtype=float32), array([1.0652432], dtype=float32))
(1020, 0.00019332914, array([1.0089965], dtype=float32), array([1.0675194], dtype=float32))
(1040, 0.00016882908, array([1.0084072], dtype=float32), array([1.0696473], dtype=float32))
(1060, 0.00014743926, array([1.0078566], dtype=float32), array([1.0716351], dtype=float32))
(1080, 0.00012875989, array([1.007342], dtype=float32), array([1.0734928], dtype=float32))
(1100, 0.00011244613, array([1.0068612], dtype=float32), array([1.0752288], dtype=float32))
(1120, 9.8200355e-05, array([1.0064118], dtype=float32), array([1.0768511], dtype=float32))
(1140, 8.5755724e-05, array([1.0059919], dtype=float32), array([1.0783674], dtype=float32))
(1160, 7.489431e-05, array([1.0055996], dtype=float32), array([1.0797837], dtype=float32))
(1180, 6.5406595e-05, array([1.0052328], dtype=float32), array([1.0811077], dtype=float32))
(1200, 5.7120622e-05, array([1.0048901], dtype=float32), array([1.0823449], dtype=float32))
(1220, 4.9882394e-05, array([1.0045699], dtype=float32), array([1.0835012], dtype=float32))
(1240, 4.3564207e-05, array([1.0042707], dtype=float32), array([1.0845816], dtype=float32))
(1260, 3.804614e-05, array([1.003991], dtype=float32), array([1.0855912], dtype=float32))
(1280, 3.3225275e-05, array([1.0037296], dtype=float32), array([1.0865349], dtype=float32))
(1300, 2.901571e-05, array([1.0034853], dtype=float32), array([1.0874166], dtype=float32))
(1320, 2.5340463e-05, array([1.003257], dtype=float32), array([1.0882409], dtype=float32))
(1340, 2.2129901e-05, array([1.0030438], dtype=float32), array([1.089011], dtype=float32))
(1360, 1.9328054e-05, array([1.0028446], dtype=float32), array([1.0897301], dtype=float32))
(1380, 1.6878726e-05, array([1.0026582], dtype=float32), array([1.0904027], dtype=float32))
(1400, 1.4740454e-05, array([1.0024842], dtype=float32), array([1.0910312], dtype=float32))
(1420, 1.2873619e-05, array([1.0023215], dtype=float32), array([1.0916185], dtype=float32))
(1440, 1.1241735e-05, array([1.0021695], dtype=float32), array([1.0921675], dtype=float32))
(1460, 9.818069e-06, array([1.0020274], dtype=float32), array([1.0926803], dtype=float32))
(1480, 8.574677e-06, array([1.0018947], dtype=float32), array([1.0931597], dtype=float32))
(1500, 7.4886166e-06, array([1.0017706], dtype=float32), array([1.0936075], dtype=float32))
(1520, 6.539272e-06, array([1.0016547], dtype=float32), array([1.0940262], dtype=float32))
(1540, 5.711003e-06, array([1.0015464], dtype=float32), array([1.0944173], dtype=float32))
(1560, 4.9874334e-06, array([1.001445], dtype=float32), array([1.094783], dtype=float32))
(1580, 4.3559958e-06, array([1.0013504], dtype=float32), array([1.0951246], dtype=float32))
(1600, 3.804345e-06, array([1.0012621], dtype=float32), array([1.0954438], dtype=float32))
(1620, 3.322312e-06, array([1.0011792], dtype=float32), array([1.0957422], dtype=float32))
(1640, 2.9007756e-06, array([1.0011021], dtype=float32), array([1.0960212], dtype=float32))
(1660, 2.5334934e-06, array([1.00103], dtype=float32), array([1.0962818], dtype=float32))
(1680, 2.2123513e-06, array([1.0009624], dtype=float32), array([1.0965253], dtype=float32))
(1700, 1.9319202e-06, array([1.0008993], dtype=float32), array([1.096753], dtype=float32))
(1720, 1.6872369e-06, array([1.0008405], dtype=float32), array([1.0969656], dtype=float32))
(1740, 1.4738443e-06, array([1.0007855], dtype=float32), array([1.0971642], dtype=float32))
(1760, 1.2871467e-06, array([1.0007341], dtype=float32), array([1.0973498], dtype=float32))
(1780, 1.12424e-06, array([1.0006859], dtype=float32), array([1.0975232], dtype=float32))
(1800, 9.815564e-07, array([1.0006411], dtype=float32), array([1.0976855], dtype=float32))
(1820, 8.573661e-07, array([1.0005993], dtype=float32), array([1.0978369], dtype=float32))
(1840, 7.4871434e-07, array([1.00056], dtype=float32), array([1.0979784], dtype=float32))
(1860, 6.5427787e-07, array([1.0005234], dtype=float32), array([1.0981107], dtype=float32))
(1880, 5.712507e-07, array([1.0004891], dtype=float32), array([1.0982342], dtype=float32))
(1900, 4.989224e-07, array([1.0004572], dtype=float32), array([1.0983498], dtype=float32))
(1920, 4.358085e-07, array([1.0004272], dtype=float32), array([1.0984578], dtype=float32))
(1940, 3.8070743e-07, array([1.0003992], dtype=float32), array([1.0985587], dtype=float32))
(1960, 3.3239553e-07, array([1.000373], dtype=float32), array([1.098653], dtype=float32))
(1980, 2.9042917e-07, array([1.0003488], dtype=float32), array([1.0987409], dtype=float32))
(2000, 2.5372992e-07, array([1.000326], dtype=float32), array([1.0988233], dtype=float32))