senwu / emmental

A deep learning framework for building multimodal multi-task learning systems.
https://emmental.readthedocs.io
MIT License
109 stars, 18 forks

Add Action and Batch classes to make Emmental more modularized #116

Closed senwu closed 2 years ago

senwu commented 2 years ago

Description of the proposed changes

To make Emmental more extensible and easier to use for downstream tasks:

  1. We introduce two new classes, Action and Batch, to make the APIs more modularized.
  2. We make the task_flow more flexible by supporting more formats for specifying the inputs to each module.
    • inputs can now be a str (e.g., inputs="input1"), which means the current action takes the whole output of input1 as its input.
    • inputs can also be a list whose items take any of three formats:
      a) x (x is a str): takes the whole output of x as input. This lets users pass all outputs from one module to another without manually specifying every input to the module.
      b) (x, y) (y is an int): takes the y-th output of x as input.
      c) (x, y) (y is a str): takes the output of x named y as input.
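The accepted formats can be reduced to one canonical shape before use. The following is a minimal sketch of that idea, not Emmental's actual implementation: `normalize_inputs` and the type aliases are hypothetical names, and treating `None` as "no inputs" is an assumption, since the description does not say what `None` means.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Hypothetical type aliases for illustration; not Emmental's own names.
InputItem = Union[str, Tuple[str, Union[int, str]]]
InputSpec = Optional[Union[str, Sequence[InputItem]]]
Normalized = List[Tuple[str, Union[int, str, None]]]


def normalize_inputs(inputs: InputSpec) -> Normalized:
    """Turn any supported spec into a list of (source, key) pairs.

    A key of None stands for "take the whole output of source".
    """
    if inputs is None:
        # Assumption: None is treated here as "no inputs specified".
        return []
    if isinstance(inputs, str):
        # A bare str names a source whose whole output is used.
        return [(inputs, None)]
    normalized: Normalized = []
    for item in inputs:
        if isinstance(item, str):
            # Format a): whole output of the named source.
            normalized.append((item, None))
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[0], str)
            and isinstance(item[1], (int, str))
        ):
            # Formats b) and c): int index or str name within the output.
            normalized.append((item[0], item[1]))
        else:
            raise ValueError(f"Unsupported input format: {item!r}")
    return normalized
```

With a single canonical shape, downstream code only has to handle `(source, key)` pairs, and malformed specs fail early with a clear error.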

A few emmental.EmmentalTaskFlowAction examples:

from emmental import Action as Act
Act(name="input", module="input_module0", inputs=[("_input_", "data")])
Act(name="input", module="input_module0", inputs=[("_input_", 0)])
Act(name="input", module="input_module0", inputs=["_input_"])
Act(name="input", module="input_module0", inputs="_input_")
Act(name="input", module="input_module0", inputs=[("_input_", "data"), ("_input_", 1), "_input_"])
Act(name="input", module="input_module0", inputs=None)

The same design applies to action_outputs. Here are a few examples:

action_outputs=[(f"{task_name}_pred_head", 0), ("_input_", "data"), f"{task_name}_pred_head"]
action_outputs="_input_"
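To make the lookup semantics concrete, here is a minimal sketch of how such a specification could be resolved against a dictionary of collected outputs. It is an illustration under assumptions, not Emmental's code: `select_outputs` is a hypothetical helper, and the layout of `outputs` (a dict for `_input_`, a tuple for the prediction head) is invented for the example.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

# Hypothetical alias: a single name, or a list of names and (name, key) pairs.
Spec = Union[str, Sequence[Union[str, Tuple[str, Union[int, str]]]]]


def select_outputs(output_dict: Dict[str, Any], spec: Spec) -> List[Any]:
    """Look up each entry of spec in output_dict (hypothetical helper)."""
    items = [spec] if isinstance(spec, str) else list(spec)
    selected = []
    for item in items:
        if isinstance(item, str):
            # A bare name selects the whole output of that action.
            selected.append(output_dict[item])
        else:
            name, key = item
            # An int key indexes into the output; a str key looks it up by name.
            selected.append(output_dict[name][key])
    return selected


# Made-up outputs for illustration only.
task_name = "task1"
outputs = {
    "_input_": {"data": [1.0, 2.0], "label": 1},
    f"{task_name}_pred_head": ([0.1, 0.9],),
}

picked = select_outputs(
    outputs,
    [(f"{task_name}_pred_head", 0), ("_input_", "data"), f"{task_name}_pred_head"],
)
whole_input = select_outputs(outputs, "_input_")
```

Here `picked` holds the first output of the prediction head, the `data` field of the input, and the head's full output tuple, in that order, while `whole_input` holds the entire `_input_` dictionary.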

Test plan

Pass the existing tests.

Checklist

codecov[bot] commented 2 years ago

Codecov Report

Merging #116 (b3035f4) into master (6cad215) will not change coverage. The diff coverage is 100.00%.


@@           Coverage Diff           @@
##           master     #116   +/-   ##
=======================================
  Coverage   92.12%   92.12%           
=======================================
  Files          40       40           
  Lines        2018     2018           
  Branches      431      431           
=======================================
  Hits         1859     1859           
  Misses         94       94           
  Partials       65       65           
| Flag | Coverage Δ |
|---|---|
| unittests | 92.12% <100.00%> (ø) |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|---|---|
| src/emmental/model.py | 93.78% <ø> (ø) |
| src/emmental/__init__.py | 100.00% <100.00%> (ø) |
| src/emmental/task.py | 100.00% <100.00%> (ø) |