LeelaChessZero / lczero-training

For code etc relating to the network training process.
143 stars 119 forks source link

Testing out coatnet style relative attention #182

Open Tilps opened 2 years ago

Tilps commented 2 years ago

multi_head_relative_attention is a bit overkill: I forked the entire multi_head_attention from Keras, which supports all kinds of input dimensions. With the relative logic added, however, it is now limited to NHWC input, despite all the options indicating otherwise.

Tilps commented 2 years ago

net.proto to go with my saving tweaks:

/* This file is part of Leela Chess Zero. Copyright (C) 2018 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

Leela Chess is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with Leela Chess. If not, see http://www.gnu.org/licenses/.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA Toolkit and the NVIDIA CUDA Deep Neural Network library (or a modified version of those libraries), containing parts covered by the terms of the respective license agreement, the licensors of this Program grant you additional permission to convey the resulting work. */ syntax = "proto2";

package pblczero;

// Three-part version number (major, minor, patch).
// Referenced by Net.min_version.
message EngineVersion {
  optional uint32 major = 1;
  optional uint32 minor = 2;
  optional uint32 patch = 3;
}

// All network weights: the nested building-block messages plus the fields
// for the input convolution, residual tower, policy / value / moves-left
// heads, and the RRA tower.
message Weights {
  // A single serialized tensor.
  // NOTE(review): presumably `params` holds the encoded values and
  // `min_val` / `max_val` give the range used to decode them (see
  // Format.Encoding) — confirm against the reader/writer code.
  message Layer {
    optional float min_val = 1;
    optional float max_val = 2;
    optional bytes params = 3;
  }

  // Convolution weights and biases together with four `bn_*` layers.
  // NOTE(review): `bn_*` presumably are batch-normalization statistics and
  // affine parameters — confirm.
  message ConvBlock {
    optional Layer weights = 1;
    optional Layer biases = 2;
    optional Layer bn_means = 3;
    optional Layer bn_stddivs = 4;
    optional Layer bn_gammas = 5;
    optional Layer bn_betas = 6;
  }

  message SEunit {
    // Squeeze-excitation unit (https://arxiv.org/abs/1709.01507)
    // weights and biases of the two fully connected layers.
    optional Layer w1 = 1;
    optional Layer b1 = 2;
    optional Layer w2 = 3;
    optional Layer b2 = 4;
  }

  // One residual block: two convolution blocks and an optional SE unit.
  message Residual {
    optional ConvBlock conv1 = 1;
    optional ConvBlock conv2 = 2;
    optional SEunit se = 3;
  }

  // NOTE(review): presumably "multi-head relative attention"; `ln*` look
  // like layer-norm parameters and q/k/v/o the query, key, value and output
  // projections (weights `*w`, biases `*b`) — confirm against the exporter.
  message MHRA {
    optional Layer lngammas = 1;
    optional Layer lnbetas = 2;
    optional Layer rel_bias_table = 3;
    optional Layer qw = 4;
    optional Layer qb = 5;
    optional Layer kw = 6;
    optional Layer kb = 7;
    optional Layer vw = 8;
    optional Layer vb = 9;
    optional Layer ow = 10;
    optional Layer ob = 11;
  }

  // NOTE(review): presumably a feed-forward block with layer-norm parameters
  // and two fully connected layers (fc1, fc2) — confirm.
  message FFN {
    optional Layer lngammas = 1;
    optional Layer lnbetas = 2;
    optional Layer fc1w = 3;
    optional Layer fc1b = 4;
    optional Layer fc2w = 5;
    optional Layer fc2b = 6;
  }

  // One RRA tower block: an MHRA sub-block and an FFN sub-block.
  message RRA {
    optional MHRA mhra = 1;
    optional FFN ffn = 2;
  }

  // Input convnet.
  optional ConvBlock input = 1;

  // Residual tower.
  repeated Residual residual = 2;

  // Policy head
  // Extra convolution for AZ-style policy head
  optional ConvBlock policy1 = 11;
  optional ConvBlock policy = 3;
  optional Layer ip_pol_w = 4;
  optional Layer ip_pol_b = 5;

  // Value head
  optional ConvBlock value = 6;
  optional Layer ip1_val_w = 7;
  optional Layer ip1_val_b = 8;
  optional Layer ip2_val_w = 9;
  optional Layer ip2_val_b = 10;

  // Moves left head
  optional ConvBlock moves_left = 12;
  optional Layer ip1_mov_w = 13;
  optional Layer ip1_mov_b = 14;
  optional Layer ip2_mov_w = 15;
  optional Layer ip2_mov_b = 16;

  // rra tower.
  repeated RRA rra = 17;
}

// Training metadata carried in a Net message (see Net.training_params).
message TrainingParams {
  optional uint32 training_steps = 1;
  optional float learning_rate = 2;
  optional float mse_loss = 3;
  optional float policy_loss = 4;
  optional float accuracy = 5;
  // NOTE(review): presumably the lc0 parameters used while generating
  // training data — confirm with the training pipeline.
  optional string lc0_params = 6;
}

// Describes how a network's inputs, outputs, body and heads are laid out.
// Each aspect is an enum declared here followed by the field that carries it.
message NetworkFormat {
  // Format to encode the input planes with. Used by position encoder.
  enum InputFormat {
    INPUT_UNKNOWN = 0;
    INPUT_CLASSICAL_112_PLANE = 1;
    INPUT_112_WITH_CASTLING_PLANE = 2;
    INPUT_112_WITH_CANONICALIZATION = 3;
    INPUT_112_WITH_CANONICALIZATION_HECTOPLIES = 4;
    INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON = 132;
    INPUT_112_WITH_CANONICALIZATION_V2 = 5;
    INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON = 133;
  }
  optional InputFormat input = 1;

  // Output format of the NN. Used by search code to interpret results.
  enum OutputFormat {
    OUTPUT_UNKNOWN = 0;
    OUTPUT_CLASSICAL = 1;
    OUTPUT_WDL = 2;
  }
  optional OutputFormat output = 2;

  // Network architecture. Used by backends to build the network.
  enum NetworkStructure {
    // Networks without PolicyFormat or ValueFormat specified
    NETWORK_UNKNOWN = 0;
    NETWORK_CLASSICAL = 1;
    NETWORK_SE = 2;
    // Networks with PolicyFormat and ValueFormat specified
    NETWORK_CLASSICAL_WITH_HEADFORMAT = 3;
    NETWORK_SE_WITH_HEADFORMAT = 4;
  }
  optional NetworkStructure network = 3;

  // Policy head architecture
  enum PolicyFormat {
    POLICY_UNKNOWN = 0;
    POLICY_CLASSICAL = 1;
    POLICY_CONVOLUTION = 2;
  }
  optional PolicyFormat policy = 4;

  // Value head architecture
  enum ValueFormat {
    VALUE_UNKNOWN = 0;
    VALUE_CLASSICAL = 1;
    VALUE_WDL = 2;
    VALUE_PARAM = 3;
  }
  optional ValueFormat value = 5;

  // Moves left head architecture
  enum MovesLeftFormat {
    MOVES_LEFT_NONE = 0;
    MOVES_LEFT_V1 = 1;
  }
  optional MovesLeftFormat moves_left = 6;
}

// Describes how the weights are encoded and which network format they use.
message Format {
  // Encoding applied to the serialized weights.
  enum Encoding {
    UNKNOWN = 0;
    LINEAR16 = 1;
  }

  optional Encoding weights_encoding = 1;
  // If network_format is missing, it's assumed to have
  // INPUT_CLASSICAL_112_PLANE / OUTPUT_CLASSICAL / NETWORK_CLASSICAL format.
  optional NetworkFormat network_format = 2;
}

// Top-level container: bundles the license text, minimum engine version,
// format description, training metadata and the weights themselves.
message Net {
  // NOTE(review): presumably a fixed file-identification constant; the
  // expected value is not visible in this file — confirm with the loader.
  optional fixed32 magic = 1;
  optional string license = 2;
  optional EngineVersion min_version = 3;
  optional Format format = 4;
  optional TrainingParams training_params = 5;
  // Field numbers 6-9 are not used by this message; weights is number 10.
  optional Weights weights = 10;
}