// SGD (stochastic gradient descent): fitting a two-feature linear model with random mini-batches

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks; namespace ConsoleApp4
{
/// <summary>
/// Console demo that fits the linear model <c>y = theta1*x1 + theta2*x2 + b</c>
/// to a small hard-coded data set using mini-batch stochastic gradient descent.
/// </summary>
class Program
{
    // One generator shared by the whole program. Building a fresh Random for every
    // draw (seeded from the clock) can yield identical sequences when the calls
    // happen in quick succession, which would make the initial weights equal and
    // repeat the same "random" batch across consecutive epochs.
    private static readonly Random Rng = new Random();

    /// <summary>
    /// Trains the model, printing the batch cost every 100 epochs and the fitted
    /// equation at the end, then waits for a line of input before exiting.
    /// </summary>
    /// <param name="args">Command-line arguments (unused).</param>
    static void Main(string[] args)
    {
        List<float[]> inputs_x = new List<float[]>
        {
            new float[] { 0.9f, 0.6f },
            new float[] { 2f, 2.5f },
            new float[] { 2.6f, 2.3f },
            new float[] { 2.7f, 1.9f }
        };
        List<float> inputs_y = new List<float> { 2.5f, 2.5f, 3.5f, 4.2f };

        // Layout: weights[0] = b, weights[1] = theta1, weights[2] = theta2.
        float[] weights = new float[3];
        for (var i = 0; i < weights.Length; i++)
        {
            weights[i] = (float)Rng.NextDouble();
        }

        int epoch = 30000;
        float epsilon = 0.00001f;
        float lr = 0.01f;
        float lastCost = 0;

        for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
        {
            // Draw a random mini-batch (with replacement) from the training data.
            var batch = GetRandomBatch(inputs_x, inputs_y, 2);

            // Accumulate (target - prediction) * d(prediction)/d(weight) over the batch.
            float[] weights_in_poch = new float[weights.Length];
            foreach (var x_y in batch)
            {
                float x1 = x_y.Item1[0];
                float x2 = x_y.Item1[1];
                float diffWithTargetY = Residual(x_y, weights);

                weights_in_poch[0] += diffWithTargetY * dy_b(x1, x2);
                weights_in_poch[1] += diffWithTargetY * dy_theta1(x1, x2);
                weights_in_poch[2] += diffWithTargetY * dy_theta2(x1, x2);
            }

            // The residual is target - prediction, so adding the scaled sum moves the
            // weights in the direction that reduces the squared error.
            for (var i = 0; i < weights.Length; i++)
            {
                weights[i] += lr * weights_in_poch[i];
            }

            // Half mean squared error of the updated weights on the same batch.
            float totalErrorCost = 0f;
            foreach (var x_y in batch)
            {
                float diffWithTargetY = Residual(x_y, weights);
                totalErrorCost += diffWithTargetY * diffWithTargetY / 2;
            }

            float cost = totalErrorCost / batch.Count;

            // Stop once the batch cost changes by at most epsilon between two
            // consecutive epochs.
            if (Math.Abs(cost - lastCost) <= epsilon)
            {
                Console.WriteLine("EPOCH {0}", epoch_i);
                Console.WriteLine("LAST MSE {0}", lastCost);
                Console.WriteLine("MSE {0}", cost);
                break;
            }

            lastCost = cost;

            if (epoch_i % 100 == 0 || epoch_i == epoch)
            {
                Console.WriteLine("MSE {0}", cost);
            }
        }

        print(weights[1], weights[2], weights[0]);
        Console.ReadLine();
    }

    /// <summary>
    /// Returns target minus prediction for one sample under the given weights.
    /// </summary>
    /// <param name="sample">Pair of feature array (x1 at index 0, x2 at index 1) and target value.</param>
    /// <param name="weights">Model weights laid out as { b, theta1, theta2 }.</param>
    private static float Residual(Tuple<float[], float> sample, float[] weights)
    {
        float[] features = sample.Item1;
        return sample.Item2 - fun(features[0], features[1], weights[1], weights[2], weights[0]);
    }

    /// <summary>
    /// Picks <paramref name="maxCount"/> samples uniformly at random, with replacement,
    /// pairing each chosen feature array with the target at the same index.
    /// </summary>
    /// <param name="inputs_x">Feature arrays, one per sample.</param>
    /// <param name="inputs_y">Target values, parallel to <paramref name="inputs_x"/>.</param>
    /// <param name="maxCount">Number of samples to draw; zero or less yields an empty list.</param>
    /// <returns>A new list holding the drawn (features, target) pairs.</returns>
    /// <exception cref="ArgumentNullException">Either input list is null.</exception>
    /// <exception cref="ArgumentException">The lists differ in length, or are empty while samples are requested.</exception>
    private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
    {
        if (inputs_x == null)
        {
            throw new ArgumentNullException("inputs_x");
        }
        if (inputs_y == null)
        {
            throw new ArgumentNullException("inputs_y");
        }
        if (inputs_x.Count != inputs_y.Count)
        {
            throw new ArgumentException("inputs_x and inputs_y must have the same number of elements.");
        }

        List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>();
        if (maxCount <= 0)
        {
            return lst;
        }
        if (inputs_x.Count == 0)
        {
            throw new ArgumentException("Cannot draw samples from an empty data set.", "inputs_x");
        }

        for (int count = 0; count < maxCount; count++)
        {
            int rndIndex = Rng.Next(inputs_x.Count);
            lst.Add(Tuple.Create(inputs_x[rndIndex], inputs_y[rndIndex]));
        }

        return lst;
    }

    /// <summary>Writes the fitted equation to the console.</summary>
    private static void print(float theta1, float theta2, float b)
    {
        Console.WriteLine("y={0}*x1+{1}*x2+{2}", theta1, theta2, b);
    }

    /// <summary>Evaluates the linear model <c>theta1*x1 + theta2*x2 + b</c>.</summary>
    private static float fun(float x1, float x2, float theta1, float theta2, float b)
    {
        return theta1 * x1 + theta2 * x2 + b;
    }

    /// <summary>Partial derivative of <see cref="fun"/> with respect to theta1, which is x1.</summary>
    private static float dy_theta1(float x1, float x2)
    {
        return x1;
    }

    /// <summary>Partial derivative of <see cref="fun"/> with respect to theta2, which is x2.</summary>
    private static float dy_theta2(float x1, float x2)
    {
        return x2;
    }

    /// <summary>Partial derivative of <see cref="fun"/> with respect to b, which is the constant 1.</summary>
    private static float dy_b(float x1, float x2)
    {
        return 1;
    }
}
}

  

上一篇:WPF:验证登录后关闭登录窗口,显示主窗口的解决方法


下一篇:scala中隐式转换之隐式值和隐式视图