AllenNeuralDynamics / aind-behavior-curriculum

Starter repository for behavior base primitives.
https://aind-behavior-curriculum.readthedocs.io
MIT License

Suggested modifications for Task model #9

Closed bruno-f-cruz closed 7 months ago

bruno-f-cruz commented 8 months ago

This PR adds a few suggestions to the Task model. The most important one is the ability to "overload" the task_parameters property with a user-created model. This lets the base Task class keep extra="forbid" while still supporting Child -> Base -> Child round-trip deserialization. An example:

from aind_behavior_curriculum.behavior import (
    Task,
    ModifiableAttr,
    GenericModel,
)
from pydantic import Field
from typing import Literal, get_args

class Foo(GenericModel):
    param1: int = Field(default=1, description="This is a property")
    prop1: int = Field(default=1, description="This is a property")
    prop2: str = ModifiableAttr(
        default="a", description="This is another property"
    )

_version = Literal["0.1.0"]

class MyTask(Task):
    version: _version = get_args(_version)[0]
    task_parameters: Foo = Field(
        default=Foo(), description="Task parameters.", validate_default=True
    )

# Note the modifiable tag on prop2 in the generated JSON schema.
# A specific application can use the tag to allow or disallow modification
# while remaining "on-curriculum".
print(MyTask.model_json_schema())

instance = MyTask(name="Task", task_parameters=Foo(prop1=2))
print(instance)
# name='Task' description='' version='0.1.0' task_parameters=Foo(param1=1, prop1=2, prop2='a')

# Child -> JSON
instance_json = instance.model_dump_json()
print(instance_json)
# {"name":"Task","description":"","version":"0.1.0","task_parameters":{"param1":1,"prop1":2,"prop2":"a"}}

# JSON -> Base: the parameters are kept, but typed as GenericModel
instance_parent = Task.model_validate_json(instance_json)
print(instance_parent)
# name='Task' description='' version='0.1.0' task_parameters=GenericModel(param1=1, prop1=2, prop2='a')

# Base -> JSON -> Child: the round trip recovers the original instance
instance_prime = MyTask.model_validate_json(instance_parent.model_dump_json())
print(instance_prime == instance)
# True
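
For readers without the package installed, the same Child -> Base -> Child idea can be sketched with plain pydantic v2. The `GenericModel` and `Task` classes below are simplified stand-ins written for this illustration, not the library's actual definitions, and the `"modifiable"` schema key is an assumed convention rather than what `ModifiableAttr` necessarily emits. The sketch assumes the base parameter model accepts extra fields while the task itself forbids them.

```python
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, ValidationError


class GenericModel(BaseModel):
    # Stand-in (assumption): the base parameter model accepts arbitrary fields.
    model_config = ConfigDict(extra="allow")


class Task(BaseModel):
    # Stand-in (assumption): the task itself rejects unknown fields.
    model_config = ConfigDict(extra="forbid")

    name: str
    description: str = ""
    version: str = "0.0.0"
    task_parameters: GenericModel = Field(default_factory=GenericModel)


class Foo(GenericModel):
    prop1: int = 1
    # Hypothetical way to tag a field as modifiable in the JSON schema.
    prop2: str = Field(default="a", json_schema_extra={"modifiable": True})


class MyTask(Task):
    version: Literal["0.1.0"] = "0.1.0"
    # Overload task_parameters with the user-defined model.
    task_parameters: Foo = Field(default_factory=Foo)


child = MyTask(name="Task", task_parameters=Foo(prop1=2))

# Child -> Base: the values survive as extra fields on GenericModel.
as_base = Task.model_validate_json(child.model_dump_json())
print(type(as_base.task_parameters).__name__, as_base.task_parameters.model_extra)

# Base -> Child: the extra fields are validated back into Foo.
round_trip = MyTask.model_validate_json(as_base.model_dump_json())
print(round_trip == child)

# The tag is visible to any application that reads the schema.
prop2_schema = MyTask.model_json_schema()["$defs"]["Foo"]["properties"]["prop2"]
print(prop2_schema)

# The base Task still rejects unknown top-level fields.
try:
    Task(name="Task", unexpected=1)
    rejected = False
except ValidationError:
    rejected = True
print(rejected)
```

The design choice this illustrates is that only the nested parameter model needs to be permissive. The Task model stays strict, so a typo in a top-level field still fails validation.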