- 节点选择
资源限制
时间限制:1.0s 内存限制:256.0MB
问题描述
有一棵 n 个节点的树,树上每个节点都有一个正整数权值。如果一个点被选择了,那么在树上和它相邻的点都不能被选择。求选出的点的权值和最大是多少?
输入格式
第一行包含一个整数 n 。
接下来的一行包含 n 个正整数,第 i 个正整数代表点 i 的权值。
接下来一共 n-1 行,每行描述树上的一条边。
输出格式
输出一个整数,代表选出的点的权值和的最大值。
样例输入
5
1 2 3 4 5
1 2
1 3
2 4
2 5
样例输出
12
样例说明
选择3、4、5号点,权值和为 3+4+5 = 12 。
数据规模与约定
对于20%的数据, n <= 20。
对于50%的数据, n <= 1000。
对于100%的数据, n <= 100000。
权值均为不超过1000的正整数。
- 解答程序
k, l = list(map(int, input().split()))
s = [0 for i in range(l + 1)]
a = [[0 for i in range(l)] for i in range(k)]
# a[x][y]表示最高位为x,剩余长度为y的K好数数目
if l == 1:
if k == 1:
print(0)
else:
print(k)
else:
for i in range(k):
a[i][0] = 1
s[1] = k
for j in range(1,l):
for i in range(k):
if i == 0:
a[0][j] = (s[j] - a[1][j - 1]) % 1000000007
elif i == k - 1:
a[k - 1][j] = (s[j] - a[k - 2][j - 1]) % 1000000007
else:
a[i][j] = (s[j] - a[i - 1][j - 1] - a[i + 1][j - 1]) % 1000000007
s[j + 1] = (s[j + 1] + a[i][j]) % 1000000007
print((s[l] - a[0][l - 1]) % 1000000007)
- 经验教训
1、第一个版本的代码为
import copy
n = int(input())
weight_list =[0] + list(map(int, input().split()))
aix_dict = {}
for i in range(1, n + 1):
aix_dict[i] = []
for i in range(n - 1):
a, b = list(map(int, input().split()))
aix_dict[a].append(b)
aix_dict[b].append(a)
flag_list = [0 for i in range(n + 1)]
chosen_list = [0 for i in range(n + 1)]
max_list = [0 for i in range(n + 1)]
def tab_flag(a, temp_flag_list):
for i in aix_dict[a]:
temp_flag_list[i] = 1
def count_weight(tmep_chosen_list):
sum = 0
for i in range(n+1):
if tmep_chosen_list[i] == 1:
sum += weight_list[i]
return sum
max_score = 0
def dp(temp_flag_list, temp_chosen_list):
for i in range(1, n + 1):
if temp_flag_list[i] == 0:
temp_flag = copy.deepcopy(temp_flag_list)
temp_flag[i] = 1
temp_chosen = copy.deepcopy(temp_chosen_list)
temp_chosen[i] = 1
tab_flag(i, temp_flag)
chosen = dp(temp_flag, temp_chosen)
s = count_weight(chosen)
global max_score
if max_score < s:
max_score = s
else:
return temp_chosen_list
dp(flag_list, chosen_list)
print(max_score)import copy
n = int(input())
weight_list =[0] + list(map(int, input().split()))
aix_dict = {}
for i in range(1, n + 1):
aix_dict[i] = []
for i in range(n - 1):
a, b = list(map(int, input().split()))
aix_dict[a].append(b)
aix_dict[b].append(a)
flag_list = [0 for i in range(n + 1)]
chosen_list = [0 for i in range(n + 1)]
max_list = [0 for i in range(n + 1)]
def tab_flag(a, temp_flag_list):
for i in aix_dict[a]:
temp_flag_list[i] = 1
def count_weight(tmep_chosen_list):
sum = 0
for i in range(n+1):
if tmep_chosen_list[i] == 1:
sum += weight_list[i]
return sum
max_score = 0
def dp(temp_flag_list, temp_chosen_list):
for i in range(1, n + 1):
if temp_flag_list[i] == 0:
temp_flag = copy.deepcopy(temp_flag_list)
temp_flag[i] = 1
temp_chosen = copy.deepcopy(temp_chosen_list)
temp_chosen[i] = 1
tab_flag(i, temp_flag)
chosen = dp(temp_flag, temp_chosen)
s = count_weight(chosen)
global max_score
if max_score < s:
max_score = s
else:
return temp_chosen_list
dp(flag_list, chosen_list)
print(max_score)
此版本代码存在问题是蓝桥的OJ只提供了math的库,不提供copy库。因此需要手动实现以下deepcopy。
2、修改代码后为
n = int(input())
weight_list =[0] + list(map(int, input().split()))
aix_dict = {}
for i in range(1, n + 1):
aix_dict[i] = []
for i in range(n - 1):
a, b = list(map(int, input().split()))
aix_dict[a].append(b)
aix_dict[b].append(a)
flag_list = [0 for i in range(n + 1)]
chosen_list = [0 for i in range(n + 1)]
max_list = [0 for i in range(n + 1)]
def tab_flag(a, temp_flag_list):
for i in aix_dict[a]:
temp_flag_list[i] = 1
def count_weight(tmep_chosen_list):
sum = 0
for i in range(n+1):
if tmep_chosen_list[i] == 1:
sum += weight_list[i]
return sum
def deepcopy(temp_list):
temp = []
for i in temp_list:
temp.append(i)
return temp
max_score = 0
def dp(temp_flag_list, temp_chosen_list):
for i in range(1, n + 1):
if temp_flag_list[i] == 0:
temp_flag = deepcopy(temp_flag_list)
temp_flag[i] = 1
temp_chosen = deepcopy(temp_chosen_list)
temp_chosen[i] = 1
tab_flag(i, temp_flag)
chosen = dp(temp_flag, temp_chosen)
s = count_weight(chosen)
global max_score
if max_score < s:
max_score = s
else:
return temp_chosen_list
dp(flag_list, chosen_list)
print(max_score)
此代码运行超时,只拿到了10分,运行时间超标,并且内存占用达到了180MB。需要做优化。明日继续更新。
3、啊啊啊刚敲完明日更新突然领悟了,如果用for循环加递归会造成超时和内存占用过大的话,采用DP就可以很好的解决问题!(明天再写DP版算法)