博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
多元线性回归----Java简单实现
阅读量:4959 次
发布时间:2019-06-12

本文共 4823 字,大约阅读时间需要 16 分钟。

学习Andrew N.g的机器学习课程之后的简单实现.

课程地址:https://class.coursera.org/ml-007

 不大会编辑公式,所以略去具体的推导,有疑惑的同学去看看Andrew 的课程吧,顺带一句,Andrew的课程实在是很赞。

如果还有疑问,feel free to contact me via emails or QQ.

 

LinearRegression.java

import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;public class LinearRegression {    /*     * 训练数据示例:     *   x0        x1        x2        y         1.0       1.0       2.0       7.2         1.0       2.0       1.0       4.9         1.0       3.0       0.0       2.6         1.0       4.0       1.0       6.3         1.0       5.0      -1.0       1.0         1.0       6.0       0.0       4.7         1.0       7.0      -2.0      -0.6         注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。        x0,x1,x2是“特征”,y是结果                h(x) = theta0 * x0 + theta1* x1 + theta2 * x2                theta0,theta1,theta2 是想要训练出来的参数                 此程序采用“梯度下降法”             *      */    private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y    private int row;//训练数据  行数    private int column;//训练数据 列数        private double [] theta;//参数theta        private double alpha;//训练步长    private int iteration;//迭代次数        public LinearRegression(String fileName)    {           int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数                trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1        this.row=rowoffile;        this.column=columnoffile+1;                this.alpha = 0.001;//步长默认为0.001        this.iteration=100000;//迭代次数默认为 100000                theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......        initialize_theta();                loadTrainDataFromFile(fileName,rowoffile,columnoffile);    }    public LinearRegression(String fileName,double alpha,int iteration)    {           int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数                trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1        this.row=rowoffile;        this.column=columnoffile+1;                this.alpha = alpha;        this.iteration=iteration;                theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......        initialize_theta();                loadTrainDataFromFile(fileName,rowoffile,columnoffile);    }            private int getRowNumber(String fileName)    {        int count =0;        File file = new File(fileName);        BufferedReader reader = null;        try {            reader = new BufferedReader(new FileReader(file));            while ( reader.readLine() != null)                 count++;            reader.close();        } catch (IOException e) {            e.printStackTrace();        } finally {            if (reader != null) {                try {                    reader.close();                } catch (IOException e1) {                }            }        }        return count;            }        private int getColumnNumber(String fileName)    {        int count =0;        File file = new File(fileName);        BufferedReader reader = null;        try {            reader = new BufferedReader(new FileReader(file));            String tempString = reader.readLine();            if(tempString!=null)                count = tempString.split(" ").length;            reader.close();        } catch (IOException e) {            e.printStackTrace();        } finally {            if (reader != null) {                try {                    reader.close();                } catch (IOException e1) {                }            }        }        return count;    }        private void initialize_theta()//将theta各个参数全部初始化为1.0    {        for(int i=0;i
0 ) { //对每个theta i 求 偏导数 double [] partial_derivative = compute_partial_derivative();//偏导数 //更新每个theta for(int i =0; i< theta.length;i++) theta[i]-= alpha * partial_derivative[i]; } } private double [] compute_partial_derivative() { double [] partial_derivative = new double[theta.length]; for(int j =0;j

TestLinearRegression.java

public class TestLinearRegression {    public static void main(String[] args) {        // TODO Auto-generated method stub         LinearRegression m = new LinearRegression("trainData",0.001,1000000);         m.printTrainData();         m.trainTheta();         m.printTheta();    }}

trainData文件中是训练数据,默认最后一列是y,比如:

             1.0       2.0       7.2

             2.0       1.0       4.9
             3.0       0.0       2.6
             4.0       1.0       6.3
             5.0      -1.0       1.0
            6.0       0.0       4.7
            7.0      -2.0      -0.6

前两列是“feature”,最后一列,也就是第三列是y

 

Email: wuzimian2006@163.com

 QQ:    726590906    

转载于:https://www.cnblogs.com/wzm-xu/p/4062266.html

你可能感兴趣的文章
N进制到M进制的转换问题
查看>>
Android------三种监听OnTouchListener、OnLongClickListener同时实现即其中返回值true或者false的含义...
查看>>
MATLAB实现多元线性回归预测
查看>>
Mac xcode 配置OpenGL
查看>>
利用sed把一行的文本文件改成每句一行
查看>>
使用Asyncio的Coroutine来实现一个有限状态机
查看>>
Android应用开发:核心技术解析与最佳实践pdf
查看>>
python——爬虫
查看>>
2.2 标识符
查看>>
孤荷凌寒自学python第五十八天成功使用python来连接上远端MongoDb数据库
查看>>
求一个字符串中最长回文子串的长度(承接上一个题目)
查看>>
简单权限管理系统原理浅析
查看>>
springIOC第一个课堂案例的实现
查看>>
求输入成绩的平均分
查看>>
php PDO (转载)
查看>>
wordpress自动截取文章摘要代码
查看>>
[置顶] 一名优秀的程序设计师是如何管理知识的?
查看>>
scanf和gets
查看>>
highcharts 图表实例
查看>>
ubuntu下如何查看用户登录及系统授权相关信息
查看>>