algo

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub dnx04/algo

✔ tests/Assignment_Problem.test.cpp

Depends on: misc/macros.h, graph/MinAssignment.h

Code

#define PROBLEM "https://judge.yosupo.jp/problem/assignment"

#include "../misc/macros.h"
#include "../graph/MinAssignment.h"

// Reads an n x n cost matrix from stdin, solves it with MinAssignment and
// prints the minimum total cost on the first line, followed by the column
// assigned to each row 0..n-1 (space-separated) on the second line.
void solve() {
  int n;
  cin >> n;
  vector<vector<i64>> W(n, vector<i64>(n));
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      cin >> W[i][j];
    }
  }
  auto [ret, L] = MinAssignment(W);
  cout << ret << '\n';
  // Separate entries with single spaces and terminate the line, so the output
  // has no trailing space and ends with a newline.
  for (int i = 0; i < n; ++i) cout << (i ? " " : "") << L[i];
  cout << '\n';
}


// Program entry point: unties cin from cout, disables C stdio syncing,
// makes a failed read throw, then runs solve() once per test case
// (exactly once unless the commented-out count read is enabled).
int main() {
  cin.tie(nullptr);
  ios::sync_with_stdio(false);
  cin.exceptions(ios::failbit);
  int testCases = 1;
  // cin >> testCases;
  while (testCases-- > 0) solve();
}
#line 1 "tests/Assignment_Problem.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/assignment"

#line 1 "misc/macros.h"
// #pragma GCC optimize("Ofast,unroll-loops")       // unroll long, simple loops
// #pragma GCC target("avx2,fma")                   // vectorizing code
// #pragma GCC target("lzcnt,popcnt,abm,bmi,bmi2")  // for fast bitset operation

#include <bits/extc++.h>
#include <tr2/dynamic_bitset>

using namespace std;
using namespace __gnu_pbds;  // ordered_set, gp_hash_table
// using namespace __gnu_cxx; // rope

// for templates to work: short container macros and fixed-width type aliases
// shared by the library headers.
#define all(x) (x).begin(), (x).end()
// Expands to a single postfix expression, so it behaves as one operand in any
// context (e.g. `sizeof sz(v)`, which the old `(int) (x).size()` form broke).
#define sz(x) static_cast<int>((x).size())
#define pb push_back
#define eb emplace_back
using i32 = int32_t;
using u32 = uint32_t;
using i64 = int64_t;
using u64 = uint64_t;
using i128 = __int128_t;
using u128 = __uint128_t;
using ld = long double;
using pii = pair<i32, i32>;
using vi = vector<i32>;

// fast map
// Seed read from the clock at startup so hash values differ between runs.
const int RANDOM = chrono::high_resolution_clock::now().time_since_epoch().count();
struct chash {  // customize hash function for gp_hash_table
  // splitmix64 finalizer applied to (x + RANDOM). A plain `x ^ RANDOM` keeps
  // keys that agree in their low bits agreeing in their low bits, so keys
  // that collide under gp_hash_table's mask-based bucket selection (e.g.
  // multiples of 2^16) would still collide. The multiply/shift rounds spread
  // every input bit across the whole output.
  int operator()(int x) const {
    uint64_t z = static_cast<uint64_t>(static_cast<uint32_t>(x)) +
                 static_cast<uint64_t>(static_cast<uint32_t>(RANDOM)) +
                 0x9e3779b97f4a7c15ULL;
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
    z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
    return static_cast<int>(z ^ (z >> 31));
  }
};
gp_hash_table<int, int, chash> table{};  // global int -> int hash table keyed through chash

/* ordered set: pb_ds red-black tree (rb_tree_tag) augmented with order statistics.
    Behaves as a set ordered by less<Key>; equal keys are stored once.
    find_by_order(k): iterator to the k-th smallest element (0-based)
    order_of_key(k): how many stored elements compare strictly less than k
*/
template <typename Key>
using ordered_set =
    tree<Key, null_type, less<Key>, rb_tree_tag, tree_order_statistics_node_update>;

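/*  rope (requires `using namespace __gnu_cxx;`, which is commented out near the top of this header)
    Example: move the elements v[l..r] (inclusive) to the front of rope v.
    rope <int> cur = v.substr(l, r - l + 1);
    v.erase(l, r - l + 1);
    v.insert(v.mutable_begin(), cur);
*/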
#line 1 "graph/MinAssignment.h"
pair<i64, vector<int>> MinAssignment(const vector<vector<i64>>& W) {
  int n = W.size(), m = W[0].size();  // assert(n <= m);
  vector<i64> v(m), dist(m);           // v: potential
  vector<int> L(n, -1), R(m, -1);     // matching pairs
  vector<int> idx(m), prev(m);
  iota(idx.begin(), idx.end(), 0);

  i64 w, h;
  int j, l, s, t;
  auto reduce = [&]() {
    if (s == t) {
      l = s;
      w = dist[idx[t++]];
      for (int k = t; k < m; ++k) {
        j = idx[k];
        h = dist[j];
        if (h > w) continue;
        if (h < w) t = s, w = h;
        idx[k] = idx[t];
        idx[t++] = j;
      }
      for (int k = s; k < t; ++k) {
        j = idx[k];
        if (R[j] < 0) return 1;
      }
    }
    int q = idx[s++], p = R[q];
    for (int k = t; k < m; ++k) {
      j = idx[k];
      h = W[p][j] - W[p][q] + v[q] - v[j] + w;
      if (h < dist[j]) {
        dist[j] = h;
        prev[j] = p;
        if (h == w) {
          if (R[j] < 0) return 1;
          idx[k] = idx[t];
          idx[t++] = j;
        }
      }
    }
    return 0;
  };
  for (int i = 0; i < n; ++i) {
    for (int k = 0; k < m; ++k) dist[k] = W[i][k] - v[k], prev[k] = i;
    s = t = 0;
    while (!reduce());
    for (int k = 0; k < l; ++k) v[idx[k]] += dist[idx[k]] - w;
    for (int k = -1; k != i;) R[j] = k = prev[j], swap(j, L[k]);
  }
  i64 ret = 0;
  for (int i = 0; i < n; ++i) ret += W[i][L[i]];  // (i, L[i]) is a solution
  return {ret, L};
}
#line 5 "tests/Assignment_Problem.test.cpp"

// Reads an n x n cost matrix from stdin, solves it with MinAssignment and
// prints the minimum total cost on the first line, followed by the column
// assigned to each row 0..n-1 (space-separated) on the second line.
void solve() {
  int n;
  cin >> n;
  vector<vector<i64>> W(n, vector<i64>(n));
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      cin >> W[i][j];
    }
  }
  auto [ret, L] = MinAssignment(W);
  cout << ret << '\n';
  // Separate entries with single spaces and terminate the line, so the output
  // has no trailing space and ends with a newline.
  for (int i = 0; i < n; ++i) cout << (i ? " " : "") << L[i];
  cout << '\n';
}


// Program entry point: unties cin from cout, disables C stdio syncing,
// makes a failed read throw, then runs solve() once per test case
// (exactly once unless the commented-out count read is enabled).
int main() {
  cin.tie(nullptr);
  ios::sync_with_stdio(false);
  cin.exceptions(ios::failbit);
  int testCases = 1;
  // cin >> testCases;
  while (testCases-- > 0) solve();
}
Back to top page