lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.67k stars 668 forks

Column and Row Parallel Linear for Apex Tensor Parallel #44

conceptofmind closed this issue 1 year ago

conceptofmind commented 1 year ago

Hi,

I am exploring tensor parallelism for training, and I was wondering if you had any input on the correct use of RowParallelLinear for the feedforward output projection (ff_out).

For example:

Column Parallel over the fused q, k, v, and ff inner projection:

self.fused_attn_ff_proj = apex.transformer.tensor_parallel.ColumnParallelLinear(
  dim,
  sum(self.fused_dims),
  bias=False,
  gather_output=False,
  init_method=nn.init.xavier_uniform_
)
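One practical detail with gather_output=False: each rank only sees its own slice of the fused projection, so any later split of that output needs per-rank widths rather than the global fused_dims. The sketch below is only the arithmetic. It assumes fused_dims is laid out as (attn_inner_dim, dim_head, dim_head, ff_inner_dim * 2) and that every segment is sharded evenly. The helper name local_fused_dims and the sizes are hypothetical. Whether a single-head k and v should be sharded at all, or replicated on every rank, is a separate design choice that this does not settle.

```python
def local_fused_dims(fused_dims, world_size):
    """Per-rank width of each fused segment, assuming even sharding of every segment."""
    for d in fused_dims:
        if d % world_size != 0:
            raise ValueError(f"segment width {d} is not divisible by world_size={world_size}")
    return tuple(d // world_size for d in fused_dims)


# Hypothetical sizes, chosen only for illustration.
attn_inner_dim, dim_head, ff_inner_dim = 512, 64, 2048

# Assumed layout of the fused projection: q, k, v, then ff value and gate together.
fused_dims = (attn_inner_dim, dim_head, dim_head, ff_inner_dim * 2)

local = local_fused_dims(fused_dims, world_size=2)
print(local)       # widths to use when splitting the local (ungathered) output
print(sum(local))  # equals sum(fused_dims) // world_size
```

On each rank, the local output would then be split with these widths instead of self.fused_dims.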

Row Parallel over attn out:

self.attn_out = apex.transformer.tensor_parallel.RowParallelLinear(
  attn_inner_dim,
  dim,
  bias=False,
  input_is_parallel=True,
  init_method=nn.init.xavier_uniform_
)

I am not 100% sure whether ff out should be Row Parallel as well:

self.ff_out = nn.Sequential(
  SwiGLU(),
  apex.transformer.tensor_parallel.RowParallelLinear(
    ff_inner_dim,
    dim,
    bias=False,
    input_is_parallel=True,
    init_method=nn.init.xavier_uniform_
  )
)

Normally I would just do Column Parallel, SwiGLU, Row Parallel in a standard FeedForward, but it is not clear to me how to handle this case with the fused attn/ff projection and the separate ff tail.
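The Column Parallel, SwiGLU, Row Parallel pattern can be checked numerically without Apex or multiple GPUs. The sketch below simulates two ranks in plain Python and compares the summed partial outputs (standing in for the all-reduce) against an unsharded reference. It assumes SwiGLU splits its input in half along the last dimension as (value, gate) and returns silu(gate) * value. All names and sizes are hypothetical. It suggests that a row-parallel ff out works when each rank holds matching value and gate columns, and that a naive contiguous split of the concatenated [value | gate] columns does not reproduce the reference.

```python
import math
import random


def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def swiglu(h):
    """Assumed SwiGLU: split last dim in half as (value, gate), return silu(gate) * value."""
    half = len(h[0]) // 2
    return [[v * g / (1.0 + math.exp(-g)) for v, g in zip(row[:half], row[half:])]
            for row in h]


def rand_matrix(n_rows, n_cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]


random.seed(0)
dim, ff_inner, world_size, batch = 4, 6, 2, 3   # hypothetical toy sizes
shard = ff_inner // world_size

x = rand_matrix(batch, dim)
w_in = rand_matrix(dim, 2 * ff_inner)   # columns: [value (ff_inner) | gate (ff_inner)]
w_out = rand_matrix(ff_inner, dim)

# Unsharded reference: ff inner projection, SwiGLU, ff out projection.
reference = matmul(swiglu(matmul(x, w_in)), w_out)


def sharded_forward(col_ids_per_rank):
    """Simulate column-parallel in, local SwiGLU, row-parallel out, then sum over ranks."""
    total = [[0.0] * dim for _ in range(batch)]
    for rank, col_ids in enumerate(col_ids_per_rank):
        w_in_local = [[row[j] for j in col_ids] for row in w_in]   # column shard
        w_out_local = w_out[rank * shard:(rank + 1) * shard]       # row shard
        partial = matmul(swiglu(matmul(x, w_in_local)), w_out_local)
        # Summing the partial outputs stands in for the all-reduce.
        total = [[t + p for t, p in zip(tr, pr)] for tr, pr in zip(total, partial)]
    return total


def max_abs_diff(a, b):
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


# Matched layout: rank r owns value columns and gate columns for the same hidden units.
matched = [list(range(r * shard, (r + 1) * shard))
           + list(range(ff_inner + r * shard, ff_inner + (r + 1) * shard))
           for r in range(world_size)]
# Naive layout: contiguous halves of the concatenated [value | gate] columns.
naive = [list(range(r * 2 * shard, (r + 1) * 2 * shard)) for r in range(world_size)]

matched_err = max_abs_diff(sharded_forward(matched), reference)
naive_err = max_abs_diff(sharded_forward(naive), reference)
print("matched layout error:", matched_err)
print("naive layout error:", naive_err)
```

When training from scratch with per-rank initialization, the matched layout amounts to treating each rank's local ff slice as its own (value, gate) pair. The column ordering mainly matters when loading weights from a non-parallel checkpoint.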

Any input would be greatly appreciated.

Thank you,

Enrico

conceptofmind commented 1 year ago

Decided to go with FSDP instead.