diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc
index 2c3920a32c8..64421596dcb 100644
--- a/third_party/xla/xla/service/gpu/buffer_sharing.cc
+++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc
@@ -23 +22,0 @@ limitations under the License.
-#include "absl/algorithm/container.h"
@@ -25 +23,0 @@ limitations under the License.
-#include "absl/log/check.h"
@@ -98,5 +96,2 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
- // We don't support nested tuples on GPU.
- CHECK_LT(user_index.size(), 2);
- if (output->opcode() == HloOpcode::kTuple) {
- CHECK(!user_index.empty());
- output = output->mutable_operand(user_index[0]);
+ for (int64_t o : user_index) {
+ output = output->mutable_operand(o);
@@ -117,3 +111,0 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
- if (hlo_operand->IsRoot()) {
- ++reached_root;
- }
@@ -137,9 +129,5 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
- // For scatter, we can share the buffer if the path goes through one of
- // the scatter inputs.
- if (hlo == non_bitcast_root && hlo->opcode() == HloOpcode::kScatter) {
- int64_t num_scatter_inputs =
- hlo->shape().IsTuple() ? hlo->shape().tuple_shapes_size() : 1;
- if (hlo->operand_index(hlo_operand) < num_scatter_inputs &&
- absl::c_count(hlo->operands(), hlo_operand) == 1) {
- continue;
- }
+ // For scatter, we can share the buffer if the path goes through the first
+ // operand.
+ if (hlo == non_bitcast_root && hlo->opcode() == HloOpcode::kScatter &&
+ hlo->operand_index(hlo_operand) == 0) {
+ continue;
@@ -188,0 +177,3 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
+ if (hlo->IsRoot()) {
+ ++reached_root;
+ }
@@ -191 +182 @@ std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
- return found_path_to_output && reached_root == 1;
+ return found_path_to_output && (user_index.empty() || reached_root == 1);
crea una funzione capace di costruire un dizionario, la cui chiave è il file path e il valore è una lista di numeri estratti dall'output del diff, per capire il numero da estrarre considera questi esempi:
1) @@ -23 +22,0 @@ limitations under the License.
-#include "absl/algorithm/container.h" In questo caso il numero da estrarre è il 23 dopo @@ -
2) @@ -98,5 +96,2 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user,
in questo caso è 98 subito dopo @@ -, ma attento, quando c'è la virgola, come in questo caso (98,5), significa che devi estrarre 5 numeri a partire dal 98, ergo 98, 99, 100, 101, 102
Answer:
Ecco una possibile implementazione in Python di una funzione che costruisce un dizionario con il file path come chiave e una lista di numeri estratti come valore:
import re
def extract_numbers_from_diff(diff_output):
file_path_pattern = re.compile(r'^\+\+\+ b/(.*)$')
line_number_pattern = re.compile(r'^@@ -(\d+)(,(\d+))? \+(\d+)(,(\d+))? @@')
file_path = None
numbers_list = []
diff_lines = diff_output.split('\n')
result_dict = {}
for line in diff_lines:
file_path_match = file_path_pattern.match(line)
line_number_match = line_number_pattern.match(line)
if file_path_match:
file_path = file_path_match.group(1)
elif line_number_match:
start_line = int(line_number_match.group(1))
num_lines = 1 if line_number_match.group(3) is None else int(line_number_match.group(3))
numbers_list.extend(range(start_line, start_line + num_lines))
if file_path and numbers_list:
result_dict[file_path] = numbers_list
return result_dict
Question: dato un output del diff:
crea una funzione capace di costruire un dizionario, la cui chiave è il file path e il valore è una lista di numeri estratti dall'output del diff, per capire il numero da estrarre considera questi esempi: 1) @@ -23 +22,0 @@ limitations under the License. -#include "absl/algorithm/container.h" In questo caso il numero da estrarre è il 23 dopo @@ - 2) @@ -98,5 +96,2 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user,
in questo caso è 98 subito dopo @@ -, ma attento, quando c'è la virgola, come in questo caso (98,5), significa che devi estrarre 5 numeri a partire dal 98, ergo 98, 99, 100, 101, 102
Answer: Ecco una possibile implementazione in Python di una funzione che costruisce un dizionario con il file path come chiave e una lista di numeri estratti come valore: