UOJ Logo mcfxmcfx的博客

博客

快速乘法取模的一种奇怪实现

2021-08-20 17:02:51 By mcfxmcfx

常见的取模优化有 Montgomery multiplicationBarrett reduction 等。当模数固定时,编译器会帮我们进行 Barrett reduction 的优化,而当模数不固定时,这两种都需要手动实现。而手动实现的速度总会比编译器慢(比如模数啥的必须占一个寄存器位置)。

那么能不能在程序中内置编译器呢?即使能,也多半超过了码长限制。但是我们可以白嫖编译器的成果,让编译器对着某个模数优化,我们之后再改掉代码里的模数。

实现如下:

const int P=1000000007;

#ifdef _WIN32
#include<windows.h>
bool set_rwx(void*addr)
{
    char tmp[8];
    return VirtualProtect(addr, 0x1000, PAGE_EXECUTE_READWRITE, (PDWORD)tmp);
}
#else
#include<sys/mman.h>
bool set_rwx(void*addr)
{
    return !mprotect(addr, 0x1000, PROT_READ | PROT_WRITE | PROT_EXEC);
}
#endif

namespace mod_jit
{
    struct pointer{
        union{uint64_t*ul;uint32_t*ui;uint16_t*us;uint64_t v;};
        pointer(){}
        template<typename T>pointer(T x){ul=(uint64_t*)x;}
    };

    pointer range_l,range_r;

    bool add_segment(uint64_t addr,uint64_t elfsz,uint64_t memsz)
    {
        if(elfsz+0x10000<memsz)return 1; // skip global arrays
        uint64_t tmp=range_l.v+addr+elfsz;
        if(range_r.v<tmp)range_r.v=tmp;
        return 0;
    }

#ifdef _WIN32
    void get_range()
    {
        pointer t=get_range;
        t.v&=~0xfffull;
        while(*t.ui!=0x905a4d)t.v-=0x1000;
        range_l=t;
        t.v+=*(t.ui+15);
        int num=*(t.us+3);
        t.v+=0x18+*(t.us+10);
        for(;num--;t.v+=0x28){
            if(*(t.ui+9)>>25&1)continue;
            if(add_segment(*(t.ui+3),*(t.ui+4),*(t.ui+2)))break;
        }
    }
#else
    void get_range()
    {
        pointer t=get_range;
        t.v&=~0xfffull;
        while(*t.ui!=0x464c457f)t.v-=0x1000;
        int num=*(t.us+28);
        range_l=t;
        t.v+=*(t.ul+4);
        for(;num--;t.v+=0x38){
            if(add_segment(*(t.ul+2),*(t.ul+4),*(t.ul+5)))break;
        }
    }
#endif

    void get_change(int P,std::function<void(uint32_t)>add_int,std::function<void(uint64_t)>add_ll)
    {
        assert(P>(1<<29)&&P<(1<<30));
        auto tint=[&](uint32_t x){add_int(x);add_int(-x);};
        auto tll=[&](uint64_t x){add_ll(x);add_ll(-x);};
        for(int i=-3;i<=3;i++)tint(P+i);
        tint(((((1ull<<30)-P)<<32)+P-1)/P);
        for(int i=59;i<=61;i++)tint(((1ull<<i)+P-1)/P);
        for(int i=91;i<=93;i++)tll(((__uint128_t(1)<<i)+P-1)/P);
    }

    int vc,tc,pl;
    std::vector<std::pair<int,pointer>>p;

    void record_int(uint32_t x)
    {
        pointer a=range_l,b=range_r;
        b.v-=3;
        while(a.v!=b.v)
        {
            if(*a.ui==x)p.push_back(std::make_pair(vc,a));
            a.v++;
        }
        vc++;
    }

    void record_ll(uint64_t x)
    {
        pointer a=range_l,b=range_r;
        b.v-=7;
        while(a.v!=b.v)
        {
            if(*a.ul==x)p.push_back(std::make_pair(vc,a));
            a.v++;
        }
        vc++;
    }

    void change_int(uint32_t x)
    {
        for(;pl<p.size()&&p[pl].first==tc;pl++)
            *p[pl].second.ui=x;
        tc++;
    }

    void change_ll(uint64_t x)
    {
        for(;pl<p.size()&&p[pl].first==tc;pl++)
            *p[pl].second.ul=x;
        tc++;
    }

    uint64_t anti_optimize_zero()
    {
        uint64_t res=0;
        for(int i=0;i<0x100;i++){
            res=res+55555^114514;
            res^=res<<9;
        }
        return res^11608655508041216768ull;
    }

    struct init{init()
    {
        get_range();
        get_change(P+anti_optimize_zero(),record_int,record_ll);
        std::set<uint64_t>map_area;
        for(auto&a:p)map_area.insert(a.second.v&~0xfffull);
        for(uint64_t x:map_area)assert(set_rwx(*(void**)&x));
    }}init_;

    void change_mod(int P)
    {
        tc=pl=0;
        get_change(P,change_int,change_ll);
    }
}

用法:粘上这坨,然后用这个固定模数 P 写代码。要改模数的时候调用 mod_jit::change_mod

速度参考:https://loj.ac/s/1233497 (原始提交:https://loj.ac/s/1233405 https://loj.ac/s/1170461

这坨代码的主要流程是,先找到程序基地址,然后读取每个段,得到整个程序的地址区间,接下来在程序中找到编译器可能优化出来的常数,把对应内存设为可修改的,每次 change_mod 时就修改他们。

然而这份代码也只是个概念验证,想粘个板子直接用,目前还不太现实。

有一个主要问题,这个找常数的过程找到的不一定真的是常数,他有可能是指令的一部分恰好和常数相同(虽然概率非常小)。为了解决这个问题,就需要把程序按指令划分开,虽然能做到(比如参考这个),但是成本会高得多。

另一个问题是,编译器优化出的常数,可能性非常多,很难枚举完。在 get_change 函数里面可以看到许多带 P 的表达式,他们勉强能覆盖大多数情况,但是显然不能覆盖所有情况。另外这里还要求 $2^{29}< P < 2^{30}$,这是因为编译器根据模数不同,生成的移位指令也不同,而想要写代码找出这些移位指令是比较困难的。(另一种解决方法是,对于每个 $2^k\sim 2^{k+1}$ 之间的 $P$ 都生成一份代码,这倒也许还不错)

评论

run
老师你好为什么我零分啊为什么谢谢老师我就是问问谢谢

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。