Monday, 26 December 2011

Matrix multiplication in Scala

First take, with idiomatic Scala (77 ms):
/** Matrix product `m1 * m2`, written as a single `Array.tabulate` expression.
  *
  * Each output cell `(r, c)` is the dot product of row `r` of `m1` with column `c` of `m2`,
  * summed left to right starting from `0.0`.
  *
  * @param m1 left operand, indexed `m1(r)(k)`; the inner dimension is read from `m1(0).length`
  * @param m2 right operand, indexed `m2(k)(c)`; the column count is read from `m2(0).length`
  * @return a new `m1.length x m2(0).length` array
  */
def singleThreadedMultiplication2(m1:Seq[Array[Double]], m2:Array[Array[Double]] ) =
  Array.tabulate(m1.length, m2(0).length) { (r, c) =>
    (0 until m1(0).length).foldLeft(0.0)((acc, k) => acc + m1(r)(k) * m2(k)(c))
  }
view raw gistfile1.scala hosted with ❤ by GitHub


Second take — replacing the `for` loops with `while` loops (63 ms):
/** Matrix product `m1 * m2` using explicit `while` loops instead of `for` comprehensions.
  *
  * Each output cell is accumulated in place: it starts at `0.0` and has every
  * `m1(r)(k) * m2(k)(c)` term added in order of increasing `k`.
  *
  * @param m1 left operand, indexed `m1(r)(k)`; the inner dimension is read from `m1(0).length`
  * @param m2 right operand, indexed `m2(k)(c)`; the column count is read from `m2(0).length`
  * @return a new `m1.length x m2(0).length` array
  */
@inline def singleThreadedMultiplication1(m1:Seq[Array[Double]], m2:Array[Array[Double]] ) = {
  val out = Array.ofDim[Double](m1.length, m2(0).length)
  var r = 0
  while (r < m1.length) {
    // Cache the output row so the inner loop indexes a single array.
    val outRow = out(r)
    var c = 0
    while (c < m2(0).length) {
      var k = 0
      while (k < m1(0).length) {
        outRow(c) += m1(r)(k) * m2(k)(c)
        k += 1
      }
      c += 1
    }
    r += 1
  }
  out
}
view raw gistfile1.scala hosted with ❤ by GitHub


Third take — accumulating each sum in a local variable (55 ms):
/** Single-threaded matrix product `m1 * m2` using `while` loops and a local accumulator.
  *
  * Each dot product is summed into a local `var` and written to the result array once,
  * instead of updating the array on every inner iteration (which the author measured
  * as slower).
  *
  * @param m1 left operand, indexed `m1(row)(i)`; rows are assumed to share the length of `m1(0)`
  * @param m2 right operand, indexed `m2(i)(col)`; must be non-empty and have at least
  *           `m1(0).length` rows
  * @return a new `m1.length x m2(0).length` array; empty when `m1` has no rows
  */
@inline def singleThreadedMultiplicationFAST(m1:Seq[Array[Double]], m2:Array[Array[Double]] ): Array[Array[Double]] = {
  val res = Array.ofDim[Double](m1.length, m2(0).length)
  val M1_ROWS = m1.length
  // Guard m1(0): a 0-row left operand yields an empty product (as in the other variants)
  // instead of throwing IndexOutOfBoundsException.
  val M1_COLS = if (M1_ROWS == 0) 0 else m1(0).length
  val M2_COLS = m2(0).length
  var row = 0
  while (row < M1_ROWS) {
    // Hoisted: m1 is a Seq, so m1(row) may not be O(1) (e.g. a List).
    val m1Row = m1(row)
    val resRow = res(row)
    var col = 0
    while (col < M2_COLS) {
      var sum = 0.0
      var i = 0
      while (i < M1_COLS) {
        sum += m1Row(i) * m2(i)(col)
        i += 1
      }
      resRow(col) = sum
      col += 1
    }
    row += 1
  }
  res
}
view raw gistfile1.scala hosted with ❤ by GitHub
Multi-threaded, with parallel collections (15 ms):
/** Multiplies `m1 * m2`, computing each output row as an independent parallel task.
  *
  * Row indices are distributed through a parallel collection. Each task writes only
  * `res(row)` for its own `row`, so no two tasks write the same cell.
  *
  * @param m1 left operand, indexed `m1(row)(i)`; rows are assumed to share the length of `m1(0)`
  * @param m2 right operand, indexed `m2(i)(col)`; must be non-empty and have at least
  *           `m1(0).length` rows
  * @return a new `m1.length x m2(0).length` array; empty when `m1` has no rows
  */
override def multiply(m1: Array[Array[Double]], m2: Array[Array[Double]]) : Array[Array[Double]] = {
  val res = Array.ofDim[Double](m1.length, m2(0).length)
  val M1_ROWS = m1.length
  // Guard m1(0) so a 0-row left operand yields an empty product rather than
  // ArrayIndexOutOfBoundsException.
  val M1_COLS = if (M1_ROWS == 0) 0 else m1(0).length
  val M2_COLS = m2(0).length

  // Fills rows [startRow, endRow) of res. Declared with `: Unit =` because procedure
  // syntax (`def f(...) { ... }`) is deprecated in Scala 2.13 and dropped in Scala 3.
  @inline def singleThreadedMultiplicationFAST(startRow: Int, endRow: Int): Unit = {
    var row = startRow
    // while loops avoid the closure overhead of for comprehensions on this hot path
    while (row < endRow) {
      val m1Row = m1(row)
      val resRow = res(row)
      var col = 0
      while (col < M2_COLS) {
        var sum = 0.0
        var i = 0
        while (i < M1_COLS) {
          sum += m1Row(i) * m2(i)(col)
          i += 1
        }
        resRow(col) = sum
        col += 1
      }
      row += 1
    }
  }

  (0 until M1_ROWS).par.foreach(row => singleThreadedMultiplicationFAST(row, row + 1))
  res
}
view raw gistfile1.scala hosted with ❤ by GitHub
Multi-threaded, with parallel collections and idiomatic Scala (32 ms):
/** Matrix product `m1 * m2`, parallelised over both output rows and output columns.
  *
  * Every `(r, c)` pair is visited by exactly one parallel task. That task accumulates
  * the cell's dot product sequentially over `k`, so concurrent tasks never write the
  * same cell.
  *
  * @param m1 left operand, indexed `m1(r)(k)`; the inner dimension is read from `m1(0).length`
  * @param m2 right operand, indexed `m2(k)(c)`; the column count is read from `m2(0).length`
  * @return a new `m1.length x m2(0).length` array
  */
def multiThreadedIdiomatic(m1:Seq[Array[Double]], m2:Array[Array[Double]] ) = {
  val product = Array.fill(m1.length, m2(0).length)(0.0)
  (0 until m1.length).par.foreach { r =>
    (0 until m2(0).length).par.foreach { c =>
      // The reduction over k stays sequential within this task.
      (0 until m1(0).length).foreach { k =>
        product(r)(c) += m1(r)(k) * m2(k)(c)
      }
    }
  }
  product
}
view raw gistfile1.scala hosted with ❤ by GitHub

No comments:

Post a Comment