titan-23 / Library_py

MIT License
3 stars 0 forks source link

WA HLD #3

Closed zronghui closed 10 months ago

zronghui commented 10 months ago

WA in problem:

[Template] Heavy-Light Decomposition

Problem Description

Given a tree with $N$ nodes (connected and acyclic), where each node contains a value, you need to support the following operations:

Input Format

The first line contains four positive integers $N$, $M$, $R$, and $P$, representing the number of nodes, the number of operations, the index of the root node, and the modulo value (i.e., all output results should be modulo $P$).

The second line contains $N$ non-negative integers, representing the initial values of each node.

The next $N-1$ lines each contain two integers $x$ and $y$, indicating an edge between node $x$ and node $y$ (guaranteed to be a connected tree without cycles).

The next $M$ lines each contain a series of positive integers, representing an operation.

Output Format

Output multiple lines, each representing the result of operation $2$ or operation $4$ (modulo $P$).

Example #1

Input

5 5 2 24
7 3 7 8 0 
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3

Output

2
21

Note

[Data Constraints]

For $30\%$ of the data: $1 \leq N \leq 10$, $1 \leq M \leq 10$.

For $70\%$ of the data: $1 \leq N \leq {10}^3$, $1 \leq M \leq {10}^3$.

For $100\%$ of the data: $1 \leq N \leq {10}^5$, $1 \leq M \leq {10}^5$, $1 \leq R \leq N$, $1 \leq P \leq 2^{31}-1$.

[Example Explanation]

The tree structure is as follows:

![]()The operations are as follows:![]()Hence, the expected output is $2$ and $21$ in order.

my code

class HLD:
    def __init__(self, G: List[List[int]], root: int):
        n = len(G)
        self.n: int = n
        self.G: List[List[int]] = G
        self.size: List[int] = [1] * n
        self.par: List[int] = [-1] * n
        self.dep: List[int] = [-1] * n
        self.nodein: List[int] = [0] * n
        self.nodeout: List[int] = [0] * n
        self.head: List[int] = [0] * n
        self.hld: List[int] = [0] * n
        self._dfs(root)

    def _dfs(self, root: int) -> None:
        dep, par, size, G = self.dep, self.par, self.size, self.G
        dep[root] = 0
        stack = [root]
        while stack:
            v = stack.pop()
            if v >= 0:
                dep_nxt = dep[v] + 1
                for x in G[v]:
                    if dep[x] != -1:
                        continue
                    dep[x] = dep_nxt
                    stack.append(~x)
                    stack.append(x)
            else:
                v = ~v
                G_v, dep_v = G[v], dep[v]
                for i, x in enumerate(G_v):
                    if dep[x] < dep_v:
                        par[v] = x
                        continue
                    size[v] += size[x]
                    if size[x] > size[G_v[0]]:
                        G_v[0], G_v[i] = G_v[i], G_v[0]

        head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
        curtime = 0
        stack = [~root, root]
        while stack:
            v = stack.pop()
            if v >= 0:
                if par[v] == -1:
                    head[v] = v
                nodein[v] = curtime
                hld[curtime] = v
                curtime += 1
                if not G[v]:
                    continue
                G_v0 = G[v][0]
                for x in reversed(G[v]):
                    if x == par[v]:
                        continue
                    head[x] = head[v] if x == G_v0 else x
                    stack.append(~x)
                    stack.append(x)
            else:
                nodeout[~v] = curtime

    def build_list(self, a: List[Any]) -> List[Any]:
        return [a[e] for e in self.hld]

    def for_each_vertex(self, u: int, v: int) -> Iterator[Tuple[int, int]]:
        head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
        while head[u] != head[v]:
            if dep[head[u]] < dep[head[v]]:
                u, v = v, u
            yield nodein[head[u]], nodein[u] + 1
            u = par[head[u]]
        if dep[u] < dep[v]:
            u, v = v, u
        yield nodein[v], nodein[u] + 1

    def for_each_vertex_subtree(self, v: int) -> Iterator[Tuple[int, int]]:
        yield self.nodein[v], self.nodeout[v]

    def path_kth_elm(self, s: int, t: int, k: int) -> int:
        head, dep, par = self.head, self.dep, self.par
        lca = self.lca(s, t)
        d = dep[s] + dep[t] - 2 * dep[lca]
        if d < k:
            return -1
        if dep[s] - dep[lca] < k:
            s = t
            k = d - k
        hs = head[s]
        while dep[s] - dep[hs] < k:
            k -= dep[s] - dep[hs] + 1
            s = par[hs]
            hs = head[s]
        return self.hld[self.nodein[s] - k]

    def lca(self, u: int, v: int) -> int:
        nodein, head, par = self.nodein, self.head, self.par
        while True:
            if nodein[u] > nodein[v]:
                u, v = v, u
            if head[u] == head[v]:
                return u
            v = par[head[v]]

T = TypeVar("T")
F = TypeVar("F")

class HLDLazySegmentTree(Generic[T, F]):
    def __init__(
        self,
        hld: HLD,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ):
        self.hld: HLD = hld
        n_or_a = (
            n_or_a if isinstance(n_or_a, int) else self.hld.build_list(list(n_or_a))
        )
        self.seg: LazySegmentTree[T, F] = LazySegmentTree(
            n_or_a=n_or_a, op=op, mapping=mapping, composition=composition, e=e, id=id
        )
        self.op: Callable[[T, T], T] = op
        self.e: T = e

    def path_prod(self, u: int, v: int) -> T:
        head, nodein, dep, par = (
            self.hld.head,
            self.hld.nodein,
            self.hld.dep,
            self.hld.par,
        )
        res = self.e
        while head[u] != head[v]:
            if dep[head[u]] < dep[head[v]]:
                u, v = v, u
            res = self.op(res, self.seg.prod(nodein[head[u]], nodein[u] + 1))
            u = par[head[u]]
        if dep[u] < dep[v]:
            u, v = v, u
        return self.op(res, self.seg.prod(nodein[v], nodein[u] + 1))

    def path_apply(self, u: int, v: int, f: F) -> None:
        head, nodein, dep, par = (
            self.hld.head,
            self.hld.nodein,
            self.hld.dep,
            self.hld.par,
        )
        while head[u] != head[v]:
            if dep[head[u]] < dep[head[v]]:
                u, v = v, u
            self.seg.apply(nodein[head[u]], nodein[u] + 1, f)
            u = par[head[u]]
        if dep[u] < dep[v]:
            u, v = v, u
        self.seg.apply(nodein[v], nodein[u] + 1, f)

    def get(self, k: int) -> T:
        return self.seg[self.hld.nodein[k]]

    def set(self, k: int, v: T) -> None:
        self.seg[self.hld.nodein[k]] = v

    __getitem__ = get
    __setitem__ = set

    def subtree_prod(self, v: int) -> T:
        return self.seg.prod(self.hld.nodein[v], self.hld.nodeout[v])

    def subtree_apply(self, v: int, f: F) -> None:
        self.seg.apply(self.hld.nodein[v], self.hld.nodeout[v], f)

T = TypeVar("T")
F = TypeVar("F")

class LazySegmentTree(Generic[T, F]):
    def __init__(
        self,
        n_or_a: Union[int, Iterable[T]],
        op: Callable[[T, T], T],
        mapping: Callable[[F, T], T],
        composition: Callable[[F, F], F],
        e: T,
        id: F,
    ):
        self.e = e
        self.id = id
        self.op = op
        self.mapping = mapping
        self.composition = composition
        if isinstance(n_or_a, int):
            self.n = n_or_a
            self.log = (self.n - 1).bit_length()
            self.size = 1 << self.log
            self.data = [e] * (self.size << 1)
        else:
            a = list(n_or_a)
            self.n = len(a)
            self.log = (self.n - 1).bit_length()
            self.size = 1 << self.log
            data = [e] * (self.size << 1)
            data[self.size : self.size + self.n] = a
            for i in range(self.size - 1, 0, -1):
                data[i] = op(data[i << 1], data[i << 1 | 1])
            self.data = data
        self.lazy = [id] * self.size

    def _update(self, k: int) -> None:
        self.data[k] = self.op(self.data[k << 1], self.data[k << 1 | 1])

    def _all_apply(self, k: int, f: F) -> None:
        self.data[k] = self.mapping(f, self.data[k])
        if k < self.size:
            self.lazy[k] = self.composition(f, self.lazy[k])

    def _propagate(self, k: int) -> None:
        if self.lazy[k] == self.id:
            return
        self._all_apply(k << 1, self.lazy[k])
        self._all_apply(k << 1 | 1, self.lazy[k])
        self.lazy[k] = self.id

    def apply_point(self, k: int, f: F) -> None:
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = self.mapping(f, self.data[k])
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def apply(self, l: int, r: int, f: F) -> None:
        if l == r:
            return
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if l >> i << i != l and self.lazy[l >> i] != self.id:
                self._propagate(l >> i)
            if r >> i << i != r and self.lazy[(r - 1) >> i] != self.id:
                self._propagate((r - 1) >> i)
        l2, r2 = l, r
        while l < r:
            if l & 1:
                self._all_apply(l, f)
                l += 1
            if r & 1:
                self._all_apply(r ^ 1, f)
            l >>= 1
            r >>= 1
        for i in range(1, self.log + 1):
            if l2 >> i << i != l2:
                self._update(l2 >> i)
            if r2 >> i << i != r2:
                self._update((r2 - 1) >> i)

    def all_apply(self, f: F) -> None:
        self.lazy[1] = self.composition(f, self.lazy[1])

    def prod(self, l: int, r: int) -> T:
        if l == r:
            return self.e
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if l >> i << i != l and self.lazy[l >> i] != self.id:
                self._propagate(l >> i)
            if r >> i << i != r and self.lazy[r >> i] != self.id:
                self._propagate(r >> i)
        lres = self.e
        rres = self.e
        while l < r:
            if l & 1:
                lres = self.op(lres, self.data[l])
                l += 1
            if r & 1:
                rres = self.op(self.data[r ^ 1], rres)
            l >>= 1
            r >>= 1
        return self.op(lres, rres)

    def all_prod(self) -> T:
        return self.data[1]

    def all_propagate(self) -> None:
        for i in range(self.size):
            self._propagate(i)

    def tolist(self) -> List[T]:
        self.all_propagate()
        return self.data[self.size : self.size + self.n]

    def max_right(self, l, f) -> int:
        assert 0 <= l <= self.n
        assert f(self.e)
        if l == self.size:
            return self.n
        l += self.size
        for i in range(self.log, 0, -1):
            self._propagate(l >> i)
        s = self.e
        while True:
            while l & 1 == 0:
                l >>= 1
            if not f(self.op(s, self.data[l])):
                while l < self.size:
                    self._propagate(l)
                    l <<= 1
                    if f(self.op(s, self.data[l])):
                        s = self.op(s, self.data[l])
                        l |= 1
                return l - self.size
            s = self.op(s, self.data[l])
            l += 1
            if l & -l == l:
                break
        return self.n

    def min_left(self, r: int, f) -> int:
        assert 0 <= r <= self.n
        assert f(self.e)
        if r == 0:
            return 0
        r += self.size
        for i in range(self.log, 0, -1):
            self._propagate((r - 1) >> i)
        s = self.e
        while True:
            r -= 1
            while r > 1 and r & 1:
                r >>= 1
            if not f(self.op(self.data[r], s)):
                while r < self.size:
                    self._propagate(r)
                    r = r << 1 | 1
                    if f(self.op(self.data[r], s)):
                        s = self.op(self.data[r], s)
                        r ^= 1
                return r + 1 - self.size
            s = self.op(self.data[r], s)
            if r & -r == r:
                break
        return 0

    def __getitem__(self, k: int) -> T:
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        return self.data[k]

    def __setitem__(self, k: int, v: T):
        k += self.size
        for i in range(self.log, 0, -1):
            self._propagate(k >> i)
        self.data[k] = v
        for i in range(1, self.log + 1):
            self._update(k >> i)

    def __str__(self) -> str:
        return (
            "["
            + ", ".join(map(str, (self.__getitem__(i) for i in range(self.n))))
            + "]"
        )

    def __repr__(self):
        return f"LazySegmentTree({self})"

T = 1

for _ in range(T):
    n, m, r, p = MII()

    r -= 1
    a = LII()
    tree = read_tree(n, base=1)
    hld = HLD(tree, r)

    def op(s, t):
        return (s + t) % p

    def mapping(f, x):
        return (f + x) % p

    def composition(f, g):
        return (f + g) % p

    hldst = HLDLazySegmentTree(
        hld, n_or_a=a, op=op, mapping=mapping, composition=composition, e=0, id=0
    )
    for _ in range(m):
        t = LII()
        if t[0] == 1:
            x, y, z = t[1:]
            x, y = x - 1, y - 1
            hldst.path_apply(x, y, z)
        elif t[0] == 2:
            x, y = t[1:]
            x, y = x - 1, y - 1
            print(hldst.path_prod(x, y))
        elif t[0] == 3:
            x, z = t[1:]
            x = x - 1
            hldst.subtree_apply(x, z)
        elif t[0] == 4:
            x = t[1] - 1
            print(hldst.subtree_prod(x))

it can pass samples, but wa when submit. input

8 10 6 623232
56 838 680 614 846 408 890 829 
1 2
2 3
3 4
1 5
4 7
4 8
7 6
3 4 859
1 4 2 189
2 1 7
1 1 2 159
1 6 7 553
3 5 649
1 3 2 672
4 6
4 8
4 1

expected output

14299
1688
3428

my output

6033
13746
1688
3428
zronghui commented 10 months ago

link of the problem: P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 https://www.luogu.com.cn/problem/P3384

titan-23 commented 10 months ago

Thank you for reaching out with the issue.

The library seems to be correct, and it appears there may be an issue in your code.

For interval addition and interval sum retrieval, each node of the SegmentTree should maintain the interval sum and interval length.

Here is an example of the implementation:

from Library_py.Graph.HLD.HLD import HLD
from Library_py.Graph.HLD.HLDLazySegmentTree import HLDLazySegmentTree

n, m, r, p = map(int, input().split())
r -= 1
a = list(map(int, input().split()))
tree = [[] for _ in range(n)]
for _ in range(n-1):
  u, v = map(int, input().split())
  u -= 1
  v -= 1
  tree[u].append(v)
  tree[v].append(u)
hld = HLD(tree, r)

class Data():

  def __init__(self, val, size):
    self.val = val
    self.size = size

def op(s: Data, t: Data):
  return Data(s.val+t.val, s.size+t.size)

def mapping(f: int, s: Data):
  return Data(s.val+f*s.size, s.size)

def composition(f: int, g: int):
  return f + g

e = Data(0, 0)
id = 0

a = [Data(v, 1) for v in a]

hldst = HLDLazySegmentTree(
    hld, n_or_a=a, op=op, mapping=mapping, composition=composition, e=e, id=id
)
for _ in range(m):
  t = list(map(int, input().split()))
  if t[0] == 1:
    x, y, z = t[1:]
    x, y = x - 1, y - 1
    hldst.path_apply(x, y, z)
  elif t[0] == 2:
    x, y = t[1:]
    x, y = x - 1, y - 1
    print(hldst.path_prod(x, y).val)
  elif t[0] == 3:
    x, z = t[1:]
    x = x - 1
    hldst.subtree_apply(x, z)
  elif t[0] == 4:
    x = t[1] - 1
    print(hldst.subtree_prod(x).val)

Sample Input:

8 10 6 623232
56 838 680 614 846 408 890 829 
1 2
2 3
3 4
1 5
4 7
4 8
7 6
3 4 859
1 4 2 189
2 1 7
1 1 2 159
1 6 7 553
3 5 649
1 3 2 672
4 6
4 8
4 1

Sample Output:

7081
14299
1688
3428

Please check it.

zronghui commented 10 months ago

Thanks a lot. I'm such an idiot. I forgot to consider the detail of interval length. There is no problem with the code being correct. By the way, using a class instead of a tuple causes more memory consumption. The time and space constraints on this test were strict, with 2 test cases timed out. I'm going to ignore it and do other problems of the same type.