yandexdataschool / nlp_course

YSDA course in Natural Language Processing
https://lena-voita.github.io/nlp_course.html
MIT License

week_2 seminar.ipynb #103

Closed: aulasau closed this issue 1 month ago

aulasau commented 2 years ago

There is a problem with this seminar. Currently, an error is raised in the cells that call model.predict(make_batch(...)), because nn.Module has no predict method. Please add a function somewhere, such as:

import torch

def predict(model, batch):
    return model(batch).detach().cpu().numpy()  # .numpy() requires a CPU tensor

and replace the expression model.predict(make_batch(...)).detach().cpu() with predict(model, make_batch(...)).
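A slightly more defensive variant of the proposed helper is sketched below. It is an assumption, not code from the seminar notebook: besides moving the output to the CPU, it switches the model to eval mode (so layers like dropout behave deterministically) and runs the forward pass under torch.no_grad(), then restores the previous training mode. It assumes the model's forward accepts whatever make_batch(...) returns and produces a single tensor.

```python
import numpy as np
import torch
import torch.nn as nn


def predict(model: nn.Module, batch) -> np.ndarray:
    """Keras-style predict for a PyTorch module (illustrative helper).

    Runs a forward pass in eval mode without gradient tracking and
    returns the output as a NumPy array on the CPU.
    """
    was_training = model.training
    model.eval()  # disable dropout, use running batch-norm statistics
    try:
        with torch.no_grad():  # no autograd graph is built, so no .detach() needed
            output = model(batch)
    finally:
        model.train(was_training)  # restore whatever mode the model was in
    return output.cpu().numpy()  # .numpy() only works on CPU tensors


# Usage with a toy model standing in for the seminar's network:
model = nn.Linear(3, 1)
preds = predict(model, torch.zeros(2, 3))
print(preds.shape)  # (2, 1)
```

Restoring the original mode in a finally block means calling the helper in the middle of a training loop does not silently leave the model in eval mode.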

poedator commented 1 month ago

While this comment may be valid, it does not seem to apply to the current version of the seminar notebook. Feel free to submit a PR if you still see a need for a similar improvement.