borgesa closed this 5 years ago.
For loss.py:
1. `*kwargs` needs to be replaced with `**kwargs`.
2. `super(_Loss, self)` needs to be changed to `super(CustomLossClass, self)`.
3. I would prefer `(output, target)` over `(input, target)`, since `input` shadows a Python built-in function. That would also match my implementations of the metric functions.

For the other files:
It seems that you want to pass arguments to the PyTorch loss class, so train.py and config.json should be changed too:
"loss": "NLLLoss",
"loss_args": {
    "reduction": "elementwise_mean"
},
in config.json, and then
loss = get_loss_function(config['loss'], **config['loss_args'])
in train.py (line 23).
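A minimal sketch of how the config fragment and the train.py call fit together, assuming the loss class is looked up by name in `torch.nn`. The thread only shows the call site, so the body of `get_loss_function` and its extra `namespace` parameter are hypothetical; `namespace` exists here only so the lookup can be tried without PyTorch. The `reduction` value is kept as in the thread; newer PyTorch releases spell it `"mean"` and treat `"elementwise_mean"` as deprecated.

```python
import json

# Corrected config.json fragment: JSON needs "key": value pairs,
# not Python-style key=value assignments.
CONFIG_TEXT = """
{
    "loss": "NLLLoss",
    "loss_args": {
        "reduction": "elementwise_mean"
    }
}
"""


def get_loss_function(name, namespace=None, **loss_args):
    """Hypothetical helper: instantiate the loss class called `name`.

    By default the class is looked up in torch.nn; `namespace` is an extra
    parameter of this sketch so the lookup can be exercised without PyTorch.
    """
    if namespace is None:
        import torch.nn as namespace  # requires PyTorch to be installed
    loss_class = getattr(namespace, name)  # AttributeError for unknown names
    return loss_class(**loss_args)  # keyword arguments come from "loss_args"


config = json.loads(CONFIG_TEXT)
# With PyTorch installed, this is the train.py call from the comment:
# loss = get_loss_function(config['loss'], **config['loss_args'])
```

Keeping the name lookup in one helper means any `torch.nn` loss can be selected from config.json without editing train.py.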
@SunQpark,
Thank you for your comments. I agree with them all, and have updated the pull request.
With regards to number 3, I used the convention that PyTorch uses in its loss module (although I agree with you: I do not like using names that shadow built-ins). Since `output, target` is already used in the repository, I followed your proposal.
Thank you for the quick response!
It seems there has been some discussion about using `input` as a variable name. I was surprised that it came up first in a Google search, even though my search terms were 'pytorch convention', which does not include 'input'.
Anyway, I will merge this PR now, thank you.
Regarding issue #21: I have made the update I described in that issue.
The loss functionality can now use any PyTorch loss function, selected by specifying its name in config.json. A template for a custom loss function class is also included.
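For illustration, a custom loss class written with the three fixes from the review (`**kwargs`, `super(CustomLossClass, self)`, and `(output, target)`) might look like the sketch below. It is not the template from the PR: it assumes PyTorch is installed, subclasses the private `torch.nn.modules.loss._Loss` base class (which stores the `reduction` argument), and uses negative log-likelihood only as a placeholder computation.

```python
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class CustomLossClass(_Loss):
    """Sketch of a custom loss template; replace the body of forward()."""

    def __init__(self, **kwargs):
        # `**kwargs` (not `*kwargs`) forwards keyword arguments such as
        # reduction="sum" to the _Loss base class, which stores them.
        super(CustomLossClass, self).__init__(**kwargs)

    def forward(self, output, target):
        # `output` rather than `input`, which would shadow the built-in.
        # Placeholder computation: negative log-likelihood of log-probabilities.
        return F.nll_loss(output, target, reduction=self.reduction)
```

Because `**kwargs` is passed through unchanged, `CustomLossClass(reduction="sum")` handles `reduction` the same way the built-in losses do. In Python 3, `super().__init__(**kwargs)` is equivalent to the explicit form used here.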
Remaining: