march alex's blog
hello,I am march alex
posts - 52,comments - 7,trackbacks - 0

实现时遇到一个问题,就是fi的值的设定问题,因为我们采用随机梯度下降法,一方面这节省了时间,但是如果fi值亘古不变的话可能会面临跳来跳去一直找不到答案的问题,所以我这里设定他得知在每一轮之后都会按比例减小(fi *= 0.5;),大家可以按自己的喜好自由设定。
import java.util.Scanner;


public class Perceptron {
    private static int N = 3;
    private static int n = 2;
    private static double[][] X = null;
    private static double[] Y = null;
    private static double[] W = null;
    private static double B = 0;
    private static double fi = 0.5;
    
    private static boolean check(int id) {
        double ans = B;
        for(int i=0;i<n;i++)
            ans += X[id][i] * W[i];
        if(ans * Y[id] > 0) return true;
        return false;
    }
    
    private static void debug() {
        System.out.print("debug: W");
        for(int i=0;i<n;i++) System.out.print(W[i] + " ");
        System.out.println("/ B : " + B);
    }
    
    public static void solve() {
        Scanner in = new Scanner(System.in);
        System.out.print("input N:"); N = in.nextInt();
        System.out.print("input n:"); n = in.nextInt();
        
        X = new double[N][n];
        Y = new double[N];
        W = new double[n];
        
        System.out.println("input N * n datas X[i][j]:");
        for(int i=0;i<N;i++)
            for(int j=0;j<n;j++)
                X[i][j] = in.nextDouble();
        System.out.println("input N datas Y[i]");
        for(int i=0;i<N;i++) 
            Y[i] = in.nextDouble();
        
        
        for(int i=0;i<n;i++) W[i] = 0;
        B = 0;
        
        boolean ok = true;
        while(ok == true) {
            ok = false;
            //这里在原来算法的基础上不断地将fi缩小,以避免跳来跳去一直达不到要求的点的效果。
            for(int i=0;i<N;i++) {
                //System.out.println("here " + i);
                while(check(i) == false) {
                    ok = true;
                    for(int j=0;j<n;j++)
                        W[j] += fi * Y[i] * X[i][j];
                    B += fi * Y[i];
                    //debug();
                }
            }
            fi *= 0.5;
        }
    }
    
    public static void main(String[] args) {
        solve();
        System.out.print("W = [");
        for(int i=0;i<n-1;i++) System.out.print(W[i] + ", ");
        System.out.println(W[n-1] + "]");
        System.out.println("B = " + B);
    }
}

posted on 2015-03-20 11:08 marchalex 阅读(630) 评论(0)  编辑  收藏 所属分类: java小程序

只有注册用户登录后才能发表评论。


网站导航: