【题解】P2573 [SCOI2012]滑雪

题目链接(洛谷)

题目大意

小a在一座雪山滑雪,这里分布着$m$条供滑行的轨道和$n$个轨道之间的交点(同时也是景点),而且每个景点都有一个编号$i$和一个高度$h_i$ 。给出每条轨道的长度$k_i$小a能从景点$i$滑到景点$j$当且仅当存在$i$和$j$之间的边,且$i$的高度不小于$j$。

小a喜欢用最短的滑行路径去访问尽量多的景点。如果仅仅访问一条路径上的景点,他会觉得数量太少。于是小a拿出了他随身携带的时间胶囊。这是一种很神奇的药物,吃下之后可以立即回到上个经过的景点(不用移动也不被认为是小a滑行的距离)。

这种神奇的药物是可以连续食用且不考虑消耗的,即能够回到较长时间之前到过的景点(比如上上个经过的景点和上上上个经过的景点)。现在,小a在$1$号景点望以最短滑行距离滑到尽量多的景点的方案(即满足经过景点数最大的前提下使得滑行总距离最小)。你能帮他求出最短距离和景点数吗?

$1\le n \le 10^5,1\le m\le 10^6,1\le h_i \le 10^9,1\le k_i \le 10^9$

分析

题意转换:一张带权有向图,问从$1$开始最多能深度遍历多少点,在遍历点最多的情况下,使得所有经过的边权值和最小。所以求有向图的最小生成树。

对于最大景点数,深搜即可,同时将经过的边记录下来。

对于最短距离,使用最小生成树。对于所有经过的边,以边终点的高度为第一关键字,从大到小,以边权为第二关键字,从小到大。之后跑最小生成树即可。

实现

$(100pts,841ms,45.49ms)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<bits/stdc++.h>

using namespace std;
const int N = 1000005;
#define int long long
int n, m, h[N], fa[N], maxn, ans, cnt, sum, num;
bool vis[N];

int find(int x) {
if (fa[x] != x) fa[x] = find(fa[x]);
return fa[x];
}

void merge(int x, int y) {
x = find(x);
y = find(y);
if (x != y) fa[y] = x;
}

struct node {
int to, next, dis;
} e[N << 2];
struct edge {
int st, ed, dis;
} ed[N << 2];

int head[N], tmp;

void add(int x, int y, int z) {
e[++tmp].to = y;
e[tmp].next = head[x];
e[tmp].dis = z;
head[x] = tmp;
}

void dfs(int cur) {
vis[cur] = true;
maxn++;
for (int i = head[cur]; i; i = e[i].next) {
int v = e[i].to;
ed[++cnt].st = cur;
ed[cnt].ed = v;
ed[cnt].dis = e[i].dis;
if (vis[v]) continue;
dfs(v);
}
}

bool cmp(edge x, edge y) {
if (h[x.ed] == h[y.ed]) return x.dis < y.dis;
return h[x.ed] > h[y.ed];
}

int read() {
int s = 0, w = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-') w = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
s = (s << 1) + (s << 3) + ch - '0';
ch = getchar();
}
return s * w;
}

signed main() {
n = read(), m = read();
for (int i = 1; i <= n; i++) {
h[i] = read();
fa[i] = i;
}
for (int i = 1; i <= m; i++) {
int u = read(), v = read(), t = read();
if (h[u] >= h[v]) add(u, v, t);
if (h[u] <= h[v]) add(v, u, t);

}
dfs(1);
cout << maxn << " ";
sort(ed + 1, ed + cnt + 1, cmp);
for (int i = 1; i <= cnt; i++) {
if (find(ed[i].st) != find(ed[i].ed)) {
merge(ed[i].st, ed[i].ed);
ans += ed[i].dis;
}
}
cout << ans << endl;
return 0;
}