deel-ai / influenciae

👋 Influenciae is a Tensorflow Toolbox for Influence Functions
https://deel-ai.github.io/influenciae
Other
55 stars, 3 forks

[Feature Request] Add parallel_iterations and experimental_use_pfor parameters in `_compute_inv_hessian` (ExactIHVP) #4

Closed: lucashervier closed this issue 2 years ago

lucashervier commented 2 years ago

Is your feature request related to a problem? Please describe.

When using the `first-order-influence-koh-liang` branch, I run into an out-of-memory error when computing the exact inverse Hessian-vector product on a semantic segmentation model. Here is a minimal example and the corresponding output logs:

import tensorflow as tf

from influenciae.common.model_wrappers import InfluenceModel
from influenciae.influence.inverse_hessian_vector_product import ExactIHVP

IMG_SIZE = 768
NUM_CLASSES = 20

inp = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

# A conv block
x = tf.keras.layers.Conv2D(filters=32, kernel_size=1, strides=(1, 1))(inp)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation('relu')(x)
# FCN block
x = tf.keras.layers.UpSampling2D(
    size=(IMG_SIZE // x.shape[1], IMG_SIZE // x.shape[2]),
    interpolation="bilinear",
)(x)
model_output = tf.keras.layers.Conv2D(NUM_CLASSES, kernel_size=(1, 1), padding="same")(x)

# define model
model = tf.keras.Model(inputs=inp, outputs=model_output)
# freeze all layers except last one
for layer in model.layers:
    layer.trainable = False
for layer in model.layers[-1:]:
    layer.trainable = True
print(model.summary())
# define a semantic segmentation loss with reduction=NONE (one loss value per sample)
class CustomLoss2(tf.keras.losses.Loss):

    def __init__(self, num_classes, ignore_label):
        super(CustomLoss2, self).__init__(name='CustomLoss2', reduction=tf.keras.losses.Reduction.NONE)

        self.num_classes = num_classes
        self.ignore_label = ignore_label

    def call(self, y_true, y_pred):

        sample_weights = tf.cast(tf.not_equal(y_true, self.ignore_label), dtype=tf.float32)
        one_hot_gt = tf.stop_gradient(tf.one_hot(y_true, self.num_classes))

        loss = tf.nn.softmax_cross_entropy_with_logits(one_hot_gt, y_pred)
        weighted_loss = tf.multiply(loss, tf.squeeze(sample_weights))

        # Average over the height axis (counting only non-zero entries), then sum over the width axis.
        num_non_zero = tf.reduce_sum(
            tf.cast(tf.not_equal(weighted_loss, 0.0), tf.float32), 1)
        loss_sum_per_sample = tf.reduce_sum(weighted_loss, 1)
        return tf.reduce_sum(tf.math.divide_no_nan(loss_sum_per_sample, num_non_zero), 1)

if __name__ == "__main__":
    random_input = tf.random.normal(shape=(4, IMG_SIZE, IMG_SIZE, 3))
    random_target = tf.random.uniform(shape=(4, IMG_SIZE, IMG_SIZE), minval=0, maxval=NUM_CLASSES, dtype=tf.int32)  # maxval is exclusive

    random_dataset = tf.data.Dataset.from_tensor_slices((random_input, random_target))

    # define InfluenceModel
    influence_model = InfluenceModel(model, target_layer=-1, loss_function=CustomLoss2(NUM_CLASSES, ignore_label=255))
    # freeze all layers except last one
    for layer in influence_model.layers:
        layer.trainable = False
    for layer in influence_model.layers[-1:]:
        layer.trainable = True
    ihvp_calculator = ExactIHVP(influence_model, random_dataset.take(1).batch(1))
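For context: judging from the traceback in the logs, `ExactIHVP.__init__` calls `_compute_inv_hessian`, which builds the Hessian with `tf.GradientTape.jacobian`. That TensorFlow method already accepts `parallel_iterations` and `experimental_use_pfor`, so the request essentially amounts to forwarding them. The following is a minimal stand-alone sketch under our own assumptions, not Influenciae's actual code: `exact_inverse_hessian` and `loss_fn` are hypothetical names, and the sketch handles a single weight tensor rather than a model's list of trainable weights.

```python
import tensorflow as tf


def exact_inverse_hessian(loss_fn, weights, parallel_iterations=None,
                          experimental_use_pfor=True):
    """Hypothetical helper: exact inverse Hessian of loss_fn() w.r.t. one tf.Variable.

    Both keyword arguments are passed straight to tf.GradientTape.jacobian:
      * experimental_use_pfor=True vectorizes the Jacobian rows with pfor
        (fast, but may tile large activations once per row);
      * parallel_iterations=p makes pfor vectorize p rows per chunk instead of all;
      * experimental_use_pfor=False uses a tf.while_loop (slower, less memory).
    """
    n = weights.shape.num_elements()
    # In eager mode, jacobian(..., experimental_use_pfor=False) requires a persistent tape.
    with tf.GradientTape(persistent=True) as outer_tape:
        with tf.GradientTape() as inner_tape:
            loss = loss_fn()
        # First-order gradient, recorded by the outer tape so it can be differentiated again.
        grads = tf.reshape(inner_tape.gradient(loss, weights), [-1])
    hessian = outer_tape.jacobian(
        grads,
        weights,
        parallel_iterations=parallel_iterations,
        experimental_use_pfor=experimental_use_pfor,
    )
    del outer_tape  # release the persistent tape's resources
    return tf.linalg.inv(tf.reshape(hessian, [n, n]))


if __name__ == "__main__":
    # Toy quadratic loss 0.5 * w^T A w, whose Hessian is A.
    w = tf.Variable([1.0, 2.0])
    A = tf.constant([[2.0, 0.0], [0.0, 4.0]])

    def loss_fn():
        return 0.5 * tf.reduce_sum(w * tf.linalg.matvec(A, w))

    # Memory-friendlier setting: sequential while_loop, one Jacobian row per iteration.
    print(exact_inverse_hessian(loss_fn, w, parallel_iterations=1,
                                experimental_use_pfor=False))
    # expected: approximately [[0.5, 0.0], [0.0, 0.25]]
```

In `ExactIHVP`, the two arguments would presumably be stored in `__init__` and passed through to `_compute_inv_hessian`, keeping the current pfor behaviour as the default; a small `parallel_iterations` or `experimental_use_pfor=False` trades runtime for a bounded memory footprint.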

Logs:

(bdd_env) (base) lucas.hervier@soda01:~/bdd100$ python issue_minimal.py 
2022-02-11 10:59:15.926556: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2022-02-11 10:59:17.602358: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2022-02-11 10:59:17.658599: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.659380: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:21:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-11 10:59:17.659421: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.660154: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 1 with properties: 
pciBusID: 0000:4a:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-11 10:59:17.660173: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2022-02-11 10:59:17.661903: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2022-02-11 10:59:17.661930: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
2022-02-11 10:59:17.662492: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcufft.so.10
2022-02-11 10:59:17.662623: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcurand.so.10
2022-02-11 10:59:17.663131: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11
2022-02-11 10:59:17.663545: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusparse.so.11
2022-02-11 10:59:17.663616: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8
2022-02-11 10:59:17.663665: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.664449: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.665198: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.665944: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.666664: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0, 1
2022-02-11 10:59:17.666925: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-02-11 10:59:17.786724: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.787447: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:21:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-11 10:59:17.787484: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.788149: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 1 with properties: 
pciBusID: 0000:4a:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.86GHz coreCount: 82 deviceMemorySize: 23.70GiB deviceMemoryBandwidth: 871.81GiB/s
2022-02-11 10:59:17.788187: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.788895: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.789599: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.790300: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:17.790980: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0, 1
2022-02-11 10:59:17.791016: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2022-02-11 10:59:18.257978: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-02-11 10:59:18.258015: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264]      0 1 
2022-02-11 10:59:18.258021: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0:   N N 
2022-02-11 10:59:18.258024: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 1:   N N 
2022-02-11 10:59:18.258195: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.258947: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.259658: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.260379: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.261077: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.261784: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 22302 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:21:00.0, compute capability: 8.6)
2022-02-11 10:59:18.262082: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-11 10:59:18.262783: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:1 with 22312 MB memory) -> physical GPU (device: 1, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:4a:00.0, compute capability: 8.6)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 768, 768, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 768, 768, 32)      128       
_________________________________________________________________
dropout (Dropout)            (None, 768, 768, 32)      0         
_________________________________________________________________
batch_normalization (BatchNo (None, 768, 768, 32)      128       
_________________________________________________________________
activation (Activation)      (None, 768, 768, 32)      0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 768, 768, 32)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 768, 768, 20)      660       
=================================================================
Total params: 916
Trainable params: 660
Non-trainable params: 256
_________________________________________________________________
None
2022-02-11 10:59:18.626040: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2022-02-11 10:59:18.644320: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3700110000 Hz
2022-02-11 10:59:18.672693: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8
2022-02-11 10:59:19.060262: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8100
2022-02-11 10:59:19.548581: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2022-02-11 10:59:19.925576: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
WARNING:tensorflow:Using a while_loop for converting Conv2D
WARNING:tensorflow:Using a while_loop for converting Conv2DBackpropInput
WARNING:tensorflow:Using a while_loop for converting ResizeBilinearGrad
2022-02-11 11:04:57.055515: W tensorflow/core/common_runtime/bfc_allocator.cc:456] Allocator (GPU_0_bfc) ran out of memory trying to allocate 45.00GiB (rounded to 48318382080)requested by op loop_body/PartitionedCall/pfor/PartitionedCall/gradients/gradient_tape/model/conv2d_1/Conv2D/Conv2DBackpropFilter_grad/Conv2D/pfor/Tile
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
2022-02-11 11:04:57.055561: I tensorflow/core/common_runtime/bfc_allocator.cc:991] BFCAllocator dump for GPU_0_bfc
2022-02-11 11:04:57.055569: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (256):   Total Chunks: 24, Chunks in use: 24. 6.0KiB allocated for chunks. 6.0KiB in use in bin. 1.3KiB client-requested in use in bin.
2022-02-11 11:04:57.055575: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (512):   Total Chunks: 1, Chunks in use: 1. 512B allocated for chunks. 512B in use in bin. 384B client-requested in use in bin.
2022-02-11 11:04:57.055581: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (1024):  Total Chunks: 1, Chunks in use: 1. 1.2KiB allocated for chunks. 1.2KiB in use in bin. 1.0KiB client-requested in use in bin.
2022-02-11 11:04:57.055587: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (2048):  Total Chunks: 5, Chunks in use: 4. 12.5KiB allocated for chunks. 10.5KiB in use in bin. 10.5KiB client-requested in use in bin.
2022-02-11 11:04:57.055593: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (4096):  Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055598: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (8192):  Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055605: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (16384):         Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055613: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (32768):         Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055621: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (65536):         Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055631: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (131072):        Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055636: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (262144):        Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055641: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (524288):        Total Chunks: 1, Chunks in use: 0. 571.0KiB allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055647: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (1048576):       Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055656: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (2097152):       Total Chunks: 3, Chunks in use: 3. 6.75MiB allocated for chunks. 6.75MiB in use in bin. 6.06MiB client-requested in use in bin.
2022-02-11 11:04:57.055664: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (4194304):       Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055671: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (8388608):       Total Chunks: 1, Chunks in use: 1. 9.00MiB allocated for chunks. 9.00MiB in use in bin. 9.00MiB client-requested in use in bin.
2022-02-11 11:04:57.055679: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (16777216):      Total Chunks: 1, Chunks in use: 1. 27.00MiB allocated for chunks. 27.00MiB in use in bin. 27.00MiB client-requested in use in bin.
2022-02-11 11:04:57.055687: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (33554432):      Total Chunks: 4, Chunks in use: 3. 172.68MiB allocated for chunks. 135.00MiB in use in bin. 135.00MiB client-requested in use in bin.
2022-02-11 11:04:57.055695: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (67108864):      Total Chunks: 5, Chunks in use: 5. 360.00MiB allocated for chunks. 360.00MiB in use in bin. 333.00MiB client-requested in use in bin.
2022-02-11 11:04:57.055702: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (134217728):     Total Chunks: 0, Chunks in use: 0. 0B allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055709: I tensorflow/core/common_runtime/bfc_allocator.cc:998] Bin (268435456):     Total Chunks: 1, Chunks in use: 0. 21.22GiB allocated for chunks. 0B in use in bin. 0B client-requested in use in bin.
2022-02-11 11:04:57.055718: I tensorflow/core/common_runtime/bfc_allocator.cc:1014] Bin for 45.00GiB was 256.00MiB, Chunk State: 
2022-02-11 11:04:57.055732: I tensorflow/core/common_runtime/bfc_allocator.cc:1020]   Size: 21.22GiB | Requested Size: 45.00MiB | in_use: 0 | bin_num: 20, prev:   Size: 72.00MiB | Requested Size: 72.00MiB | in_use: 1 | bin_num: -1, for: loop_body/PartitionedCall/pfor/PartitionedCall/gradients/model/conv2d_1/Conv2D_grad/Conv2DBackpropFilter/pfor/Conv2DBackpropFilter-0-TransposeNHWCToNCHW-LayoutOptimizer, stepid: 15496386686427765080, last_action: 4278547630, for: UNUSED, stepid: 15496386686427765080, last_action: 4278547628
2022-02-11 11:04:57.055739: I tensorflow/core/common_runtime/bfc_allocator.cc:1027] Next region of size 23385669632
2022-02-11 11:04:57.055747: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000000 of size 1280 by op ScratchBuffer action_count 4278547493 step 0 next 1
2022-02-11 11:04:57.055753: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000500 of size 256 by op Fill action_count 4278547503 step 0 next 5
2022-02-11 11:04:57.055758: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000600 of size 256 by op Fill action_count 4278547504 step 0 next 2
2022-02-11 11:04:57.055764: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000700 of size 256 by op Sub action_count 4278547495 step 0 next 3
2022-02-11 11:04:57.055768: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000800 of size 256 by op Sub action_count 4278547496 step 0 next 4
2022-02-11 11:04:57.055774: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000900 of size 256 by op Fill action_count 4278547505 step 0 next 8
2022-02-11 11:04:57.055778: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000a00 of size 256 by op Fill action_count 4278547506 step 0 next 9
2022-02-11 11:04:57.055784: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000b00 of size 256 by op Fill action_count 4278547507 step 0 next 6
2022-02-11 11:04:57.055790: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000c00 of size 512 by op Add action_count 4278547500 step 0 next 7
2022-02-11 11:04:57.055795: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000e00 of size 256 by op Fill action_count 4278547508 step 0 next 10
2022-02-11 11:04:57.055801: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6000f00 of size 256 by op Fill action_count 4278547509 step 0 next 11
2022-02-11 11:04:57.055806: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001000 of size 256 by op Fill action_count 4278547519 step 0 next 15
2022-02-11 11:04:57.055812: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001100 of size 256 by op AssignVariableOp action_count 4278547520 step 0 next 18
2022-02-11 11:04:57.055818: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001200 of size 256 by op Mul action_count 4278547522 step 0 next 20
2022-02-11 11:04:57.055823: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001300 of size 256 by op Add action_count 4278547524 step 0 next 22
2022-02-11 11:04:57.055829: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001400 of size 256 by op Equal action_count 4278547529 step 0 next 24
2022-02-11 11:04:57.055835: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001500 of size 256 by op CustomLoss2/weighted_loss/Const action_count 4278547533 step 0 next 26
2022-02-11 11:04:57.055841: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001600 of size 256 by op CustomLoss2/NotEqual_1/y action_count 4278547534 step 0 next 27
2022-02-11 11:04:57.055847: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001700 of size 256 by op model/batch_normalization/FusedBatchNormV3 action_count 4278547560 step 13684086849625510338 next 34
2022-02-11 11:04:57.055852: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001800 of size 256 by op model/batch_normalization/FusedBatchNormV3 action_count 4278547561 step 13684086849625510338 next 35
2022-02-11 11:04:57.055858: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001900 of size 256 by op model/batch_normalization/FusedBatchNormV3 action_count 4278547562 step 13684086849625510338 next 12
2022-02-11 11:04:57.055864: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001a00 of size 256 by op Sub action_count 4278547511 step 0 next 13
2022-02-11 11:04:57.055869: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001b00 of size 256 by op Sub action_count 4278547512 step 0 next 14
2022-02-11 11:04:57.055875: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001c00 of size 256 by op model/batch_normalization/FusedBatchNormV3 action_count 4278547563 step 13684086849625510338 next 36
2022-02-11 11:04:57.055880: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001d00 of size 256 by op model/batch_normalization/FusedBatchNormV3 action_count 4278547564 step 13684086849625510338 next 37
2022-02-11 11:04:57.055886: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6001e00 of size 256 by op gradient_tape/UnsortedSegmentSum/pfor/mul_1 action_count 4278547624 step 0 next 45
2022-02-11 11:04:57.055892: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] Free  at 7fdcc6001f00 of size 2048 by op UNUSED action_count 0 step 0 next 16
2022-02-11 11:04:57.055898: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6002700 of size 2560 by op Add action_count 4278547516 step 0 next 17
2022-02-11 11:04:57.055903: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6003100 of size 9437184 by op RandomUniformInt action_count 4278547528 step 0 next 19
2022-02-11 11:04:57.055909: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6903100 of size 3072 by op gradient_tape/CustomLoss2/Tile action_count 4278547532 step 0 next 25
2022-02-11 11:04:57.055915: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6903d00 of size 2560 by op gradient_tape/model/conv2d_1/Conv2D/Conv2DBackpropFilter action_count 4278547612 step 13684086849625510338 next 44
2022-02-11 11:04:57.055921: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6904700 of size 2560 by op gradient_tape/UnsortedSegmentSum/pfor/Tile action_count 4278547623 step 0 next 43
2022-02-11 11:04:57.055927: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] Free  at 7fdcc6905100 of size 584704 by op UNUSED action_count 4278547633 step 0 next 28
2022-02-11 11:04:57.055934: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6993d00 of size 2359296 by op CustomLoss2/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_float_Cast action_count 4278547544 step 13684086849625510338 next 33
2022-02-11 11:04:57.055940: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6bd3d00 of size 2359296 by op gradient_tape/UnsortedSegmentSum/pfor/UnsortedSegmentSum action_count 4278547632 step 15496386686427765080 next 41
2022-02-11 11:04:57.055946: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc6e13d00 of size 2359296 by op CustomLoss2/softmax_cross_entropy_with_logits action_count 4278547592 step 13684086849625510338 next 42
2022-02-11 11:04:57.055952: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] Free  at 7fdcc7053d00 of size 39515136 by op UNUSED action_count 4278547619 step 13684086849625510338 next 21
2022-02-11 11:04:57.055957: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcc9603100 of size 28311552 by op Add action_count 4278547525 step 0 next 23
2022-02-11 11:04:57.055963: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdccb103100 of size 75497472 by op model/activation/Relu-0-1-TransposeNCHWToNHWC-LayoutOptimizer action_count 4278547566 step 13684086849625510338 next 31
2022-02-11 11:04:57.055969: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdccf903100 of size 47185920 by op gradient_tape/CustomLoss2/softmax_cross_entropy_with_logits/mul action_count 4278547610 step 13684086849625510338 next 32
2022-02-11 11:04:57.055975: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcd2603100 of size 75497472 by op model/conv2d/BiasAdd-0-1-TransposeNCHWToNHWC-LayoutOptimizer action_count 4278547558 step 13684086849625510338 next 29
2022-02-11 11:04:57.055981: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcd6e03100 of size 75497472 by op model/up_sampling2d/resize/ResizeBilinear action_count 4278547568 step 13684086849625510338 next 30
2022-02-11 11:04:57.055988: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcdb603100 of size 75497472 by op loop_body/PartitionedCall/pfor/PartitionedCall/gradients/CustomLoss2/softmax_cross_entropy_with_logits_grad/Softmax/pfor/Softmax action_count 4278547625 step 15496386686427765080 next 38
2022-02-11 11:04:57.055994: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdcdfe03100 of size 47185920 by op CustomLoss2/softmax_cross_entropy_with_logits action_count 4278547593 step 13684086849625510338 next 39
2022-02-11 11:04:57.056000: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdce2b03100 of size 47185920 by op model/conv2d_1/BiasAdd-0-0-TransposeNCHWToNHWC-LayoutOptimizer action_count 4278547589 step 13684086849625510338 next 40
2022-02-11 11:04:57.056006: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] InUse at 7fdce5803100 of size 75497472 by op loop_body/PartitionedCall/pfor/PartitionedCall/gradients/model/conv2d_1/Conv2D_grad/Conv2DBackpropFilter/pfor/Conv2DBackpropFilter-0-TransposeNHWCToNCHW-LayoutOptimizer action_count 4278547630 step 15496386686427765080 next 46
2022-02-11 11:04:57.056012: I tensorflow/core/common_runtime/bfc_allocator.cc:1046] Free  at 7fdcea003100 of size 22781677312 by op UNUSED action_count 4278547628 step 15496386686427765080 next 18446744073709551615
2022-02-11 11:04:57.056017: I tensorflow/core/common_runtime/bfc_allocator.cc:1051]      Summary of in-use Chunks by size: 
2022-02-11 11:04:57.056024: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 24 Chunks of size 256 totalling 6.0KiB
2022-02-11 11:04:57.056030: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 1 Chunks of size 512 totalling 512B
2022-02-11 11:04:57.056038: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 1 Chunks of size 1280 totalling 1.2KiB
2022-02-11 11:04:57.056048: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 3 Chunks of size 2560 totalling 7.5KiB
2022-02-11 11:04:57.056057: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 1 Chunks of size 3072 totalling 3.0KiB
2022-02-11 11:04:57.056067: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 3 Chunks of size 2359296 totalling 6.75MiB
2022-02-11 11:04:57.056076: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 1 Chunks of size 9437184 totalling 9.00MiB
2022-02-11 11:04:57.056087: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 1 Chunks of size 28311552 totalling 27.00MiB
2022-02-11 11:04:57.056096: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 3 Chunks of size 47185920 totalling 135.00MiB
2022-02-11 11:04:57.056104: I tensorflow/core/common_runtime/bfc_allocator.cc:1054] 5 Chunks of size 75497472 totalling 360.00MiB
2022-02-11 11:04:57.056113: I tensorflow/core/common_runtime/bfc_allocator.cc:1058] Sum Total of in-use chunks: 537.77MiB
2022-02-11 11:04:57.056122: I tensorflow/core/common_runtime/bfc_allocator.cc:1060] total_region_allocated_bytes_: 23385669632 memory_limit_: 23385669632 available bytes: 0 curr_region_allocation_bytes_: 46771339264
2022-02-11 11:04:57.056135: I tensorflow/core/common_runtime/bfc_allocator.cc:1066] Stats: 
Limit:                     23385669632
InUse:                       563890432
MaxInUse:                    600314368
NumAllocs:                          92
MaxAllocSize:                 99865600
Reserved:                            0
PeakReserved:                        0
LargestFreeBlock:                    0

2022-02-11 11:04:57.056161: W tensorflow/core/common_runtime/bfc_allocator.cc:467] ***_________________________________________________________________________________________________
2022-02-11 11:04:57.056221: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at tile_ops.cc:198 : Resource exhausted: OOM when allocating tensor with shape[640,1,768,768,32] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
Traceback (most recent call last):
  File "issue_minimal.py", line 67, in <module>
    ihvp_calculator = ExactIHVP(influence_model, random_dataset.take(1).batch(1))
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/Influenciae-0.0.1-py3.8.egg/influenciae/influence/inverse_hessian_vector_product.py", line 59, in __init__
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/Influenciae-0.0.1-py3.8.egg/influenciae/influence/inverse_hessian_vector_product.py", line 83, in _compute_inv_hessian
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/backprop.py", line 1175, in jacobian
    output = pfor_ops.pfor(loop_fn, target_size,
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py", line 206, in pfor
    outputs = f()
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 956, in _call
    return self._concrete_stateful_fn._call_flat(
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/lucas.hervier/bdd100/bdd_env/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.ResourceExhaustedError:  OOM when allocating tensor with shape[640,1,768,768,32] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
         [[{{node loop_body/PartitionedCall/pfor/PartitionedCall/gradients/gradient_tape/model/conv2d_1/Conv2D/Conv2DBackpropFilter_grad/Conv2D/pfor/Tile}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_f_1291]

Function call stack:
f

As you can see, I run into an OOM error when TensorFlow tries to allocate a tensor of shape [640, 1, 768, 768, 32]. 640 is the number of weights (i.e. the size of the gradient vector), 1 is the number of inputs (the batch size), and [768, 768, 32] is the shape of the input once it has passed through all the layers except the last one. As you might notice, this tensor is allocated when we call:

hess = tf.squeeze(tape_hess.jacobian(grads, weights))

in the _compute_inv_hessian function of the inverse_hessian_vector_product.py file.
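A quick back-of-the-envelope check shows why this allocation cannot succeed. It assumes float32 (4 bytes per element) and counts only this one tensor, ignoring every other intermediate. A single [1, 768, 768, 32] slice is exactly 75497472 bytes, which matches the chunk size that appears in the allocator log above, and 640 such slices far exceed the allocator limit.

```python
# Rough size of the tensor reported in the OOM message, assuming float32
# (4 bytes per element) and ignoring all other intermediate tensors.
BYTES_PER_FLOAT32 = 4
num_weights = 640                      # Jacobian rows (size of the gradient vector)
activation_shape = (1, 768, 768, 32)   # batch of 1, then the 768x768x32 feature map

per_row = BYTES_PER_FLOAT32
for dim in activation_shape:
    per_row *= dim                     # bytes needed for one Jacobian row

total = num_weights * per_row          # all rows vectorized at once
gpu_limit = 23385669632                # "Limit" reported by the BFC allocator

print(f"one row: {per_row / 2**20:.0f} MiB")                  # 72 MiB
print(f"full tensor: {total / 2**30:.1f} GiB")                # 45.0 GiB
print(f"GPU limit: {gpu_limit / 2**30:.2f} GiB")              # 21.78 GiB
print(f"10 rows at a time: {10 * per_row / 2**20:.0f} MiB")   # 720 MiB
```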

Describe the solution you'd like

I know this tensor is needed to compute the Hessian, but I was wondering whether we could split it along the gradient dimension. My colleague @dv-ai found a workaround that only requires a small change to the _compute_inv_hessian function:

Old:

  def _compute_inv_hessian(self, dataset: tf.data.Dataset) -> tf.Tensor:
      """
      Compute the (pseudo)-inverse of the hessian matrix wrt to the model's parameters using backward-mode AD.

      Disclaimer: this implementation trades memory usage for speed, so it can be quite memory intensive, especially
      when dealing with big models.

      Args:
          dataset: tf.data.Dataset
              A TF dataset containing the whole or part of the training dataset for the computation of the inverse
              of the mean hessian matrix.

      Returns:
          A tf.Tensor with the resulting inverse hessian matrix
      """
      weights = self.model.weights
      with tf.GradientTape(persistent=False, watch_accessed_variables=False) as tape_hess:
          tape_hess.watch(weights)
          grads = self.model.batch_gradient(dataset) if dataset._batch_size == 1 \
              else self.model.batch_jacobian(dataset)

      hess = tf.squeeze(tape_hess.jacobian(grads, weights))
      hessian = tf.reduce_mean(tf.reshape(hess, (-1, int(tf.reduce_prod(weights.shape)), int(tf.reduce_prod(weights.shape)))), axis=0)

      return tf.linalg.pinv(hessian)

Alternative:

  def _compute_inv_hessian(self, dataset: tf.data.Dataset) -> tf.Tensor:
      """
      Compute the (pseudo-)inverse of the Hessian matrix with respect to the model's parameters
      using backward-mode AD.

      Disclaimer: this implementation trades speed for memory usage by computing the Jacobian
      with a sequential loop instead of vectorizing it, so it can be slow, especially when
      dealing with big models.

      Parameters
      ----------
      dataset
          A TF dataset containing the whole or part of the training dataset for the
          computation of the inverse of the mean hessian matrix.

      Returns
      ----------
      inv_hessian
          A tf.Tensor with the resulting inverse hessian matrix
      """
      weights = self.model.weights
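      # persistent=True is required: when running eagerly, GradientTape.jacobian raises an
      # error if experimental_use_pfor=False and the tape is not persistent.
      with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape_hess:
          tape_hess.watch(weights)
          grads = self.model.batch_gradient(dataset) if dataset._batch_size == 1 \
              else self.model.batch_jacobian(dataset) # pylint: disable=W0212

      # Compute the Jacobian rows in a while_loop instead of vectorizing them all at once
      # with pfor; parallel_iterations controls how many rows are dispatched in parallel.
      hess = tf.squeeze(tape_hess.jacobian(grads, weights, parallel_iterations=10, experimental_use_pfor=False))
      del tape_hess  # release the resources held by the persistent tape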

      hessian = tf.reduce_mean(tf.reshape(hess,
                                          (-1, int(tf.reduce_prod(weights.shape)),
                                           int(tf.reduce_prod(weights.shape)))), axis=0)

      return tf.linalg.pinv(hessian)

By setting persistent=True on the tape and passing parallel_iterations=10 and experimental_use_pfor=False to the .jacobian call, the computation completes. (With experimental_use_pfor=False, TensorFlow requires a persistent tape when running eagerly.)

N.B.: The value 10 itself does not matter, as long as it evenly divides the length of the gradient vector (which is unfortunate when that length is prime).
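To illustrate what the two jacobian parameters change, here is a minimal, self-contained sketch on a toy function instead of the Influenciae model (the variable w and the cubic loss are assumptions for illustration only). Both calls should return the same Hessian. The difference is that experimental_use_pfor=False computes the rows in a while_loop, where, per the TensorFlow documentation, parallel_iterations controls how many iterations are dispatched in parallel and can therefore be used to bound memory usage, whereas the default path vectorizes all rows at once.

```python
import tensorflow as tf

# Toy stand-in for the model weights (hypothetical, not the Influenciae API):
# f(w) = sum(w^3) has gradient 3*w^2 and Hessian diag(6*w).
w = tf.Variable([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# The outer tape must be persistent: in eager mode, jacobian() refuses to run
# with experimental_use_pfor=False on a non-persistent tape.
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape_hess:
    tape_hess.watch(w)
    with tf.GradientTape() as tape_grad:
        loss = tf.reduce_sum(w * w * w)
    grads = tape_grad.gradient(loss, w)  # first-order gradients, shape (6,)

# Sequential while_loop path: rows are dispatched at most 3 at a time.
hess_loop = tape_hess.jacobian(grads, w, parallel_iterations=3,
                               experimental_use_pfor=False)
# Default vectorized (pfor) path, for comparison: all 6 rows at once.
hess_pfor = tape_hess.jacobian(grads, w)
del tape_hess  # release the resources held by the persistent tape

print(hess_loop.numpy())
```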

For instance, if I add this to my script:

print(ihvp_calculator.inv_hessian)

I got:

tf.Tensor(
[[ 2.3457441  -0.07872738 -0.11368337 ...  0.02131678  0.02238739
   0.04105094]
 [-0.07837234  2.576137   -0.12778574 ...  0.02324321  0.02715976
   0.03761083]
 [-0.11375846 -0.12770845  2.8135462  ...  0.02000072  0.0255051
   0.03220554]
 ...
 [ 0.02132007  0.02319163  0.01998054 ...  0.7005969  -0.01072854
  -0.0270703 ]
 [ 0.02241289  0.02717561  0.02547131 ... -0.0106203   0.87094194
  -0.03467852]
 [ 0.04103031  0.03757853  0.03215647 ... -0.02701988 -0.0346096
   0.77158403]], shape=(640, 640), dtype=float32)

The computation still takes some time, but that makes sense given the number of parameters. Is there any way to set these parameters in the constructor, or at least when calling _compute_inv_hessian? Or, alternatively, to automatically split the computation over the different gradients?
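As for splitting the computation over the gradients, one possible approach is to slice the gradient vector into chunks inside the outer tape and call jacobian once per chunk, so that only one chunk's worth of rows is materialized at a time. The sketch below uses a hypothetical helper, chunked_hessian, on a toy loss; it is not part of Influenciae, and a real integration would have to go through InfluenceModel's own gradient methods instead.

```python
import tensorflow as tf


def chunked_hessian(loss_fn, w, chunk_size):
    """Hypothetical helper (not part of Influenciae): builds the Hessian of
    loss_fn at w block by block, one chunk of gradient entries at a time."""
    n = int(tf.size(w))
    with tf.GradientTape(persistent=True) as tape_hess:
        with tf.GradientTape() as tape_grad:
            loss = loss_fn(w)
        grads = tf.reshape(tape_grad.gradient(loss, w), [-1])
        # Slice inside the outer tape so each chunk stays connected to w.
        chunks = [grads[i:i + chunk_size] for i in range(0, n, chunk_size)]
    # Each jacobian call only vectorizes chunk_size rows; in this sketch the
    # last chunk may simply be shorter than the others.
    blocks = [tf.reshape(tape_hess.jacobian(c, w), [-1, n]) for c in chunks]
    del tape_hess  # release the resources held by the persistent tape
    return tf.concat(blocks, axis=0)


# Usage on a toy loss whose Hessian is diag(6*w).
w = tf.Variable([1.0, 2.0, 3.0, 4.0, 5.0])
hess = chunked_hessian(lambda v: tf.reduce_sum(v * v * v), w, chunk_size=2)
print(hess.numpy())
```

The chunk size plays the same role as parallel_iterations: smaller chunks lower peak memory at the cost of more sequential jacobian calls.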

Additional remarks

While running these experiments, I also noticed a few things: