senwu / emmental

A deep learning framework for building multimodal multi-task learning systems.
https://emmental.readthedocs.io
MIT License
109 stars, 18 forks

Add Action and Batch classes to make Emmental more modular #114

Closed senwu closed 2 years ago

senwu commented 3 years ago

Description of the proposed changes

This PR makes Emmental more extensible and easier to use for downstream tasks.

  1. We introduce two new classes, Action and Batch, to make the APIs more modular.
  2. We make the task_flow more flexible by supporting more formats for specifying the inputs to each module.
    • inputs can now be a str (e.g., inputs="input1"), which means the output of input1 is used as the input for the current action.
    • inputs can also be a list whose items take one of three formats:
      a) x (x is a str): takes the whole output of x as input. This lets users pass all outputs from one module to another without manually specifying every input to the module.
      b) (x, y) (y is an int): takes the y-th output of x as input.
      c) (x, y) (y is a str): takes the output of x stored under the key y as input.
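
The accepted formats can be illustrated with a small normalization helper. This is a minimal sketch and not Emmental's implementation: `normalize_inputs`, `Key`, and `InputSpec` are hypothetical names, and treating a bare str as shorthand for a one-item list is an assumption based on the description above.

```python
from typing import List, Tuple, Union

# Hypothetical type aliases for this sketch; they are not Emmental names.
Key = Union[int, str, None]
InputSpec = Union[str, Tuple[str, Union[int, str]]]


def normalize_inputs(
    inputs: Union[None, str, List[InputSpec]]
) -> List[Tuple[str, Key]]:
    """Turn any accepted `inputs` value into a list of (source, key) pairs.

    A key of None stands for "the whole output of source".
    """
    if inputs is None:
        # No explicit inputs; what a flow does in that case is out of scope here.
        return []
    if isinstance(inputs, str):
        # Assumption: a bare str behaves like a one-item list.
        inputs = [inputs]

    normalized: List[Tuple[str, Key]] = []
    for item in inputs:
        if isinstance(item, str):
            # Format a): the whole output of x.
            normalized.append((item, None))
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[0], str)
            and isinstance(item[1], (int, str))
        ):
            # Formats b) and c): the y-th output of x, or the output keyed by y.
            normalized.append((item[0], item[1]))
        else:
            raise ValueError(f"Unsupported input spec: {item!r}")
    return normalized
```

Normalizing up front means the code that later reads module outputs only has to handle one shape, a (source, key) pair, regardless of which format the user wrote.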

A few emmental.EmmentalTaskFlowAction examples:

from emmental import Action as Act
Act(name="input", module="input_module0", inputs=[("_input_", "data")])  # the "data" entry of _input_
Act(name="input", module="input_module0", inputs=[("_input_", 0)])  # the 0-th output of _input_
Act(name="input", module="input_module0", inputs=["_input_"])  # the whole output of _input_
Act(name="input", module="input_module0", inputs="_input_")  # the whole output of _input_, given as a str
Act(name="input", module="input_module0", inputs=[("_input_", "data"), ("_input_", 1), "_input_"])  # mixed formats in one list
Act(name="input", module="input_module0", inputs=None)  # no inputs specified

The same design applies to action_outputs. Here are a few examples:

action_outputs=[(f"{task_name}_pred_head", 0), ("_input_", "data"), f"{task_name}_pred_head"]
action_outputs="_input_"
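
To make the selection semantics concrete, the sketch below resolves an action_outputs specification against a dictionary of already computed outputs. It is an illustration only, not Emmental's code: `lookup`, `collect_action_outputs`, the layout of `output_dict`, and the sample values are all assumptions.

```python
from typing import Any, Dict, List, Tuple, Union

# Hypothetical alias: a spec is either "x" or (x, y) with y an int or a str.
Spec = Union[str, Tuple[str, Union[int, str]]]


def lookup(output_dict: Dict[str, Any], spec: Spec) -> Any:
    """Return the value that a single spec points at."""
    if isinstance(spec, str):
        return output_dict[spec]  # the whole output of x
    name, key = spec
    return output_dict[name][key]  # the y-th output of x, or the output keyed by y


def collect_action_outputs(
    output_dict: Dict[str, Any],
    action_outputs: Union[None, str, List[Spec]],
) -> List[Any]:
    """Resolve every spec in `action_outputs`, preserving order."""
    if action_outputs is None:
        return []
    if isinstance(action_outputs, str):
        # Assumption: a bare str behaves like a one-item list.
        action_outputs = [action_outputs]
    return [lookup(output_dict, spec) for spec in action_outputs]


# Made-up outputs of a finished forward pass, for illustration only.
task_name = "task1"
output_dict = {
    "_input_": {"data": [1.0, 2.0], "label": 1},
    f"{task_name}_pred_head": [[0.2, 0.8]],
}

selected = collect_action_outputs(
    output_dict,
    [(f"{task_name}_pred_head", 0), ("_input_", "data"), f"{task_name}_pred_head"],
)
whole_input = collect_action_outputs(output_dict, "_input_")
```

Here `selected` holds the first output of the prediction head, the "data" entry of the input, and the head's full output, in that order, while `whole_input` holds the entire input dictionary as its only item.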

Test plan

Pass the existing tests.

Checklist

codecov[bot] commented 3 years ago

Codecov Report

Merging #114 (651f4eb) into master (03c80bc) will increase coverage by 0.91%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #114      +/-   ##
==========================================
+ Coverage   91.21%   92.12%   +0.91%     
==========================================
  Files          40       40              
  Lines        1991     2018      +27     
  Branches      425      431       +6     
==========================================
+ Hits         1816     1859      +43     
+ Misses        101       94       -7     
+ Partials       74       65       -9     
Flag       Coverage Δ
unittests  92.12% <100.00%> (+0.91%) ↑


Impacted Files                                     Coverage Δ
src/emmental/__init__.py                           100.00% <100.00%> (ø)
src/emmental/learner.py                            78.32% <100.00%> (ø)
src/emmental/model.py                              93.78% <100.00%> (+5.43%) ↑
src/emmental/schedulers/mixed_scheduler.py         100.00% <100.00%> (ø)
src/emmental/schedulers/round_robin_scheduler.py   100.00% <100.00%> (ø)
src/emmental/schedulers/scheduler.py               100.00% <100.00%> (ø)
src/emmental/schedulers/sequential_scheduler.py    100.00% <100.00%> (ø)
src/emmental/task.py                               100.00% <100.00%> (ø)