学习Andrew N.g的机器学习课程之后的简单实现.
课程地址:https://class.coursera.org/ml-007
不大会编辑公式,所以略去具体的推导,有疑惑的同学去看看Andrew 的课程吧,顺带一句,Andrew的课程实在是很赞。
如果还有疑问,feel free to contact me via emails or QQ.
LinearRegression.java
import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;public class LinearRegression { /* * 训练数据示例: * x0 x1 x2 y 1.0 1.0 2.0 7.2 1.0 2.0 1.0 4.9 1.0 3.0 0.0 2.6 1.0 4.0 1.0 6.3 1.0 5.0 -1.0 1.0 1.0 6.0 0.0 4.7 1.0 7.0 -2.0 -0.6 注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。 x0,x1,x2是“特征”,y是结果 h(x) = theta0 * x0 + theta1* x1 + theta2 * x2 theta0,theta1,theta2 是想要训练出来的参数 此程序采用“梯度下降法” * */ private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y private int row;//训练数据 行数 private int column;//训练数据 列数 private double [] theta;//参数theta private double alpha;//训练步长 private int iteration;//迭代次数 public LinearRegression(String fileName) { int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数 int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数 trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1 this.row=rowoffile; this.column=columnoffile+1; this.alpha = 0.001;//步长默认为0.001 this.iteration=100000;//迭代次数默认为 100000 theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + ....... initialize_theta(); loadTrainDataFromFile(fileName,rowoffile,columnoffile); } public LinearRegression(String fileName,double alpha,int iteration) { int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数 int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数 trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1 this.row=rowoffile; this.column=columnoffile+1; this.alpha = alpha; this.iteration=iteration; theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + ....... initialize_theta(); loadTrainDataFromFile(fileName,rowoffile,columnoffile); } private int getRowNumber(String fileName) { int count =0; File file = new File(fileName); BufferedReader reader = null; try { reader = new BufferedReader(new FileReader(file)); while ( reader.readLine() != null) count++; reader.close(); } catch (IOException e) { e.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException e1) { } } } return count; } private int getColumnNumber(String fileName) { int count =0; File file = new File(fileName); BufferedReader reader = null; try { reader = new BufferedReader(new FileReader(file)); String tempString = reader.readLine(); if(tempString!=null) count = tempString.split(" ").length; reader.close(); } catch (IOException e) { e.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException e1) { } } } return count; } private void initialize_theta()//将theta各个参数全部初始化为1.0 { for(int i=0;i0 ) { //对每个theta i 求 偏导数 double [] partial_derivative = compute_partial_derivative();//偏导数 //更新每个theta for(int i =0; i< theta.length;i++) theta[i]-= alpha * partial_derivative[i]; } } private double [] compute_partial_derivative() { double [] partial_derivative = new double[theta.length]; for(int j =0;j
TestLinearRegression.java
public class TestLinearRegression { public static void main(String[] args) { // TODO Auto-generated method stub LinearRegression m = new LinearRegression("trainData",0.001,1000000); m.printTrainData(); m.trainTheta(); m.printTheta(); }}
trainData文件中是训练数据,默认最后一列是y,比如:
1.0 2.0 7.2
2.0 1.0 4.9 3.0 0.0 2.6 4.0 1.0 6.3 5.0 -1.0 1.0 6.0 0.0 4.7 7.0 -2.0 -0.6前两列是“feature”,最后一列,也就是第三列是y
Email: wuzimian2006@163.com
QQ: 726590906