ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Fixed instruction::replace() logic. #3553

Closed: tcgu-amd closed this 2 weeks ago

tcgu-amd commented 3 weeks ago

The previous BFS-based fix doesn't fully work in more complex cases (e.g. it fails on the newly added test case check_replace_dag). This fix sorts the affected instructions topologically and replaces them in that order, which should work for all cases.

More details:

In a toy example, add2(reduce(x), add1(abs(reduce(x)), sin(reduce(x)))), the dependency graph looks like:

reduce -+-> abs --+
        |         +-> add1 --+
        +-> sin --+          |
        |                    v
        +------------------> add2

If we call reduce.replace(), BFS will visit the instructions in the following order:

reduce -> abs -> sin -> add2 -> add1

This causes a shape-mismatch error at add2, because add2 is recomputed before its input add1 has been updated.
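A small self-contained model makes the ordering problem easy to reproduce. The `toy_graph` type, `make_example()` and `bfs_order()` below are hypothetical illustrations, not MIGraphX APIs: each node simply lists the indices of its consumers in program order, and the BFS follows the visit order described above.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy graph: node i lists the indices of the nodes that consume it.
struct toy_graph
{
    std::vector<std::string> names;
    std::vector<std::vector<std::size_t>> outputs;
};

// add2(reduce(x), add1(abs(reduce(x)), sin(reduce(x))))
toy_graph make_example()
{
    toy_graph g;
    g.names   = {"reduce", "abs", "sin", "add1", "add2"};
    g.outputs = {{1, 2, 4}, {3}, {3}, {4}, {}};
    return g;
}

// Order in which a plain BFS starting at `start` visits (and recomputes) nodes.
std::vector<std::string> bfs_order(const toy_graph& g, std::size_t start)
{
    std::vector<std::string> order;
    std::unordered_set<std::size_t> seen{start};
    std::deque<std::size_t> q{start};
    while(not q.empty())
    {
        std::size_t n = q.front();
        q.pop_front();
        order.push_back(g.names[n]);
        for(std::size_t out : g.outputs[n])
            if(seen.insert(out).second) // enqueue each node only once
                q.push_back(out);
    }
    return order;
}
```

`bfs_order(make_example(), 0)` returns reduce, abs, sin, add2, add1: add2 is reached directly from reduce before add1 has been recomputed.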

Topologically sorting the instruction graph yields:

reduce -> sin -> abs -> add1 -> add2

which is a correct order in which to process the instructions.

This approach should extend to more complex cases as well.
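As a sketch of the topological-sort idea (not the code in this PR), the function below first collects every node downstream of the replaced one and then runs Kahn's algorithm over that sub-graph only, counting how many affected inputs each node is still waiting on. It uses the same kind of hypothetical `toy_graph` representation. With a FIFO queue it yields reduce -> abs -> sin -> add1 -> add2, which is an equally valid topological order, since abs and sin do not depend on each other.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical toy graph: node i lists the indices of the nodes that consume it.
struct toy_graph
{
    std::vector<std::string> names;
    std::vector<std::vector<std::size_t>> outputs;
};

// Kahn's algorithm restricted to the sub-graph downstream of `start`.
std::vector<std::string> affected_topo_order(const toy_graph& g, std::size_t start)
{
    // 1. Collect every node reachable from `start` (including `start`).
    std::unordered_set<std::size_t> affected{start};
    std::deque<std::size_t> q{start};
    while(not q.empty())
    {
        std::size_t n = q.front();
        q.pop_front();
        for(std::size_t out : g.outputs[n])
            if(affected.insert(out).second)
                q.push_back(out);
    }

    // 2. For each affected node, count its incoming edges from affected nodes.
    std::unordered_map<std::size_t, std::size_t> pending;
    for(std::size_t n : affected)
        for(std::size_t out : g.outputs[n])
            ++pending[out];

    // 3. Emit a node only once all of its affected inputs have been emitted.
    std::vector<std::string> order;
    q.push_back(start);
    while(not q.empty())
    {
        std::size_t n = q.front();
        q.pop_front();
        order.push_back(g.names[n]);
        for(std::size_t out : g.outputs[n])
            if(--pending[out] == 0)
                q.push_back(out);
    }
    return order;
}
```

Because the sort covers everything downstream of the replaced instruction, its cost grows with the size of that sub-graph even where shapes stop changing early.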

codecov[bot] commented 3 weeks ago

Codecov Report

All modified and coverable lines are covered by tests.

Project coverage is 92.16%. Comparing base (1e1a229) to head (92ebe7f).

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##           develop    #3553   +/-   ##
========================================
  Coverage    92.16%   92.16%
========================================
  Files          512      512
  Lines        21401    21408    +7
========================================
+ Hits         19724    19731    +7
  Misses        1677     1677
```

pfultz2 commented 3 weeks ago

This seems like it will be really slow, since it needs to topologically sort until the end of the model instead of just until the shapes no longer change.

tcgu-amd commented 3 weeks ago

> This seems like it will be really slow, since it needs to topologically sort until the end of the model instead of just until the shapes no longer change.

Yes, unfortunately this is definitely going to be slower than the previous implementations. I'm not sure there's a better approach, since we don't know the dependencies between instructions until after the sort.

One approach I can think of is optimistic: perform the BFS assuming everything will be fine, and on a shape mismatch push the instruction to the back of the queue. Only return an error if every instruction in the queue is a shape mismatch. This is a bit unconventional, so I will need to test it to make sure it generates correct results.
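The optimistic re-queue idea can be sketched on a toy model where shapes are plain integers, unary nodes pass their input's shape through, and a node with several inputs reports a mismatch unless all input shapes agree. The `node` struct, `outputs_of()`, `compute_shape()` and `optimistic_replace()` are hypothetical names; an error is reported only after every instruction currently in the queue has failed in a row with no progress in between.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy instruction: shapes are plain integers.
struct node
{
    std::string name;
    std::vector<std::size_t> inputs;
    int shape;
};

// Indices of the nodes that consume node `n`, in program order.
std::vector<std::size_t> outputs_of(const std::vector<node>& g, std::size_t n)
{
    std::vector<std::size_t> outs;
    for(std::size_t m = 0; m < g.size(); ++m)
        for(std::size_t in : g[m].inputs)
            if(in == n)
                outs.push_back(m);
    return outs;
}

// Toy shape function: all input shapes must agree, otherwise it is a mismatch.
std::optional<int> compute_shape(const std::vector<node>& g, std::size_t n)
{
    const node& x = g[n];
    if(x.inputs.empty())
        return x.shape;
    int s = g[x.inputs.front()].shape;
    for(std::size_t in : x.inputs)
        if(g[in].shape != s)
            return std::nullopt;
    return s;
}

// Optimistic BFS: on a mismatch, retry the node later; give up only after every
// node currently in the queue has failed in a row.
bool optimistic_replace(std::vector<node>& g,
                        std::size_t start,
                        int new_shape,
                        std::vector<std::string>& updated)
{
    g[start].shape                     = new_shape;
    const std::vector<std::size_t> first = outputs_of(g, start);
    std::deque<std::size_t> q(first.begin(), first.end());
    std::size_t failures = 0; // consecutive mismatches since the last success
    while(not q.empty())
    {
        std::size_t n = q.front();
        q.pop_front();
        std::optional<int> s = compute_shape(g, n);
        if(not s)
        {
            q.push_back(n); // an input may simply not be updated yet
            if(++failures >= q.size())
                return false; // the whole queue mismatches: a real error
            continue;
        }
        failures = 0;
        if(*s != g[n].shape)
        {
            g[n].shape = *s;
            updated.push_back(g[n].name);
            for(std::size_t out : outputs_of(g, n))
                q.push_back(out);
        }
    }
    return true;
}
```

On the example above, add2 fails once, is pushed to the back, and succeeds after add1 has been updated; a genuinely incompatible input (say, an unrelated operand whose shape never changes) eventually leaves only mismatching nodes in the queue and returns false.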

Edit: Actually, upon further consideration, I think this problem can be solved easily by using a modified version of Kahn's algorithm. I will update the code and try it out.

pfultz2 commented 3 weeks ago

There might be a way to traverse up the inputs to check for dependencies. I would need to think about it more.

tcgu-amd commented 3 weeks ago

Hi @pfultz2, I have created a new version of the algorithm that should have the same performance as the old versions.

This is loosely based on Kahn's algorithm, in that we only process a node once it has been visited by all of its children (i.e. its inputs) that need to be replaced.

To achieve this, we perform a BFS from the base instruction as usual, but keep a map that counts the remaining arguments of each instruction we encounter. If an instruction is unary, we can process it directly. If it has more than one argument, we decrement its count in the map and check whether it has reached zero. If it has, all of its arguments must have been replaced and we can replace this instruction; otherwise some arguments may still need to be replaced, so we skip this instruction for now and wait for it to be encountered again when one of its remaining arguments adds it back to the queue.

For instructions that have more than one child where only one of them needs to be replaced and the others come from unrelated sub-graphs, we can move them from the map to the queue once the queue empties and try to process them. If this ends up generating a shape mismatch, it will error out as normal.
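Under the same kind of toy model (a hypothetical `node` struct and helper names, not MIGraphX code), the argument-counting scheme might look like this. Multi-input nodes are held back until they have been reached once per argument; when the queue runs dry, any node still waiting in the map is forced through, which is where an argument from an unrelated sub-graph would surface as a shape mismatch in the real code.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy instruction.
struct node
{
    std::string name;
    std::vector<std::size_t> inputs;
};

// Indices of the nodes that consume node `n`, in program order.
std::vector<std::size_t> outputs_of(const std::vector<node>& g, std::size_t n)
{
    std::vector<std::size_t> outs;
    for(std::size_t m = 0; m < g.size(); ++m)
        for(std::size_t in : g[m].inputs)
            if(in == n)
                outs.push_back(m);
    return outs;
}

// Order in which nodes downstream of `start` would be replaced.
std::vector<std::string> counted_replace_order(const std::vector<node>& g, std::size_t start)
{
    std::vector<std::string> order;
    std::map<std::size_t, std::size_t> remaining; // node -> arguments not yet seen
    const std::vector<std::size_t> first = outputs_of(g, start);
    std::deque<std::size_t> q(first.begin(), first.end());
    while(not q.empty() or not remaining.empty())
    {
        if(q.empty())
        {
            // Leftover nodes have arguments from unrelated sub-graphs:
            // force each of them through on its next pop.
            for(auto& [id, count] : remaining)
            {
                count = 1;
                q.push_back(id);
            }
        }
        std::size_t n = q.front();
        q.pop_front();
        if(g[n].inputs.size() > 1)
        {
            auto it = remaining.try_emplace(n, g[n].inputs.size()).first;
            if(--it->second != 0)
                continue; // wait until every argument has reached this node
            remaining.erase(it);
        }
        order.push_back(g[n].name); // unary, or all arguments seen
        for(std::size_t out : outputs_of(g, n))
            q.push_back(out);
    }
    return order;
}
```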

Edit:

> For instructions that have more than one child where only one of them needs to be replaced and the others come from unrelated sub-graphs, we can move them from the map to the queue once the queue empties and try to process them. If this ends up generating a shape mismatch, it will error out as normal.

I just realized that there might still be dependencies between the instructions that are only partially replaced, and the current version may not be able to capture them.

pfultz2 commented 3 weeks ago

I think instead you would check whether any of the other inputs are reachable from the instruction, and if so add the output to a revisit queue:

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        std::deque<instruction_ref> q(output.begin(), output.end());
        std::deque<instruction_ref> revisit;
        std::unordered_set<instruction_ref> visited;
        while(not q.empty() or not revisit.empty())
        {
            // Move deferred instructions back only once the main queue has drained
            if(q.empty())
            {
                q.insert(q.end(), revisit.begin(), revisit.end());
                revisit.clear();
            }
            instruction_ref ins = q.front();
            q.pop_front();
            if(not visited.insert(ins).second)
                continue;
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                for(auto out : ins->outputs())
                {
                    // Defer `out` if another of its inputs is downstream of `ins`,
                    // since that input may not have been updated yet
                    if(any_of(out->inputs(), [&](instruction_ref x) {
                           return x != ins and reaches(ins, x);
                       }))
                    {
                        revisit.push_back(out);
                    }
                    else
                    {
                        q.push_back(out);
                    }
                }
            }
        }
    }
}

This would fix the simple case you presented but I am not sure it would handle more complicated cases.

pfultz2 commented 3 weeks ago

Actually, I think it might be much simpler if we just use the order in the instruction list, as that should already be in topological order. So we could just use a priority_queue instead:

struct replace_shape_order
{
    instruction_ref start;

    // Position of `x` relative to `start` in the instruction list
    std::size_t location(instruction_ref x) const
    {
        return std::distance(start, x);
    }

    // Inverted comparison so the priority_queue pops the earliest instruction first
    bool operator()(instruction_ref x, instruction_ref y) const
    {
        return location(x) > location(y);
    }
};

void instruction::replace(const shape& r)
{
    if(r != result)
    {
        result = r;
        if(output.empty())
            return;
        // Recover an instruction_ref to `this` from the inputs of one of its outputs
        auto start = std::find_if(output.front()->inputs().begin(),
                                  output.front()->inputs().end(),
                                  [&](instruction_ref x) { return this == as_address(x); });
        assert(as_address(*start) == this);
        // The std::priority_queue constructor takes the comparator first, then the initial elements
        std::priority_queue<instruction_ref, std::vector<instruction_ref>, replace_shape_order> q(
            replace_shape_order{*start}, output);
        while(not q.empty())
        {
            instruction_ref ins = q.top();
            q.pop();
            assert(ins->name() == "@return" or ins->name().front() != '@');
            shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args);
            if(new_r != ins->result)
            {
                ins->result = new_r;
                std::copy(ins->output.begin(), ins->output.end(), push_inserter(q));
            }
        }
    }
}
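The same idea in miniature: because the instruction list is already topologically ordered, popping the earliest pending instruction first guarantees that its inputs have already been handled. In this hypothetical toy the node index stands in for the position in the list, so `std::greater` replaces the `std::distance`-based `replace_shape_order` comparator, and duplicate pops are skipped explicitly rather than relying on the shape no longer changing.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <string>
#include <vector>

// Hypothetical toy instruction list: a node's index is its position in the list,
// and every node appears after all of its inputs.
struct node
{
    std::string name;
    std::vector<std::size_t> inputs;
};

// Indices of the nodes that consume node `n`, in program order.
std::vector<std::size_t> outputs_of(const std::vector<node>& g, std::size_t n)
{
    std::vector<std::size_t> outs;
    for(std::size_t m = 0; m < g.size(); ++m)
        for(std::size_t in : g[m].inputs)
            if(in == n)
                outs.push_back(m);
    return outs;
}

std::vector<std::string> program_order_replace(const std::vector<node>& g, std::size_t start)
{
    // Min-heap on list position: the earliest pending instruction is popped first.
    std::priority_queue<std::size_t, std::vector<std::size_t>, std::greater<std::size_t>> q;
    for(std::size_t out : outputs_of(g, start))
        q.push(out);
    std::vector<std::string> order;
    std::size_t last = start;
    while(not q.empty())
    {
        std::size_t n = q.top();
        q.pop();
        // Outputs always sit later in the list, so pops are non-decreasing and
        // duplicates come out back-to-back.
        if(n == last)
            continue;
        last = n;
        order.push_back(g[n].name);
        for(std::size_t out : outputs_of(g, n))
            q.push(out);
    }
    return order;
}
```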
tcgu-amd commented 2 weeks ago

> Actually, I think it might be much simpler if we just use the order in the instruction list, as that should already be in order. So we could just use a priority_queue instead.

This makes sense! Much more elegant too! I will test it out.

tcgu-amd commented 2 weeks ago

@pfultz2 I pushed a commit with the new solution you proposed, and it seems to be working. Worth noting: there is no std::inserter equivalent for priority_queue (it only has push(), not insert()), so we can't use std::copy to insert the instruction outputs.

pfultz2 commented 2 weeks ago

> Worth noting that there is no std::inserter for priority_queue so we can't use std::copy to insert the instruction outputs.

You could add one to the migraphx/output_iterator.hpp header:

template <class Container>
auto push_inserter(Container& c)
{
    return make_function_output_iterator([&](const auto& x) { c.push(x); });
}
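For a self-contained illustration that does not depend on MIGraphX headers, the sketch below defines a minimal function output iterator (`function_output_iterator_sketch` is a hypothetical name standing in for whatever `make_function_output_iterator` returns in migraphx/output_iterator.hpp) and builds `push_inserter` on top of it, so that `std::copy` can feed a `std::priority_queue`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <queue>
#include <utility>
#include <vector>

// Minimal output iterator that forwards every assigned value to a callable.
template <class F>
struct function_output_iterator_sketch
{
    using iterator_category = std::output_iterator_tag;
    using value_type        = void;
    using difference_type   = std::ptrdiff_t;
    using pointer           = void;
    using reference         = void;

    F f;

    // `*it = x` ends up here and calls f(x)
    template <class T>
    function_output_iterator_sketch& operator=(const T& x)
    {
        f(x);
        return *this;
    }
    // Dereference and increment are no-ops, as for std::back_insert_iterator
    function_output_iterator_sketch& operator*() { return *this; }
    function_output_iterator_sketch& operator++() { return *this; }
    function_output_iterator_sketch& operator++(int) { return *this; }
};

template <class F>
function_output_iterator_sketch<F> make_function_output_iterator_sketch(F f)
{
    return {std::move(f)};
}

// Adapter for containers that expose push() but not insert(), e.g. std::priority_queue.
template <class Container>
auto push_inserter(Container& c)
{
    return make_function_output_iterator_sketch([&c](const auto& x) { c.push(x); });
}
```

With this in place, `std::copy(outs.begin(), outs.end(), push_inserter(q))` pushes each element of `outs` into `q` in turn, which is exactly the call used in the priority_queue version of `instruction::replace`.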