algo

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub dnx04/algo

✔ tests/Sum_of_Exponential_times_Polynomial_limit.test.cpp

Depends on: misc/macros.h, math/ModInt.h, math/SumPowerPoly.h

Code

#define PROBLEM "https://judge.yosupo.jp/problem/sum_of_exponential_times_polynomial_limit"

#include "../misc/macros.h"
#include "../math/SumPowerPoly.h"

// Returns pws with pws[i] = i^d (in Fp) for 0 <= i < n, using the convention 0^0 = 1.
// Linear sieve: i^d is completely multiplicative in i, so only primes need a pow() call
// and each composite i * p (p = its least prime factor) is filled as i^d * p^d.
// n may be 0 or 1 (solve() passes n = d + 1, so d == 0 gives n == 1); those tables are
// returned without touching the nonexistent entry pws[1].
vector<Fp> getMonomials(int n, int d) {
  vector<Fp> pws(n);
  if (n == 0) return pws;
  pws[0] = (d == 0 ? 1 : 0);
  if (n == 1) return pws;
  pws[1] = 1;
  vector<int> primes, lpf(n);  // lpf[i]: least prime factor of i, 0 while unseen
  for (int i = 2; i < n; ++i) {
    if (lpf[i] == 0) lpf[i] = i, primes.eb(i), pws[i] = Fp(i).pow(d);
    for (auto p : primes) {
      // Stop once p exceeds lpf(i) (i * p would have a smaller least prime factor)
      // or i * p leaves the table.
      if (p > lpf[i] || i * p >= n) break;
      lpf[i * p] = p;
      pws[i * p] = pws[i] * pws[p];
    }
  }
  return pws;
}

void solve() {
  Fp r;
  int d;
  cin >> r >> d;
  prepareFac(d + 2);
  cout << sumPolyLimit(r, getMonomials(d + 1, d));
}

int main() {
  solve();
  return 0;
}
#line 1 "tests/Sum_of_Exponential_times_Polynomial_limit.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/sum_of_exponential_times_polynomial_limit"

#line 1 "misc/macros.h"
// #pragma GCC optimize("Ofast,unroll-loops")       // unroll long, simple loops
// #pragma GCC target("avx2,fma")                   // vectorizing code
// #pragma GCC target("lzcnt,popcnt,abm,bmi,bmi2")  // for fast bitset operation

#include <bits/extc++.h>
#include <tr2/dynamic_bitset>

using namespace std;
using namespace __gnu_pbds;  // ordered_set, gp_hash_table
// using namespace __gnu_cxx; // rope

// Shorthand macros relied on by the library templates.
#define all(x) (x).begin(), (x).end()
#define sz(x) static_cast<int>((x).size())
#define pb push_back
#define eb emplace_back

// Fixed-width integer, floating-point and container aliases.
using i32 = std::int32_t;
using u32 = std::uint32_t;
using i64 = std::int64_t;
using u64 = std::uint64_t;
using i128 = __int128_t;
using u128 = __uint128_t;
using ld = long double;
using pii = std::pair<i32, i32>;
using vi = std::vector<i32>;

// fast map
const int RANDOM = chrono::high_resolution_clock::now().time_since_epoch().count();
struct chash {  // customize hash function for gp_hash_table
  int operator()(int x) const { return x ^ RANDOM; }
};
gp_hash_table<int, int, chash> table;

/* ordered set: policy-based red-black tree with order-statistics node updates
    find_by_order(k): iterator to the k-th smallest element (0-based)
    order_of_key(k): count of elements strictly less than k
*/
template <typename Key>
using ordered_set =
    tree<Key, null_type, less<Key>, rb_tree_tag, tree_order_statistics_node_update>;

/*  rope
    rope <int> cur = v.substr(l, r - l + 1);
    v.erase(l, r - l + 1);
    v.insert(v.mutable_begin(), cur);
*/
#line 2 "math/ModInt.h"

template <int mod>
struct modint {
  // Integers modulo `mod`, stored in Montgomery form with R = 2^32: for a represented
  // value a, the member x holds a * R mod `mod`, always kept in [0, mod).
  using M = modint;
  static_assert(mod > 0 && mod <= 2147483647);
  // Montgomery reduction needs mod to be invertible modulo 2^32, i.e. odd. An even
  // modulus would compile but produce wrong results. (inv() further assumes mod is prime.)
  static_assert(mod % 2 == 1, "modint: Montgomery form requires an odd modulus");
  static constexpr int modulo = mod;
  // r1 = -mod^{-1} mod 2^32. Each Newton step r1 <- r1 * (2 - mod * r1) doubles the
  // number of correct low bits, starting from r1 = mod (already an inverse mod 8).
  static constexpr u32 r1 = []() {
    u32 r1 = mod;
    for (int i = 0; i < 5; ++i) r1 *= 2 - mod * r1;
    return -r1;
  }();
  // r2 = R^2 mod `mod` = 2^64 mod `mod`; reduce(v * r2) converts v into Montgomery form.
  static constexpr u32 r2 = -u64(mod) % mod;
  // Montgomery reduction: for x < mod * 2^32, returns x * 2^-32 mod `mod`, in [0, mod).
  static u32 reduce(u64 x) {
    u32 y = u32(x) * r1, r = (x + u64(y) * mod) >> 32;
    return r >= mod ? r - mod : r;
  }
  u32 x;
  modint() : x(0) {}
  // Non-explicit: plain integers convert implicitly. Negative inputs are normalized first.
  modint(i64 x) : x(reduce(u64(x % mod + mod) * r2)) {}
  M& operator+=(const M& a) {
    if ((x += a.x) >= mod) x -= mod;
    return *this;
  }
  M& operator-=(const M& a) {
    if ((x += mod - a.x) >= mod) x -= mod;
    return *this;
  }
  M& operator*=(const M& a) {
    x = reduce(u64(x) * a.x);  // (aR)(bR) / R = (ab)R
    return *this;
  }
  M& operator/=(const M& a) { return *this *= a.inv(); }
  M operator-() const { return M(0) - *this; }
  M operator+(const M& a) const { return M(*this) += a; }
  M operator-(const M& a) const { return M(*this) -= a; }
  M operator*(const M& a) const { return M(*this) *= a; }
  M operator/(const M& a) const { return M(*this) /= a; }
  // The Montgomery map is a bijection on [0, mod), so comparing x compares the values.
  bool operator==(const M& a) const { return x == a.x; }
  bool operator!=(const M& a) const { return x != a.x; }
  // Binary exponentiation; pow(0) is 1.
  M pow(u64 k) const {
    M res(1), b = *this;
    while (k) {
      if (k & 1) res *= b;
      b *= b, k >>= 1;
    }
    return res;
  }
  // Fermat inverse a^(mod - 2): only valid for prime mod. The inverse of zero comes out as zero.
  M inv() const { return pow(mod - 2); }
  // Streams print / read the ordinary (non-Montgomery) value.
  friend ostream& operator<<(ostream& os, const M& a) {
    return os << reduce(a.x);
  }
  friend istream& operator>>(istream& is, M& a) {
    i64 v;
    is >> v;
    a = M(v);
    return is;
  }
};

// x * y mod m for any 64-bit operands: the product is formed in 128 bits, so it cannot overflow.
u64 modmul(u64 x, u64 y, u64 m) {
  const u128 wide = u128(x) * y;
  return u64(wide % m);
}
// x^k mod m by binary exponentiation; modmul keeps every product exact for any m < 2^64.
// Requires m >= 1. The result always lies in [0, m). In particular, k == 0 with m == 1
// gives 0, not an unreduced 1.
u64 modpow(u64 x, u64 k, u64 m) {
  u64 res = 1 % m;
  while (k) {
    if (k & 1) res = modmul(res, x, m);
    x = modmul(x, x, m);
    k >>= 1;
  }
  return res;
}
#line 2 "math/SumPowerPoly.h"

using Fp = modint<998244353>;

vector<Fp> fac, invFac;
void prepareFac(int n) {
  fac.resize(n + 1);
  invFac.resize(n + 1);
  fac[0] = 1;
  for (int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i;
  invFac[n] = fac[n].inv();
  for (int i = n; i >= 1; --i) invFac[i - 1] = invFac[i] * i;
}

// Lagrange interpolation on the nodes 0, 1, ..., k (k = sz(y) - 1) in O(k):
// returns P(n), where P is the unique polynomial of degree <= k with P(i) = y[i].
// Any integer n is accepted. Negative n never reaches the y[n] shortcut (which would
// read out of bounds) and is reduced into Fp before any arithmetic.
// Requires prepareFac(m) with m >= k (reads invFac[0..k]). An empty y yields 0.
Fp interpolate(const vector<Fp>& y, i64 n) {
  if (y.empty()) return 0;
  int k = sz(y) - 1;
  if (0 <= n && n <= k) return y[n];  // n is itself a node
  const Fp at = n;  // reducing first avoids signed overflow in n - i for extreme n
  // pre[i] = prod_{j<i} (n - j), suf[i] = prod_{j>i} (n - j)
  vector<Fp> pre(k + 1), suf(k + 1);
  pre[0] = suf[k] = 1;
  for (int i = 0; i < k; ++i) pre[i + 1] = pre[i] * (at - Fp(i));
  for (int i = k; i > 0; --i) suf[i - 1] = suf[i] * (at - Fp(i));
  Fp ans = 0;
  for (int i = 0; i <= k; ++i) {
    // Denominator prod_{j != i} (i - j) = (-1)^(k - i) * i! * (k - i)!
    Fp val = pre[i] * suf[i] * y[i] * invFac[i] * invFac[k - i];
    if ((k - i) & 1) ans -= val;
    else ans += val;
  }
  return ans;
}

// C = sum_{i=0->inf} r^i * fs[i] (r != 1)
// fs[0..d] are the values f(0..d) of a polynomial f of degree <= d (d = sz(fs) - 1).
// The series is evaluated in closed form:
//   C = (1 - r)^{-(d+1)} * sum_{i=0}^{d} binom(d+1, i+1) * (-r)^{d-i} * S_i,
//   S_i = sum_{j=0}^{i} r^j * fs[j].
// Requires prepareFac(m) with m >= d + 1 (reads fac[d+1] and invFac[0..d+1]).
// For r == 1 the final division is by zero; Fp's inverse of zero is zero, so 0 is returned.
// An empty fs (zero polynomial) yields 0.
Fp sumPolyLimit(Fp r, const vector<Fp>& fs) {
  if (fs.empty()) return 0;
  int d = sz(fs) - 1;
  if (r == 0) return fs[0];  // only the i = 0 term survives (0^0 = 1)
  vector<Fp> rr(d + 1);  // rr[i] = r^i
  rr[0] = 1;
  for (int i = 1; i <= d; ++i) rr[i] = rr[i - 1] * r;
  Fp ans = 0, S = 0;
  for (int i = 0; i <= d; ++i) {
    S += rr[i] * fs[i];  // S = S_i
    Fp term = invFac[d - i] * invFac[i + 1] * rr[d - i] * S;
    if ((d - i) & 1) ans -= term;
    else ans += term;
  }
  return ans * fac[d + 1] / (Fp(1) - r).pow(d + 1);
}

// Sum_{i=0->n-1} r^i * fs[i]
// fs[0..d] are the values f(0..d) of a polynomial f of degree <= d (d = sz(fs) - 1).
// r == 1: the prefix sums S[m] = f(0) + ... + f(m-1) are interpolated from m = 0..d+1.
// Otherwise: the sum equals C + r^n * g(n), where C = sumPolyLimit(r, fs) and g is
//   interpolated from g(k) = (S_k - C) * r^{-k}, with S_k = sum_{i<k} r^i f(i), k = 0..d.
// Requires prepareFac(m) with m >= d + 1. Returns 0 for n == 0 or an empty fs.
Fp sumPoly(Fp r, const vector<Fp>& fs, u64 n) {
  if (n == 0 || fs.empty()) return 0;
  if (r == 0) return fs[0];
  int d = sz(fs) - 1;
  // The interpolated polynomials have coefficients in F_p, so evaluating them at n mod p
  // is exact. It also keeps the argument inside interpolate()'s i64 parameter, where a
  // raw u64 n >= 2^63 would wrap to a negative value.
  const i64 nModP = i64(n % Fp::modulo);
  if (r == 1) {
    vector<Fp> S(d + 2);
    S[0] = 0;
    for (int i = 0; i <= d; ++i) S[i + 1] = S[i] + fs[i];
    return interpolate(S, nModP);
  }
  Fp C = sumPolyLimit(r, fs), S_curr = 0, rp = 1, rip = 1, ri = r.inv();
  vector<Fp> g(d + 1);
  for (int k = 0; k <= d; ++k) {
    g[k] = (S_curr - C) * rip;  // here S_curr = S_k, rp = r^k, rip = r^{-k}
    S_curr += rp * fs[k], rp *= r, rip *= ri;
  }
  return C + r.pow(n) * interpolate(g, nModP);
}
#line 5 "tests/Sum_of_Exponential_times_Polynomial_limit.test.cpp"

// Returns pws with pws[i] = i^d (in Fp) for 0 <= i < n, using the convention 0^0 = 1.
// Linear sieve: i^d is completely multiplicative in i, so only primes need a pow() call
// and each composite i * p (p = its least prime factor) is filled as i^d * p^d.
// n may be 0 or 1 (solve() passes n = d + 1, so d == 0 gives n == 1); those tables are
// returned without touching the nonexistent entry pws[1].
vector<Fp> getMonomials(int n, int d) {
  vector<Fp> pws(n);
  if (n == 0) return pws;
  pws[0] = (d == 0 ? 1 : 0);
  if (n == 1) return pws;
  pws[1] = 1;
  vector<int> primes, lpf(n);  // lpf[i]: least prime factor of i, 0 while unseen
  for (int i = 2; i < n; ++i) {
    if (lpf[i] == 0) lpf[i] = i, primes.eb(i), pws[i] = Fp(i).pow(d);
    for (auto p : primes) {
      // Stop once p exceeds lpf(i) (i * p would have a smaller least prime factor)
      // or i * p leaves the table.
      if (p > lpf[i] || i * p >= n) break;
      lpf[i * p] = p;
      pws[i * p] = pws[i] * pws[p];
    }
  }
  return pws;
}

void solve() {
  Fp r;
  int d;
  cin >> r >> d;
  prepareFac(d + 2);
  cout << sumPolyLimit(r, getMonomials(d + 1, d));
}

int main() {
  solve();
  return 0;
}
Back to top page