yongwonshin / PIMFlow


Difficulty testing out PIMFlow on custom network. #3

Open sbird-kick opened 1 year ago

sbird-kick commented 1 year ago

Hi, I am trying to run a custom network using your framework. In particular, I am using a modified version of resnet-18 (its classifier layers have been modified to work with 10 classes instead of 1000). Furthermore, this network expects an input shape of (3, 32, 32) instead of (3, 224, 224) as in the original resnet-18. Note that this removes the adaptive pooling layer in the classifier (which is present in the original torchvision model).

My first question is: what are the minimal modifications necessary to run this version of resnet-18 using your simulator? I understand some modifications are necessary. Below I list the modifications I made so that the code works with my version of resnet-18 (let's call it resnet-18-mine).

  1. I modified PIMFlow/pim/util.py: I defined my model there, added resnet-18-mine to MODEL_LIST, and modified _get_torchmodel to return my model when asked for resnet-18-mine.
  2. I did the same in _extractlayers.py.
  3. I did the same in _inspectshape.py.
  4. I modified _get_randominput to return (1, 3, 32, 32) when asked for resnet-18-mine. (A rough sketch of these changes is shown below.)
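
For concreteness, here is a minimal sketch of the kind of registration described in items 1-4 above. The helper _build_resnet18_mine and the exact signatures are assumptions for illustration only; the actual functions in PIMFlow/pim/util.py may look different.

import torch
import torch.nn as nn
from torchvision import models

MODEL_LIST = ["mobilenet-v2", "resnet-18", "resnet-18-mine"]  # hypothetical contents

def _build_resnet18_mine(num_classes=10):
    # resnet-18 adapted for 10 classes and (3, 32, 32) inputs
    m = models.resnet18()
    m.avgpool = nn.Dropout(p=0.333)  # the custom model replaces adaptive pooling with dropout
    m.fc = nn.Linear(512, num_classes)
    return m

def _get_torchmodel(name):
    if name == "resnet-18-mine":
        return _build_resnet18_mine()
    ...  # existing branches for the stock models

def _get_randominput(name):
    if name == "resnet-18-mine":
        return torch.randn(1, 3, 32, 32)
    ...  # existing branches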

Is this all I need to do? I am facing some difficulty, which I will elaborate on below:

I first ran everything for mobilenet-v2 to confirm that the framework was installed correctly and working, then followed the same commands for resnet-18-mine.

The following commands worked:

    ./pimflow -m=profile -t=split -n=resnet-18-mine
    ./pimflow -m=profile -t=pipeline -n=resnet-18-mine
    ./pimflow -m=stat --conv_only -n=resnet-18-mine
    ./pimflow -m=solve -n=resnet-18-mine

However, running the following command (TVM_BACKTRACE=1 ./pimflow -m=run --gpu_only -n=resnet-18-mine) afterwards gives an error which I have put in a file:

output.txt

In particular, the error seems to be that it is expecting 25088 rather than 512. I do not know why it would expect 25088 (though it is interesting to note that 25088 is the size of the tensor fed to the final linear layer if the input to the network is (3, 224, 224) rather than (3, 32, 32)). Do you have any idea why this could be happening?

Is there some other place where I have to modify the expected input shape and/or am I missing something?

Also, does this imply that only square strides work in your simulator?

yongwonshin commented 1 year ago

Hi,

I haven't tested a broad spectrum of CNN architectures, especially when input size changes. Could you share your model code in Python or an ONNX file? It would be good for me to inspect the problem first to locate where it arises.

For the last question, I assume input models have square strides (padding and dilation may have stricter restrictions). Some of these restrictions could be relaxed, but I put assertions in place to prevent running untested models, which could introduce silent performance bugs.

Thank you

sbird-kick commented 1 year ago

Ah, fair. I will get back to you after I try running a layer with non-square strides. In the meantime, here is the ONNX file for the model; I uploaded it to Google Drive.

If that is not acceptable, here is the model:

import pytorch_lightning as pl
import torch.nn as nn
from torchvision import models

class MyModelSimple(pl.LightningModule):  # the LightningModule machinery is not relevant, please ignore
    def __init__(self, num_classes=100, **kwargs):
        super(MyModelSimple, self).__init__()
        self.model = models.resnet18()
        self.model.avgpool = nn.Dropout(p=0.333)  # replaces the adaptive average pooling layer
        self.model.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

I defined and trained it on a separate machine using pytorch-lightning before exporting it as ONNX. Then I used onnx2pytorch to convert it into a torch model for use by your simulator. Please let me know if I can be more helpful.

(This piece of code is the appropriate snippet; I also had some trainers etc. defined, but they should not matter.)

sbird-kick commented 1 year ago

Hi, I tried it with a non-square stride convolutional layer as well. It did not seem to create any errors (other than the one above).

sbird-kick commented 1 year ago

Hi, I was able to recreate the issue above with another neural network, which I designed manually rather than using torchvision models. The ONNX file for this neural network is here.

I realize now that the error is not the size mismatch (that is only a warning).

The error seems to be that this line results in a list index out of range if I follow the steps above. Why would this list index be out of range, and how does one fix it?

sbird-kick commented 1 year ago

Apologies for the deluge of questions!

I can't seem to understand what the "get_kernel_start_and_end" function is supposed to do. It is supposed to operate on a stats.csv file (I have attached mine), but what exactly is it supposed to output? Does my stats.csv file look okay? I have noticed that there is no "forward_kernel_cuda_start" in it, which your code seems to expect. stats.csv

yongwonshin commented 1 year ago

I'm really sorry for the late reply; I'm currently behind my working schedule... I'm reproducing your bug today, and I hope to find the source of the problem, or at least locate where it occurs.

It seems that your stats.csv is malformed: it must contain the forward_kernel_cuda_start and forward_kernel_cuda_end kernels. The purpose of the get_kernel_start_and_end function is to parse the names of those kernels to find where the actual (target CONV) GPU kernels start and end.
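
For readers following along, here is a minimal sketch of this kind of marker-based parsing. The "Kernel Name" column and the exact logic are assumptions for illustration; PIMFlow's actual get_kernel_start_and_end may differ.

import csv

def find_marker_indices(stats_path):
    # Assumes one row per launched kernel with a "Kernel Name" column
    # (a guess at the stats.csv layout, not the verified format).
    with open(stats_path) as f:
        rows = list(csv.DictReader(f))
    start = end = None
    for i, row in enumerate(rows):
        name = row.get("Kernel Name", "")
        if "forward_kernel_cuda_start" in name:
            start = i + 1  # the target CONV kernels begin after the start marker
        elif "forward_kernel_cuda_end" in name:
            end = i        # and end right before the end marker
    if start is None or end is None:
        raise ValueError("marker kernels not found; stats.csv may be malformed")
    return start, end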

sbird-kick commented 1 year ago

No worries about the late reply! Thank you. I did manage to make it work with resnet-18 (slightly modified for CIFAR-10) using just the modifications listed above, provided I write a torch class in each of those three files. That is good news. I guess the issue is in the ONNX file? Is there a recommended way you would want the ONNX file to be exported or imported for your simulator to use?

yongwonshin commented 1 year ago

I think your ONNX models are fine once they are constant-folded (e.g., batch normalization layers folded into the weights), regardless of how they are created.
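
As one possible way to obtain a constant-folded ONNX file, here is a sketch using onnx-simplifier; the thread does not prescribe a specific tool, and the file names below are hypothetical.

import onnx
from onnxsim import simplify  # pip install onnx-simplifier

model = onnx.load("resnet18_cifar10.onnx")  # hypothetical file name
simplified, ok = simplify(model)            # folds constants, e.g. BatchNorm into Conv weights
assert ok, "simplified model failed the consistency check"
onnx.save(simplified, "resnet18_cifar10_folded.onnx")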

I'm not sure exactly what the source of the previous errors is, but I guess NVBit may produce errors in the profiling phase, leading to some kernels being missed. Have you checked split-<NET>.err or pipeline-<NET>.err?

sbird-kick commented 1 year ago

Ah, those don't exist for me. I was following the instructions in the README rather than run.example.sh, and I don't think the .err files are produced in that case. (Also, you should change the --gpu_only command in the README, ./pimflow -m=run --gpu_only -n=mobilenet-v2, to include policy=None; it fails for me otherwise, and I think that is expected as per run.example.sh.) But yes, let me run it again and try to catch something in those .err files. I did skim everything in stdout and didn't find any errors as far as I can tell, but I could have missed something. I will get back to you on this soon.

sbird-kick commented 1 year ago

It will take me a while to get back to you with those results, since I am already running the simulator on a large network and don't want to stop it and lose progress. Apologies. In the meantime, I have a few high-level questions:

  1. The Newton architecture does not allow for anything other than GEMM operations. How are you computing the results of something like sigmoid (or other activations)? They must be done on the GPU SMs, right? How exactly is this overhead accounted for in your simulator if Ramulator and Accel-Sim are fed traces separately (if that is correct)?
  2. This is a slightly superficial question, but which GPU are you using in Accel-Sim? I understand the architecture is the 2000 series, but which particular one (2080 Ti, 2060, etc.)? Is there a place where you have defined this so it could be easily modified? And if so, would your simulator still work with those modifications?

Thanks!

yongwonshin commented 1 year ago

For the first question, you pointed out a weak point of our work. We received the same question at the conference, and I answered that we haven't precisely modeled the data movement from the PIM device to the host. I calculated the bandwidth of the data transferred and concluded that it wasn't a serious problem, as the original Newton paper argues that the data movement for reduction and for applying activations is effectively hidden behind PIM computation. In fact, later versions of the Newton architecture have native support for activation functions (native ReLU, LeakyReLU, and look-up-table-based implementations of the others), so applying activation functions is not a serious problem, though it would lengthen the simulation time. One caveat is that I haven't modeled the gather (or reduction) operation during CONV (or GEMV), for the same reason stated above. At the time I thought that was a valid assumption, but it may be wrong or too aggressive.

For the second question, I used a GeForce RTX 2080 Ti GPU for tracing (and simulation) but used the RTX 2060 config files provided/tested in Accel-Sim, since they share the same architecture. You can modify the config files.

Thank you!

sbird-kick commented 1 year ago

Thank you for your reply, @yongwonshin. I have another question (thanks again for answering these).

The Newton architecture requires the data to be structured in a particular way before it can be used by the PIM. This adds overhead (for example, you must do tiling and some sort of patching together on the "host"). Is this tiling and patching together also considered negligible (like the activation functions), or is it done on the GPU cores?

sbird-kick commented 1 year ago

I would also like to discuss another thing with you. I was trying to plot the split for each layer. I took mobilenet-v2 and noticed that, in a couple of places, some odd things were happening which I could not explain. If you could help me with that, I will be very grateful. Here are my questions:

  1. Is max_performance_mobilenet-v2_16_4.csv in the pipeline folder a valid source of information?
  2. If it isn't, how would I find the split per layer and the GPU/PIM cycles per layer?
  3. Is it normal/expected for some layers to have a non-zero/non-100 split but have only GPU cycles (PIM cycles = 0), or vice versa?
  4. Is it normal/expected for some layers to be mapped entirely to the GPU while others are mapped entirely to PIM?
  5. How are the GPU cycles calculated? If they come from Accel-Sim, is that approximate, or does it use the actual physical GPU on my device?

I have attached an image. The blue line is the split (or RATIO), the green bars are the GPU cycles per layer, and the red bars are the PIM cycles per layer.

I would appreciate answers to these questions, and I already appreciate the detailed answers you have given me.


yongwonshin commented 1 year ago

> Thank you for your reply, @yongwonshin. I have another question (thanks again for answering these).
>
> The Newton architecture requires the data to be structured in a particular way before it can be used by the PIM. This adds overhead (for example, you must do tiling and some sort of patching together on the "host"). Is this tiling and patching together also considered negligible (like the activation functions), or is it done on the GPU cores?

If I understand correctly, you are asking about input data movement overhead. Input data movement latency is considered in our paper; it is related to the GWRITE latency. GWRITE is part of the Newton ISA, so the PIM units are responsible for moving the data. For large matrix-vector multiplications it is negligible (amortized), but for small matrices it can take a substantial portion of the runtime.
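
As a purely illustrative back-of-the-envelope model of that amortization argument (all numbers below are invented, not taken from the paper or the code):

# Toy model: assume GWRITE-ing the input costs g cycles up front and the MAC
# phase costs c cycles per output row; g and c are made-up values.
def pim_gemv_cycles(rows, g=100, c=4):
    return g + rows * c

print(pim_gemv_cycles(4096))  # 16484 cycles; the 100-cycle GWRITE is ~0.6% of the total
print(pim_gemv_cycles(16))    # 164 cycles; the same GWRITE is ~61% of the total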

For the next question in another thread, I have to look into the code, so I'll answer it by the end of the weekend.

I'm happy that my answer helps you. Thank you.

sbird-kick commented 1 year ago

Thank you! I am looking forward to your further replies.

sbird-kick commented 1 year ago

Hi, I have another question (apologies again for so many; your kindness in answering them is very valuable!). Could you point me to where the MatMul layers are split between PIM and GPU? The apply function in transform.py here does it for Conv layers, because your satisfy function works for Conv only. I want to investigate PIMFlow's effect on transformers like BERT, as you have done in your paper :)

yongwonshin commented 1 year ago

> Hi, I have another question (apologies again for so many; your kindness in answering them is very valuable!). Could you point me to where the MatMul layers are split between PIM and GPU? The apply function in transform.py here does it for Conv layers, because your satisfy function works for Conv only. I want to investigate PIMFlow's effect on transformers like BERT, as you have done in your paper :)

You're right, I only do the split for Conv. For MatMul/Gemm, I set the batch size manually from the script (batch_size = math.ceil(int(config[5]) * args.split_ratio / 100)). Our primary focus is CNNs, so we don't support BERT in the code, since large matrix-matrix multiplication is more suitable for accelerators.
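
For example (purely hypothetical numbers; the value of config[5] below is invented just to show the arithmetic):

import math

config = [None, None, None, None, None, 128]  # pretend config[5] == 128
split_ratio = 30                              # pretend --split_ratio 30 was passed
batch_size = math.ceil(int(config[5]) * split_ratio / 100)
print(batch_size)  # 39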

> I would also like to discuss another thing with you. I was trying to plot the split for each layer. I took mobilenet-v2 and noticed that, in a couple of places, some odd things were happening which I could not explain. If you could help me with that, I will be very grateful. Here are my questions:
>
>   1. Is max_performance_mobilenet-v2_16_4.csv in the pipeline folder a valid source of information?
>   2. If it isn't, how would I find the split per layer and the GPU/PIM cycles per layer?
>   3. Is it normal/expected for some layers to have a non-zero/non-100 split but have only GPU cycles (PIM cycles = 0), or vice versa?
>   4. Is it normal/expected for some layers to be mapped entirely to the GPU while others are mapped entirely to PIM?
>   5. How are the GPU cycles calculated? If they come from Accel-Sim, is that approximate, or does it use the actual physical GPU on my device?
>
> I have attached an image. The blue line is the split (or RATIO), the green bars are the GPU cycles per layer, and the red bars are the PIM cycles per layer.

2: Performance for each layer, for all split ratios, is saved in CSV files named {model}_{TERM}{i}_{args.n_channel}_{args.n_gwrite}{postfix}.csv. process_csv.py gathers each record and calculates the optimal performance.

1: It's a valid source. to_full_layer.py calculates the maximum performance, but its performance (cycle) information is "mixed", which means you should assemble the records of the other rows to get the correct pipelined performance. Assembling is done in the solving phase. The CSV and script dependencies are quite (maybe unnecessarily) complicated; I should have cleaned them up for usability and readability...

3: If that happens during the split/pipeline profiling phase, it could be a bug. But I think it happens during "stat", which assembles the performance records. If you think it's a bug, then the profiling and statistics results (.csv) may be helpful.

4: Yes, that's expected behavior.

5: I used Accel-Sim cycles based on GPU traces obtained from the actual GPU (GeForce RTX 2080 Ti), but with the tested RTX 2060 configs in Accel-Sim. If you use a GPU with the Turing architecture, you should get similar performance trends. Simulation results can have some discrepancies with the actual GPU, but we used the simulator to investigate the performance impact of the number of memory channels.
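
For what it's worth, here is a rough sketch of the kind of gathering step described in answer 2. The column names and file-name pattern are assumptions for illustration, not the actual format used by process_csv.py.

import csv
import glob

def best_split_per_layer(pattern):
    # Walk all per-split-ratio CSVs matching the pattern and keep, for each
    # layer, the split ratio with the fewest cycles (hypothetical columns).
    best = {}
    for path in glob.glob(pattern):
        with open(path) as f:
            for row in csv.DictReader(f):
                layer = row["layer"]
                ratio = int(row["ratio"])
                cycles = int(row["cycles"])
                if layer not in best or cycles < best[layer][1]:
                    best[layer] = (ratio, cycles)
    return best

# e.g. best_split_per_layer("mobilenet-v2_*_16_4*.csv")  # hypothetical pattern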

sbird-kick commented 1 year ago

Hi, I am slightly confused by this statement: "Our primary focus is CNNs, so we don't support BERT in the code, since large matrix-matrix multiplication is more suitable for accelerators." I would imagine that FC layers would be suited to both PIM and GPU, necessitating data splitting?

Thank you for your answers, again.

yongwonshin commented 1 year ago

We support the FC layer with batch size 1, but there's no splitting. For BERT, however, (sequence length x batch size) corresponds to the batch size of the FC layer, leading to matrix-matrix multiplication. For matrix-matrix multiplication (A x B), we split the B matrix along the column dimension to make two matrix multiplications {A x B1, A x B2} without a reduction. (In the case of splitting along the row dimension, we would need a reduction: A x B1 + A x B2.)
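
A quick numpy check of the column-split vs. row-split point above (illustrative only; the shapes are arbitrary):

import numpy as np

A = np.random.rand(8, 6)
B = np.random.rand(6, 10)

# Splitting B along the column dimension: the partial products can simply be
# concatenated, so no reduction is needed.
B1, B2 = B[:, :5], B[:, 5:]
assert np.allclose(np.concatenate([A @ B1, A @ B2], axis=1), A @ B)

# Splitting B along the row dimension: the partial products must be reduced (summed).
Ba, Bb = B[:3, :], B[3:, :]
assert np.allclose(A[:, :3] @ Ba + A[:, 3:] @ Bb, A @ B)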

Thank you!

sbird-kick commented 10 months ago

Hi, I don't understand why you are multiplying the speed by 32 when latency hiding is enabled here.

Could you give me some insight? Thank you for your previous (and future) replies.

yongwonshin commented 7 months ago

Hi,

While the GWRITE latency is not specified in the original paper, we assume GWRITE needs a 32-column access for input data fetching.

Thank you.