Open · kirk86 opened 2 years ago
Hey @kirk86, thanks for the detailed report!
> Looking at the following losses BCELoss vs CrossEntropyLoss it seems that there's no clear indication of the type of data expected for input and target.
We'd accept a PR updating the docs to make it clearer what the input / target formats should be.
> Why doesn't BCELoss accept a target of type int instead of expecting float?
There's no particular reason I can think of atm that we couldn't support an int target
- it just hasn't been done yet.
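To make the current behavior concrete, here's a minimal sketch of the dtype restriction (the exact error message text is an assumption based on versions I've seen; casting to float is the usual workaround):

```python
import torch
import torch.nn as nn

loss_fn = nn.BCELoss()
probs = torch.sigmoid(torch.randn(4))    # predictions in (0, 1)
target_int = torch.randint(0, 2, (4,))   # int64 labels

# Passing target_int directly raises a RuntimeError about dtypes
# (something like "Found dtype Long but expected Float");
# casting to float is the current workaround.
loss = loss_fn(probs, target_int.float())
```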
> Also we see this discrepancy where BCELoss expects the target to be of type float whereas CrossEntropyLoss expects it to be int.
For CrossEntropyLoss, note that both float (class probabilities) and int (class indices) targets are supported (this is detailed within the docs for the module form `nn.CrossEntropyLoss`). For consistency, I'd agree it would be good for both to be supported for BCELoss as well.
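A small sketch of the two accepted target forms for `CrossEntropyLoss` (float probability targets are only available in recent PyTorch versions; I'm assuming >= 1.10 here):

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)

# int target: class indices, shape (N,)
loss_idx = loss_fn(logits, torch.tensor([0, 2, 1, 0]))

# float target: per-class probabilities, shape (N, C)
prob_target = torch.softmax(torch.randn(4, 3), dim=-1)
loss_prob = loss_fn(logits, prob_target)
```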
> Another thing that would really add clarity is to have links from the docs directly to the source code of the forward and backward calls for each loss.
Note that the source code for forward / backward in most cases (this one included) is in C++. There's often separate kernels implemented for CPU and CUDA, and codegen creates the python bindings. So this isn't straightforward to do at the moment. @mruberry is currently working on readable reference implementations on the Python side, so at some point in the future, we should be able to link to those instead.
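As an illustration of what such a readable Python-side reference could look like, here's a hedged sketch of BCELoss's forward; the clamp at -100 on the log terms mirrors what I believe the C++ kernel does, so treat the details as assumptions rather than the actual implementation:

```python
import torch

def bce_loss_reference(input, target, reduction="mean"):
    # -(t*log(p) + (1-t)*log(1-p)); clamping the log terms at -100
    # avoids propagating -inf when p hits 0 or 1.
    log_p = torch.clamp(torch.log(input), min=-100.0)
    log_1mp = torch.clamp(torch.log1p(-input), min=-100.0)
    loss = -(target * log_p + (1 - target) * log_1mp)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
```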
📚 The doc issue
Looking at the following losses BCELoss vs CrossEntropyLoss it seems that there's no clear indication of the type of data expected for `input` and `target`. We see that for other variables it clearly indicates if it's `int`, `bool`, `float`, etc., but not for `input` and `target`.

Also we see this discrepancy where `BCELoss` expects the `target` to be of type `float` whereas `CrossEntropyLoss` expects it to be `int`. Why doesn't `BCELoss` accept a `target` of type `int` instead of expecting `float`?

Another thing that would really add clarity is to have links from the docs directly to the source code of the forward and backward calls for each loss.
I know you might say that there are already such links, for instance the SOURCE link. But following the SOURCE link brings us to the following definition, with the following dispatch code. As we can see, to find the actual implementation we'd have to go and find the definition of
`torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)`
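For what it's worth, for integer class-index targets that opaque binding is, to my understanding, numerically equivalent to composing `log_softmax` and `nll_loss` on the Python side (an assumption worth checking against the C++ kernel):

```python
import torch
import torch.nn.functional as F

def cross_entropy_reference(logits, target):
    # For integer class-index targets, cross entropy decomposes into
    # log_softmax followed by the negative log likelihood loss.
    return F.nll_loss(F.log_softmax(logits, dim=-1), target)

logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))
assert torch.allclose(cross_entropy_reference(logits, target),
                      F.cross_entropy(logits, target))
```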
This prevents folks from directly retrieving relevant pieces of information on how different parts work directly from the docs. One would have to spend a whole chunk of time searching through the GitHub code base until they reach what they are looking for.
Suggest a potential alternative/fix
The types of data should be clearly indicated for all arguments, and there should be some consistency in the types of the first two inputs, `input` and `target`. Also improve navigation from the docs to the source code, linking each method directly to its actual forward/backward pass instead of to some dispatch function.
cc @svekars @holly1238 @albanD @mruberry @jbschlosser @walterddr @kshitij12345