uralbash / sqlalchemy_mptt

SQLAlchemy nested sets mixin (MPTT)
http://sqlalchemy-mptt.readthedocs.io
MIT License
196 stars 32 forks source link

Moving node to leftmost position doesn't work #68

Open RafGb opened 4 years ago

RafGb commented 4 years ago

When I try to move a node to the leftmost position using node.move_before, nothing happens to the tree. I think it's because there is no left_sibling here: https://github.com/uralbash/sqlalchemy_mptt/blob/master/sqlalchemy_mptt/events.py#L362

Here's test code:

class MoveLeft(object):
    """Regression test mixin: move a node in front of its first sibling.

    The host test case is expected to provide ``session``, ``model`` and
    ``result``, which this test reads. Result rows appear to be
    ``(pk, left, right, level, parent_id, tree_id)`` -- TODO confirm
    against the fixture that defines ``result``.
    """

    def test_move_to_leftmost(self):
        """``move_before`` the first child must make the node leftmost."""
        self.session.query(self.model).delete()

        _level = self.model.get_default_level()
        pk_column = self.model.get_pk_column()

        self.session.add_all([
            self.model(**{pk_column.name: 1}),
            self.model(**{pk_column.name: 2, 'parent_id': 1}),
            self.model(**{pk_column.name: 3, 'parent_id': 1}),
            self.model(**{pk_column.name: 4, 'parent_id': 1}),
        ])

        # initial tree:
        #        1
        #   /    |   \
        #  2     3    4

        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 2, 3, _level + 1, 1, 1),
                (3, 4, 5, _level + 1, 1, 1),
                (4, 6, 7, _level + 1, 1, 1)
            ],
            self.result.all()
        )

        # move 4 to left
        node4 = self.session.query(self.model).filter(pk_column == 4).one()
        node4.move_before("2")

        # expected result:
        #        1
        #   /    |   \
        #  4     2    3

        # Nodes 2, 3 and 4 are still children of node 1, so only their
        # left/right bounds change; their level stays at _level + 1.
        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 4, 5, _level + 1, 1, 1),
                (3, 6, 7, _level + 1, 1, 1),
                (4, 2, 3, _level + 1, 1, 1)
            ],
            self.result.all()
        )
Meller008 commented 3 years ago

When i'm trying to move node to leftmost position using node.move_before nothing happens with tree. i think it's because there's no left_sibling here: https://github.com/uralbash/sqlalchemy_mptt/blob/master/sqlalchemy_mptt/events.py#L362

Here's test code:

class MoveLeft(object):
    """Regression test mixin: move a node in front of its first sibling.

    The host test case is expected to provide ``session``, ``model`` and
    ``result``, which this test reads. Result rows appear to be
    ``(pk, left, right, level, parent_id, tree_id)`` -- TODO confirm
    against the fixture that defines ``result``.
    """

    def test_move_to_leftmost(self):
        """``move_before`` the first child must make the node leftmost."""
        self.session.query(self.model).delete()

        _level = self.model.get_default_level()
        pk_column = self.model.get_pk_column()

        self.session.add_all([
            self.model(**{pk_column.name: 1}),
            self.model(**{pk_column.name: 2, 'parent_id': 1}),
            self.model(**{pk_column.name: 3, 'parent_id': 1}),
            self.model(**{pk_column.name: 4, 'parent_id': 1}),
        ])

        # initial tree:
        #        1
        #   /    |   \
        #  2     3    4

        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 2, 3, _level + 1, 1, 1),
                (3, 4, 5, _level + 1, 1, 1),
                (4, 6, 7, _level + 1, 1, 1)
            ],
            self.result.all()
        )

        # move 4 to left
        node4 = self.session.query(self.model).filter(pk_column == 4).one()
        node4.move_before("2")

        # expected result:
        #        1
        #   /    |   \
        #  4     2    3

        # Nodes 2, 3 and 4 are still children of node 1, so only their
        # left/right bounds change; their level stays at _level + 1.
        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 4, 5, _level + 1, 1, 1),
                (3, 6, 7, _level + 1, 1, 1),
                (4, 2, 3, _level + 1, 1, 1)
            ],
            self.result.all()
        )

Yes, there is such a problem. You have correctly specified the location of the error. I apply a small patch before the move_before operation. And then it works correctly! Thanks!

igorwang commented 3 years ago

Has the problem been fixed?

sushauai commented 3 years ago

When i'm trying to move node to leftmost position using node.move_before nothing happens with tree. i think it's because there's no left_sibling here: https://github.com/uralbash/sqlalchemy_mptt/blob/master/sqlalchemy_mptt/events.py#L362 Here's test code:

class MoveLeft(object):
    """Regression test mixin: move a node in front of its first sibling.

    The host test case is expected to provide ``session``, ``model`` and
    ``result``, which this test reads. Result rows appear to be
    ``(pk, left, right, level, parent_id, tree_id)`` -- TODO confirm
    against the fixture that defines ``result``.
    """

    def test_move_to_leftmost(self):
        """``move_before`` the first child must make the node leftmost."""
        self.session.query(self.model).delete()

        _level = self.model.get_default_level()
        pk_column = self.model.get_pk_column()

        self.session.add_all([
            self.model(**{pk_column.name: 1}),
            self.model(**{pk_column.name: 2, 'parent_id': 1}),
            self.model(**{pk_column.name: 3, 'parent_id': 1}),
            self.model(**{pk_column.name: 4, 'parent_id': 1}),
        ])

        # initial tree:
        #        1
        #   /    |   \
        #  2     3    4

        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 2, 3, _level + 1, 1, 1),
                (3, 4, 5, _level + 1, 1, 1),
                (4, 6, 7, _level + 1, 1, 1)
            ],
            self.result.all()
        )

        # move 4 to left
        node4 = self.session.query(self.model).filter(pk_column == 4).one()
        node4.move_before("2")

        # expected result:
        #        1
        #   /    |   \
        #  4     2    3

        # Nodes 2, 3 and 4 are still children of node 1, so only their
        # left/right bounds change; their level stays at _level + 1.
        self.assertEqual(
            [
                (1, 1, 8, _level + 0, None, 1),
                (2, 4, 5, _level + 1, 1, 1),
                (3, 6, 7, _level + 1, 1, 1),
                (4, 2, 3, _level + 1, 1, 1)
            ],
            self.result.all()
        )

Yes, there is such a problem. You have correctly specified the location of the error. I apply a small patch before the move_before operation. And then it works correctly! Thanks!

How to fix it?

Meller008 commented 2 years ago

I apply a patch that changes the `if not left_sibling` condition.

I did this

from sqlalchemy_mptt import events
from sqlalchemy import and_, select
from sqlalchemy.sql import func

def mptt_before_update(mapper, connection, instance):
    """Re-position ``instance`` inside its nested-set tree before an UPDATE.

    Replacement for ``sqlalchemy_mptt.events.mptt_before_update`` that also
    handles ``move_before`` when the target is the leftmost child of its
    parent. Based on this example:
        http://stackoverflow.com/questions/889527/move-node-in-nested-set

    The node is first removed from its current position with
    ``events.mptt_before_delete`` and then either re-inserted under
    ``instance.parent_id`` with ``events._insert_subtree`` or, when it has
    no parent, turned into the root of its own tree.

    A plain update (same parent, no ``mptt_move_inside``, no
    ``mptt_move_before``, no left sibling selected) returns early and
    leaves the tree untouched.

    :param mapper: mapper of the tree model; passed to
        ``events._get_tree_table`` and ``events.mptt_before_delete``.
    :param connection: connection used to run every SELECT/UPDATE below.
    :param instance: node being flushed. Reads ``left``, ``right``,
        ``tree_id`` and ``parent_id`` plus the optional markers
        ``mptt_move_inside``, ``mptt_move_before`` and ``mptt_move_after``;
        may assign ``parent_id`` and ``tree_id``.
    :return: ``None``.
    """
    node_id = getattr(instance, instance.get_pk_name())
    table = events._get_tree_table(mapper)
    db_pk = instance.get_pk_column()
    default_level = instance.get_default_level()
    table_pk = getattr(table.c, db_pk.name)
    mptt_move_inside = None
    left_sibling = None
    left_sibling_tree_id = None
    # An explicit move_before request is a move even when no left sibling
    # exists (target is the leftmost child) and the parent is unchanged.
    move_before_requested = hasattr(instance, 'mptt_move_before')

    if hasattr(instance, 'mptt_move_inside'):
        mptt_move_inside = instance.mptt_move_inside

    if move_before_requested:
        (
            right_sibling_left,
            right_sibling_right,
            right_sibling_parent,
            right_sibling_level,
            right_sibling_tree_id
        ) = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.level,
                    table.c.tree_id
                ]
            ).where(
                table_pk == instance.mptt_move_before
            )
        ).fetchone()
        # Real siblings to the left of the target: same parent, ordered by
        # lft so that the last row is the nearest one. Without the parent
        # filter a same-level cousin could be picked and the node would be
        # re-parented; without the ordering "[-1]" depends on row order.
        current_lvl_nodes = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.tree_id
                ]
            ).where(
                and_(
                    table.c.level == right_sibling_level,
                    table.c.tree_id == right_sibling_tree_id,
                    table.c.parent_id == right_sibling_parent,
                    table.c.lft < right_sibling_left
                )
            ).order_by(
                table.c.lft
            )
        ).fetchall()
        if current_lvl_nodes:
            (
                left_sibling_left,
                left_sibling_right,
                left_sibling_parent,
                left_sibling_tree_id
            ) = current_lvl_nodes[-1]
            instance.parent_id = left_sibling_parent
            left_sibling = {
                'lft': left_sibling_left,
                'rgt': left_sibling_right,
                'is_parent': False
            }
        # if move_before to top level
        elif not right_sibling_parent:
            left_sibling_tree_id = right_sibling_tree_id - 1

    # if placed after a particular node
    if hasattr(instance, 'mptt_move_after'):
        (
            left_sibling_left,
            left_sibling_right,
            left_sibling_parent,
            left_sibling_tree_id
        ) = connection.execute(
            select(
                [
                    table.c.lft,
                    table.c.rgt,
                    table.c.parent_id,
                    table.c.tree_id
                ]
            ).where(
                table_pk == instance.mptt_move_after
            )
        ).fetchone()
        instance.parent_id = left_sibling_parent
        left_sibling = {
            'lft': left_sibling_left,
            'rgt': left_sibling_right,
            'is_parent': False
        }

    # Get subtree from node
    #
    #   SELECT id, name, level FROM my_tree
    #   WHERE left_key >= $left_key AND right_key <= $right_key
    #   ORDER BY left_key
    subtree = connection.execute(
        select([table_pk])
        .where(
            and_(
                table.c.lft >= instance.left,
                table.c.rgt <= instance.right,
                table.c.tree_id == instance.tree_id
            )
        ).order_by(
            table.c.lft
        )
    ).fetchall()
    subtree = [x[0] for x in subtree]

    # step 0: Initialize parameters.
    #
    # Put there left and right position of moving node
    (
        node_pos_left,
        node_pos_right,
        node_tree_id,
        node_parent_id,
        node_level
    ) = connection.execute(
        select(
            [
                table.c.lft,
                table.c.rgt,
                table.c.tree_id,
                table.c.parent_id,
                table.c.level
            ]
        ).where(
            table_pk == node_id
        )
    ).fetchone()

    # if instance just update w/o move
    # XXX why this str() around parent_id comparison?
    # A move_before request is excluded here: with a leftmost target there
    # is no left sibling and the parent may be unchanged, yet the node still
    # has to be moved.
    if not left_sibling \
            and not move_before_requested \
            and str(node_parent_id) == str(instance.parent_id) \
            and not mptt_move_inside:
        if left_sibling_tree_id is None:
            return

    # fix tree shorting
    if instance.parent_id is not None:
        (
            parent_id,
            parent_pos_right,
            parent_pos_left,
            parent_tree_id,
            parent_level
        ) = connection.execute(
            select(
                [
                    table_pk,
                    table.c.rgt,
                    table.c.lft,
                    table.c.tree_id,
                    table.c.level
                ]
            ).where(
                table_pk == instance.parent_id
            )
        ).fetchone()
        # A root cannot be moved under a node of its own tree: drop the
        # requested parent and leave the tree as it is.
        if node_parent_id is None and node_tree_id == parent_tree_id:
            instance.parent_id = None
            return

    # delete from old tree
    events.mptt_before_delete(mapper, connection, instance, False)

    if instance.parent_id is not None:
        # Put there right position of new parent node (there moving node
        # should be moved). Read again after the delete step above, which
        # runs between the two parent lookups.
        (
            parent_id,
            parent_pos_right,
            parent_pos_left,
            parent_tree_id,
            parent_level
        ) = connection.execute(
            select(
                [
                    table_pk,
                    table.c.rgt,
                    table.c.lft,
                    table.c.tree_id,
                    table.c.level
                ]
            ).where(
                table_pk == instance.parent_id
            )
        ).fetchone()
        # 'size' of moving node (including all it's sub nodes)
        node_size = node_pos_right - node_pos_left + 1

        # left sibling node: without one, the parent's own bounds are
        # passed on, flagged with is_parent=True
        if not left_sibling:
            left_sibling = {
                'lft': parent_pos_left,
                'rgt': parent_pos_right,
                'is_parent': True
            }

        # insert subtree in exist tree
        instance.tree_id = parent_tree_id
        events._insert_subtree(
            table,
            connection,
            node_size,
            node_pos_left,
            node_pos_right,
            parent_pos_left,
            parent_pos_right,
            subtree,
            parent_tree_id,
            parent_level,
            node_level,
            left_sibling,
            table_pk
        )
    else:
        # if insert after
        if left_sibling_tree_id or left_sibling_tree_id == 0:
            tree_id = left_sibling_tree_id + 1
            # make room: shift every following tree one slot to the right
            connection.execute(
                table.update(
                    table.c.tree_id > left_sibling_tree_id
                ).values(
                    tree_id=table.c.tree_id + 1
                )
            )
        # if just insert
        else:
            tree_id = connection.scalar(
                select(
                    [
                        func.max(table.c.tree_id) + 1
                    ]
                )
            )

        # Rebase the moved subtree so that its root starts at lft=1 on the
        # default level of the new tree.
        connection.execute(
            table.update(
                table_pk.in_(
                    subtree
                )
            ).values(
                lft=table.c.lft - node_pos_left + 1,
                rgt=table.c.rgt - node_pos_left + 1,
                level=table.c.level - node_level + default_level,
                tree_id=tree_id
            )
        )

def apply_path():
    """Install the patched handler as ``events.mptt_before_update``.

    Rebinds the attribute on the ``events`` module; the previous value is
    not kept, so the replacement stays in effect once this has been called.
    """
    setattr(events, 'mptt_before_update', mptt_before_update)

I apply this patch right before calling move_before, for example:

# Install the patched update handler before the move is requested.
# presumably patch_mptt is the module holding the patch shown above -- confirm
patch_mptt.apply_path()
# presumably move_node is a tree model instance and before_id the primary key
# of the node it should be placed in front of -- confirm against the caller
move_node.move_before(before_id)
# presumably db is the SQLAlchemy session; committing it is what would
# flush the pending move -- confirm
db.commit()