import java.util.*;

/**
 * Integer matrix multiplication: the basic cubic algorithm ({@link #perus}) and Strassen's
 * algorithm ({@link #strassen}), plus a small benchmark in {@link #main} that times both.
 *
 * <p>All arithmetic uses Java {@code int} operators and therefore wraps on overflow. Strassen's
 * identities hold in any ring, so both algorithms return identical results even when
 * intermediate values wrap.
 */
public class Matriisit {
    /**
     * Multiplies {@code a} (r x s) by {@code b} (s x t) with the basic O(r*s*t) algorithm.
     *
     * <p>The loops run in i-k-j order so the innermost loop walks a row of {@code b} and a row
     * of the result sequentially, which is much more cache-friendly than the textbook i-j-k
     * order while producing the same sums.
     *
     * @param a left matrix; assumed rectangular (all rows the same length)
     * @param b right matrix; assumed rectangular
     * @return the r x t product
     * @throws IllegalArgumentException if the column count of {@code a} differs from the row
     *     count of {@code b}
     */
    public static int[][] perus(int[][] a, int[][] b) {
        int rows = a.length;
        int cols = b.length == 0 ? 0 : b[0].length;
        // Without this check a b with extra rows would silently give a wrong product.
        if (rows > 0 && a[0].length != b.length) {
            throw new IllegalArgumentException("Matriiseja ei voi kertoa: "
                    + rows + "x" + a[0].length + " * " + b.length + "x" + cols);
        }
        int[][] c = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            int[] ai = a[i];
            int[] ci = c[i];
            for (int k = 0; k < ai.length; k++) {
                int aik = ai[k];
                int[] bk = b[k];
                for (int j = 0; j < cols; j++) {
                    ci[j] += aik * bk[j];
                }
            }
        }
        return c;
    }

    /**
     * Multiplies {@code a} by {@code b} with Strassen's algorithm, recursing all the way down to
     * 1x1 blocks. Equivalent to {@code strassen(a, b, 1)}.
     *
     * @param a left matrix (r x s); assumed rectangular
     * @param b right matrix (s x t); assumed rectangular
     * @return the r x t product
     * @throws IllegalArgumentException if the inner dimensions do not match
     */
    public static int[][] strassen(int[][] a, int[][] b) {
        return strassen(a, b, 1);
    }

    /**
     * Multiplies {@code a} by {@code b} with Strassen's algorithm, switching to {@link #perus}
     * once a block is at most {@code cutoff} x {@code cutoff}.
     *
     * <p>Any compatible shapes are accepted. The operands are zero-padded to the smallest
     * power-of-two square that holds them, multiplied, and the result is cropped back to r x t.
     * When both inputs already have the same 2^k x 2^k shape, no copies are made.
     *
     * @param a left matrix (r x s); assumed rectangular
     * @param b right matrix (s x t); assumed rectangular
     * @param cutoff largest block size handled by the basic algorithm; 1 means pure Strassen
     * @return the r x t product
     * @throws IllegalArgumentException if {@code cutoff < 1} or the inner dimensions do not match
     */
    public static int[][] strassen(int[][] a, int[][] b, int cutoff) {
        if (cutoff < 1) {
            throw new IllegalArgumentException("cutoff-arvon pitää olla vähintään 1: " + cutoff);
        }
        int rows = a.length;
        int cols = b.length == 0 ? 0 : b[0].length;
        if (rows == 0) {
            // Without this guard an empty input would recurse forever.
            return new int[0][cols];
        }
        int inner = a[0].length;
        if (inner != b.length) {
            throw new IllegalArgumentException("Matriiseja ei voi kertoa: "
                    + rows + "x" + inner + " * " + b.length + "x" + cols);
        }
        if (cols == 0) {
            // Degenerate product (also covers inner == 0, since then b has no rows).
            return new int[rows][cols];
        }
        int size = nextPowerOfTwo(Math.max(rows, Math.max(inner, cols)));
        int[][] c = strassenSquare(pad(a, size), pad(b, size), cutoff);
        return crop(c, rows, cols);
    }

    /**
     * Strassen recursion on two n x n matrices where n is a power of two.
     *
     * <p>Each operand is split into quadrants X11, X12, X21, X22 (top-left, top-right,
     * bottom-left, bottom-right) and seven half-size products are formed:
     * <pre>
     * M1 = (A11 + A22)(B11 + B22)    M5 = (A11 + A12) B22
     * M2 = (A21 + A22) B11           M6 = (A21 - A11)(B11 + B12)
     * M3 = A11 (B12 - B22)           M7 = (A12 - A22)(B21 + B22)
     * M4 = A22 (B21 - B11)
     * C11 = M1 + M4 - M5 + M7        C12 = M3 + M5
     * C21 = M2 + M4                  C22 = M1 - M2 + M3 + M6
     * </pre>
     */
    private static int[][] strassenSquare(int[][] a, int[][] b, int cutoff) {
        int n = a.length;
        if (n == 1) {
            return new int[][] {{a[0][0] * b[0][0]}};
        }
        if (n <= cutoff) {
            return perus(a, b);
        }
        int p = n / 2;
        int[][] m1a = new int[p][p], m1b = new int[p][p];
        int[][] m2a = new int[p][p], m2b = new int[p][p];
        int[][] m3a = new int[p][p], m3b = new int[p][p];
        int[][] m4a = new int[p][p], m4b = new int[p][p];
        int[][] m5a = new int[p][p], m5b = new int[p][p];
        int[][] m6a = new int[p][p], m6b = new int[p][p];
        int[][] m7a = new int[p][p], m7b = new int[p][p];
        // Build the operands of M1..M7 from the quadrants: [i][j] is X11,
        // [i][p+j] is X12, [p+i][j] is X21 and [p+i][p+j] is X22.
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                m1a[i][j] = a[i][j] + a[p + i][p + j];
                m1b[i][j] = b[i][j] + b[p + i][p + j];
                m2a[i][j] = a[p + i][j] + a[p + i][p + j];
                m2b[i][j] = b[i][j];
                m3a[i][j] = a[i][j];
                m3b[i][j] = b[i][p + j] - b[p + i][p + j];
                m4a[i][j] = a[p + i][p + j];
                m4b[i][j] = b[p + i][j] - b[i][j];
                m5a[i][j] = a[i][j] + a[i][p + j];
                m5b[i][j] = b[p + i][p + j];
                m6a[i][j] = a[p + i][j] - a[i][j];
                m6b[i][j] = b[i][j] + b[i][p + j];
                m7a[i][j] = a[i][p + j] - a[p + i][p + j];
                m7b[i][j] = b[p + i][j] + b[p + i][p + j];
            }
        }
        int[][] m1 = strassenSquare(m1a, m1b, cutoff);
        int[][] m2 = strassenSquare(m2a, m2b, cutoff);
        int[][] m3 = strassenSquare(m3a, m3b, cutoff);
        int[][] m4 = strassenSquare(m4a, m4b, cutoff);
        int[][] m5 = strassenSquare(m5a, m5b, cutoff);
        int[][] m6 = strassenSquare(m6a, m6b, cutoff);
        int[][] m7 = strassenSquare(m7a, m7b, cutoff);
        // Recombine the seven products into the four quadrants of the result.
        int[][] c = new int[n][n];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                c[i][j] = m1[i][j] + m4[i][j] - m5[i][j] + m7[i][j];
                c[i][p + j] = m3[i][j] + m5[i][j];
                c[p + i][j] = m2[i][j] + m4[i][j];
                c[p + i][p + j] = m1[i][j] - m2[i][j] + m3[i][j] + m6[i][j];
            }
        }
        return c;
    }

    /** Returns the smallest power of two that is at least {@code x}; assumes 1 <= x <= 2^30. */
    private static int nextPowerOfTwo(int x) {
        int high = Integer.highestOneBit(x);
        return high == x ? x : high << 1;
    }

    /** Returns {@code m} zero-padded to size x size, or {@code m} itself if already that shape. */
    private static int[][] pad(int[][] m, int size) {
        if (m.length == size && m[0].length == size) {
            return m;
        }
        int[][] out = new int[size][size];
        for (int i = 0; i < m.length; i++) {
            System.arraycopy(m[i], 0, out[i], 0, m[i].length);
        }
        return out;
    }

    /** Returns the top-left rows x cols part of {@code m}, or {@code m} itself if already that shape. */
    private static int[][] crop(int[][] m, int rows, int cols) {
        if (m.length == rows && m[0].length == cols) {
            return m;
        }
        int[][] out = new int[rows][];
        for (int i = 0; i < rows; i++) {
            out[i] = Arrays.copyOf(m[i], cols);
        }
        return out;
    }

    /**
     * Times {@link #perus} against {@link #strassen} on random n x n matrices with entries 0..9,
     * then checks that both produced the same product.
     *
     * <p>NOTE: each algorithm runs once with no JIT warm-up, so treat the timings as rough.
     *
     * @param args optional matrix size n as the first argument (default 1024)
     */
    public static void main(String[] args) {
        int n = args.length > 0 ? Integer.parseInt(args[0]) : 1024;
        System.out.println("Matriisin koko " + n + "x" + n);
        int[][] a = new int[n][n];
        int[][] b = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = (int) (Math.random() * 10);
                b[i][j] = (int) (Math.random() * 10);
            }
        }
        // nanoTime is monotonic; currentTimeMillis is wall-clock time and may jump.
        long alku = System.nanoTime();
        int[][] perusTulos = perus(a, b);
        long loppu = System.nanoTime();
        System.out.println("Perusalgoritmi: " + (loppu - alku) / 1_000_000 + " ms");
        alku = System.nanoTime();
        int[][] strassenTulos = strassen(a, b);
        loppu = System.nanoTime();
        System.out.println("Strassenin algoritmi: " + (loppu - alku) / 1_000_000 + " ms");
        if (!Arrays.deepEquals(perusTulos, strassenTulos)) {
            throw new IllegalStateException(
                    "Strassenin algoritmin tulos poikkeaa perusalgoritmin tuloksesta");
        }
    }
}
