michaele
Published on 2025-04-26 / 17 Visits
2
3

模拟赛T4 题解

思路

这题赛时想到了用质因数来求每个子树的因数个数,但是没时间写了,而且也没有那么熟悉STL,所以最后只拿了a_{i}\leq 3的10pts

下面说下思路,题解写的太恶心,所以我结合自己的思路,再捋一遍

首先我们要预处理出每个数分解质因数以后的结果,也就是每个数分别有哪些质因数、以及每个质因数的指数是多少

void init() {  
    for(int i = 2; i <= 1e6; i++) {  
        if(p[i].size())  
            continue;  
            
        for(int j = i; j <= 1e6; j += i) {  
            int x = j;  
            int t = 0;  
            while(x % i == 0) {  
                x /= i;
                t++;  
            }
	        p[j].push_back({i, t});  
	    }  
    }  
}

然后对这棵树进行dfs,因为是乘积,所以可以转化为质因子相加

启发式合并(把小的加到大的上面)当前节点与其子节点,而后算一下每个质因子对答案的贡献,看是否\geq k即可

d[x]x节点的因子个数,p_{i}表示i的指数

\displaystyle \begin{align} d[x]=\prod_{y\in x的质因数}p_{y} + 1 \end{align}

code

typedef long long ll;  
const int N = 2e5 + 10, M = 1e6 + 10;  

int n;  
ll k;  
int h[N], ver[N << 1], ne[N << 1], tot;  
vector <pair<int, int> > p[M];  
map <int, int> mp[N];  
int a[N];  
ll mul, ans;  
  
void add(int x, int y) {  
    ver[++tot] = y;  
    ne[tot] = h[x];  
    h[x] = tot;  
}  
  
void init() {  
    for(int i = 2; i <= 1e6; i++) {  
        if(p[i].size())  
            continue;  
        for(int j = i; j <= 1e6; j += i) {  
            int x = j;  
            int t = 0;  
            while(x % i == 0) {  
                x /= i;  
                t++;  
            }  
            p[j].push_back({i, t});  
                    }  
    }  
}  
void dfs(int x, int fa) {  
        //将x节点自己的权值的贡献加入mp  
    for(auto o : p[a[x]]) {  
        mp[x][o.first] += o.second;  
    }  
        //分别加儿子的贡献  
    for(int i = h[x]; i; i = ne[i]) {  
        int y = ver[i];  
        if (y == fa) continue;  
        dfs(y, x);  
        if(mp[x].size() < mp[y].size())  
            swap(mp[x], mp[y]);  
        for (auto o : mp[y]) {  
            mp[x][o.first] += o.second;  
        }  
    }  
        //统计因数个数  
    mul = 1;  
    for(auto o : mp[x]) {  
        mul *= o.second + 1;  
        if(mul >= k)  
            break;  
    }  
    ans += mul >= k ? 1 : 0;    
}  
int main () {  
    init();  
    scanf("%d%lld", &n, &k);  
        for (int i = 1; i <= n; i++)  
        scanf("%d", &a[i]);  
    for(int i = 1; i < n; i++) {  
        int u, v;  
        scanf("%d%d", &u, &v);  
        add(u, v);  
        add(v, u);  
    }  
    dfs(1, 0);  
    printf("%lld\n", ans);  
        return 0;  
}

Comment