A3shTnT opened 6 months ago
we only model it as bypassing L1. Are you suggesting the strong identifier does more than that?
Yes. Suppose a certain address already has a value in the L1 cache. For example, with a strong st followed by a weak load, does that load take the new value from the L2 cache or the old value from the L1 cache? Or, with a strong load followed by a weak load, can the weak load receive the new value fetched by the strong load instead of the old L1 value?
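Concretely, the kind of sequence I mean is something like this (illustration only, the names are mine):

__global__ void strong_st_weak_ld(unsigned* data) {
  unsigned old_v = data[0];  // weak load: may bring the line into L1
  asm volatile("st.relaxed.gpu.global.b32 [%0], %1;"
               :: "l"(data), "r"(old_v + 1)
               : "memory");  // strong store: bypasses L1 in the model
  data[1] = data[0];         // weak load again: stale L1 copy or new L2 value?
}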
On the other hand, the ordering between stores and the MEMBAR seems incomplete. The MEMBAR relies only on the scoreboard to block unfinished loads; it does not block unfinished stores. Also, the MEMBAR itself has a scope. Do we need to consider the scope here?
I hope you can consider these questions. Thank you very much.
For the first point, yes. That is what skipping L1 implies.
The data is not cached in L1, so the load always goes to L2 for the data. No matter whether a strong access is followed by a weak one or a weak access by a strong one, the load always goes to L2.
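For intuition: in the model a strong access behaves like a .cg (cache-at-global-level) access, which CUDA exposes as __ldcg/__stcg. A rough analogy, not the simulator code:

__global__ void cg_analogy(unsigned* data) {
  // .cg accesses cache at L2 and below, not in L1; that is the same
  // "skip L1, always go to L2" behavior the model gives strong accesses.
  unsigned v = __ldcg(&data[0]);  // load served from L2
  __stcg(&data[1], v);            // store written to L2, no L1 allocation
}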
I'm a little confused about your second point. It's true that stores are non-blocking. GPUs are not coherent (at least not Nvidia ones); read-after-write across different warps is not guaranteed.
Are you suggesting the .STRONG suffix creates fences that block loads to the same address? If yes, can you point me to the related document so I can check?
Thanks.
For the first point, my question is: if there is already a copy of an address in the L1 cache, and a strong load or store then operates on that address, reading or writing the contents of the L2 cache, will it in addition delete or refresh the existing L1 copy? I am considering only the instruction itself, not whether a CCTL.IVALL instruction follows it. If there are additional operations performed on the L1 cache, please kindly point me to that code section. If not, will future reads or writes actually end up reading this old L1 copy, and how is that prevented?
For the second point, I may have been vague. I am not asking whether the read and write instructions themselves are blocked, but about the relationship between the MEMBAR instruction and these strong load/store instructions. In the code, I see that the MEMBAR instruction blocks read operations based on the scoreboard, without any additional operation to block write operations, and without different handling for different scopes. If that code exists and I overlooked it, please show me where it is. Thank you very much for your reply.
The CACHE_GLOBAL option bypasses the L1 completely, so it will not put data in the L1 at writeback. As for the .STRONG suffix: does hardware put data in L1 on strong loads?
So, to answer your question ("If not, will future reads or writes actually result in reading this old copy of l1cache?"): whether it is an old or new copy doesn't matter; it will be a hit. If you want a miss, do an L1 invalidate for the addr at writeback when the access is CACHE_GLOBAL. Do it here.
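A minimal sketch of what I mean, assuming a hypothetical invalidate_addr() helper on the L1 (gpgpu-sim does not necessarily expose exactly this; the point is where the hook would go):

// Sketch only: at writeback of a CACHE_GLOBAL access, drop any stale L1 copy.
// invalidate_addr() is a hypothetical helper, not existing gpgpu-sim API.
if (inst.cache_op == CACHE_GLOBAL && m_L1D) {
  new_addr_type block = m_config->m_L1D_config.block_addr(inst.get_addr(0));
  m_L1D->invalidate_addr(block);  // mark the matching L1 line invalid
}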
I'm not sure what you mean by "MEMBAR instruction blocks read operations based on the scoreboard". At a membar, no instruction can be issued until the membar is resolved. The warp stalls at the membar until all stores are finished. Here, https://github.com/JRPan/gpgpu-sim_distribution/blob/a0c12f5d63504c67c8bdfb1a6cc689b4ab7867a6/src/gpgpu-sim/shader.cc#L3931 the scheduler blocks everything on a membar; the warp does not progress at all. The membar is freed after all writes are finished. Is this what you are looking for?
For the question "Does hardware put data in L1 on strong loads?", I ran a simple test like this:
#include <cuda_runtime.h>
#include <iostream>

__global__ void kernel(unsigned* data) {
  data[1] = data[0];
  unsigned a = 3;
  asm volatile("st.relaxed.gpu.global.b32 [%0], %1;" ::"l"(data), "r"(a)
               : "memory");
  a = data[0];
  data[2] = a;
}

int main() {
  unsigned h_data[3] = {1, 0, 0};
  unsigned* d_data;
  cudaMalloc(&d_data, sizeof(unsigned) * 3);
  cudaMemcpy(d_data, h_data, sizeof(unsigned) * 3, cudaMemcpyHostToDevice);
  kernel<<<1, 1>>>(d_data);
  cudaDeviceSynchronize();
  cudaMemcpy(h_data, d_data, sizeof(unsigned) * 3, cudaMemcpyDeviceToHost);
  std::cout << h_data[2] << std::endl;
  cudaFree(d_data);
  return 0;
}
and its SASS for sm_75 is:
Fatbin elf code:
================
arch = sm_75
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_75
Function : _Z6kernelPj
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM75 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM75)"
/*0000*/ MOV R1, c[0x0][0x28] ; /* 0x00000a0000017a02 */
/* 0x000fc40000000f00 */
/*0010*/ MOV R2, 0x3 ; /* 0x0000000300027802 */
/* 0x000fe20000000f00 */
/*0020*/ ULDC.64 UR4, c[0x0][0x160] ; /* 0x0000580000047ab9 */
/* 0x000fe40000000a00 */
/*0030*/ LDG.E.SYS R3, [UR4] ; /* 0x00000004ff037981 */
/* 0x000eaa000c1ee900 */
/*0040*/ STG.E.STRONG.GPU [UR4], R2 ; /* 0x00000002ff007986 */
/* 0x000fe8000c114904 */
/*0050*/ LDG.E.SYS R0, [UR4] ; /* 0x00000004ff007981 */
/* 0x000ee8000c1ee900 */
/*0060*/ STG.E.SYS [UR4+0x4], R3 ; /* 0x00000403ff007986 */
/* 0x004fe8000c10e904 */
/*0070*/ STG.E.SYS [UR4+0x8], R0 ; /* 0x00000800ff007986 */
/* 0x008fe2000c10e904 */
/*0080*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
/*0090*/ BRA 0x90; /* 0xfffffff000007947 */
/* 0x000fc0000383ffff */
/*00a0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00b0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00c0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00d0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00e0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00f0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
..........
Fatbin ptx code:
================
arch = sm_75
code version = [8,2]
host = linux
compile_size = 64bit
compressed
This program has no practical purpose; it only tests this issue. The instructions at 0x30 through 0x50 in the SASS are the main components: 0x30 caches the given address in L1 via a load, 0x40 executes the strong store, and 0x50 executes a new load. Running the program shows that the final result of 0x50 is the value stored by 0x40. Can this example show that instruction 0x40 not only bypasses L1 but also performs some operation on the L1 cache? From the meaning of scope itself, its purpose is to let threads within the specified scope SEE the value at a given address. But if those threads access the address again after this instruction executes and actually read stale values, that seems incorrect. This is only my understanding, though; I did not find a detailed explanation in the NVIDIA documentation of whether strong accesses should also operate on the L1 cache.
For the MEMBAR question: my understanding of line 3931 in shader.cc is that it checks the scoreboard for registers related to this instruction that have not yet been fully written. For a store instruction, that line can only wait until its two operand registers are available, i.e., until the store address and the value to be written are ready. But the BEHAVIOR of the store itself is unstoppable; there is no check of whether the store has actually reached the lower-level cache. NVIDIA's documentation describes membar as follows: "The membar instruction guarantees that prior memory accesses requested by this thread (ld, st, atom and red instructions) are performed at the specified level, before later memory operations requested by this thread following the membar instruction. The level qualifier specifies the set of threads that may observe the ordering effect of this operation." How should we understand "performed"? Do we consider an access performed as soon as its operands are ready? Or do we need to consider the mem_fetch object that passes information between caches: is it performed once it is pushed into the L1 cache queue, or only when it returns from the L1 cache queue? If we bring the L2 cache into it as well, this becomes even more complex. Or are these considerations redundant?
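To make "performed" concrete, the pattern I have in mind is the usual flag handshake (illustration only; __threadfence() compiles to a GPU-scope membar, and the two blocks are assumed to be co-resident):

__device__ unsigned payload;
__device__ volatile unsigned flag = 0;
__device__ unsigned observed;

__global__ void handshake() {
  if (blockIdx.x == 0) {
    payload = 42;        // this store must be "performed" (visible GPU-wide)...
    __threadfence();     // membar.gl: GPU-scope ordering
    flag = 1;            // ...before any thread can observe flag == 1
  } else {
    while (flag == 0);   // spin until the flag is observed
    observed = payload;  // must read 42 if the ordering holds
  }
}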
The code makes sense. You can verify the L1 behavior by profiling your example. Could you please use Nsight Compute to check how many L1 misses there are? Thanks. I expect there are 2 misses. I think by default, Volta L1 is write-through, but since it's .STRONG it bypasses L1. So 0050 should be a miss. Or I'm wrong.
Also, maybe you can try an ld.strong example as well? That is probably more interesting. I want to see if ld.strong puts data in L1.
I don't think we consider scope, though. It looks like it's per-warp. For example:
ST addr, R0
MEMBAR
LD R0, addr
The store will issue, and then R0 will be saved in the scoreboard. Then the membar stalls the warp: pendingWrites checks whether the scoreboard is empty. Since R0 is in the scoreboard, pendingWrites returns true and the warp waits.
Once the ST is issued, it is unstoppable. But the scoreboard knows when the ST is finished. When the mem_fetch of the ST arrives at L2, L2 sends an ack back to L1, indicating the ST is finished. L1 receives the ACK and then clears R0. Now pendingWrites returns false because the scoreboard is empty, so the warp can proceed and the LD can be issued.
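For reference, the check itself is tiny; roughly (paraphrasing the source, not a verbatim quote):

// Paraphrased: pendingWrites() is just "does this warp still have any
// outstanding destination registers in the scoreboard?"
bool Scoreboard::pendingWrites(unsigned wid) const {
  return !reg_table[wid].empty();
}
// The membar stall then reduces to: while the warp has a membar set and
// pendingWrites(wid) is true, the scheduler issues nothing from that warp;
// once the scoreboard drains, the membar is cleared and issue resumes.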
Does this make sense?
The hit rate of the strong-store/load program is 50%, rather than 0 (according to Nsight Compute 2024.2.1). And the strong-load program is like this:
#include <cuda_runtime.h>
#include <iostream>

__global__ void kernel(unsigned* data, volatile unsigned* flag) {
  data[1] = data[0];
  unsigned a, b;
  asm volatile("ld.relaxed.gpu.global.b32 %0, [%1];"
               : "=r"(a)
               : "l"(data)
               : "memory");
  b = data[0];
  data[2] = a;
  data[3] = b;
}

int main() {
  unsigned h_data[4] = {1, 0, 0, 0};
  unsigned *d_data, *d_flag;
  cudaMalloc(&d_data, sizeof(unsigned) * 4);
  cudaMalloc(&d_flag, sizeof(unsigned));
  cudaMemcpy(d_data, h_data, sizeof(unsigned) * 4, cudaMemcpyHostToDevice);
  kernel<<<1, 1>>>(d_data, d_flag);
  cudaDeviceSynchronize();
  cudaMemcpy(h_data, d_data, sizeof(unsigned) * 4, cudaMemcpyDeviceToHost);
  std::cout << h_data[1] << " " << h_data[2] << " " << h_data[3] << std::endl;
  cudaFree(d_data);
  cudaFree(d_flag);  // free the flag buffer as well
  return 0;
}
and the SASS is:
Fatbin elf code:
================
arch = sm_75
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_75
Function : _Z6kernelPjPVj
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM75 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM75)"
/*0000*/ MOV R1, c[0x0][0x28] ; /* 0x00000a0000017a02 */
/* 0x000fc40000000f00 */
/*0010*/ ULDC.64 UR4, c[0x0][0x160] ; /* 0x0000580000047ab9 */
/* 0x000fe40000000a00 */
/*0020*/ LDG.E.SYS R0, [UR4] ; /* 0x00000004ff007981 */
/* 0x000ea8000c1ee900 */
/*0030*/ LDG.E.STRONG.GPU R2, [UR4] ; /* 0x00000004ff027981 */
/* 0x000ee8000c1f4900 */
/*0040*/ LDG.E.SYS R3, [UR4] ; /* 0x00000004ff037981 */
/* 0x000f28000c1ee900 */
/*0050*/ STG.E.SYS [UR4+0x4], R0 ; /* 0x00000400ff007986 */
/* 0x004fe8000c10e904 */
/*0060*/ STG.E.SYS [UR4+0x8], R2 ; /* 0x00000802ff007986 */
/* 0x008fe8000c10e904 */
/*0070*/ STG.E.SYS [UR4+0xc], R3 ; /* 0x00000c03ff007986 */
/* 0x010fe2000c10e904 */
/*0080*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
/*0090*/ BRA 0x90; /* 0xfffffff000007947 */
/* 0x000fc0000383ffff */
/*00a0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00b0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00c0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00d0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00e0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
/*00f0*/ NOP; /* 0x0000000000007918 */
/* 0x000fc00000000000 */
..........
Fatbin ptx code:
================
arch = sm_75
code version = [8,2]
host = linux
compile_size = 64bit
compressed
The output of this program is not meaningful, since it must be all 1s. But the load hit rate is 0. So the hit rates of these two programs may differ from what was expected; maybe you can also try measuring the hit rates of the programs yourself.
As for the MEMBAR instruction, thanks for your reply. I may have taken the code at line 3931 for granted; I will check that part of the code.
In addition, these issues are only a small part of the memory consistency model problem. I have read the memory consistency model section of the PTX documentation, but it is very difficult for me to understand, being mostly mathematical relations and axioms. What I want to know is how the hardware executes so that these axioms are satisfied, but that remains very vague. With the arrival of NVIDIA's Hopper and Blackwell architectures, the number of instructions related to scope or consistency is clearly growing. I hope gpgpu-sim can consider adding a relatively systematic model for this part.
Why doesn't the store need to consider .STRONG and .GPU? "Isn't this related to the L1 cache policy? For example, would the behavior of .strong.gpu differ between write-back and write-through?"