Skip to content
树形DP入门: 树的直径

"如果能够找到与原问题相似的子问题,我们就能用递归解决"
--- 0x3f

而树的子树,本身就和父亲有相似性,所有天然的具备原问题与子问题的相似性


重点

从二叉树的直径推到树的直径

  • 枚举拐点考虑
  • 有时候DP计算边好写,有时候计算点好写,只需要记住边与点的转换关系即可(边+1=点,注意特殊情况点/边为0要特别处理)

543. 二叉树的直径

枚举拐点考虑,从边的角度写

py
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        ans=0
        def f(root):
            if not root:
                return -1
            l=f(root.left)
            r=f(root.right)
            nonlocal ans
            ans=max(ans,l+r+2)

            return max(l,r)+1
        f(root)
        return ans

124. 二叉树中的最大路径和

枚举拐点考虑,从点的角度写

py
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        ans=-inf
        def f(root):
            if not root:
                return 0
            l=f(root.left)
            r=f(root.right)
            nonlocal ans
            ans=max(ans,l+r+root.val)
            return max(0,l+root.val,r+root.val)
        f(root)
        return ans

2246. 相邻字符不同的最长路径

可将这题视为树形DP-树的直径的模板题

同样枚举拐点考虑,记得先建图

这里dp时算边,答案要求点,所以+1

py
class Solution:
    def longestPath(self, parent: List[int], s: str) -> int:
        n=len(parent)
        g=[[] for _ in range(n)]
        for i in range(1,n):
            g[parent[i]].append(i)
        ans=0
        def f(root):
            if root>=n:
                return 0
            res=0
            for son in g[root]:
                v=f(son)+1
                nonlocal ans
                if s[root]!=s[son]:
                    ans=max(ans,v+res)
                    res=max(res,v)
            return res
        f(0)
        return ans+1

687. 最长同值路径

这里维护点写起来比较优雅,然后答案求边则要-1

然后注意树为空时边数应该返回0(虽然这时点数也为0,但不要再减1)

py
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def longestUnivaluePath(self, root: Optional[TreeNode]) -> int:
        ans=0
        def f(root):
            if not root:
                return 0
            l=f(root.left)
            r=f(root.right)
            if not l or root.left.val!=root.val: l=0
            if not r or root.right.val!=root.val: r=0
            nonlocal ans
            ans=max(ans,l+r+1)
            return max(l,r)+1
        f(root)
        return max(ans-1,0)

1617. 统计子树中城市之间最大距离

二进制枚举+状压+树的直径

一些值得注意的小细节:

  • 状压use有两个用途: 用于对比p和标记是否走过
  • 如果对于一颗树(不是图)单纯的标记是否走过其实可以传递fa参数
  • 为什么要对比use==p?为了去掉一些重复非法情况,比如1,2联通但与4不联通时,枚举子树为点1,2,4实际上是非法的,又因为use是所有走过的点,如果use不等于二进制枚举的状态则说明枚举非法(非子树)
  • 直径dp入口为二进制枚举的第一个1,即log2(lowbit(type))
  • 点从1开始,因为有点地方偏移一下
py
class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        g=[[] for _ in range(n+5)]
        for x,y in edges:
            g[x].append(y)
            g[y].append(x)
        res=[0 for _ in range(n)]
        use=ans=0
        def f(x,p):
            ret=0
            for y in g[x]:
                nonlocal use,ans
                if (not (p>>(y-1))&1) or (use>>(y-1))&1:
                    continue
                use|=1<<(y-1)
                length=f(y,p)
                length+=1
                ans=max(ans,ret+length)
                ret=max(ret,length)
            return ret
        for p in range(1,1<<n):
            use=p&(-p)
            ans=0
            root=int(log2(p&(-p)))+1
            f(root,p)
            if use==p:
                res[ans]+=1
        return res[1:]

2538. 最大价值和与最小价值和的差值

很好的一道题,纠正了我对树的直径的一些错误理解与思路

记忆化枚举根会被菊花图hack掉,所以T了

TLE做法54/58
py
class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        g=[[] for _ in range(n)]
        for x,y in edges:
            g[x].append(y)
            g[y].append(x)
        ans=0
        @cache
        def f(x,fa):
            res=ans=0
            is_leaf=1
            for y in g[x]:
                if y!=fa:
                    is_leaf=0
                    v,vv=f(y,x)
                    ans=max(ans,res+v+price[x],vv)
                    res=max(res,v)
            if not is_leaf:
                res+=price[x]
            return res,ans
        ret=0
        for r in range(n):
            ret=max(ret,f(r,-1)[1])
        return ret

这么写也是对于树的直径的本质不够深入了解,需要明白: 枚举拐点考虑就是该dp的本质

因此类似的思路我们维护一个带叶子的一个不带叶子的即可

请思考以下问题:

为什么更新res与res_lef都要加上price[x]?为什么初始化的时候一个是0一个是price[x]?

答案

其实这两个是一个问题

本质上来说设为0和price[x]是为了保证叶子节点返回值的正确性,以及退化为单链的正确性(退化为单链的时候不再是拐点,可能为头),其他情况必然小于正常值,所以没影响

更新也一样,因为更新的时候必然不是叶子节点,所以都需要加上当前值

py
class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        g=[[] for _ in range(n)]
        for x,y in edges:
            g[x].append(y)
            g[y].append(x)
        ans=0
        def f(x,fa):
            res=0
            res_lef=price[x]
            for y in g[x]:
                if y!=fa:
                    v,v_lef=f(y,x)
                    nonlocal ans
                    ans=max(ans,res+v_lef,res_lef+v)
                    res=max(res,v+price[x])
                    res_lef=max(res_lef,v_lef+price[x])
            return res,res_lef
        f(0,-1)
        return ans