LeelaChessZero / lc0

The rewritten engine, originally for TensorFlow. All other backends have since been ported here.
GNU General Public License v3.0
2.38k stars 525 forks source link

Common changes for the new multiple-heads architecture. #1915

Closed almaudoh closed 6 months ago

almaudoh commented 1 year ago

Changes shared by all backends for the new BT3 architecture with multiple policy and value heads.

borg323 commented 6 months ago

Here is an example of porting the BLAS backend to the new MultiHeadWeights. It mostly consists of renaming LegacyWeights to MultiHeadWeights and moving some weight references to the appropriate head (these are easy to find, because the build fails until they are updated):

```diff diff --git a/src/neural/blas/network_blas.cc b/src/neural/blas/network_blas.cc index a04708f..7a10b6e 100644 --- a/src/neural/blas/network_blas.cc +++ b/src/neural/blas/network_blas.cc @@ -61,9 +61,9 @@ class BlasNetwork; template class BlasComputation : public NetworkComputation { public: - BlasComputation(BlasNetwork* network, const LegacyWeights& weights, - const size_t max_batch_size, const bool wdl, - const bool moves_left, const bool conv_policy, + BlasComputation(BlasNetwork* network, + const MultiHeadWeights& weights, const size_t max_batch_size, + const bool wdl, const bool moves_left, const bool conv_policy, const ActivationFunction default_activation, const ActivationFunction smolgen_activation, const ActivationFunction ffn_activation, @@ -119,7 +119,7 @@ class BlasComputation : public NetworkComputation { std::vector& head_buffer2, std::vector& head_buffer3, std::vector& head_buffer4, size_t batch_size, - const LegacyWeights::EncoderLayer& layer, + const MultiHeadWeights::EncoderLayer& layer, int embedding_size, int heads, ActivationFunction smolgen_activation, ActivationFunction ffn_activation, float alpha); @@ -132,7 +132,7 @@ class BlasComputation : public NetworkComputation { // The real number of planes is higher because of padding. 
static constexpr auto kPolicyUsedPlanes = 73; - const LegacyWeights& weights_; + const MultiHeadWeights& weights_; size_t max_batch_size_; std::vector planes_; std::vector> policies_; @@ -194,7 +194,7 @@ class BlasNetwork : public Network { static constexpr auto kHardMaxBatchSize = 2048; const NetworkCapabilities capabilities_; - LegacyWeights weights_; + MultiHeadWeights weights_; size_t max_batch_size_; bool wdl_; bool moves_left_; @@ -210,7 +210,7 @@ class BlasNetwork : public Network { template BlasComputation::BlasComputation( - BlasNetwork* network, const LegacyWeights& weights, + BlasNetwork* network, const MultiHeadWeights& weights, const size_t max_batch_size, const bool wdl, const bool moves_left, const bool conv_policy, const ActivationFunction default_activation, const ActivationFunction smolgen_activation, @@ -260,7 +260,7 @@ template void BlasComputation::MakeEncoderLayer( std::vector& head_buffer, std::vector& head_buffer2, std::vector& head_buffer3, std::vector& head_buffer4, - size_t batch_size, const LegacyWeights::EncoderLayer& layer, + size_t batch_size, const MultiHeadWeights::EncoderLayer& layer, int embedding_size, int heads, ActivationFunction smolgen_activation, ActivationFunction ffn_activation, float alpha) { const int d_model = layer.mha.q_b.size(); @@ -457,12 +457,14 @@ void BlasComputation::MakeEncoderLayer( template void BlasComputation::ComputeBlocking() { + const auto& value_head = weights_.value_heads.winner; + const auto& policy_head = weights_.policy_heads.vanilla; // Retrieve network key dimensions from the weights structure. - const auto num_value_channels = weights_.ip1_val_b.size(); + const auto num_value_channels = value_head.ip1_val_b.size(); const auto num_moves_channels = weights_.ip1_mov_b.size(); const auto num_value_input_planes = - attn_body_ ? weights_.ip_val_b.size() : weights_.value.biases.size(); - const auto num_policy_input_planes = weights_.policy.biases.size(); + attn_body_ ? 
value_head.ip_val_b.size() : value_head.value.biases.size(); + const auto num_policy_input_planes = policy_head.policy.biases.size(); const auto num_moves_input_planes = attn_body_ ? weights_.ip_mov_b.size() : weights_.moves_left.biases.size(); const auto num_output_policy = static_cast(kPolicyOutputs); @@ -480,8 +482,8 @@ void BlasComputation::ComputeBlocking() { // The policy head may increase convolution max output size. const auto max_output_channels = - (conv_policy_ && weights_.policy.biases.size() > output_channels) - ? weights_.policy.biases.size() + (conv_policy_ && policy_head.policy.biases.size() > output_channels) + ? policy_head.policy.biases.size() : output_channels; // Determine the largest batch for allocations. @@ -505,7 +507,7 @@ void BlasComputation::ComputeBlocking() { std::max(num_value_input_planes, num_moves_input_planes)); if (attn_policy_) { max_head_planes = std::max(std::max(max_head_planes, size_t{67}), - weights_.ip_pol_b.size()); + policy_head.ip_pol_b.size()); } std::unique_ptr buffers = network_->GetBuffers(); @@ -635,21 +637,21 @@ void BlasComputation::ComputeBlocking() { if (attn_body_) { FullyConnectedLayer::Forward1D( batch_size * kSquares, weights_.ip_emb_b.size(), - num_value_input_planes, buffer1.data(), weights_.ip_val_w.data(), - weights_.ip_val_b.data(), default_activation_, head_buffer.data()); + num_value_input_planes, buffer1.data(), value_head.ip_val_w.data(), + value_head.ip_val_b.data(), default_activation_, head_buffer.data()); } else { Convolution1::Forward( batch_size, output_channels, num_value_input_planes, buffer2.data(), - weights_.value.weights.data(), head_buffer.data()); + value_head.value.weights.data(), head_buffer.data()); BiasActivate(batch_size, num_value_input_planes, &head_buffer[0], - weights_.value.biases.data(), default_activation_); + value_head.value.biases.data(), default_activation_); } FullyConnectedLayer::Forward1D( batch_size, num_value_input_planes * kSquares, num_value_channels, - 
head_buffer.data(), weights_.ip1_val_w.data(), - weights_.ip1_val_b.data(), + head_buffer.data(), value_head.ip1_val_w.data(), + value_head.ip1_val_b.data(), default_activation_, // Activation On buffer3.data()); @@ -658,7 +660,7 @@ void BlasComputation::ComputeBlocking() { std::vector wdl(3 * batch_size); FullyConnectedLayer::Forward1D( batch_size, num_value_channels, 3, buffer3.data(), - weights_.ip2_val_w.data(), weights_.ip2_val_b.data(), + value_head.ip2_val_w.data(), value_head.ip2_val_b.data(), ACTIVATION_NONE, // Activation Off wdl.data()); @@ -673,9 +675,9 @@ void BlasComputation::ComputeBlocking() { } else { for (size_t j = 0; j < batch_size; j++) { double winrate = FullyConnectedLayer::Forward0D( - num_value_channels, weights_.ip2_val_w.data(), + num_value_channels, value_head.ip2_val_w.data(), &buffer3[j * num_value_channels]) + - weights_.ip2_val_b[0]; + value_head.ip2_val_b[0]; q_values_.emplace_back(std::tanh(winrate)); } @@ -726,22 +728,23 @@ void BlasComputation::ComputeBlocking() { } } } - const size_t policy_embedding_size = weights_.ip_pol_b.size(); + const size_t policy_embedding_size = policy_head.ip_pol_b.size(); // Policy Embedding. FullyConnectedLayer::Forward1D( batch_size * kSquares, output_channels, policy_embedding_size, - buffer1.data(), weights_.ip_pol_w.data(), weights_.ip_pol_b.data(), + buffer1.data(), policy_head.ip_pol_w.data(), + policy_head.ip_pol_b.data(), attn_body_ ? default_activation_ : ACTIVATION_SELU, // SELU activation hardcoded for apmish nets. buffer2.data()); - const size_t policy_d_model = weights_.ip2_pol_b.size(); + const size_t policy_d_model = policy_head.ip2_pol_b.size(); - for (auto& layer : weights_.pol_encoder) { + for (auto& layer : policy_head.pol_encoder) { MakeEncoderLayer(buffer2, buffer1, buffer3, head_buffer, batch_size, layer, policy_embedding_size, - weights_.pol_encoder_head_count, + policy_head.pol_encoder_head_count, attn_body_ ? smolgen_activation_ : ACTIVATION_NONE, attn_body_ ? 
ffn_activation_ : ACTIVATION_SELU, 1.0f); } @@ -749,13 +752,13 @@ void BlasComputation::ComputeBlocking() { // Q FullyConnectedLayer::Forward1D( batch_size * kSquares, policy_embedding_size, policy_d_model, - buffer2.data(), weights_.ip2_pol_w.data(), weights_.ip2_pol_b.data(), - ACTIVATION_NONE, buffer1.data()); + buffer2.data(), policy_head.ip2_pol_w.data(), + policy_head.ip2_pol_b.data(), ACTIVATION_NONE, buffer1.data()); // K FullyConnectedLayer::Forward1D( batch_size * kSquares, policy_embedding_size, policy_d_model, - buffer2.data(), weights_.ip3_pol_w.data(), weights_.ip3_pol_b.data(), - ACTIVATION_NONE, buffer3.data()); + buffer2.data(), policy_head.ip3_pol_w.data(), + policy_head.ip3_pol_b.data(), ACTIVATION_NONE, buffer3.data()); const float scaling = 1.0f / sqrtf(policy_d_model); for (auto batch = size_t{0}; batch < batch_size; batch++) { const float* A = &buffer1[batch * 64 * policy_d_model]; @@ -789,7 +792,7 @@ void BlasComputation::ComputeBlocking() { for (size_t k = 0; k < policy_d_model; k++) { sum += buffer3[batch * kSquares * policy_d_model + (56 + j) * policy_d_model + k] * - weights_.ip4_pol_w[i * policy_d_model + k]; + policy_head.ip4_pol_w[i * policy_d_model + k]; } promotion_offsets[i][j] = sum; } @@ -824,18 +827,18 @@ void BlasComputation::ComputeBlocking() { } else if (conv_policy_) { assert(!attn_body_); // not supported with attention body convolve3.Forward(batch_size, output_channels, output_channels, - buffer2.data(), weights_.policy1.weights.data(), + buffer2.data(), policy_head.policy1.weights.data(), buffer1.data()); BiasActivate(batch_size, output_channels, buffer1.data(), - weights_.policy1.biases.data(), default_activation_); + policy_head.policy1.biases.data(), default_activation_); convolve3.Forward(batch_size, output_channels, num_policy_input_planes, - buffer1.data(), weights_.policy.weights.data(), + buffer1.data(), policy_head.policy.weights.data(), head_buffer.data()); BiasActivate(batch_size, num_policy_input_planes, 
head_buffer.data(), - weights_.policy.biases.data(), ACTIVATION_NONE); + policy_head.policy.biases.data(), ACTIVATION_NONE); // Mapping from convolutional policy to lc0 policy for (auto batch = size_t{0}; batch < batch_size; batch++) { @@ -854,15 +857,15 @@ void BlasComputation::ComputeBlocking() { assert(!attn_body_); // not supported with attention body Convolution1::Forward( batch_size, output_channels, num_policy_input_planes, buffer2.data(), - weights_.policy.weights.data(), head_buffer.data()); + policy_head.policy.weights.data(), head_buffer.data()); BiasActivate(batch_size, num_policy_input_planes, &head_buffer[0], - weights_.policy.biases.data(), default_activation_); + policy_head.policy.biases.data(), default_activation_); FullyConnectedLayer::Forward1D( batch_size, num_policy_input_planes * kSquares, num_output_policy, - head_buffer.data(), weights_.ip_pol_w.data(), - weights_.ip_pol_b.data(), + head_buffer.data(), policy_head.ip_pol_w.data(), + policy_head.ip_pol_b.data(), ACTIVATION_NONE, // Activation Off buffer3.data()); @@ -913,8 +916,9 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, attn_policy_ = file.format().network_format().policy() == pblczero::NetworkFormat::POLICY_ATTENTION; - attn_body_ = file.format().network_format().network() == - pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; + attn_body_ = + file.format().network_format().network() == + pblczero::NetworkFormat::NETWORK_ATTENTIONBODY_WITH_HEADFORMAT; default_activation_ = file.format().network_format().default_activation() == pblczero::NetworkFormat::DEFAULT_ACTIVATION_MISH @@ -955,11 +959,12 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, } if (conv_policy_) { - weights_.policy1.weights = - WinogradFilterTransformF(weights_.policy1.weights, channels, channels); - auto pol_channels = weights_.policy.biases.size(); - weights_.policy.weights = WinogradFilterTransformF(weights_.policy.weights, - pol_channels, channels); + auto& policy_head = 
weights_.policy_heads.vanilla; + policy_head.policy1.weights = WinogradFilterTransformF( + policy_head.policy1.weights, channels, channels); + auto pol_channels = policy_head.policy.biases.size(); + policy_head.policy.weights = WinogradFilterTransformF( + policy_head.policy.weights, pol_channels, channels); } if (use_eigen) { @@ -1052,6 +1057,15 @@ std::unique_ptr MakeBlasNetwork(const std::optional& w, weights.format().network_format().default_activation()) + " is not supported by BLAS backend."); } + if (weights.format().network_format().input_embedding() != + pblczero::NetworkFormat::INPUT_EMBEDDING_NONE && + weights.format().network_format().input_embedding() != + pblczero::NetworkFormat::INPUT_EMBEDDING_PE_MAP) { + throw Exception("Input embedding " + + pblczero::NetworkFormat::InputEmbeddingFormat_Name( + weights.format().network_format().input_embedding()) + + " is not supported by BLAS backend."); + } return std::make_unique>(weights, options); } ```
almaudoh commented 6 months ago

The BLAS backend reference implementation has been added.