Skip to content

Commit

Permalink
Merge pull request #1196 from heeba-khan/heeba
Browse files Browse the repository at this point in the history
Added Strassen's algo in cpp,java and python
  • Loading branch information
OtacilioN committed Oct 18, 2023
2 parents 9fa86ed + ae32b6f commit dffaaf2
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 0 deletions.
182 changes: 182 additions & 0 deletions Strassen's Algorithm/Strassens.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include <bits/stdc++.h>
using namespace std;

#define ROW_1 4
#define COL_1 4

#define ROW_2 4
#define COL_2 4

void print(string display, vector<vector<int> > matrix,
int start_row, int start_column, int end_row,
int end_column)
{
cout << endl << display << " =>" << endl;
for (int i = start_row; i <= end_row; i++) {
for (int j = start_column; j <= end_column; j++) {
cout << setw(10);
cout << matrix[i][j];
}
cout << endl;
}
cout << endl;
return;
}

vector<vector<int> >
add_matrix(vector<vector<int> > matrix_A,
vector<vector<int> > matrix_B, int split_index,
int multiplier = 1)
{
for (auto i = 0; i < split_index; i++)
for (auto j = 0; j < split_index; j++)
matrix_A[i][j]
= matrix_A[i][j]
+ (multiplier * matrix_B[i][j]);
return matrix_A;
}

vector<vector<int> >
multiply_matrix(vector<vector<int> > matrix_A,
vector<vector<int> > matrix_B)
{
int col_1 = matrix_A[0].size();
int row_1 = matrix_A.size();
int col_2 = matrix_B[0].size();
int row_2 = matrix_B.size();

if (col_1 != row_2) {
cout << "\nError: The number of columns in Matrix "
"A must be equal to the number of rows in "
"Matrix B\n";
return {};
}

vector<int> result_matrix_row(col_2, 0);
vector<vector<int> > result_matrix(row_1,
result_matrix_row);

if (col_1 == 1)
result_matrix[0][0]
= matrix_A[0][0] * matrix_B[0][0];
else {
int split_index = col_1 / 2;

vector<int> row_vector(split_index, 0);

vector<vector<int> > a00(split_index, row_vector);
vector<vector<int> > a01(split_index, row_vector);
vector<vector<int> > a10(split_index, row_vector);
vector<vector<int> > a11(split_index, row_vector);
vector<vector<int> > b00(split_index, row_vector);
vector<vector<int> > b01(split_index, row_vector);
vector<vector<int> > b10(split_index, row_vector);
vector<vector<int> > b11(split_index, row_vector);

for (auto i = 0; i < split_index; i++)
for (auto j = 0; j < split_index; j++) {
a00[i][j] = matrix_A[i][j];
a01[i][j] = matrix_A[i][j + split_index];
a10[i][j] = matrix_A[split_index + i][j];
a11[i][j] = matrix_A[i + split_index]
[j + split_index];
b00[i][j] = matrix_B[i][j];
b01[i][j] = matrix_B[i][j + split_index];
b10[i][j] = matrix_B[split_index + i][j];
b11[i][j] = matrix_B[i + split_index]
[j + split_index];
}

vector<vector<int> > p(multiply_matrix(
a00, add_matrix(b01, b11, split_index, -1)));
vector<vector<int> > q(multiply_matrix(
add_matrix(a00, a01, split_index), b11));
vector<vector<int> > r(multiply_matrix(
add_matrix(a10, a11, split_index), b00));
vector<vector<int> > s(multiply_matrix(
a11, add_matrix(b10, b00, split_index, -1)));
vector<vector<int> > t(multiply_matrix(
add_matrix(a00, a11, split_index),
add_matrix(b00, b11, split_index)));
vector<vector<int> > u(multiply_matrix(
add_matrix(a01, a11, split_index, -1),
add_matrix(b10, b11, split_index)));
vector<vector<int> > v(multiply_matrix(
add_matrix(a00, a10, split_index, -1),
add_matrix(b00, b01, split_index)));

vector<vector<int> > result_matrix_00(add_matrix(
add_matrix(add_matrix(t, s, split_index), u,
split_index),
q, split_index, -1));
vector<vector<int> > result_matrix_01(
add_matrix(p, q, split_index));
vector<vector<int> > result_matrix_10(
add_matrix(r, s, split_index));
vector<vector<int> > result_matrix_11(add_matrix(
add_matrix(add_matrix(t, p, split_index), r,
split_index, -1),
v, split_index, -1));

for (auto i = 0; i < split_index; i++)
for (auto j = 0; j < split_index; j++) {
result_matrix[i][j]
= result_matrix_00[i][j];
result_matrix[i][j + split_index]
= result_matrix_01[i][j];
result_matrix[split_index + i][j]
= result_matrix_10[i][j];
result_matrix[i + split_index]
[j + split_index]
= result_matrix_11[i][j];
}

a00.clear();
a01.clear();
a10.clear();
a11.clear();
b00.clear();
b01.clear();
b10.clear();
b11.clear();
p.clear();
q.clear();
r.clear();
s.clear();
t.clear();
u.clear();
v.clear();
result_matrix_00.clear();
result_matrix_01.clear();
result_matrix_10.clear();
result_matrix_11.clear();
}
return result_matrix;
}

int main()
{
vector<vector<int> > matrix_A = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };

print("Array A", matrix_A, 0, 0, ROW_1 - 1, COL_1 - 1);

vector<vector<int> > matrix_B = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };

print("Array B", matrix_B, 0, 0, ROW_2 - 1, COL_2 - 1);

vector<vector<int> > result_matrix(
multiply_matrix(matrix_A, matrix_B));

print("Result Array", result_matrix, 0, 0, ROW_1 - 1,
COL_2 - 1);
}

// Time Complexity: T(N) = 7T(N/2) + O(N^2) => O(N^Log7)
// which is approximately O(N^2.8074)
// Code Contributed By: Heeba Khan
145 changes: 145 additions & 0 deletions Strassen's Algorithm/Strassens.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@

import java.util.Scanner;


public class Strassen
{

public int[][] multiply(int[][] A, int[][] B)
{
int n = A.length;
int[][] R = new int[n][n];

if (n == 1)
R[0][0] = A[0][0] * B[0][0];
else
{
int[][] A11 = new int[n/2][n/2];
int[][] A12 = new int[n/2][n/2];
int[][] A21 = new int[n/2][n/2];
int[][] A22 = new int[n/2][n/2];
int[][] B11 = new int[n/2][n/2];
int[][] B12 = new int[n/2][n/2];
int[][] B21 = new int[n/2][n/2];
int[][] B22 = new int[n/2][n/2];


split(A, A11, 0 , 0);
split(A, A12, 0 , n/2);
split(A, A21, n/2, 0);
split(A, A22, n/2, n/2);

split(B, B11, 0 , 0);
split(B, B12, 0 , n/2);
split(B, B21, n/2, 0);
split(B, B22, n/2, n/2);



int [][] M1 = multiply(add(A11, A22), add(B11, B22));
int [][] M2 = multiply(add(A21, A22), B11);
int [][] M3 = multiply(A11, sub(B12, B22));
int [][] M4 = multiply(A22, sub(B21, B11));
int [][] M5 = multiply(add(A11, A12), B22);
int [][] M6 = multiply(sub(A21, A11), add(B11, B12));
int [][] M7 = multiply(sub(A12, A22), add(B21, B22));


int [][] C11 = add(sub(add(M1, M4), M5), M7);
int [][] C12 = add(M3, M5);
int [][] C21 = add(M2, M4);
int [][] C22 = add(sub(add(M1, M3), M2), M6);


join(C11, R, 0 , 0);
join(C12, R, 0 , n/2);
join(C21, R, n/2, 0);
join(C22, R, n/2, n/2);
}

return R;
}

public int[][] sub(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] - B[i][j];
return C;
}

public int[][] add(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] + B[i][j];
return C;
}

public void split(int[][] P, int[][] C, int iB, int jB)
{
for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
C[i1][j1] = P[i2][j2];
}

public void join(int[][] C, int[][] P, int iB, int jB)
{
for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
P[i2][j2] = C[i1][j1];
}

public static void main (String[] args)
{
Scanner scan = new Scanner(System.in);
System.out.println("Strassen Multiplication Algorithm Test\n");

Strassen s = new Strassen();


int N = 4;


int[][] A = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };

int[][] B = { { 1, 1, 1, 1 },
{ 2, 2, 2, 2 },
{ 3, 3, 3, 3 },
{ 2, 2, 2, 2 } };
System.out.println("\nArray A =>");

for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(A[i][j] +" ");
System.out.println();
}

System.out.println("\nArray B =>");
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(B[i][j] +" ");
System.out.println();
}

int[][] C = s.multiply(A, B);

System.out.println("\nProduct of matrices A and B : ");
for (int i = 0; i < N; i++)
{
for (int j = 0; j < N; j++)
System.out.print(C[i][j] +" ");
System.out.println();
}

}
}
37 changes: 37 additions & 0 deletions Strassen's Algorithm/Strassens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np

def split(matrix):

row, col = matrix.shape
row2, col2 = row//2, col//2
return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]

def strassen(x, y):

if len(x) == 1:
return x * y


a, b, c, d = split(x)
e, f, g, h = split(y)


p1 = strassen(a, f - h)
p2 = strassen(a + b, h)
p3 = strassen(c + d, e)
p4 = strassen(d, g - e)
p5 = strassen(a + d, e + h)
p6 = strassen(b - d, g + h)
p7 = strassen(a - c, e + f)


c11 = p5 + p4 - p2 + p6
c12 = p1 + p2
c21 = p3 + p4
c22 = p1 + p5 - p3 - p7


c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))

return c

20 changes: 20 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"name": "funnyalgorithms",
"version": "1.0.0",
"description": "<!-- <img align=\"center\" height=80% width=80% src=\"https://hacktoberfest.digitalocean.com/assets/HF-full-logo-b05d5eb32b3f3ecc9b2240526104cf4da3187b8b61963dd9042fdc2536e4a76c.svg\" alt=\"hacktoberfest-2020\"> -->",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"repository": {
"type": "git",
"url": "git+https://github.com/heeba-khan/FunnyAlgorithms.git"
},
"keywords": [],
"author": "",
"license": "ISC",
"bugs": {
"url": "https://github.com/heeba-khan/FunnyAlgorithms/issues"
},
"homepage": "https://github.com/heeba-khan/FunnyAlgorithms#readme"
}

0 comments on commit dffaaf2

Please sign in to comment.