矩阵乘法的多线程实现:
/**
* @Title: MultiThreadMatrix.java
* @Package matrix
* @Description: 多线程计算矩阵乘法
* @author Aloong
* @date 2010-10-28 下午09:45:56
* @version V1.0
*/
package matrix;
import java.util.Date;
/**
 * Multiplies an {@code m x k} matrix by a {@code k x n} matrix using a few
 * worker threads ({@link MyThread}). Each worker claims whole result rows from
 * a shared counter ({@link #getTask()}). When all workers finish, the elapsed
 * time is printed.
 */
public class MultiThreadMatrix
{
    /** Upper bound on the number of worker threads started by {@link #main}. */
    static final int MAX_THREADS = 4;

    static int[][] matrix1;
    static int[][] matrix2;
    static int[][] matrix3;
    static int m,n,k;
    /** Next result row to hand out; read and written only under the class lock. */
    static int index;
    /**
     * Number of {@link MyThread} workers currently inside {@code run()}.
     * It is modified only through {@link #adjustThreadCount(int)}. It is
     * volatile so that unsynchronized readers see the latest value. Completion
     * is detected with {@link Thread#join()}, not by polling this counter.
     */
    static volatile int threadCount;
    static long startTime;

    /**
     * Fills two random 1024x1024 matrices, multiplies them with up to
     * {@link #MAX_THREADS} workers, and prints the elapsed wall-clock time.
     *
     * @throws InterruptedException if interrupted while waiting for workers
     */
    public static void main(String[] args) throws InterruptedException
    {
        // matrix a is m x k, matrix b is k x n  ==>  matrix c is m x n
        m = 1024;
        n = 1024;
        k = 1024;
        matrix1 = new int[m][k];
        matrix2 = new int[k][n];
        matrix3 = new int[m][n];
        index = 0;
        // Fill a and b with random values.
        fillRandom(matrix1);
        fillRandom(matrix2);
        startTime = System.currentTimeMillis();
        // Debug output of a and b:
        // printMatrix(matrix1);
        // printMatrix(matrix2);

        // Start at most MAX_THREADS workers, and never more than there are rows.
        // Everything written above happens-before each Thread.start(), so the
        // workers see the filled matrices.
        int workerCount = Math.min(MAX_THREADS, m);
        Thread[] workers = new Thread[workerCount];
        for (int i = 0; i < workerCount; i++)
        {
            workers[i] = new Thread(new MyThread());
            workers[i].start();
        }
        // join() waits for every worker and also makes their writes to matrix3
        // visible here. The old loop polled an unsynchronized counter that
        // could still read 0 before any worker had even started.
        for (Thread worker : workers)
        {
            worker.join();
        }
        // printMatrix(matrix3);
        long finishTime = System.currentTimeMillis();
        System.out.println("计算完成,用时"+(finishTime-startTime)+"毫秒");
    }

    /** Prints {@code x} row by row (space-separated), followed by a blank line. */
    static void printMatrix(int[][] x)
    {
        for (int i=0; i<x.length; i++)
        {
            for(int j=0; j<x[i].length; j++)
            {
                System.out.print(x[i][j]+" ");
            }
            System.out.println("");
        }
        System.out.println("");
    }

    /** Sets every element of {@code x} to a random integer in [0, 99]. */
    static void fillRandom(int[][] x)
    {
        for (int i=0; i<x.length; i++)
        {
            for(int j=0; j<x[i].length; j++)
            {
                x[i][j] = (int) (Math.random() * 100);
            }
        }
    }

    /**
     * Claims the next unprocessed result row.
     *
     * @return a row index in [0, m), or -1 when every row has been handed out
     */
    synchronized static int getTask()
    {
        if(index < m)
        {
            return index++;
        }
        return -1;
    }

    /** Adds {@code delta} to {@link #threadCount} under the class lock. */
    synchronized static void adjustThreadCount(int delta)
    {
        threadCount += delta;
    }
}

/**
 * Worker that keeps claiming row indices from
 * {@link MultiThreadMatrix#getTask()}. For each claimed row it computes that
 * row of {@code matrix3 = matrix1 * matrix2}, and it stops when no rows remain.
 */
class MyThread implements Runnable
{
    /** Row most recently claimed by this worker (-1 once all rows are taken). */
    int task;

    @Override
    public void run()
    {
        // The original ++/-- on a plain static int raced between workers;
        // update it under the lock, and always undo it even on failure.
        MultiThreadMatrix.adjustThreadCount(1);
        try
        {
            while ((task = MultiThreadMatrix.getTask()) != -1)
            {
                System.out.println("进程: "+Thread.currentThread().getName()+"\t开始计算第 "+(task+1)+"行");
                computeRow(task);
            }
        }
        finally
        {
            MultiThreadMatrix.adjustThreadCount(-1);
        }
    }

    /** Adds row {@code row} of matrix1 * matrix2 into the same row of matrix3. */
    private static void computeRow(int row)
    {
        int[] left = MultiThreadMatrix.matrix1[row];
        int[][] right = MultiThreadMatrix.matrix2;
        int[] out = MultiThreadMatrix.matrix3[row];
        int cols = MultiThreadMatrix.n;
        int inner = MultiThreadMatrix.k;
        for (int i = 0; i < cols; i++)
        {
            // Accumulate in a local. int overflow wraps, so the result is the
            // same as adding every term into out[i] one at a time.
            int sum = 0;
            for (int j = 0; j < inner; j++)
            {
                sum += left[j] * right[j][i];
            }
            out[i] += sum;
        }
    }
}
单线程:
/**
* @Title: SingleThreadMatrix.java
* @Package matrix
* @Description: 单线程计算矩阵乘法
* @author Aloong
* @date 2010-10-28 下午11:33:18
* @version V1.0
*/
package matrix;
import java.util.Date;
/**
 * Single-threaded baseline for the matrix-multiplication benchmark.
 * It computes an {@code m x k} by {@code k x n} product row by row on the
 * calling thread and reports the elapsed time.
 */
public class SingleThreadMatrix
{
    static int[][] matrix1;
    static int[][] matrix2;
    static int[][] matrix3;
    static int m,n,k;
    static long startTime;

    /** Multiplies two random 1024x1024 matrices and prints how long it took. */
    public static void main(String[] args)
    {
        m = n = k = 1024;
        matrix1 = new int[m][k];
        matrix2 = new int[k][n];
        matrix3 = new int[m][n];
        fillRandom(matrix1);
        fillRandom(matrix2);
        startTime = System.currentTimeMillis();
        // Debug output of the operands:
        // printMatrix(matrix1);
        // printMatrix(matrix2);
        String threadName;
        for (int row = 0; row < m; row++)
        {
            threadName = Thread.currentThread().getName();
            System.out.println("进程: " + threadName + "\t开始计算第 " + (row + 1) + "行");
            int[] lhs = matrix1[row];
            int[] out = matrix3[row];
            for (int col = 0; col < n; col++)
            {
                for (int step = 0; step < k; step++)
                {
                    out[col] += lhs[step] * matrix2[step][col];
                }
            }
        }
        // printMatrix(matrix3);
        System.out.println("计算完成,用时" + (System.currentTimeMillis() - startTime) + "毫秒");
    }

    /** Overwrites every cell of {@code x} with a random integer between 0 and 99. */
    static void fillRandom(int[][] x)
    {
        for (int[] row : x)
        {
            for (int c = 0; c < row.length; c++)
            {
                row[c] = (int) (Math.random() * 100);
            }
        }
    }
}
修改m,n,k的值可以修改相乘矩阵的阶数.
结果对比,计算1024阶矩阵的时候多线程用时约4.8秒,单线程用时16秒,
单线程占用内存21M,多线程占用16M.
本机是4核CPU,单线程的时候只有25%的CPU占用,使用4个子线程可以达到接近100%的CPU使用率.
另外请教一个问题,是矩阵乘法的Strassen算法
下面这个是来自网上的一段代码,在我自己的电脑上,只要超过12阶就会内存溢出
不解是什么原因,设置jvm的内存不管多大也会崩溃在12阶
请高手帮忙解答....
package matrix;
import java.io.*;
import java.util.*;
/**
 * Square integer matrix stored with 1-based indices. Element (i, j) of an
 * order-n matrix lives in {@code m[i][j]} for 1 <= i, j <= n, and row/column 0
 * is unused.
 */
class Matrix
{
    /** Largest order the no-argument constructor can hold (backing array is 32x32). */
    static final int DEFAULT_ORDER = 31;

    public int[][] m;

    /** Creates a zero matrix with the legacy 32x32 backing array (orders up to 31). */
    public Matrix()
    {
        this(DEFAULT_ORDER);
    }

    /**
     * Creates a zero matrix able to hold order {@code order}.
     *
     * @param order largest index usable in either dimension; must be >= 0
     */
    public Matrix(int order)
    {
        m = new int[order + 1][order + 1];
    }
}

/**
 * Strassen matrix multiplication for square matrices whose order is a power of
 * two. All matrices use the 1-based {@link Matrix} layout. Temporaries are
 * sized to the order they hold, so the order is no longer capped by a fixed
 * 32x32 array.
 */
public class StrassenMatrix2
{
    /**
     * Checks whether {@code n} is a positive power of two.
     *
     * @return 0 if n is 1, 2, 4, 8, ...; 1 otherwise (including n <= 0)
     */
    public int IfIsEven(int n)
    {
        // The old halving loop accepted every odd number (e.g. 3) and every
        // odd multiple of a power of two (e.g. 12). Those orders then recursed
        // forever in MatrixMultiply (StackOverflowError, which a larger heap
        // cannot fix). It also never terminated for n == 0.
        // A positive power of two has exactly one bit set.
        return (n > 0 && (n & (n - 1)) == 0) ? 0 : 1;
    }

    /**
     * Copies the four n x n quadrants of the 2n x 2n matrix {@code d} into
     * {@code d11} (top-left), {@code d12} (top-right), {@code d21}
     * (bottom-left) and {@code d22} (bottom-right). The targets must hold
     * order n.
     */
    public void Divide(Matrix d, Matrix d11, Matrix d12, Matrix d21, Matrix d22, int n)
    {
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                d11.m[i][j] = d.m[i][j];
                d12.m[i][j] = d.m[i][j + n];
                d21.m[i][j] = d.m[i + n][j];
                d22.m[i][j] = d.m[i + n][j + n];
            }
        }
    }

    /** Assembles four n x n quadrants into a new 2n x 2n matrix (inverse of Divide). */
    public Matrix Merge(Matrix a11, Matrix a12, Matrix a21, Matrix a22, int n)
    {
        Matrix a = new Matrix(2 * n);
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                a.m[i][j] = a11.m[i][j];
                a.m[i][j + n] = a12.m[i][j];
                a.m[i + n][j] = a21.m[i][j];
                a.m[i + n][j + n] = a22.m[i][j];
            }
        }
        return a;
    }

    /** Multiplies two 2x2 matrices with Strassen's seven products. */
    public Matrix TwoMatrixMultiply(Matrix x, Matrix y)
    {
        Matrix z = new Matrix(2);
        int m1 = (y.m[1][2] - y.m[2][2]) * x.m[1][1];
        int m2 = y.m[2][2] * (x.m[1][1] + x.m[1][2]);
        int m3 = (x.m[2][1] + x.m[2][2]) * y.m[1][1];
        int m4 = x.m[2][2] * (y.m[2][1] - y.m[1][1]);
        int m5 = (x.m[1][1] + x.m[2][2]) * (y.m[1][1] + y.m[2][2]);
        int m6 = (x.m[1][2] - x.m[2][2]) * (y.m[2][1] + y.m[2][2]);
        int m7 = (x.m[1][1] - x.m[2][1]) * (y.m[1][1] + y.m[1][2]);
        z.m[1][1] = m5 + m4 - m2 + m6;
        z.m[1][2] = m1 + m2;
        z.m[2][1] = m3 + m4;
        z.m[2][2] = m5 + m1 - m3 - m7;
        return z;
    }

    /** Returns a new matrix equal to f + g over the top-left n x n block. */
    public Matrix MatrixPlus(Matrix f, Matrix g, int n)
    {
        Matrix h = new Matrix(n);
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                h.m[i][j] = f.m[i][j] + g.m[i][j];
            }
        }
        return h;
    }

    /** Returns a new matrix equal to f - g over the top-left n x n block. */
    public Matrix MatrixMinus(Matrix f, Matrix g, int n)
    {
        Matrix h = new Matrix(n);
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                h.m[i][j] = f.m[i][j] - g.m[i][j];
            }
        }
        return h;
    }

    /**
     * Multiplies two order-n matrices with Strassen's algorithm.
     *
     * @param n matrix order; must be a power of two
     * @return a new matrix holding a * b
     * @throws IllegalArgumentException if n is not a positive power of two
     */
    public Matrix MatrixMultiply(Matrix a, Matrix b, int n)
    {
        if (IfIsEven(n) != 0)
        {
            // Without this guard, halving reaches an odd order or 0 and recurses
            // until StackOverflowError.
            throw new IllegalArgumentException("matrix order must be a power of two: " + n);
        }
        if (n == 1)
        {
            Matrix c = new Matrix(1);
            c.m[1][1] = a.m[1][1] * b.m[1][1];
            return c;
        }
        if (n == 2)
        {
            return TwoMatrixMultiply(a, b);
        }

        int k = n / 2;
        Matrix a11 = new Matrix(k), a12 = new Matrix(k), a21 = new Matrix(k), a22 = new Matrix(k);
        Matrix b11 = new Matrix(k), b12 = new Matrix(k), b21 = new Matrix(k), b22 = new Matrix(k);
        Divide(a, a11, a12, a21, a22, k);
        Divide(b, b11, b12, b21, b22, k);

        // Strassen's seven half-size products.
        Matrix m1 = MatrixMultiply(a11, MatrixMinus(b12, b22, k), k);
        Matrix m2 = MatrixMultiply(MatrixPlus(a11, a12, k), b22, k);
        Matrix m3 = MatrixMultiply(MatrixPlus(a21, a22, k), b11, k);
        Matrix m4 = MatrixMultiply(a22, MatrixMinus(b21, b11, k), k);
        Matrix m5 = MatrixMultiply(MatrixPlus(a11, a22, k), MatrixPlus(b11, b22, k), k);
        Matrix m6 = MatrixMultiply(MatrixMinus(a12, a22, k), MatrixPlus(b21, b22, k), k);
        Matrix m7 = MatrixMultiply(MatrixMinus(a11, a21, k), MatrixPlus(b11, b12, k), k);

        Matrix c11 = MatrixPlus(MatrixMinus(MatrixPlus(m5, m4, k), m2, k), m6, k);
        Matrix c12 = MatrixPlus(m1, m2, k);
        Matrix c21 = MatrixPlus(m3, m4, k);
        Matrix c22 = MatrixMinus(MatrixMinus(MatrixPlus(m5, m1, k), m3, k), m7, k);
        return Merge(c11, c12, c21, c22, k);
    }

    /**
     * Returns a new order-n matrix with random entries in [0, 9] and prints it
     * to stdout. The incoming {@code X} is ignored and replaced.
     */
    public Matrix GetMatrix(Matrix X, int n)
    {
        X = new Matrix(n);
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                X.m[i][j] = (int) (Math.random() * 10);
            }
        }
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                System.out.print(X.m[i][j] + " ");
            }
            System.out.println();
        }
        return X;
    }

    /**
     * Computes A * B with the classic triple loop and stores the result in
     * {@code C}. If {@code C} is null or too small for order n, a new matrix
     * is allocated instead.
     *
     * @return the matrix holding the product
     */
    public Matrix UsualMatrixMultiply(Matrix A, Matrix B, Matrix C, int n)
    {
        if (C == null || C.m.length <= n)
        {
            C = new Matrix(Math.max(n, 0));
        }
        for (int i = 1; i <= n; i++)
        {
            for (int j = 1; j <= n; j++)
            {
                int t = 0;
                for (int k = 1; k <= n; k++)
                {
                    t += A.m[i][k] * B.m[k][j];
                }
                C.m[i][j] = t;
            }
        }
        return C;
    }

    /**
     * Reads an order from stdin and generates two random matrices of that
     * order. It multiplies them with Strassen's algorithm and prints the
     * timing and the result.
     */
    public static void main(String[] args) throws IOException
    {
        StrassenMatrix2 instance = new StrassenMatrix2();
        Matrix A = new Matrix();
        Matrix B = new Matrix();
        Matrix C = new Matrix();
        Scanner in = new Scanner(System.in);
        System.out.print("输入矩阵的阶数: ");
        int n = in.nextInt();
        if (instance.IfIsEven(n) == 0)
        {
            System.out.println("矩阵A:");
            A = instance.GetMatrix(A, n);
            System.out.println("矩阵B:");
            B = instance.GetMatrix(B, n);
            if (n == 1)
            {
                // Order 1 is a plain scalar product.
                C.m[1][1] = A.m[1][1] * B.m[1][1];
            }
            else
            {
                long startTime = System.currentTimeMillis();
                C = instance.MatrixMultiply(A, B, n);
                long finishTime = System.currentTimeMillis();
                System.out.println("计算完成,用时"+(finishTime-startTime)+"毫秒");
            }
            System.out.println("Strassen矩阵C为:");
            for (int i = 1; i <= n; i++)
            {
                for (int j = 1; j <= n; j++)
                {
                    System.out.print(C.m[i][j] + " ");
                }
                System.out.println();
            }
            // To cross-check, compare with instance.UsualMatrixMultiply(A, B, null, n).
        }
        else
        {
            System.out.println("输入的阶数不是2的N次方");
        }
    }
}
posted on 2010-10-29 16:23
ApolloDeng 阅读(4110)
评论(0) 编辑 收藏 所属分类:
提问 、
分享 、
Java