C++编译期多项式exp

Posted by wyj on March 13, 2020

前言

强行创造一点实际意义:也许对于某些无法打表的题,这个可以用作一种通用的不占运行时时间的预处理方法。多项式exp只是一个例子,用来描述C++模板元编程到底有多强大。

这里实现的是$O(n^2)$的exp,要FFT的话我猜没有10K代码写不完。并且前几天我刚刚在校内训练时$O(n^2)$多项式快速幂卡过$T=10,n=10^4$。下面一段代码描述了正常的C++写法:

g[0]=1;
for(int i=1;i<=n;i++){
	for(int j=0;j<i;j++)(g[i]+=g[j]*f[i-j]%mod*(i-j))%=mod;
	(g[i]*=inv[i])%=mod;
}

编译期线性求逆

这个最简单了,并且也可以结合后面提到的“导出成数组”方法。然而我程序里面没有别的地方需要用到逆元数组,所以就没有导出逆元数组了。

using ll=long long;
const ll mod=998244353;

template<int x> struct Inv{enum:ll{val=-Inv<mod%x>::val*(mod/x)%mod};};
template<> struct Inv<1>{enum:ll{val=1};};

这里的enum:ll是C++11的带类型枚举类型,可以看做是const static的等价替代品。

求Exp数组的单项值

把它叫做expTerm类好了。expTerm<x>::val就是结果。可以把它理解成python中的generator,因为这是在调用时才动态计算的。我们要先构造一个辅助类sum<x,y>,用来递归计算卷积的单项值。这两个类是互相调用的关系,所以需要声明提前。

程序中间涉及到的f数组声明必须是constexpr而非const

template<int x> struct expTerm;

template<int x,int y> struct sum{
	enum:ll{val=(f[y+1]*(y+1)%mod*expTerm<x-y>::val%mod+sum<x,y-1>::val)%mod};
};
template<int x> struct sum<x,-1>{enum:ll{val=0};};

template<int x> struct expTerm{
	enum:ll{val=sum<x-1,x-1>::val*Inv<x>::val%mod};
};
template<> struct expTerm<0>{enum:ll{val=1};};

我一开始的逻辑是sum<x,y>调用sum<x,y+1>,然而这个东西貌似只有clang++编译的过,g++死活不认。所以改成了sum<x,y>调用sum<x,y-1>。g++的报错是这样的:

2.cpp:5:24: error: template argument ‘(y + 1)’ involves template parameter(s)

我可以理解这句话的意思,但是不能把它翻译成中文,因为翻译之后就变成了”错误:模板参数涉及到了模板参数”,狗屁不通。这是理解argumentparameter的意义差别的绝妙例子。

单项值组合成数组

之前的内容看似已经完全够用了。但是如果你要真的像打表程序一样调用的话,是不能在运行时计算expTerm<x>::val的(模板参数必须为常数),必须预先存进数组里。正常人不会写$10000$条赋值语句,所以这个也必须使用模板元编程搞定。

C++11为我们提供了一个强力工具:变长模板参数。我之前一直不是很懂这个东西。最近学习js的过程中使用js作为跳板,终于理解了那个...的意思(js和C++11的...都是从C语言的va_args发展而来的,js的语法更加好懂一些)。

于是先声明一个arrayHolder类,作用大概和js的(...x)=>[...x]类似:把参数列表变成一个数组。这里也必须是constexprconst会CE。

template<ll... args> struct arrayHolder{
	static constexpr ll data[sizeof...(args)]={args...};
};

上面的是C++17语法,C++11的要更反人类一点,就不写了。sizeof...相当于js的x.length,用来得到参数个数(而不是占用空间)。

声明一个ga类用来把expTerm的结果依次放到模板参数列表里。ga<n,F>相当于python中提取出生成器$F$的前$n+1$项。递归边界接上arrayHolder就可以完成单项值->模板参数->数组的转变。递归返回时需要拷贝数组?其实拷贝arrayHolder的类型就可以了,因为数组存放在类型信息(而不是实例)之内。

template<int n,template<int> class F,ll... args> struct ga_impl{
	using result=typename ga_impl<n-1,F,F<n>::val,args...>::result;
};
template<template<int> class F,ll... args> struct ga_impl<0,F,args...>{
	using result=arrayHolder<F<0>::val,args...>;
};

template<int n,template<int> class F> struct ga{
	using result=typename ga_impl<n,F>::result;
};

这里为什么要typename呢?因为编译器在递归到当前层时无法”预判”ga_impl::result到底是static变量还是一个类型,需要typename提示编译器这是一个类型(我才不会说其实是clang++给我提示)。

使用方法

constexpr ll f[5010]={0,1,2,3};

using g=ga<100,expTerm>::result;
for(int i=0;i<=100;i++)
	printf("%lld\n",(g::data[i]+mod)%mod);

这个100就是长度。不要开太大了,否则你会体验比Fork炸弹更加惨烈的死机(不要问我怎么知道的)。