当前位置 博文首页 > longwind09,多容寡欲,千里江河:Scala-Spark digamma stackove

    longwind09,多容寡欲,千里江河:Scala-Spark digamma stackove

    作者:[db:作者] 时间:2021-08-27 10:09

    Scala-Spark digamma stackoverflow问题

    这两天在用spark做点击率的贝叶斯平滑,参考雅虎的论文进行了一番尝试。

    先上代码:

    # click_count, show_count # this method takes time
    def do_smooth(data_list):
        import scipy.special as sp
        a, b, i = 1.0, 1.0, 0
        da, db = a, b
        while i < 1000 and (da > 1.0E-10 or db > 1.0E-10):
            x1, y1, x2 = 0.0, 0.0, 0.0
            for lineList in data_list:
                x1 += sp.digamma((lineList[0]) + a) - sp.digamma(a)
                y1 += sp.digamma((lineList[1]) + a + b) - sp.digamma(a + b)
                x2 += sp.digamma((lineList[1]) - (lineList[0]) + b) - sp.digamma(b)
            na, nb = a, b
            a *= (x1 / y1)
            b *= (x2 / y1)
            da, db = abs(a - na), abs(b - nb)
            i += 1
        print i, a, b
        return a, b

    这是我之前用的python代码,改成scala也相当容易,digamma函数非常耗时,而且还要迭代1000次。最要命的是digamma在scala里面默认的实现会出现栈溢出!!!

    var a, b, da, db: Double = 1.0
    var index = 0
    while (index < 1000 && (da > 1.0E-9 || db > 1.0E-9)) {
        var x1,x2,y1 = 0.0
        traindata.foreach(p => {
            x1 += MBlas.digamma(p(2) + a) - MBlas.digamma(a)
            y1 += MBlas.digamma(p(1) + a + b) - MBlas.digamma(a + b)
            x2 += MBlas.digamma(p(1) - p(2) + b) - MBlas.digamma(b)
            val na = a
            val nb = b
            a *= (x1 / y1)
            b *= (x2 / y1)
            da = Math.abs(a - na)
            db = Math.abs(b - nb)
        })
    }

    digamma 函数是个递归函数,问题就处在递归上了。

       public static double digamma(double x) {
            if (x > 0 && x <= S_LIMIT) {
                return -GAMMA - 1 / x;
            }
            if (x >= C_LIMIT) {
                double inv = 1 / (x * x);
                return FastMath.log(x) - 0.5 / x - inv * ((1.0 / 12) + inv * (1.0 / 120 - inv / 252));
            }
            return digamma(x + 1) - 1 / x;
        }

    既然知道问题所在,是不是就可以重写递归为非递归呢?在Stack Overflow上找到了一个答案

      val GAMMA = 0.577215664901532860606512090082
            val GAMMA_MINX = 1.e-12
            val DIGAMMA_MINNEGX = -1250
            val C_LIMIT = 49
            val S_LIMIT = 1e-5
            var value = 0.0
            var x = input
            while (true) {
                if (x >= 0 && x < GAMMA_MINX) x = GAMMA_MINX
                if (x < DIGAMMA_MINNEGX) {
                    x = DIGAMMA_MINNEGX + GAMMA_MINX
                } else {
                    if (x > 0 && x <= S_LIMIT) return value + -GAMMA - 1 / x
                    if (x >= C_LIMIT) {
                        val inv = 1 / (x * x)
                        return value + Math.log(x) - 0.5 / x - inv * ((1.0 / 12) + inv * (1.0 / 120 - inv / 252))
                    }
                    value = value - 1.0 / x
                    x += 1
                }
            }

    经测试,没看出什么问题,可以用了。
    不过,上面的代码并没有解决慢的问题,当需要计算CTR的对象比较多时(几百万),仍然比较耗时。所以我决定用两个替代方法:

    1. 抽样,抽取能在可接受时间内出结果的样本数,得到α和β;
    2. 直接使用平均值作为α和β

    参考:
    1. 雅虎专家的论文,如上
    2. Stack Overflow 网友代码,如上

    cs