IvanYashchuk opened this issue 2 weeks ago
Thanks @IvanYashchuk for writing down what we spoke about yesterday and for writing the repro snippet. I'll take a look at it, since I'm working in a similar area.
Using thunder.jit(f, nv_enable_linear=True, nv_enable_matmul=True) helps Thunder save fewer tensors because it simplifies the task for rematerialization, leaving only one producer -> one consumer pair. With this change, rematerialization decides to save only two intermediates, the outputs of the linear layers:
len(jout.grad_fn.saved_tensors)=5
Saved tensors size Thunder: 580.00 MiB
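Roughly, these numbers can be collected with a sketch along these lines (f, x, fc_weight, proj_weight are the function and tensors from the repro snippet in the issue body below; measuring the saved-tensor footprint via numel/element_size is an assumption about how the MiB figure was produced):

import thunder

# Enable nvFuser's linear/matmul support so the fusions absorb the linears.
jf = thunder.jit(f, nv_enable_linear=True, nv_enable_matmul=True)
jout = jf(x, fc_weight, proj_weight)

# Inspect what autograd keeps alive for the backward pass.
saved = jout.grad_fn.saved_tensors
print(f"{len(jout.grad_fn.saved_tensors)=}")
total_bytes = sum(t.numel() * t.element_size() for t in saved)
print(f"Saved tensors size Thunder: {total_bytes / 2**20:.2f} MiB")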
def computation(x, fc_weight, proj_weight):
  # x: "cuda:0 f16[1, 16384, 2560]"
  # fc_weight: "cuda:0 f16[10240, 2560]"
  # proj_weight: "cuda:0 f16[2560, 10240]"
  [t10, t14, t21] = nvFusion0(x, fc_weight, proj_weight)
    # t22 = prims.convert_element_type(x, dtypes.float32) # t22: "cuda:0 f32[1, 16384, 2560]"
    # t23 = prims.exp(t22) # t23: "cuda:0 f32[1, 16384, 2560]"
    # t27 = prims.add(t23, t22) # t27: "cuda:0 f32[1, 16384, 2560]"
    # t30 = prims.exp(t27) # t30: "cuda:0 f32[1, 16384, 2560]"
    # nx = prims.convert_element_type(t30, dtypes.float16) # nx: "cuda:0 f16[1, 16384, 2560]"
    # t10 = prims.linear(nx, fc_weight, None) # t10: "cuda:0 f16[1, 16384, 10240]"
    # t33 = prims.convert_element_type(t10, dtypes.float32) # t33: "cuda:0 f32[1, 16384, 10240]"
    # t34 = prims.exp(t33) # t34: "cuda:0 f32[1, 16384, 10240]"
    # t13 = prims.convert_element_type(t34, dtypes.float16) # t13: "cuda:0 f16[1, 16384, 10240]"
    # t14 = prims.linear(t13, proj_weight, None) # t14: "cuda:0 f16[1, 16384, 2560]"
    # t37 = prims.convert_element_type(t14, dtypes.float32) # t37: "cuda:0 f32[1, 16384, 2560]"
    # t39 = prims.add(t37, t27) # t39: "cuda:0 f32[1, 16384, 2560]"
    # t42 = prims.exp(t39) # t42: "cuda:0 f32[1, 16384, 2560]"
    # t21 = prims.convert_element_type(t42, dtypes.float16) # t21: "cuda:0 f16[1, 16384, 2560]"
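  # Saved for backward (second element of the return value below): the three
  # inputs plus only two intermediates, t10 and t14, the linear outputs.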
  return {'output': t21, 'flat_args': [x, fc_weight, proj_weight], 'flat_output': (t21,)}, ((fc_weight, proj_weight, t10, t14, x), ())
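This computation trace and the backward_fn trace below were likely obtained by printing the last transformed traces; a sketch, assuming the jf/jout from the measurement snippet above:

import thunder

fwd = thunder.last_traces(jf)[-1]           # final forward execution trace
bwd = thunder.last_backward_traces(jf)[-1]  # final backward execution trace
print(fwd)
print(bwd)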
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t44, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  fc_weight, proj_weight, t10, t14, x, = C0
  clear_mutable_collection(C0)
  del C0
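  # Only x, t10, t14 (plus the weights) were saved for backward; the fusion
  # below recomputes the elementwise chain (t22, t23, t27, t30, t34, t42)
  # from them instead of loading those activations from saved memory.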
  [t62, t72, t89] = nvFusion0(x, t14, t44, proj_weight, t10, fc_weight)
    # t22 = prims.convert_element_type(x, dtypes.float32) # t22: "cuda:0 f32[1, 16384, 2560]"
    # t23 = prims.exp(t22) # t23: "cuda:0 f32[1, 16384, 2560]"
    # t27 = prims.add(t23, t22) # t27: "cuda:0 f32[1, 16384, 2560]"
    # t37 = prims.convert_element_type(t14, dtypes.float32) # t37: "cuda:0 f32[1, 16384, 2560]"
    # t39 = prims.add(t37, t27) # t39: "cuda:0 f32[1, 16384, 2560]"
    # t42 = prims.exp(t39) # t42: "cuda:0 f32[1, 16384, 2560]"
    # t50 = prims.convert_element_type(t44, dtypes.float32) # t50: "cuda:0 f32[1, 16384, 2560]"
    # t51 = prims.mul(t50, t42) # t51: "cuda:0 f32[1, 16384, 2560]"
    # t52 = prims.convert_element_type(t51, dtypes.float16) # t52: "cuda:0 f16[1, 16384, 2560]"
    # t56 = prims.reshape(t52, (16384, 2560)) # t56: "cuda:0 f16[16384, 2560]"
    # t57 = prims.matmul(t56, proj_weight) # t57: "cuda:0 f16[16384, 10240]"
    # t33 = prims.convert_element_type(t10, dtypes.float32) # t33: "cuda:0 f32[1, 16384, 10240]"
    # t58 = prims.reshape(t57, (1, 16384, 10240)) # t58: "cuda:0 f16[1, 16384, 10240]"
    # t34 = prims.exp(t33) # t34: "cuda:0 f32[1, 16384, 10240]"
    # t63 = prims.convert_element_type(t58, dtypes.float32) # t63: "cuda:0 f32[1, 16384, 10240]"
    # t64 = prims.mul(t63, t34) # t64: "cuda:0 f32[1, 16384, 10240]"
    # t65 = prims.convert_element_type(t64, dtypes.float16) # t65: "cuda:0 f16[1, 16384, 10240]"
    # t66 = prims.reshape(t65, (16384, 10240)) # t66: "cuda:0 f16[16384, 10240]"
    # t67 = prims.matmul(t66, fc_weight) # t67: "cuda:0 f16[16384, 2560]"
    # t68 = prims.reshape(t67, (1, 16384, 2560)) # t68: "cuda:0 f16[1, 16384, 2560]"
    # t30 = prims.exp(t27) # t30: "cuda:0 f32[1, 16384, 2560]"
    # t73 = prims.convert_element_type(t68, dtypes.float32) # t73: "cuda:0 f32[1, 16384, 2560]"
    # t74 = prims.mul(t73, t30) # t74: "cuda:0 f32[1, 16384, 2560]"
    # t78 = prims.add(t51, t74) # t78: "cuda:0 f32[1, 16384, 2560]"
    # t84 = prims.mul(t78, t23) # t84: "cuda:0 f32[1, 16384, 2560]"
    # nx = prims.convert_element_type(t30, dtypes.float16) # nx: "cuda:0 f16[1, 16384, 2560]"
    # t13 = prims.convert_element_type(t34, dtypes.float16) # t13: "cuda:0 f16[1, 16384, 10240]"
    # t88 = prims.add(t78, t84) # t88: "cuda:0 f32[1, 16384, 2560]"
    # t71 = prims.reshape(nx, (16384, 2560)) # t71: "cuda:0 f16[16384, 2560]"
    # t70 = prims.transpose(t66, (1, 0)) # t70: "cuda:0 f16[10240, 16384]"
    # t61 = prims.reshape(t13, (16384, 10240)) # t61: "cuda:0 f16[16384, 10240]"
    # t60 = prims.transpose(t56, (1, 0)) # t60: "cuda:0 f16[2560, 16384]"
    # t89 = prims.convert_element_type(t88, dtypes.float16) # t89: "cuda:0 f16[1, 16384, 2560]"
    # t72 = prims.matmul(t70, t71) # t72: "cuda:0 f16[10240, 2560]"
    # t62 = prims.matmul(t60, t61) # t62: "cuda:0 f16[2560, 10240]"
  del x, t14, t44, proj_weight, t10, fc_weight
  return (t89, t72, t62)
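As a sanity check, the 580 MiB figure matches the shapes recorded in the trace (assuming 2 bytes per fp16 element):

MiB = 2**20
sizes = {
    "x":           1 * 16384 * 2560 * 2,   # 80 MiB
    "t10":         1 * 16384 * 10240 * 2,  # 320 MiB, output of the first linear
    "t14":         1 * 16384 * 2560 * 2,   # 80 MiB, output of the second linear
    "fc_weight":   10240 * 2560 * 2,       # 50 MiB
    "proj_weight": 2560 * 10240 * 2,       # 50 MiB
}
print(sum(sizes.values()) / MiB)  # 580.0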
🐛 Bug
To Reproduce
Here's a problematic pattern where Thunder's rematerialization algorithm is not effective; a reconstruction is shown below.
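A minimal reconstruction of the pattern, pieced together from the traces above (shapes and dtypes come from the trace comments; the function body and variable names are an approximation of the original repro snippet, not a verbatim copy):

import torch
import thunder

def f(x, fc_weight, proj_weight):
    # Elementwise chain whose intermediates are expensive to save.
    y = torch.exp(x) + x
    nx = torch.exp(y)
    t = torch.nn.functional.linear(nx, fc_weight)
    t = torch.exp(t)
    out = torch.nn.functional.linear(t, proj_weight)
    # Residual-style reuse of y across both linears creates a long
    # producer -> consumer distance for the rematerialization pass.
    return torch.exp(out + y)

x = torch.randn(1, 16384, 2560, device="cuda", dtype=torch.float16, requires_grad=True)
fc_weight = torch.randn(10240, 2560, device="cuda", dtype=torch.float16, requires_grad=True)
proj_weight = torch.randn(2560, 10240, device="cuda", dtype=torch.float16, requires_grad=True)

jf = thunder.jit(f)
jout = jf(x, fc_weight, proj_weight)
print(f"{len(jout.grad_fn.saved_tensors)=}")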
Script output:
Currently, thunder.examine.make_trace_dot doesn't print out all output variables, so the graph is difficult to read, but here, anyway, is the joint forward and backward graph before rematerialization is applied. It was created with a breakpoint set before this line: https://github.com/Lightning-AI/lightning-thunder/blob/9c916d9df73f3920b51e5951303a76b25ab2d4d4/thunder/core/rematerialization.py#L609
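For completeness, a hypothetical sketch of regenerating such a graph (the make_trace_dot signature, a trace in and a graphviz.Digraph out, is an assumption, as is the joint_trace variable name at the breakpoint):

# Hypothetical usage; signature assumed, not verified against thunder.examine.
from thunder.examine import make_trace_dot

dot = make_trace_dot(joint_trace)        # joint_trace: the fw+bw trace at the breakpoint
dot.render("joint_trace", format="svg")  # writes joint_trace.svg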
Resolving this bug would help us resolve https://github.com/Lightning-AI/lightning-thunder/issues/246.
cc @riccardofelluga