相关矩阵的求解的一些优化心得 - TBY's Blog

相关矩阵的求解的一些优化心得

TBY posted @ 2013年1月07日 16:00 in Haskell with tags Haskell , 1464 阅读

问题如下,有 [tex]n \times m [/tex]Matrix, where [tex] n > 10^4 [/tex] , [tex] m[/tex] ~ [tex]100[/tex],以Row-major order 存储于内存中,需要求 相关矩阵 [tex]M_{cor}[/tex]。由于相关矩阵是对称的,所以只需要求上半个矩阵,共 [tex]p = \frac{(n-1)n}{2}[/tex]次相关系数求解。求解计算本身很简单无外乎:

  for(int i=0;i<r;i++)
    for(int j=i+1;j<r;j++)
    {
      double r = correlation(row(i),row(j));
    }

 

现在想并发该计算,考虑cache locality,最理想的并发方式无疑是顺序产生如下[(i,j)序列:
(0,1),(0,2),(0,3) ... (1,2),(1,3) ... (2,3),(2,4) ...

然后再以如下方式并发(以4核为例):

 

// thread 1 :
for (int idx = 0; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r1 = correlation(row(i),row(j));
}

// thread 2 :
for (int idx = 1; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r2 = correlation(row(i),row(j));
}

// thread 3 :
for (int idx = 2; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(idx);
   double r3 = correlation(row(i),row(j));
}

// thread 4 :
for (int idx = 3; idx < (n-1)*n/2; idx=idx+4) {
   (i,j) = calc(p);
   double r4 = correlation(row(i),row(j));
}

好了,所以重点在于如何以并不高的代价解出那个calc函数。想了一下,直接要 idx -> (i,j) 不太容易,先反过来如何从(i,j) -> idx,草稿纸推了下(我小学数学老师要哭了,推了蛮长时间),idx满足以下等式:

[tex] idx = \frac{n\times(n-1)-(n-i-1)\times(n-i)}{2}+j-i-1=\frac{2\times i\times n-i^2-3i+2j-2}{2} [/tex]
这个关系蛮容易验证的:
#include <stdio.h>
int main(int argc, char *argv[])
{
  int n = 10;
  int i,j;
  int idx = 0; 
  for(i=0;i<n;i++)
    for(j=i+1;j<n;j++)
    {
      int p = (n*(n-1) - (n-1-i)*(n-i))/2 + j - i - 1;
      int q = (2 * i * n - i * i - 3 * i + 2 * j - 2) / 2;
      printf("%d  %d  %d\n",p,q,idx);
      ++idx;
    }
  return 0;
}
有了上述关系式子以后,可以发现,事实上,原先要求解的calc可以用一个整形规划来描述:
[tex]idx =  \frac{2\times i\times n - i^2 -3i +2j-2}{2}[/tex],subject to :
[tex]0 \le i < n, i < j < n, 0 \le idx < \frac{(n-1)n}{2}[/tex],其中i,j,n,idx皆为整数。
 
这里当然不可能用通用的整形规划相关的库(比如glpk),去解这种小儿科问题。看看有无其他办法。下面代码直接都上Haskell了,C之类太繁琐,写一小段代码都得撸半天。由关系式可以得到一个判别式:
testIdx :: Int -> Int -> Int -> Bool
testIdx i j idx = (2*i*n-i*i-3*i+2*j-2) == idx * 2
所以可以穷举所有i、j的可能值来求得idx:
f1 :: Int -> Int -> (Int,Int)
f1 n idx = head [(i,j) | i <- [0..n-2], 
                         j <- [i+1..n-1],
                         testIdx i j idx]
当然,这是个很笨的办法,当 n > 1000时已经很慢了。进一步优化需要利用那些约束,由于(i,j)都属于上三角阵的元素,所以很自然的想法就是去 缓存 那些 边界条件 (其实求解整形规划问题也就是这路子),在这里就是对角线上那一排(i,j)对应的idx值。这个比较容易,因为那些点还有个隐含关系:[tex]j = i + 1[/tex]。
f2 :: Int -> Int -> (Int,Int)
f2 n idx = -- quot开销比div更小, Haskell中的div比quot多额外的检查项,这里都是正整数运算无此必要
  let i = getI idx
      j = getJ i idx
  in (i,j)
  where
    getJ :: Int -> Int -> Int
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    getI :: Int -> Int
    getI ix = length (takeWhile (<= ix) ps0) - 1
    ps0 :: [Int] -- 上对角元素对应的idx值
    ps0 = map (\i -> (2*i*n-i*i-i) `quot` 2) [0..n-2]
上面f2在n=100时,比f1快了4.5倍,但还是不够高效,因为每次求解还是要重新计算ps0,所以很自然的想法就是再把ps0拉出来缓存。

import qualified Data.Vector.Unboxed as UV
getVec :: Int -> UV.Vector Int
getVec n = UV.generate n (\i -> (2*i*n-i*i-i) `quot` 2)  

           
f3 :: UV.Vector Int -> Int -> (Int,Int)
f3 vec idx =
  let i = getI
      j = getJ i
  in (i,j)
  where
    n = UV.length vec
    getJ i = (2*idx+i*i+3*i-2*i*n+2) `quot` 2
    getI = go 0
      where
        go acc | acc == n - 1 = acc
               | otherwise = if UV.unsafeIndex vec  (acc+1) > idx
                             then acc
                             else go (acc+1)
略微修改上述f1,f2,f3以便同时测试程序正确性,得到如下benchmark.hs文件
import Criterion.Main
import qualified Data.Vector.Unboxed as UV

f1 :: Int -> Bool
f1 n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $ zipWith (==) ls $ map (\p -> head [getP i j | i <- [0..n-2],j <- [i+1..n-1],testP i j p]) ls
  where
    testP :: Int -> Int -> Int -> Bool
    testP i j p = (2*i*n-i*i-3*i+2*j-2) == p * 2
    getP :: Int -> Int -> Int
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2

f2 :: Int -> Bool
f2 n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $
     zipWith (\a b -> a == b) ls $
     map (\p ->
           let i = (length $ takeWhile (<= p) ps0) - 1
               j = getJ i p
           in getP i j) ls
  where
    getP :: Int -> Int -> Int
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2
    getJ :: Int -> Int -> Int
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    ps0 :: [Int]
    ps0 = map (\i -> (2*i*n-i*i-i) `quot` 2) [0..n-2]


f3 :: UV.Vector Int -> Int -> Bool
f3 vec n =
  let l = (n * (n-1)) `quot` 2
      ls = [0..l-1]
  in and $
     zipWith (==) ls $
     map (\p ->
           let i = getIdx p
               j = getJ i p
           in getP i j) ls
  where
    getP i j = (2*i*n-i*i-3*i+2*j-2) `quot` 2
    getJ i p = (2*p+i*i+3*i-2*i*n+2) `quot` 2
    getIdx p = go 0
      where
          go acc | acc == n - 1 = acc
                 | otherwise = if UV.unsafeIndex vec (acc+1) > p
                               then acc
                               else go (acc+1)

main = do 
   let vec = getVec 100 
   print $ f1 100 
   print $ f2 100 
   print $ f3 vec 100 
   defaultMain 
   [bench "f1 100" $ nf f1 100 
   ,bench "f2 100" $ nf f2 100 
   ,bench "f3 100" $ nf (f3 vec) 100 
   ,bench "f3 1000" $ nf (f3 (getVec 1000)) 1000]
结果:
$ ghc --make -O2 -fllvm -optl-O3 benchmark.hs
$ ./benchmark
True
True
True
warming up
estimating clock resolution...
mean is 1.296372 us (640001 iterations)
found 3334 outliers among 639999 samples (0.5%)
  3075 (0.5%) high severe
estimating cost of a clock call...
mean is 35.00575 ns (12 iterations)

benchmarking f1 100
mean: 13.92723 ms, lb 13.88566 ms, ub 13.96976 ms, ci 0.950
std dev: 215.6054 us, lb 190.5129 us, ub 246.8802 us, ci 0.950
variance introduced by outliers: 8.472%
variance is slightly inflated by outliers

benchmarking f2 100
mean: 2.810448 ms, lb 2.801007 ms, ub 2.821439 ms, ci 0.950
std dev: 52.27388 us, lb 43.98545 us, ub 65.98971 us, ci 0.950
found 3 outliers among 100 samples (3.0%)
  3 (3.0%) high mild
variance introduced by outliers: 11.350%
variance is moderately inflated by outliers

benchmarking f3 100
mean: 173.9339 us, lb 173.4058 us, ub 174.5147 us, ci 0.950
std dev: 2.829666 us, lb 2.432605 us, ub 3.579531 us, ci 0.950
found 2 outliers among 100 samples (2.0%)
  2 (2.0%) high mild
variance introduced by outliers: 9.410%
variance is slightly inflated by outliers

benchmarking f3 1000
mean: 17.29405 ms, lb 17.26062 ms, ub 17.33314 ms, ci 0.950
std dev: 185.0009 us, lb 157.9185 us, ub 221.9862 us, ci 0.950
 
后记:实际运算的时候,其实并不需要像上面那样给定任意idx求其(i,j),因为i总是从0起始迭代的,所以只需要用关系式求j,并判定下是否需要更新i即可。运算量非常小。
192.168. 0.1 说:
2022年8月10日 02:29

Internet is something we all use in our daily life to experience different things from Entertainment in our personal life to getting extensive work done in the office life as well, but in order to do this you might turn to your Internet connection via the Wi-Fi which is only accessible if you have a router that allows to give you Wireless connection which is WAN in simple terms. 192.168. 0.1 And there are a lot of routers all across the world that use different Internet Protocol Gateways to access the Admin panel and some of these routers use the IP Address 192.168.0.1.

UP Board 12th Questi 说:
2022年8月18日 17:50

UP Board 12th Question Paper 2023, The Board of Higher Education Uttar Pradesh UPMSP, conducts the final Matriculation examination every year. UP Board 12th Question Paper 2023 This year too, Uttar-Pradesh Higher Secondary School Certificate public Examination 2023 will be performed during March and April 2023.


登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter
Host by is-Programmer.com | Power by Chito 1.3.3 beta | © 2007 LinuxGem | Design by Matthew "Agent Spork" McGee