Anaxagor / applyBN

4 stars 0 forks source link

Продумывание кейсов по модулю объяснения результатов работы ML моделей с помощью каузальных моделей #2

Open Anaxagor opened 2 months ago

Anaxagor commented 2 months ago

Варианты возможных кейсов: 1) Расчёт важности признаков 2) Проведение A/B тестирования 3) Оценка уверенности предсказания семплов для ML модели на основе каузальных моделей (DataCentric AI + Causal Models) 4) Использование каузальных моделей для прореживания слоёв NN.

jrzkaminski commented 2 months ago
  1. Важность признаков может быть расчитана двумя способами:

    • Простой: обучаем эго-бс с вайтлистом только на те ребра, которые входят в таргет, тут можно применять exhaustive search, потому что пространство поиска маленькое даже для больших датасетов с сотнями узлов.
    • Посложнее: Использовать independence test как инструмент для расчета p-value, на основе этого делать вывод о "силе" связи между узлами как тут
  2. Для A/B тестирования есть много вариаций, включая те, что представлены в DoWhy, должно быть несложно написать удобный инструмент для подсчета ATE, CATE и т.д.

  3. Опции:

jrzkaminski commented 6 days ago
  1. Подмодуль для прунинга сверточных нейронных сетей на основе каузальных сетей. Данный модуль основан на данном исследовании.

Алгоритм, предложенный в статье, основывается на применении принципов причинно-следственного анализа для интерпретации глубоких нейронных сетей, в частности, сверточных нейронных сетей (CNN). Основная идея заключается в том, чтобы построить structural causal model (SCM) для конкретного аспекта CNN, что позволяет делать количественные оценки важности и влияния различных компонентов модели (например, фильтров) на ее производительность.

Шаги алгоритма:

  1. Построение DAG-структуры:

    • CNN уже имеет встроенную структуру в виде направленного ациклического графа (DAG). Структура DAG создается на основе связей между узлами в модели нейронной сети, где каждый узел представляет собой нейрон, группу нейронов или фильтр, а каждое соединение — направленное ребро.
  2. Применение преобразования:

    • Важным аспектом построения SCM является выбор подходящего преобразования (\phi), которое позволит выразить реакцию фильтра в CNN как вещественное число. Примеры таких преобразований включают дискретизацию отклика фильтра в бинарную форму (низкая или высокая дисперсия) или использование нормы Фробениуса для числового представления матрицы отклика фильтра.
  3. Оценка структурных уравнений:

    • На данном этапе необходимо построить уравнения, которые будут описывать причинно-следственные зависимости между узлами в DAG. Для этого решается задача регрессии, где каждому узлу (r) сопоставляется функция (f), которая зависит от его родителей (PA_r) в DAG.
  4. Оценка влияния элемента:

    • Оценка коэффициентов перед откликами нейронов в данном методе отражает их важность. На основе этих коэффициентов и определяется важность фильтров.

Этот подход позволяет не только объяснить важность отдельных компонентов модели, но и ответить на более сложные вопросы, например, как изменение одного фильтра повлияет на конечные результаты модели.

Подобный алгоритм был имплементирован в applybn как для свертоных нейронных сетей, так и для нейронных сетей, которые работают с табличными данными. Для оценки алгоритма на примере свертоных сетей используется датасет CIFAR-10, а для проверки алгоритма на обычной нейронной сети используется датасет Breast Cancer из библиотеки scikit-learn.

Сверточная сеть, результаты

В случае сверточной сети сначала обучается простая сверточная сеть с тремя свертками, затем происходит прунинг фильтров с нарастанием количества убранных фильтров и заметряется Accuracy на каждом этапе прунинга. Результат сравнивается со случайным прунингом такого же количества фильтров. Результат показан на картинке ниже.

image

Простая сеть для табличных данных, результаты Обучается простая сеть с двумя слоями, происходит последовательный прунинг по всем слоям определенного количества нейронов. Результат сравнивается с прунингом такого же количества случайных нейронов.

image

Таким образом можно сделать вывод о том, что такой метод прунинга действительно позволяет сократить количество нейронов в нейронной сети с минимальной потерей качества, что уменьшит время инференса. На следующей стадии экспримента планируется сравнить такой прунинг с этими методами и сделать этот метод прунинга pytorch-совместимым, используя их API.