caelum02 / Lux-AI-Season-2

Lux AI Season 2 - NeurIPS Stage | Team Martian

Model Architecture #2

Open gyusang opened 11 months ago

gyusang commented 11 months ago

Inputs

Full game state

state._fields

('env_cfg',
'seed',
'rng_state',
'env_steps',
'board',
'units',
'unit_id2idx',
'n_units',
'factories',
'factory_id2idx',
'n_factories',
'teams',
'global_id',
'place_first')

Each field

EnvConfig(
    max_episode_length=ShapedArray(int32[], weak_type=True),
    map_size=ShapedArray(int32[], weak_type=True),
    verbose=ShapedArray(int32[], weak_type=True),
    validate_action_space=ShapedArray(bool[]),
    max_transfer_amount=ShapedArray(int32[], weak_type=True),
    MIN_FACTORIES=ShapedArray(int32[], weak_type=True),
    MAX_FACTORIES=ShapedArray(int32[], weak_type=True),
    CYCLE_LENGTH=ShapedArray(int32[], weak_type=True),
    DAY_LENGTH=ShapedArray(int32[], weak_type=True),
    UNIT_ACTION_QUEUE_SIZE=ShapedArray(int32[], weak_type=True),
    MAX_RUBBLE=ShapedArray(int32[], weak_type=True),
    FACTORY_RUBBLE_AFTER_DESTRUCTION=ShapedArray(int32[], weak_type=True),
    INIT_WATER_METAL_PER_FACTORY=ShapedArray(int32[], weak_type=True),
    INIT_POWER_PER_FACTORY=ShapedArray(int32[], weak_type=True),
    MIN_LICHEN_TO_SPREAD=ShapedArray(int32[], weak_type=True),
    LICHEN_LOST_WITHOUT_WATER=ShapedArray(int32[], weak_type=True),
    LICHEN_GAINED_WITH_WATER=ShapedArray(int32[], weak_type=True),
    MAX_LICHEN_PER_TILE=ShapedArray(int32[], weak_type=True),
    POWER_PER_CONNECTED_LICHEN_TILE=ShapedArray(int32[], weak_type=True),
    LICHEN_WATERING_COST_FACTOR=ShapedArray(int32[], weak_type=True),
    BIDDING_SYSTEM=ShapedArray(bool[]),
    FACTORY_PROCESSING_RATE_WATER=ShapedArray(int32[], weak_type=True),
    ICE_WATER_RATIO=ShapedArray(int32[], weak_type=True),
    FACTORY_PROCESSING_RATE_METAL=ShapedArray(int32[], weak_type=True),
    ORE_METAL_RATIO=ShapedArray(int32[], weak_type=True),
    FACTORY_CHARGE=ShapedArray(int32[], weak_type=True),
    FACTORY_WATER_CONSUMPTION=ShapedArray(int32[], weak_type=True),
    POWER_LOSS_FACTOR=ShapedArray(float32[], weak_type=True),
    ROBOTS=(
        UnitConfig(
            METAL_COST=ShapedArray(int32[], weak_type=True),
            POWER_COST=ShapedArray(int32[], weak_type=True),
            CARGO_SPACE=ShapedArray(int32[], weak_type=True),
            BATTERY_CAPACITY=ShapedArray(int32[], weak_type=True),
            CHARGE=ShapedArray(int32[], weak_type=True),
            INIT_POWER=ShapedArray(int32[], weak_type=True),
            MOVE_COST=ShapedArray(int32[], weak_type=True),
            RUBBLE_MOVEMENT_COST=ShapedArray(float32[], weak_type=True),
            DIG_COST=ShapedArray(int32[], weak_type=True),
            DIG_RUBBLE_REMOVED=ShapedArray(int32[], weak_type=True),
            DIG_RESOURCE_GAIN=ShapedArray(int32[], weak_type=True),
            DIG_LICHEN_REMOVED=ShapedArray(int32[], weak_type=True),
            SELF_DESTRUCT_COST=ShapedArray(int32[], weak_type=True),
            RUBBLE_AFTER_DESTRUCTION=ShapedArray(int32[], weak_type=True),
            ACTION_QUEUE_POWER_COST=ShapedArray(int32[], weak_type=True)
        ),
        UnitConfig(
            METAL_COST=ShapedArray(int32[], weak_type=True),
            POWER_COST=ShapedArray(int32[], weak_type=True),
            CARGO_SPACE=ShapedArray(int32[], weak_type=True),
            BATTERY_CAPACITY=ShapedArray(int32[], weak_type=True),
            CHARGE=ShapedArray(int32[], weak_type=True),
            INIT_POWER=ShapedArray(int32[], weak_type=True),
            MOVE_COST=ShapedArray(int32[], weak_type=True),
            RUBBLE_MOVEMENT_COST=ShapedArray(float32[], weak_type=True),
            DIG_COST=ShapedArray(int32[], weak_type=True),
            DIG_RUBBLE_REMOVED=ShapedArray(int32[], weak_type=True),
            DIG_RESOURCE_GAIN=ShapedArray(int32[], weak_type=True),
            DIG_LICHEN_REMOVED=ShapedArray(int32[], weak_type=True),
            SELF_DESTRUCT_COST=ShapedArray(int32[], weak_type=True),
            RUBBLE_AFTER_DESTRUCTION=ShapedArray(int32[], weak_type=True),
            ACTION_QUEUE_POWER_COST=ShapedArray(int32[], weak_type=True)
        )
    )
)

We will mostly play with the same settings, so most of this is not needed. Perhaps only env_cfg.map_size?

seed: 
ShapedArray(uint32[])
rng_state: 
ShapedArray(uint32[2])

Is this really needed?

env_steps: 
ShapedArray(int16[])

One-hot encode this with the day/night cycle in mind, as in the paper?

board: 
Board(
    seed=ShapedArray(uint32[]),
    factories_per_team=ShapedArray(int8[]),
    map=GameMap(
        rubble=ShapedArray(int8[64,64]),
        ice=ShapedArray(bool[64,64]),
        ore=ShapedArray(bool[64,64]),
        symmetry=ShapedArray(int8[])
    ),
    lichen=ShapedArray(int32[64,64]),
    lichen_strains=ShapedArray(int8[64,64]),
    units_map=ShapedArray(int16[64,64]),
    factory_map=ShapedArray(int8[64,64]),
    factory_occupancy_map=ShapedArray(int8[64,64]),
    factory_pos=ShapedArray(int8[22,2])
)

Merge ice and ore into a single 0/1/2 category and one-hot encode it? And what is symmetry?
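A minimal NumPy sketch of the merged encoding, in case it helps the discussion. The function name `encode_resources` is hypothetical, and the category assignment (0 = empty, 1 = ice, 2 = ore) is an assumption; it also assumes a tile never holds both ice and ore.

```python
import numpy as np

def encode_resources(ice, ore):
    """Merge boolean ice/ore maps into one categorical map, then one-hot it.

    Assumed categories: 0 = no resource, 1 = ice, 2 = ore.
    Assumes no tile is both ice and ore (that would give category 3).
    """
    ice = np.asarray(ice, dtype=bool)
    ore = np.asarray(ore, dtype=bool)
    category = ice.astype(np.int8) + 2 * ore.astype(np.int8)  # values in {0, 1, 2}
    one_hot = np.eye(3, dtype=np.int8)[category]              # shape (H, W, 3)
    return category, one_hot

# Tiny 2x2 example instead of the real 64x64 board.
ice = np.array([[True, False], [False, False]])
ore = np.array([[False, True], [False, False]])
category, one_hot = encode_resources(ice, ore)
```

Note that the two boolean maps already carry the same information; the one-hot form only adds an explicit "empty" channel.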

units: 
Unit(
    unit_type=ShapedArray(int8[2,200]),
    action_queue=ActionQueue(
        data=UnitAction(
            action_type=ShapedArray(int8[2,200,20]),
            direction=ShapedArray(int8[2,200,20]),
            resource_type=ShapedArray(int8[2,200,20]),
            amount=ShapedArray(int16[2,200,20]),
            repeat=ShapedArray(int16[2,200,20]),
            n=ShapedArray(int16[2,200,20])
        ),
        front=ShapedArray(int8[2,200]),
        rear=ShapedArray(int8[2,200]),
        count=ShapedArray(int8[2,200])
    ),
    team_id=ShapedArray(int8[2,200]),
    unit_id=ShapedArray(int16[2,200]),
    pos=Position(pos=ShapedArray(int8[2,200,2])),
    cargo=UnitCargo(stock=ShapedArray(int32[2,200,4])),
    power=ShapedArray(int32[2,200])
)
unit_id2idx: 
ShapedArray(int16[2000,2])
n_units: 
ShapedArray(int16[2])

Should this be added as extra features at the corresponding pixel?

factories: 
Factory(
    team_id=ShapedArray(int8[2,11]),
    unit_id=ShapedArray(int8[2,11]),
    pos=Position(pos=ShapedArray(int8[2,11,2])),
    power=ShapedArray(int32[2,11]),
    cargo=UnitCargo(stock=ShapedArray(int32[2,11,4]))
)
factory_id2idx: 
ShapedArray(int8[22,2])
n_factories: 
ShapedArray(int8[2])

Same question for this one.

teams: 
Team(
    team_id=ShapedArray(int8[2]),
    faction=ShapedArray(int8[2]),
    init_water=ShapedArray(int32[2]),
    init_metal=ShapedArray(int32[2]),
    factories_to_place=ShapedArray(int32[2]),
    factory_strains=ShapedArray(int8[2,11]),
    n_factory=ShapedArray(int8[2]),
    bid=ShapedArray(int32[2])
)

Is this only for the bidding and factory placement phases?

global_id: 
ShapedArray(int16[])
place_first: 
ShapedArray(int8[])

Outputs

Pipeline

caelum02 commented 11 months ago

env step

caelum02 commented 11 months ago

Unit

Robots

| Feature | Description | Type | Normalization |
|---|---|---|---|
| [Ally/Enemy] [Light/Heavy] Robot presence | | Bool[2][2] | |
| [Ally/Enemy] [Ice/Ore/Water/Metal/Power] amount held | | Int[2][5] | |
| [Ally/Enemy] [Ice/Ore/Water/Metal/Power] remaining amount | | Int[2][5] | |

Factories

| Feature | Description | Type | Normalization |
|---|---|---|---|
| [Ally/Enemy] Factory presence | | Bool[2][1] | |
| [Ally/Enemy] Factory [Ice/Ore/Water/Metal/Power] amount held | | Int[2][5] | |
| [Ally/Enemy] watering cost | | Int[2][1] | |
| [Ally/Enemy] delta power | | Int[2][1] | |

Map

Lichen

| Feature | Description | Type | Normalization |
|---|---|---|---|
| Lichen Value | | Int | |

Rubble

| Feature | Description | Type | Normalization |
|---|---|---|---|
| Rubble value | | Int | |

Resource

| Feature | Description | Type | Normalization |
|---|---|---|---|
| Ice | | Bool | |
| Ore | | Bool | |

Global Feature

| Feature | Description | Type | Normalization |
|---|---|---|---|
| Current Cycle | | Bool[20] | |
| Current Turn in Cycle | | Bool[50] | |
| If at day | | Bool | |
| [Ally/Enemy] Total Lichen | | Int[2] | |
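For the time-related global features, a rough pure-Python sketch follows. The constants are assumptions taken from the default game settings (CYCLE_LENGTH = 50, DAY_LENGTH = 30, and 20 cycles for a 1000-step episode), `encode_time` is a hypothetical helper name, and `step` is assumed to be the 0-based turn index within the main game phase, not counting bidding or factory placement.

```python
CYCLE_LENGTH = 50  # assumed default, matches Bool[50] above
DAY_LENGTH = 30    # assumed default
N_CYCLES = 20      # assumed: 1000-step episode / CYCLE_LENGTH, matches Bool[20]

def one_hot(index, size):
    """Boolean list of length `size` with a single True at `index`."""
    return [i == index for i in range(size)]

def encode_time(step):
    """Return (cycle one-hot, turn-in-cycle one-hot, is_day) for a game step."""
    cycle = min(step // CYCLE_LENGTH, N_CYCLES - 1)  # clamp in case step overshoots
    turn_in_cycle = step % CYCLE_LENGTH
    is_day = turn_in_cycle < DAY_LENGTH              # day comes first in each cycle
    return one_hot(cycle, N_CYCLES), one_hot(turn_in_cycle, CYCLE_LENGTH), is_day

# Step 135 falls in cycle 2, turn 35 of that cycle, which is night.
cycle_vec, turn_vec, is_day = encode_time(135)
```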

Action

gyusang commented 11 months ago
UnitAction(
    action_type=Array(0, dtype=int8),
    direction=Array(1, dtype=int8),
    resource_type=Array(0, dtype=int8),
    amount=Array(0, dtype=int16),
    repeat=Array(0, dtype=int16),
    n=Array(1, dtype=int16)
)

The UnitAction above is converted as follows:

UnitAction(
    action_type=Categorical((MOVE, TRANSFER, PICKUP, DIG, SELF_DESTRUCT, RECHARGE), n=6),
    direction=Categorical((CENTER, UP, RIGHT, DOWN, LEFT), n=5),
    resource_type=Categorical((ice, ore, water, metal, power), n=5),
    amount=Array(0, dtype=int16),
    repeat=Array(False, dtype=bool),
    n=Array(1, dtype=int16)
)

How should amount and n be handled?

ActionQueue(
    data=UnitAction(...),
    front=ShapedArray(int8[]),
    rear=ShapedArray(int8[]),
    count=ShapedArray(int8[])
)

Valid entries, oldest first: [front], [(front + 1) % capacity], ..., [(rear - 1) % capacity], where (rear - 1) % capacity == (front + count - 1) % capacity (assuming these slots are exactly the count valid entries).
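The slot ordering of the circular queue can be sketched in plain Python. This assumes `rear` points one past the newest entry, which is the usual ring-buffer convention but is not confirmed here; `queue_slots` is a hypothetical helper name.

```python
def queue_slots(front, count, capacity):
    """Indices of the valid entries in a circular queue, oldest first."""
    return [(front + i) % capacity for i in range(count)]

# A queue of capacity 20 holding 4 actions that wraps around the end.
capacity = 20
front, count = 18, 4
slots = queue_slots(front, count, capacity)

# Assumed convention: rear is one past the last valid slot.
rear = (front + count) % capacity
```

Reading the queue in this order, and padding the remaining `capacity - count` slots, would give a fixed-length sequence per unit.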

Unit(
    unit_type=ShapedArray(int8[2,200]),
    action_queue=...,
    team_id=ShapedArray(int8[2,200]),
    unit_id=ShapedArray(int16[2,200]),
    pos=Position(pos=ShapedArray(int8[2,200,2])),
    cargo=UnitCargo(stock=ShapedArray(int32[2,200,4])),
    power=ShapedArray(int32[2,200])
)

Use pos to concatenate these onto the board?

id2idx[4] = [0 0]
id2idx[5] = [0 1]
id2idx[6] = [1 0]
id2idx[7] = [1 1]
state.units.unit_id[0][0]=Array(4, dtype=int16)
state.units.unit_id[0][1]=Array(5, dtype=int16)
state.units.unit_id[1][0]=Array(6, dtype=int16)
state.units.unit_id[1][1]=Array(7, dtype=int16)
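The dump above suggests that id2idx maps a global unit id to a (team, slot) pair that indexes the per-team arrays. A small NumPy sketch of that relationship, using a table of 8 rows instead of 2000 and `-1` as an assumed placeholder for unused ids (the real fill value is not shown here):

```python
import numpy as np

MAX_ID = 8  # tiny illustrative table; the real one has 2000 rows
id2idx = np.full((MAX_ID, 2), -1, dtype=np.int16)     # -1 is an assumed placeholder
unit_id = np.array([[4, 5], [6, 7]], dtype=np.int16)  # unit_id[team, slot]

# Build the inverse mapping: id -> (team, slot).
for team in range(unit_id.shape[0]):
    for slot in range(unit_id.shape[1]):
        id2idx[unit_id[team, slot]] = (team, slot)

# Look up unit 6 and use the result to index a per-team array.
team, slot = id2idx[6]
looked_up_id = unit_id[team, slot]
```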

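If we slice id2idx into batches, write the unit/factory data into the corresponding pixel of an image for each batch, and then sum() everything, each entry lands exactly once and the rest is zero, so the observation could be built in a somewhat parallel way.

Rather than processing one unit at a time sequentially?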
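A rough NumPy sketch of the scatter-then-sum idea. Each unit slot writes into its own zero image and the images are summed over the slot axis. The names `scatter_one` and `build_feature_map` are hypothetical, and the list comprehension only stands in for what would be a `jax.vmap` over the slot axis in the actual JAX code.

```python
import numpy as np

MAP_SIZE = 4  # tiny map for illustration; the real board is 64x64

def scatter_one(pos, value, valid):
    """Image holding `value` at `pos`, or all zeros if the slot is empty."""
    img = np.zeros((MAP_SIZE, MAP_SIZE), dtype=np.int32)
    if valid:
        img[pos[0], pos[1]] = value
    return img

def build_feature_map(pos, values, valid):
    """Scatter every unit slot independently, then sum over the slot axis."""
    per_slot = np.stack([scatter_one(p, v, m) for p, v, m in zip(pos, values, valid)])
    return per_slot.sum(axis=0)

# Three slots, the last one unused (its data must not leak into the map).
pos = np.array([[0, 0], [1, 2], [3, 3]])
power = np.array([10, 50, 99])
valid = np.array([True, True, False])
power_map = build_feature_map(pos, power, valid)
```

The cost is one full image per slot before the sum, so memory grows with the number of slots times the map area; a direct scatter-add avoids that intermediate.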

caelum02 commented 11 months ago

Implemented per-pixel unit positions

0e2de0aa7a9c9dcee621f70b18b078a822b732a5

from functools import partial

import jax.numpy as jnp
from jax import vmap

# UnitType and MAP_SIZE are assumed to be defined elsewhere in the project.

@partial(vmap, in_axes=0)
def add_at_mask(array, x, y, mask):
    # Mapped over the leading (player) axis: each call sees one player's units.
    zeros = jnp.zeros_like(array)

    # `mode='drop'` drops updates whose index is out of bounds
    out = zeros.at[x, y].add(mask.astype(array.dtype), mode='drop')

    return out

# TODO: Consider jitting this function
def get_unit_existence(unit_mask, unit_type, x, y):
    '''
        unit_type : ShapedArray(int8[2, MAX_N_UNITS])
        unit_mask : ShapedArray(bool[2, MAX_N_UNITS])
        x : ShapedArray(int8[2, MAX_N_UNITS])
        y : ShapedArray(int8[2, MAX_N_UNITS])

        output: ShapedArray(int8[4, MAP_SIZE, MAP_SIZE])

        channel order:
        light player 0, light player 1, heavy player 0, heavy player 1

        unit type goes to axis 0 to preserve locality of player & unit_id axis
    '''

    # Keep only real units of each type; padded slots are masked out.
    light_mask = unit_mask & (unit_type==UnitType.LIGHT)
    heavy_mask = unit_mask & (unit_type==UnitType.HEAVY)

    zeros = jnp.zeros((2, MAP_SIZE, MAP_SIZE), dtype=jnp.int8)

    light_unit_map = add_at_mask(zeros, x, y, light_mask)
    heavy_unit_map = add_at_mask(zeros, x, y, heavy_mask)

    # Stack along axis 0: two light channels followed by two heavy channels.
    unit_map = jnp.concatenate((light_unit_map, heavy_unit_map))

    return unit_map

caelum02 commented 10 months ago

How about encoding the action queue and feeding it in as channel attention weights?
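One possible reading of this idea, sketched in NumPy: an encoded action queue vector is projected to one weight per channel, squashed with a sigmoid, and used to rescale the feature map. Everything here is an assumption for illustration, including the function name `channel_attention`, the linear projection `w`, `b`, and the choice of sigmoid gating; how the queue itself is encoded is left open.

```python
import numpy as np

def channel_attention(features, queue_code, w, b):
    """Rescale each channel of `features` by a weight derived from `queue_code`.

    features   : (C, H, W) feature map
    queue_code : (D,) encoded action queue (encoding left unspecified)
    w, b       : (C, D) and (C,) parameters of a hypothetical linear projection
    """
    logits = w @ queue_code + b                # one logit per channel, shape (C,)
    weights = 1.0 / (1.0 + np.exp(-logits))    # sigmoid gate in (0, 1)
    return features * weights[:, None, None]   # broadcast over H and W

# With zero parameters every gate is sigmoid(0) = 0.5.
features = np.ones((3, 2, 2))
queue_code = np.array([1.0, 0.0])
out = channel_attention(features, queue_code, np.zeros((3, 2)), np.zeros(3))
```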