当前位置 博文首页 > 文章内容

    MATLAB 神经网络分类

    作者: 栏目:未分类 时间:2020-09-10 15:01:13

    本站于2023年9月4日。收到“大连君*****咨询有限公司”通知
    说我们IIS7站长博客,有一篇博文用了他们的图片。
    要求我们给他们一张图片6000元。要不然法院告我们

    为避免不必要的麻烦,IIS7站长博客,全站内容图片下架、并积极应诉
    博文内容全部不再显示,请需要相关资讯的站长朋友到必应搜索。谢谢!

    另祝:版权碰瓷诈骗团伙,早日弃暗投明。

    相关新闻:借版权之名、行诈骗之实,周某因犯诈骗罪被判处有期徒刑十一年六个月

    叹!百花齐放的时代,渐行渐远!



    注:这里的练习鉴于当时理解不完全,可能会有些错误,关于神经网络的实践可以参考我的这篇博文

    这里的代码只是简单的练习,不涉及代码优化,也不涉及神经网络优化,所以我用了最能体现原理的方式来写的代码。

    激活函数用的是h = 1/(1+exp(-y)),其中y=sum([X Y].*w)。

    代价函数用的是E = 1/2*(t-h)^2,其中t为目标值,t为1代表是该类,t为0代表不是该类。

    权值更新采用BP算法。

    网络1形式如下,没有隐含层,1个偏置量,输入直接连接输出:

    分类结果:

    代码如下:

      1 clear all;
      2 close all;
      3 clc;
      4 
      5 n=5;
      6 randn('seed',1);
      7 mu1=[0 0];
      8 S1=[0.5 0;
      9     0 0.5];
     10 P1=mvnrnd(mu1,S1,n);
     11 
     12 mu2=[0 6];
     13 S2=[0.5 0;
     14     0 0.5];
     15 P2=mvnrnd(mu2,S2,n);
     16 
     17 mu3=[6 3];
     18 S3=[0.5 0;
     19     0 0.5];
     20 P3=mvnrnd(mu3,S3,n);
     21 
     22 
     23 P=[P1;P2;P3];
     24 meanP=mean(P);
     25 
     26 P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];
     27 
     28 sigma = 5;
     29 
     30 X=P(:,1);
     31 Y=P(:,2);
     32 B=rand(3*n,1);
     33 
     34 w1 = rand(3*n,1);
     35 w2 = rand(3*n,1);
     36 w3 = rand(3*n,1);
     37 
     38 w4 = rand(3*n,1);
     39 w5 = rand(3*n,1);
     40 w6 = rand(3*n,1);
     41 
     42 
     43 for i=1:3*n
     44     i
     45     while 1
     46         
     47         y1 = X(i)*w1(i) + Y(i)*w4(i) + B(i);       
     48         y2 = X(i)*w2(i) + Y(i)*w5(i) + B(i);        
     49         y3 = X(i)*w3(i) + Y(i)*w6(i) + B(i);     
     50         
     51         h1 = 1/(1+exp(-y1));
     52         h2 = 1/(1+exp(-y2));       
     53         h3 = 1/(1+exp(-y3));      
     54         
     55         e1  = 1/2*(1 - h1)^2;
     56         e2  = 1/2*(1 - h2)^2;       
     57         e3  = 1/2*(1 - h3)^2;
     58  
     59         if i<=n && e1<=0.0000001
     60             break;
     61         elseif i>n && i<=2*n && e2<0.0000001
     62             break;
     63         elseif i>2*n && e3<0.0000001
     64             break;
     65         end
     66         
     67         
     68         if i<=n
     69             w1(i) = w1(i)-sigma*(h1-1)*h1*(1-h1)*X(i);
     70             w2(i) = w2(i)-sigma*(h2-0)*h2*(1-h2)*X(i);
     71             w3(i) = w3(i)-sigma*(h3-0)*h3*(1-h3)*X(i);    
     72             
     73             w4(i) = w4(i)-sigma*(h1-1)*h1*(1-h1)*Y(i);
     74             w5(i) = w5(i)-sigma*(h2-0)*h2*(1-h2)*Y(i);
     75             w6(i) = w6(i)-sigma*(h3-0)*h3*(1-h3)*Y(i);                   
     76             
     77             B(i) =B(i)- sigma*((h1-1)*h1*(1-h1)+(h2-0)*h2*(1-h2)+(h3-0)*h3*(1-h3));
     78         elseif i>n && i<=2*n
     79             w1(i) = w1(i)-sigma*(h1-0)*h1*(1-h1)*X(i);
     80             w2(i) = w2(i)-sigma*(h2-1)*h2*(1-h2)*X(i);
     81             w3(i) = w3(i)-sigma*(h3-0)*h3*(1-h3)*X(i);    
     82             
     83             w4(i) = w4(i)-sigma*(h1-0)*h1*(1-h1)*Y(i);
     84             w5(i) = w5(i)-sigma*(h2-1)*h2*(1-h2)*Y(i);
     85             w6(i) = w6(i)-sigma*(h3-0)*h3*(1-h3)*Y(i);                   
     86             
     87             B(i) =B(i)- sigma*((h1-0)*h1*(1-h1)+(h2-1)*h2*(1-h2)+(h3-0)*h3*(1-h3));         
     88         else
     89             w1(i) = w1(i)-sigma*(h1-0)*h1*(1-h1)*X(i);
     90             w2(i) = w2(i)-sigma*(h2-0)*h2*(1-h2)*X(i);
     91             w3(i) = w3(i)-sigma*(h3-1)*h3*(1-h3)*X(i);    
     92             
     93             w4(i) = w4(i)-sigma*(h1-0)*h1*(1-h1)*Y(i);
     94             w5(i) = w5(i)-sigma*(h2-0)*h2*(1-h2)*Y(i);
     95             w6(i) = w6(i)-sigma*(h3-1)*h3*(1-h3)*Y(i);                   
     96             
     97             B(i) =B(i)- sigma*((h1-0)*h1*(1-h1)+(h2-0)*h2*(1-h2)+(h3-1)*h3*(1-h3));                   
     98         end
     99          
    100 
    101     end
    102 end
    103 
    104 plot(P(:,1),P(:,2),'o');
    105 hold on;
    106 
    107 flag = 0;
    108 M=[];
    109 for x=-8:0.3:8
    110     for y=-8:0.3:8
    111 
    112         H=[]; 
    113         for i=1:3*n
    114             y1 = x*w1(i)+y*w4(i) +B(i);
    115             y2 = x*w2(i)+y*w5(i) +B(i);
    116             y3 = x*w3(i)+y*w6(i) +B(i);
    117             h1=1/(1+exp(-y1));
    118             h2=1/(1+exp(-y2));
    119             h3=1/(1+exp(-y3));
    120             
    121             H=[H;h1 h2 h3];
    122         end
    123   %      H1 = mean(H(1:n,1));
    124   %      H2 = mean(H(n:2*n,2));
    125   %      H3 = mean(H(2*n:3*n,3));
    126         
    127         meanH = mean(H);
    128         H1 = meanH(1);
    129         H2 = meanH(2);
    130         H3= meanH(3);
    131         if H1>H2 && H1>H3
    132             plot(x,y,'g.')
    133         elseif H2 > H1 && H2 > H3
    134             plot(x,y,'r.')
    135         elseif H3 > H1 && H3 > H2
    136             plot(x,y,'b.')
    137         end
    138         
    139     end
    140 end

    网络2形式如下,有1个隐含层,2个偏置量:

     

    分类结果:

    代码如下:

      1 clear all;
      2 close all;
      3 clc;
      4 
      5 n=5;
      6 randn('seed',1);
      7 mu1=[0 0];
      8 S1=[0.5 0;
      9     0 0.5];
     10 P1=mvnrnd(mu1,S1,n);
     11 
     12 mu2=[0 6];
     13 S2=[0.5 0;
     14     0 0.5];
     15 P2=mvnrnd(mu2,S2,n);
     16 
     17 mu3=[6 3];
     18 S3=[0.5 0;
     19     0 0.5];
     20 P3=mvnrnd(mu3,S3,n);
     21 
     22 
     23 P=[P1;P2;P3];
     24 meanP=mean(P);
     25 
     26 P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];
     27 
     28 sigma = 5;
     29 
     30 X=P(:,1);
     31 Y=P(:,2);
     32 
     33 B1=rand(3*n,1);
     34 B2=rand(3*n,1);
     35 
     36 w1 = rand(3*n,1);
     37 w2 = rand(3*n,1);
     38 
     39 w3 = rand(3*n,1);
     40 w4 = rand(3*n,1);
     41 w5 = rand(3*n,1);
     42 
     43 for i=1:3*n
     44     i
     45     while 1
     46         
     47         y0 = X(i)*w1(i) + Y(i)*w2(i) + B1(i);  
     48         h0 = 1/(1+exp(-y0));  
     49               
     50         y1 = h0*w3(i) + B2(i);        
     51         y2 = h0*w4(i) + B2(i);     
     52         y3 = h0*w5(i) + B2(i);
     53         
     54         h1 = 1/(1+exp(-y1));       
     55         h2 = 1/(1+exp(-y2));      
     56         h3 = 1/(1+exp(-y3));
     57         
     58         e1  = 1/2*(1 - h1)^2;
     59         e2  = 1/2*(1 - h2)^2;       
     60         e3  = 1/2*(1 - h3)^2;
     61  
     62         if i<=n && e1<=0.0000001
     63             break;
     64         elseif i>n && i<=2*n && e2<0.0000001
     65             break;
     66         elseif i>2*n && e3<0.0000001
     67             break;
     68         end
     69                
     70         %e1
     71         if i<=n
     72             
     73             w1(i) = w1(i)- sigma*((h1-1)*h1*(1-h1)*w3(i)*h0*(1-h0)*X(i) + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)*X(i) + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0)*X(i));      
     74             w2(i) = w2(i)- sigma*((h1-1)*h1*(1-h1)*w3(i)*h0*(1-h0)*Y(i) + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)*Y(i) + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0)*Y(i));           
     75             B1(i) = B1(i)- sigma*((h1-1)*h1*(1-h1)*w3(i)*h0*(1-h0)      + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)      + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0));
     76             
     77             w3(i) = w3(i)-sigma*(h1-1)*h1*(1-h1)*h0;              
     78             w4(i) = w4(i)-sigma*(h2-0)*h2*(1-h2)*h0;
     79             w5(i) = w5(i)-sigma*(h3-0)*h3*(1-h3)*h0;
     80             B2(i) =B2(i)- sigma*((h1-1)*h1*(1-h1)+(h2-0)*h2*(1-h2)+(h3-0)*h3*(1-h3));   
     81                           
     82         elseif i>n && i<=2*n
     83             w1(i) = w1(i)-sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)*X(i) + (h2-1)*h2*(1-h2)*w4(i)*h0*(1-h0)*X(i) + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0)*X(i));      
     84             w2(i) = w2(i)-sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)*Y(i) + (h2-1)*h2*(1-h2)*w4(i)*h0*(1-h0)*Y(i) + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0)*Y(i));           
     85             B1(i) =B1(i)- sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)      + (h2-1)*h2*(1-h2)*w4(i)*h0*(1-h0)      + (h3-0)*h3*(1-h3)*w5(i)*h0*(1-h0));
     86             
     87             w3(i) = w3(i)-sigma*(h1-0)*h1*(1-h1)*h0;              
     88             w4(i) = w4(i)-sigma*(h2-1)*h2*(1-h2)*h0;
     89             w5(i) = w5(i)-sigma*(h3-0)*h3*(1-h3)*h0;
     90             B2(i) =B2(i)- sigma*((h1-0)*h1*(1-h1)+(h2-1)*h2*(1-h2)+(h3-0)*h3*(1-h3));   
     91                      
     92         else
     93             w1(i) = w1(i)-sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)*X(i) + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)*X(i) + (h3-1)*h3*(1-h3)*w5(i)*h0*(1-h0)*X(i));      
     94             w2(i) = w2(i)-sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)*Y(i) + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)*Y(i) + (h3-1)*h3*(1-h3)*w5(i)*h0*(1-h0)*Y(i));           
     95             B1(i) =B1(i)- sigma*((h1-0)*h1*(1-h1)*w3(i)*h0*(1-h0)      + (h2-0)*h2*(1-h2)*w4(i)*h0*(1-h0)      + (h3-1)*h3*(1-h3)*w5(i)*h0*(1-h0));
     96           
     97             w3(i) = w3(i)-sigma*(h1-0)*h1*(1-h1)*h0;              
     98             w4(i) = w4(i)-sigma*(h2-0)*h2*(1-h2)*h0;
     99             w5(i) = w5(i)-sigma*(h3-1)*h3*(1-h3)*h0;
    100             B2(i) =B2(i)- sigma*((h1-0)*h1*(1-h1)+(h2-0)*h2*(1-h2)+(h3-1)*h3*(1-h3));   
    101                              
    102         end
    103          
    104 
    105     end
    106 end
    107 
    108 
    109 plot(P(:,1),P(:,2),'o');
    110 hold on;
    111 
    112 flag = 0;
    113 M=[];
    114 for x=-8:0.3:8
    115     for y=-8:0.3:8
    116   
    117        H=[]; 
    118         for i=1:3*n
    119             y0 = x*w1(i)+y*w2(i) +B1(i);
    120             h0=1/(1+exp(-y0));     
    121             
    122             y1 = h0*w3(i) + B2(i);
    123             y2 = h0*w4(i) + B2(i);
    124             y3 = h0*w5(i) + B2(i);
    125 
    126             h1 =1/(1+exp(-y1));
    127             h2 =1/(1+exp(-y2));
    128             h3 =1/(1+exp(-y3));
    129             
    130             H=[H;h1 h2 h3];
    131         end
    132 
    133         meanH = mean(H);
    134        H1 = meanH(1);
    135         H2 = meanH(2);
    136        H3= meanH(3);
    137         if H1>H2 && H1>H3
    138             plot(x,y,'g.')
    139         elseif H2 > H1 && H2 > H3
    140             plot(x,y,'r.')
    141         elseif H3 > H1 && H3 > H2
    142             plot(x,y,'b.')
    143         end
    144         
    145     end
    146 end

    网络3形式如下,有2个隐含层,2个偏置量:

     

     

    分类结果:

    代码如下:

      1 clear all;
      2 close all;
      3 clc;
      4 
      5 n=5;
      6 randn('seed',1);
      7 mu1=[0 0];
      8 S1=[0.5 0;
      9     0 0.5];
     10 P1=mvnrnd(mu1,S1,n);
     11 
     12 mu2=[0 6];
     13 S2=[0.5 0;
     14     0 0.5];
     15 P2=mvnrnd(mu2,S2,n);
     16 
     17 mu3=[6 3];
     18 S3=[0.5 0;
     19     0 0.5];
     20 P3=mvnrnd(mu3,S3,n);
     21 
     22 
     23 P=[P1;P2;P3];
     24 meanP=mean(P);
     25 
     26 P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];
     27 
     28 sigma = 20;
     29 
     30 X=P(:,1);
     31 Y=P(:,2);
     32 
     33 B1=rand(3*n,1);
     34 B2=rand(3*n,1);
     35 
     36 w1 = rand(3*n,1);
     37 w2 = rand(3*n,1);
     38 
     39 w3 = rand(3*n,1);
     40 w4 = rand(3*n,1);
     41 
     42 w5 = rand(3*n,1);
     43 w6 = rand(3*n,1);
     44 w7 = rand(3*n,1);
     45 
     46 w8 = rand(3*n,1);
     47 w9 = rand(3*n,1);
     48 w10 = rand(3*n,1);
     49 
     50 for i=1:3*n
     51     i
     52     while 1
     53         
     54         y1 = X(i)*w1(i) + Y(i)*w3(i) + B1(i);
     55         y2 = X(i)*w2(i) + Y(i)*w4(i) + B1(i);
     56         
     57         h1 = 1/(1+exp(-y1));  
     58         h2 = 1/(1+exp(-y2));        
     59         
     60         dh1 = h1*(1-h1);
     61         dh2 = h2*(1-h2);
     62         
     63         y3 = h1*w5(i) + h2*w8(i)+ B2(i);        
     64         y4 = h1*w6(i) + h2*w9(i)+ B2(i);      
     65         y5 = h1*w7(i) + h2*w10(i)+ B2(i);    
     66         
     67         h3 = 1/(1+exp(-y3));       
     68         h4 = 1/(1+exp(-y4));      
     69         h5 = 1/(1+exp(-y5));
     70         
     71         dh3 = h3*(1-h3);
     72         dh4 = h4*(1-h4);
     73         dh5 = h5*(1-h5);
     74         
     75         e1  = 1/2*(1 - h3)^2;
     76         e2  = 1/2*(1 - h4)^2;       
     77         e3  = 1/2*(1 - h5)^2;
     78  
     79         if i<=n && e1<=0.0000001
     80             break;
     81         elseif i>n && i<=2*n && e2<0.0000001
     82             break;
     83         elseif i>2*n && e3<0.0000001
     84             break;
     85         end
     86                
     87         %e1
     88         if i<=n
     89             
     90             w1(i) = w1(i) -sigma * ((h3-1)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-0)*dh5*w7(i))  * dh1*X(i);
     91             w2(i) = w2(i) -sigma * ((h3-1)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-0)*dh5*w10(i)) * dh2*X(i);          
     92             
     93             w3(i) = w3(i) -sigma * ((h3-1)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-0)*dh5*w7(i))  * dh1*Y(i);
     94             w4(i) = w4(i) -sigma * ((h3-1)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-0)*dh5*w10(i)) * dh2*Y(i);       
     95                      
     96             B1(i) = B1(i)- sigma*(((h3-1)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-0)*dh5*w7(i))*dh1+((h3-1)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-0)*dh5*w10(i))*dh2);
     97             
     98             w5(i) = w5(i)-sigma*(h3-1)*dh3*h1;              
     99             w6(i) = w6(i)-sigma*(h4-0)*dh4*h1;
    100             w7(i) = w7(i)-sigma*(h5-0)*dh5*h1;
    101             
    102             w8(i) = w8(i)-sigma*(h3-1)*dh3*h2;              
    103             w9(i) = w9(i)-sigma*(h4-0)*dh4*h2;
    104             w10(i) = w10(i)-sigma*(h5-0)*dh5*h2;         
    105             
    106             B2(i) =B2(i)- sigma*((h3-1)*dh3+(h4-0)*dh4+(h5-0)*dh5);   
    107                           
    108         elseif i>n && i<=2*n
    109             w1(i) = w1(i) -sigma * ((h3-0)*dh3*w5(i)+(h4-1)*dh4*w6(i)+(h5-0)*dh5*w7(i))  * dh1*X(i);
    110             w2(i) = w2(i) -sigma * ((h3-0)*dh3*w8(i)+(h4-1)*dh4*w9(i)+(h5-0)*dh5*w10(i)) * dh2*X(i);          
    111             
    112             w3(i) = w3(i) -sigma * ((h3-0)*dh3*w5(i)+(h4-1)*dh4*w6(i)+(h5-0)*dh5*w7(i))  * dh1*Y(i);
    113             w4(i) = w4(i) -sigma * ((h3-0)*dh3*w8(i)+(h4-1)*dh4*w9(i)+(h5-0)*dh5*w10(i)) * dh2*Y(i);       
    114                      
    115             B1(i) = B1(i)- sigma*(((h3-0)*dh3*w5(i)+(h4-1)*dh4*w6(i)+(h5-0)*dh5*w7(i))*dh1+((h3-0)*dh3*w8(i)+(h4-1)*dh4*w9(i)+(h5-0)*dh5*w10(i))*dh2);
    116             
    117             w5(i) = w5(i)-sigma*(h3-0)*dh3*h1;              
    118             w6(i) = w6(i)-sigma*(h4-1)*dh4*h1;
    119             w7(i) = w7(i)-sigma*(h5-0)*dh5*h1;
    120             
    121             w8(i) = w8(i)-sigma*(h3-0)*dh3*h2;              
    122             w9(i) = w9(i)-sigma*(h4-1)*dh4*h2;
    123             w10(i) = w10(i)-sigma*(h5-0)*dh5*h2;         
    124             
    125             B2(i) =B2(i)- sigma*((h3-0)*dh3+(h4-1)*dh4+(h5-0)*dh5);   
    126                      
    127         else
    128             w1(i) = w1(i) -sigma * ((h3-0)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-1)*dh5*w7(i))  * dh1*X(i);
    129             w2(i) = w2(i) -sigma * ((h3-0)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-1)*dh5*w10(i)) * dh2*X(i);          
    130             
    131             w3(i) = w3(i) -sigma * ((h3-0)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-1)*dh5*w7(i))  * dh1*Y(i);
    132             w4(i) = w4(i) -sigma * ((h3-0)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-1)*dh5*w10(i)) * dh2*Y(i);       
    133                      
    134             B1(i) = B1(i)- sigma*(((h3-0)*dh3*w5(i)+(h4-0)*dh4*w6(i)+(h5-1)*dh5*w7(i))*dh1+((h3-0)*dh3*w8(i)+(h4-0)*dh4*w9(i)+(h5-1)*dh5*w10(i))*dh2);
    135             
    136             w5(i) = w5(i)-sigma*(h3-0)*dh3*h1;              
    137             w6(i) = w6(i)-sigma*(h4-0)*dh4*h1;
    138             w7(i) = w7(i)-sigma*(h5-1)*dh5*h1;
    139             
    140             w8(i) = w8(i)-sigma*(h3-0)*dh3*h2;              
    141             w9(i) = w9(i)-sigma*(h4-0)*dh4*h2;
    142