题解
Step1:公式分析
先把所有牛按坐标x从小到大排序,记排序后的第j头牛为(v_j,x_j)(v为强度,x为坐标)。所求答案为:
\text{ans}=\sum_{j}\sum_{i \lt j}\max\{ v_i,v_j \}|x_i-x_j|
对每个j,按v_i与v_j的大小关系,把它左侧的牛(即所有i<j)分为两个集合来处理:
A=\{ i \lt j \space | \space v_i \le v_j \}, max(v_i,v_j)=v_j \\ B=\{ i \lt j \space | \space v_i \gt v_j \}, max(v_i,v_j)=v_i \\
由于已按坐标排序,i<j时有|x_i-x_j|=x_j-x_i,于是第j头牛与其左侧所有牛的贡献可拆分为:
\begin{aligned} &\sum_{i \lt j}\max\{ v_i,v_j \}(x_j-x_i) \\ =&\sum_{i \in A}v_j(x_j-x_i)+\sum_{i \in B}v_i(x_j-x_i) \\ =&\sum_{i \in A}v_jx_j-\sum_{i \in A}v_jx_i + \sum_{i \in B}v_ix_j-\sum_{i \in B}v_ix_i \\ =&v_j(|A|x_j-\sum_{i \in A}x_i) + x_j\sum_{i \in B}v_i - \sum_{i \in B}v_ix_i \end{aligned}
Step2:树状数组维护
我们分析上式中的四个变量: |A|、\sum_{i \in A}x_i、\sum_{i \in B}v_i、\sum_{i \in B}v_ix_i
从左到右扫描j,用树状数组维护所有已扫描过的牛(即所有i<j),则能够得到:
以强度值v为下标的前缀和(v_i \le v_j):|A|、\sum_{i \in A}x_i
以强度值v为下标的后缀和(v_i \gt v_j,即总和减去前缀和):\sum_{i \in B}v_i、\sum_{i \in B}v_ix_i
分别用4棵树状数组维护这四个量即可(先用第j头牛查询,再把它插入树状数组)。
树状数组:
// Fenwick tree (binary indexed tree) over 1-based positions 1..n.
// add(i, v) performs the point update a[i] += v; sum(i) returns the prefix
// sum a[1] + ... + a[i]. Both run in O(log n).
struct BIT {
    int n;
    vector<ll> t;
    BIT(int n = 0) : n(n), t(n + 1, 0) {}
    // a[i] += v. Positions outside [1, n] are ignored: i <= 0 would otherwise
    // loop forever (0 & -0 == 0), and i > n has no slot in t.
    void add(int i, ll v) {
        if (i <= 0) return;
        for (; i <= n; i += i & -i)
            t[i] += v;
    }
    // Returns a[1] + ... + a[i]. i is clamped to n so a query past the end
    // yields the total instead of reading past t; i <= 0 yields 0.
    ll sum(int i) {
        if (i > n) i = n;
        ll r = 0;
        for (; i > 0; i -= i & -i) {
            r += t[i];
        }
        return r;
    }
};
维护处理:
// Returns the sum over all pairs i < j of max(v_i, v_j) * |x_i - x_j|.
// Precondition: the global `a` holds (v, x) pairs already sorted by x
// ascending, so for every earlier cow i, |x_i - x_j| == x_j - x_i.
// Cow j contributes
//   v_j * (|A| * x_j - sum_{i in A} x_i) + x_j * sum_{i in B} v_i - sum_{i in B} v_i x_i
// where A = {i < j : v_i <= v_j} and B = {i < j : v_i > v_j}.
ll solve() {
    // Compress strengths to ranks 1..m so the Fenwick trees are sized by the
    // input rather than a fixed constant (no out-of-range index for large v,
    // no zero index for v == 0). Rank order equals v order, so a prefix query
    // by rank is exactly the "v_i <= v_j" query.
    vector<int> vals;
    vals.reserve(a.size());
    for (const auto& p : a) vals.push_back(p.first);
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    const int m = static_cast<int>(vals.size());

    BIT bit_cnt(m), bit_sumx(m), bit_sumv(m), bit_sumvx(m);
    // Running totals over every processed cow: the "v_i > v_j" sums are the
    // total minus the "v_i <= v_j" prefix, so no full-range query is needed.
    ll total_v = 0, total_vx = 0;
    ll ans = 0;
    for (auto [v, x] : a) {
        const int idx = static_cast<int>(
            lower_bound(vals.begin(), vals.end(), v) - vals.begin()) + 1;
        ll cnt_le = bit_cnt.sum(idx);                 // |A|
        ll sumx_le = bit_sumx.sum(idx);               // sum_{i in A} x_i
        ll sumv_gt = total_v - bit_sumv.sum(idx);     // sum_{i in B} v_i
        ll sumvx_gt = total_vx - bit_sumvx.sum(idx);  // sum_{i in B} v_i x_i
        ans += 1ll * v * (cnt_le * x - sumx_le) +
               1ll * x * sumv_gt - sumvx_gt;
        // Insert cow j only after its own query so it is never paired with itself.
        bit_cnt.add(idx, 1);
        bit_sumx.add(idx, x);
        bit_sumv.add(idx, v);
        bit_sumvx.add(idx, 1ll * x * v);
        total_v += v;
        total_vx += 1ll * x * v;
    }
    return ans;
}
完整代码:
时间复杂度:O(n\log n + n\log V),其中n为牛的数量,V为树状数组的大小(强度v的取值范围)
空间复杂度:O(n + V)
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// 2e4 + 10; presumably the problem's bound on strength v plus slack.
// NOTE(review): confirm against the problem constraints.
constexpr int N = 20010;
int n;                     // number of cows read from input
vector<pair<int, int>> a;  // one (v, x) pair per cow: strength, position
// Fenwick tree (binary indexed tree) over 1-based positions 1..n.
// add(i, v) performs the point update a[i] += v; sum(i) returns the prefix
// sum a[1] + ... + a[i]. Both run in O(log n).
struct BIT {
    int n;
    vector<ll> t;
    BIT(int n = 0) : n(n), t(n + 1, 0) {}
    // a[i] += v. Positions outside [1, n] are ignored: i <= 0 would otherwise
    // loop forever (0 & -0 == 0), and i > n has no slot in t.
    void add(int i, ll v) {
        if (i <= 0) return;
        for (; i <= n; i += i & -i)
            t[i] += v;
    }
    // Returns a[1] + ... + a[i]. i is clamped to n so a query past the end
    // yields the total instead of reading past t; i <= 0 yields 0.
    ll sum(int i) {
        if (i > n) i = n;
        ll r = 0;
        for (; i > 0; i -= i & -i) {
            r += t[i];
        }
        return r;
    }
};
// Returns the sum over all pairs i < j of max(v_i, v_j) * |x_i - x_j|.
// Precondition: the global `a` holds (v, x) pairs already sorted by x
// ascending (main() does this), so for every earlier cow i,
// |x_i - x_j| == x_j - x_i. Cow j contributes
//   v_j * (|A| * x_j - sum_{i in A} x_i) + x_j * sum_{i in B} v_i - sum_{i in B} v_i x_i
// where A = {i < j : v_i <= v_j} and B = {i < j : v_i > v_j}.
ll solve() {
    // Compress strengths to ranks 1..m so the Fenwick trees are sized by the
    // input rather than the fixed constant N (no out-of-range index for large
    // v, no zero index for v == 0). Rank order equals v order, so a prefix
    // query by rank is exactly the "v_i <= v_j" query.
    vector<int> vals;
    vals.reserve(a.size());
    for (const auto& p : a) vals.push_back(p.first);
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    const int m = static_cast<int>(vals.size());

    BIT bit_cnt(m), bit_sumx(m), bit_sumv(m), bit_sumvx(m);
    // Running totals over every processed cow: the "v_i > v_j" sums are the
    // total minus the "v_i <= v_j" prefix, so no full-range query is needed.
    ll total_v = 0, total_vx = 0;
    ll ans = 0;
    for (auto [v, x] : a) {
        const int idx = static_cast<int>(
            lower_bound(vals.begin(), vals.end(), v) - vals.begin()) + 1;
        ll cnt_le = bit_cnt.sum(idx);                 // |A|
        ll sumx_le = bit_sumx.sum(idx);               // sum_{i in A} x_i
        ll sumv_gt = total_v - bit_sumv.sum(idx);     // sum_{i in B} v_i
        ll sumvx_gt = total_vx - bit_sumvx.sum(idx);  // sum_{i in B} v_i x_i
        ans += 1ll * v * (cnt_le * x - sumx_le) +
               1ll * x * sumv_gt - sumvx_gt;
        // Insert cow j only after its own query so it is never paired with itself.
        bit_cnt.add(idx, 1);
        bit_sumx.add(idx, x);
        bit_sumv.add(idx, v);
        bit_sumvx.add(idx, 1ll * x * v);
        total_v += v;
        total_vx += 1ll * x * v;
    }
    return ans;
}
// Reads n followed by n pairs (v, x), sorts the cows by position x
// (ties broken by v), and prints the answer computed by solve().
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // A missing or negative count is treated as an empty input (answer 0)
    // instead of letting a.resize() throw on a negative size.
    if (!(cin >> n) || n < 0) n = 0;
    a.resize(n);
    for (auto& [v, x] : a) {
        cin >> v >> x;
    }
    sort(a.begin(), a.end(), [](const auto& p, const auto& q) {
        if (p.second != q.second) {
            return p.second < q.second;
        }
        return p.first < q.first;
    });
    // '\n' rather than endl: the stream is flushed at normal exit anyway.
    cout << solve() << '\n';
    return 0;
}