相关矩阵的求解的一些优化心得 - TBY's Blog

相关矩阵的求解的一些优化心得

TBY posted @ 2013年1月07日 16:00 in Haskell with tags Haskell , 776 阅读

问题如下,有 [tex]n \times m [/tex]Matrix, where [tex] n > 10^4 [/tex] , [tex] m[/tex] ~ [tex]100[/tex],以Row-major order 存储于内存中,需要求 相关矩阵 [tex]M_{cor}[/tex]。由于相关矩阵是对称的,所以只需要求上半个矩阵,共 [tex]p = \frac{(n-1)n}{2}[/tex]次相关系数求解。求解计算本身很简单无外乎:

  for(int i=0;i<r;i++)
    for(int j=i+1;j<r;j++)
    {
      double r = correlation(row(i),row(j));
    }

 

现在想并发该计算,考虑cache locality,最理想的并发方式无疑是顺序产生如下[(i,j)序列:
(0,1),(0,2),(0,3) ... (1,2),(1,3) ... (2,3),(2,4) ...

然后再以如下方式并发(以4核为例):

 

// thread 1 :
for (int idx = 0; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r1 = correlation(row(i),row(j));
}

// thread 2 :
for (int idx = 1; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r2 = correlation(row(i),row(j));
}

// thread 3 :
for (int idx = 2; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r3 = correlation(row(i),row(j));
}

// thread 4 :
for (int idx = 3; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(p);
   double r4 = correlation(row(i),row(j));
}

好了,所以重点在于如何以并不高的代价解出那个calc函数。想了一下,直接要 idx -> (i,j) 不太容易,先反过来如何从(i,j) -> idx,草稿纸推了下(我小学数学老师要哭了,推了蛮长时间),idx满足以下等式:

[tex] idx = \frac{n\times(n-1)-(n-i-1)\times(n-i)}{2}+j-i-1=\frac{2\times i\times n-i^2-3i+2j-2}{2} [/tex]
这个关系蛮容易验证的:
#include <stdio.h>
int main(int argc, char *argv[])
{
  int n = 10;
  int i,j;
  int idx = 0; 
  for(i=0;i<n;i++)
    for(j=i+1;j<n;j++)
    {
      int p = (n*(n-1) - (n-1-i)*(n-i))/2 + j - i - 1;
      int q = (2 * i * n - i * i - 3 * i + 2 * j - 2) / 2;
      printf("%d  %d  %d\n",p,q,idx);
      ++idx;
    }
  return 0;
}
有了上述关系式子以后,可以发现,事实上,原先要求解的calc可以用一个整形规划来描述:
[tex]idx =  \frac{2\times i\times n - i^2 -3i +2j-2}{2}[/tex],subject to :
[tex]0 \le i < n, i < j < n, 0 \le idx < \frac{(n-1)n}{2}[/tex],其中i,j,n,idx皆为整数。
 
这里当然不可能用通用的整形规划相关的库(比如glpk),去解这种小儿科问题。看看有无其他办法。下面代码直接都上Haskell了,C之类太繁琐,写一小段代码都得撸半天。由关系式可以得到一个判别式:
testIdx :: Int -> Int -> Int -> Bool
testIdx i j idx = (2*i*n-i*i-3*i+2*j-2) == idx * 2
所以可以穷举所有i、j的可能值来求得idx:
f1 :: Int -> Int -> (Int,Int)
f1 n idx = head [(i,j) | i <- [0..n-2], 
                         j <- [i+1..n-1],
                         testIdx i j idx]
当然,这是个很笨的办法,当 n > 1000时已经很慢了。进一步优化需要利用那些约束,由于(i,j)都属于上三角阵的元素,所以很自然的想法就是去 缓存 那些 边界条件 (其实求解整形规划问题也就是这路子),在这里就是对角线上那一排(i,j)对应的idx值。这个比较容易,因为那些点还有个隐含关系:[tex]j = i + 1[/tex]。
f2 :: Int -> Int -> (Int,Int)
f2 n idx = -- quot开销比div更小, Haskell中的div比quot多额外的检查项,这里都是正整数运算无此必要
  let i = getI idx
      j = getJ i idx
  in (i,j)
  where
    getJ :: Int -> Int -> Int
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    getI :: Int -> Int
    getI ix = length (takeWhile (<= ix) ps0) - 1
    ps0 :: [Int] -- 上对角元素对应的idx值
    ps0 = map (\i -> (2*i*n-i*i-i) `quot` 2) [0..n-2]
上面f2在n=100时,比f1快了4.5倍,但还是不够高效,因为每次求解还是要重新计算ps0,所以很自然的想法就是再把ps0拉出来缓存。

import qualified Data.Vector.Unboxed as UV
getVec :: Int -> UV.Vector Int
getVec n = UV.generate n (\i -> (2*i*n-i*i-i) `quot` 2)  

           
f3 :: UV.Vector Int -> Int -> (Int,Int)
f3 vec idx =
  let i = getI
      j = getJ i
  in (i,j)
  where
    n = UV.length vec
    getJ i = (2*idx+i*i+3*i-2*i*n+2) `quot` 2
    getI = go 0
      where
        go acc | acc == n - 1 = acc
               | otherwise = if UV.unsafeIndex vec  (acc+1) > idx
                             then acc
                             else go (acc+1)
略微修改上述f1,f2,f3以便同时测试程序正确性,得到如下benchmark.hs文件
import Criterion.Main
import qualified Data.Vector.Unboxed as UV

f1 :: Int -> Bool
f1 n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $ zipWith (==) ls $ map (\p -> head [getP i j | i <- [0..n-2],j <- [i+1..n-1],testP i j p]) ls
  where
    testP :: Int -> Int -> Int -> Bool
    testP i j p = (2*i*n-i*i-3*i+2*j-2) == p * 2
    getP :: Int -> Int -> Int
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2

f2 :: Int -> Bool
f2 n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $
     zipWith (\a b -> a == b) ls $
     map (\p ->
           let i = (length $ takeWhile (<= p) ps0) - 1
               j = getJ i p
           in getP i j) ls
  where
    getP :: Int -> Int -> Int
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2
    getJ :: Int -> Int -> Int
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    ps0 :: [Int]
    ps0 = map (\i -> (2*i*n-i*i-i) `quot` 2) [0..n-2]


f3 :: UV.Vector Int -> Int -> Bool
f3 vec n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $
     zipWith (==) ls $
     map (\p ->
           let i = getIdx p
               j = getJ i p
           in getP i j) ls
  where
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    getIdx p = go 0
      where
          go acc | acc == n - 1 = acc
                 | otherwise = if UV.unsafeIndex vec (acc+1) > p
                               then acc
                               else go (acc+1)

main = do 
   let vec = getVec 100 
   print $ f1 100 
   print $ f2 100 
   print $ f3 vec 100 
   defaultMain 
   [bench "f1 100" $ nf f1 100 
   ,bench "f2 100" $ nf f2 100 
   ,bench "f3 100" $ nf (f3 vec) 100 
   ,bench "f3 1000" $ nf (f3 (getVec 1000)) 1000]
结果:
$ ghc --make -O2 -fllvm -optl-O3 benchmark.hs
$ ./benchmark
True
True
True
warming up
estimating clock resolution...
mean is 1.296372 us (640001 iterations)
found 3334 outliers among 639999 samples (0.5%)
  3075 (0.5%) high severe
estimating cost of a clock call...
mean is 35.00575 ns (12 iterations)

benchmarking f1 100
mean: 13.92723 ms, lb 13.88566 ms, ub 13.96976 ms, ci 0.950
std dev: 215.6054 us, lb 190.5129 us, ub 246.8802 us, ci 0.950
variance introduced by outliers: 8.472%
variance is slightly inflated by outliers

benchmarking f2 100
mean: 2.810448 ms, lb 2.801007 ms, ub 2.821439 ms, ci 0.950
std dev: 52.27388 us, lb 43.98545 us, ub 65.98971 us, ci 0.950
found 3 outliers among 100 samples (3.0%)
  3 (3.0%) high mild
variance introduced by outliers: 11.350%
variance is moderately inflated by outliers

benchmarking f3 100
mean: 173.9339 us, lb 173.4058 us, ub 174.5147 us, ci 0.950
std dev: 2.829666 us, lb 2.432605 us, ub 3.579531 us, ci 0.950
found 2 outliers among 100 samples (2.0%)
  2 (2.0%) high mild
variance introduced by outliers: 9.410%
variance is slightly inflated by outliers

benchmarking f3 1000
mean: 17.29405 ms, lb 17.26062 ms, ub 17.33314 ms, ci 0.950
std dev: 185.0009 us, lb 157.9185 us, ub 221.9862 us, ci 0.950
 
后记:实际运算的时候,其实并不需要像上面那样给定任意idx求其(i,j),因为i总是从0起始迭代的,所以只需要用关系式求j,并判定下是否需要更新i即可。运算量非常小。

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter
Host by is-Programmer.com | Power by Chito 1.3.3 beta | © 2007 LinuxGem | Design by Matthew "Agent Spork" McGee