Constraints:
The number of nodes in the tree is in the range [2, 5 * 10^4].
1 <= Node.val <= 10^4
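A quick bound from these constraints (a worked estimate, not part of the original analysis) shows how large the unreduced product can get. With at most 5 * 10^4 nodes of value at most 10^4, the total S and the product s(S - s) for any subtree sum s satisfy:

```latex
S \le 5\cdot 10^{4} \times 10^{4} = 5\cdot 10^{8},
\qquad
s\,(S - s) \le \frac{S^{2}}{4} \le 6.25\cdot 10^{16} < 2^{63} - 1 \approx 9.22\cdot 10^{18}
```

So the exact product can exceed the 32-bit range but always fits in a signed 64-bit integer. Python integers are arbitrary precision anyway, so comparing exact products and reducing modulo 10^9 + 7 only once at the end is safe.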
Analysis
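Time: O(n), since each of the two traversals visits every node once
Space: O(h) for the recursion stack, where h is the height of the tree (O(n) for a skewed tree)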
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def maxProduct(self, root: Optional[TreeNode]) -> int:
        # Pass 1: sum of all node values in the tree.
        def getTotal(node):
            if not node:
                return 0
            leftTotal = getTotal(node.left)
            rightTotal = getTotal(node.right)
            return leftTotal + rightTotal + node.val

        # Pass 2 (post-order): returns the sum of the subtree rooted at node
        # and records the best product for cutting either child edge.
        def trySplit(node):
            nonlocal ans  # total is only read, so it needs no nonlocal
            if not node:
                return 0
            ltotal = trySplit(node.left)
            rtotal = trySplit(node.right)
            # try splitting off the left subtree (cut the left edge)
            ans = max(ans, (total - ltotal) * ltotal)
            # try splitting off the right subtree (cut the right edge)
            ans = max(ans, (total - rtotal) * rtotal)
            return ltotal + rtotal + node.val

        total = getTotal(root)
        ans = 0
        trySplit(root)
        # Maximize the exact product first, then reduce modulo 10^9 + 7.
        return ans % (10**9 + 7)
Note
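The key is to compute the total sum of the tree first.
Once the total is known, a post-order traversal accumulates each subtree's sum. At every node we try cutting the edge to its left child and the edge to its right child: removing an edge that separates a subtree with sum s leaves two trees with sums s and total - s, so the product is s * (total - s). Every edge is the left or right edge of exactly one parent, so this checks every possible split. The maximum is taken over the exact products, and the modulo 10^9 + 7 is applied only to the final answer.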
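The same idea can also be sketched without recursion, which matters in plain CPython: a skewed tree with 5 * 10^4 nodes is far deeper than the default recursion limit of 1000. This is an alternative sketch, not the solution above. The function name `max_product_iterative` and the constant `MOD` are hypothetical, and `TreeNode` is redefined to match the commented-out definition. It collects every subtree sum in one explicit-stack post-order pass, takes the root's sum as the total, and maximizes s * (total - s).

```python
from typing import Dict, List, Optional


class TreeNode:
    # Mirrors the commented-out node definition used by the solution.
    def __init__(self, val: int = 0,
                 left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


MOD = 10**9 + 7  # hypothetical name for the modulus


def max_product_iterative(root: Optional[TreeNode]) -> int:
    """Hypothetical helper: one post-order pass with an explicit stack."""
    if root is None:
        return 0
    sum_of: Dict[TreeNode, int] = {}   # node -> sum of its subtree
    subtree_sums: List[int] = []       # every subtree sum, in post-order
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if children_done:
            # Both children are finished; a missing child contributes 0.
            s = node.val + sum_of.get(node.left, 0) + sum_of.get(node.right, 0)
            sum_of[node] = s
            subtree_sums.append(s)
        else:
            # Revisit this node after its left and right subtrees.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
    total = subtree_sums[-1]  # the root is finished last
    # Cutting the edge above a subtree with sum s gives s * (total - s);
    # the root's own entry yields 0, so it never wins.
    best = max(s * (total - s) for s in subtree_sums)
    return best % MOD  # reduce only after maximizing the exact product


if __name__ == "__main__":
    example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)),
                       TreeNode(3, TreeNode(6)))
    print(max_product_iterative(example))  # 110
```

Collecting all subtree sums in one pass removes the separate total traversal, at the cost of O(n) extra space for the stack and the node-to-sum map.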
Binary Tree Traversal