neurostuff / NiMARE

Coordinate- and image-based meta-analysis in Python
https://nimare.readthedocs.io
MIT License

Recruiting GPU not possible (coordinate-based meta-analysis) #819

Open MaxKorbmacher opened 1 year ago

MaxKorbmacher commented 1 year ago

Summary

When fitting a coordinate-based meta-regression model on Google Colab (with the correct GPU runtime settings), the GPU is not recruited properly; the work falls back to the CPU instead.

Here is the error message:

`RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)`
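For context, this `RuntimeError` is PyTorch's generic mixed-device error: some operation (here a matrix multiply, per `wrapper_CUDA_mm` in the message) received one tensor on `cuda:0` and one on the CPU. A minimal sketch of the pattern that avoids it (plain PyTorch, not NiMARE code):

```python
import torch

# Pick one device for everything; fall back to CPU when no GPU is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.randn(4, 3, device=device)
b = torch.randn(3, 2)   # created on the CPU by default
b = b.to(device)        # move it BEFORE mixing the two in one op

c = a @ b               # safe: both operands now share `device`
print(c.shape, c.device)
```

The bug report suggests some internal tensor in the estimator is not being moved this way before the optimizer step.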

Additional details

What were you trying to do?

Here is the code (imports included; both names come from `nimare.meta`, as in the notebook below):

from nimare.meta import models
from nimare.meta.cbmr import CBMREstimator

cbmr = CBMREstimator(
    group_categories=["motor"],
    moderators=[
        "moderator",
    ],
    spline_spacing=100,  # a reasonable choice is 10 or 5; 100 is for speed
    model=models.PoissonEstimator,
    penalty=False,
    lr=1e-1,
    tol=1e3,  # a reasonable choice is 1e-2; 1e3 is for speed
    device="cuda",  # "cuda" if you have a GPU, or "cpu" for CPU
)
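Until the device mismatch is fixed upstream, one defensive pattern (a sketch, not part of NiMARE's API; `choose_device` is a hypothetical helper) is to request `"cuda"` only when PyTorch can actually see a GPU, and otherwise stay on the CPU:

```python
import torch

def choose_device() -> str:
    """Return "cuda" only when PyTorch reports a usable GPU, else "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"

# Usage (hypothetical): CBMREstimator(..., device=choose_device())
print(choose_device())
```

This does not make the estimator use the GPU, but it avoids hard-coding a device that the runtime may not honor.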

What did you expect to happen?

GPU usage

What actually happened?

(Mainly) CPU usage, plus the error message above. Yet GPU availability was confirmed with:

!nvidia-smi
import torch
torch.cuda.is_available()
import tensorflow as tf
tf.test.gpu_device_name()
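Note that `tf.test.gpu_device_name()` only reports what TensorFlow sees, while the model here runs on PyTorch. A torch-only diagnostic (a sketch; `gpu_report` is a hypothetical helper, safe to run without a GPU) would be:

```python
import torch

def gpu_report() -> dict:
    """Summarize what PyTorch can see; works with or without a GPU."""
    info = {
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }
    if info["cuda_available"]:
        # Only query device properties when a GPU actually exists.
        info["device_name"] = torch.cuda.get_device_name(0)
    return info

print(gpu_report())
```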

Reproducing the bug

Some code above, rest of the code is here: https://colab.research.google.com/drive/1EVngYoYlryl-YcjmIEOjm0VBBplc-Qyw?usp=sharing

MaxKorbmacher commented 1 year ago

Whole notebook:

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## **Setup**"
      ],
      "metadata": {
        "id": "25Zmh-iSX6jz"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "\n",
        "```\n",
        "# IMPORTANT NOTE:\n",
        "# We work with a GPU here, which requires to change the Runtime settings\n",
        "# go to Runtime >> Change runtime type >> select GPU under Hardware accelerator\n",
        "```\n",
        "\n"
      ],
      "metadata": {
        "id": "dkLtvJsb6Wgp"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fR0PZAqD_moQ",
        "outputId": "f6e247fa-8010-4263-f362-9f15d8e3ffae"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting nimare\n",
            "  Downloading NiMARE-0.1.1-py3-none-any.whl (13.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m85.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting cognitiveatlas (from nimare)\n",
            "  Downloading cognitiveatlas-0.1.9.tar.gz (5.1 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting fuzzywuzzy (from nimare)\n",
            "  Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from nimare) (3.1.2)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nimare) (1.3.1)\n",
            "Requirement already satisfied: matplotlib>=3.3 in /usr/local/lib/python3.10/dist-packages (from nimare) (3.7.1)\n",
            "Collecting nibabel>=3.2.0 (from nimare)\n",
            "  Downloading nibabel-5.1.0-py3-none-any.whl (3.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 MB\u001b[0m \u001b[31m86.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting nilearn>=0.10.1 (from nimare)\n",
            "  Downloading nilearn-0.10.1-py3-none-any.whl (10.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.3/10.3 MB\u001b[0m \u001b[31m122.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting numba>=0.57.0 (from nimare)\n",
            "  Downloading numba-0.57.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m94.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.10/dist-packages (from nimare) (1.22.4)\n",
            "Collecting pandas>=2.0.0 (from nimare)\n",
            "  Downloading pandas-2.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.3/12.3 MB\u001b[0m \u001b[31m81.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: patsy in /usr/local/lib/python3.10/dist-packages (from nimare) (0.5.3)\n",
            "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from nimare) (5.13.1)\n",
            "Collecting pymare~=0.0.4rc2 (from nimare)\n",
            "  Downloading PyMARE-0.0.4rc2-py3-none-any.whl (36 kB)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from nimare) (6.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from nimare) (2.27.1)\n",
            "Requirement already satisfied: scikit-learn>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from nimare) (1.2.2)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from nimare) (1.10.1)\n",
            "Collecting sparse>=0.13.0 (from nimare)\n",
            "  Downloading sparse-0.14.0-py2.py3-none-any.whl (80 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.0/81.0 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: statsmodels!=0.13.2 in /usr/local/lib/python3.10/dist-packages (from nimare) (0.13.5)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from nimare) (2.0.1+cu118)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nimare) (4.65.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (1.1.0)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (0.11.0)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (4.41.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (1.4.4)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (23.1)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (8.4.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (3.1.0)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.3->nimare) (2.8.2)\n",
            "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from nilearn>=0.10.1->nimare) (4.9.3)\n",
            "Collecting llvmlite<0.41,>=0.40.0dev0 (from numba>=0.57.0->nimare)\n",
            "  Downloading llvmlite-0.40.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (42.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.1/42.1 MB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=2.0.0->nimare) (2022.7.1)\n",
            "Collecting tzdata>=2022.1 (from pandas>=2.0.0->nimare)\n",
            "  Downloading tzdata-2023.3-py2.py3-none-any.whl (341 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m341.8/341.8 kB\u001b[0m \u001b[31m38.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from pymare~=0.0.4rc2->nimare) (1.11.1)\n",
            "Requirement already satisfied: wrapt in /usr/local/lib/python3.10/dist-packages (from pymare~=0.0.4rc2->nimare) (1.14.1)\n",
            "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->nimare) (1.26.16)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->nimare) (2023.5.7)\n",
            "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->nimare) (2.0.12)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->nimare) (3.4)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.0->nimare) (3.1.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from patsy->nimare) (1.16.0)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.10/dist-packages (from cognitiveatlas->nimare) (0.18.3)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->nimare) (2.1.3)\n",
            "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->nimare) (8.2.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->nimare) (3.12.2)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->nimare) (4.7.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->nimare) (3.1)\n",
            "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->nimare) (2.0.0)\n",
            "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->nimare) (3.25.2)\n",
            "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->nimare) (16.0.6)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->pymare~=0.0.4rc2->nimare) (1.3.0)\n",
            "Building wheels for collected packages: cognitiveatlas\n",
            "  Building wheel for cognitiveatlas (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for cognitiveatlas: filename=cognitiveatlas-0.1.9-py3-none-any.whl size=6376 sha256=d38afcb03ae02081f8101428d4b23675517511c3deb14ec2053c311bca668a73\n",
            "  Stored in directory: /root/.cache/pip/wheels/a3/6c/09/eff269417bf07149992261253dd9d11eae5c2a53bb53cdc7ed\n",
            "Successfully built cognitiveatlas\n",
            "Installing collected packages: fuzzywuzzy, tzdata, nibabel, llvmlite, pandas, numba, sparse, pymare, nilearn, cognitiveatlas, nimare\n",
            "  Attempting uninstall: nibabel\n",
            "    Found existing installation: nibabel 3.0.2\n",
            "    Uninstalling nibabel-3.0.2:\n",
            "      Successfully uninstalled nibabel-3.0.2\n",
            "  Attempting uninstall: llvmlite\n",
            "    Found existing installation: llvmlite 0.39.1\n",
            "    Uninstalling llvmlite-0.39.1:\n",
            "      Successfully uninstalled llvmlite-0.39.1\n",
            "  Attempting uninstall: pandas\n",
            "    Found existing installation: pandas 1.5.3\n",
            "    Uninstalling pandas-1.5.3:\n",
            "      Successfully uninstalled pandas-1.5.3\n",
            "  Attempting uninstall: numba\n",
            "    Found existing installation: numba 0.56.4\n",
            "    Uninstalling numba-0.56.4:\n",
            "      Successfully uninstalled numba-0.56.4\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "google-colab 1.0.0 requires pandas==1.5.3, but you have pandas 2.0.3 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed cognitiveatlas-0.1.9 fuzzywuzzy-0.18.0 llvmlite-0.40.1 nibabel-5.1.0 nilearn-0.10.1 nimare-0.1.1 numba-0.57.1 pandas-2.0.3 pymare-0.0.4rc2 sparse-0.14.0 tzdata-2023.3\n"
          ]
        }
      ],
      "source": [
        "pip install nimare"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "pip install biopython"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WCip6as1AYod",
        "outputId": "d17f36a5-8d95-479e-8ea3-bf2cb90cf566"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting biopython\n",
            "  Downloading biopython-1.81-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m32.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from biopython) (1.22.4)\n",
            "Installing collected packages: biopython\n",
            "Successfully installed biopython-1.81\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# objective: testing whether meta-regression reanalysis produces same results as meta-analyses on neurosynth\n",
        "\n",
        "############ PREP\n",
        "# dependencies which need prior installation: nimare & biopython\n",
        "\n",
        "import os\n",
        "from pprint import pprint\n",
        "\n",
        "from nimare.extract import download_abstracts, fetch_neuroquery, fetch_neurosynth\n",
        "from nimare.io import convert_neurosynth_to_dataset\n",
        "import Bio\n",
        "\n",
        "########## DOWNLOAD NEUROSYNTH DATA\n",
        "out_dir = os.path.abspath(\"../example_data/\")\n",
        "os.makedirs(out_dir, exist_ok=True)\n",
        "\n",
        "files = fetch_neurosynth(\n",
        "    data_dir=out_dir,\n",
        "    version=\"7\",\n",
        "    overwrite=False,\n",
        "    source=\"abstract\",\n",
        "    vocab=\"terms\",\n",
        ")\n",
        "# Note that the files are saved to a new folder within \"out_dir\" named \"neurosynth\".\n",
        "pprint(files)\n",
        "neurosynth_db = files[0]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gJYAYp3oAnx9",
        "outputId": "a4909a23-9900-4893-adb7-8e3306e24220"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading data-neurosynth_version-7_coordinates.tsv.gz\n",
            "Downloading data-neurosynth_version-7_metadata.tsv.gz\n",
            "Downloading data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz\n",
            "Downloading data-neurosynth_version-7_vocab-terms_vocabulary.txt\n",
            "[{'coordinates': '/example_data/neurosynth/data-neurosynth_version-7_coordinates.tsv.gz',\n",
            "  'features': [{'features': '/example_data/neurosynth/data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz',\n",
            "                'vocabulary': '/example_data/neurosynth/data-neurosynth_version-7_vocab-terms_vocabulary.txt'}],\n",
            "  'metadata': '/example_data/neurosynth/data-neurosynth_version-7_metadata.tsv.gz'}]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert Neurosynth database to NiMARE dataset file\n",
        "neurosynth_dset = convert_neurosynth_to_dataset(\n",
        "    coordinates_file=neurosynth_db[\"coordinates\"],\n",
        "    metadata_file=neurosynth_db[\"metadata\"],\n",
        "    annotations_files=neurosynth_db[\"features\"],\n",
        ")\n",
        "neurosynth_dset.save(os.path.join(out_dir, \"neurosynth_dataset.pkl.gz\"))\n",
        "print(neurosynth_dset)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QZv6Nzk1BN9e",
        "outputId": "c2ddcdfe-7068-4335-81e8-73c5bbecb227"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:nimare.utils:Not applying transforms to coordinates in unrecognized space 'UNKNOWN'\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset(14371 experiments, space='mni152_2mm')\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Add article abstracts to dataset\n",
        "neurosynth_dset = download_abstracts(neurosynth_dset, \"example@example.edu\")\n",
        "neurosynth_dset.save(os.path.join(out_dir, \"neurosynth_dataset_with_abstracts.pkl.gz\"))"
      ],
      "metadata": {
        "id": "NDB2SjoaBQP-"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# # FOR NOW, THIS STEP OF **LOADING NEUROQUERY DATA** IS BEING **SKIPPED**!!\n",
        "# # (too large RAM requirements for Google Colab)\n",
        "#\n",
        "#\n",
        "# # Do the same with NeuroQuery\n",
        "# # NeuroQuery’s data files are stored at https://github.com/neuroquery/neuroquery_data.\n",
        "# files = fetch_neuroquery(\n",
        "#     data_dir=out_dir,\n",
        "#     version=\"1\",\n",
        "#     overwrite=False,\n",
        "#     source=\"combined\",\n",
        "#     vocab=\"neuroquery6308\",\n",
        "#     type=\"tfidf\",\n",
        "# )\n",
        "# # Note that the files are saved to a new folder within \"out_dir\" named \"neuroquery\".\n",
        "# pprint(files)\n",
        "# neuroquery_db = files[0]\n",
        "\n",
        "# # Note that the conversion function says \"neurosynth\".\n",
        "# # This is just for backwards compatibility.\n",
        "# neuroquery_dset = convert_neurosynth_to_dataset(\n",
        "#     coordinates_file=neuroquery_db[\"coordinates\"],\n",
        "#     metadata_file=neuroquery_db[\"metadata\"],\n",
        "#     annotations_files=neuroquery_db[\"features\"],\n",
        "# )\n",
        "# neuroquery_dset.save(os.path.join(out_dir, \"neuroquery_dataset.pkl.gz\"))\n",
        "# print(neuroquery_dset)\n",
        "\n",
        "# # NeuroQuery also uses PMIDs as study IDs.\n",
        "# neuroquery_dset = download_abstracts(neuroquery_dset, \"example@example.edu\")\n",
        "# neuroquery_dset.save(os.path.join(out_dir, \"neuroquery_dataset_with_abstracts.pkl.gz\"))\n",
        "#\n",
        "# ######################################"
      ],
      "metadata": {
        "id": "MmK0B7NmBUcb"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Test 1: Motor Studies**"
      ],
      "metadata": {
        "id": "sZbbaTdaYLqL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Label data and test meta-regression on subsets\n",
        "\n",
        "# motor (2565 studies)\n",
        "# compare results to https://neurosynth.org/analyses/terms/motor/\n",
        "\n",
        "# identify which studies examine motor functions\n",
        "motor_study_id = neurosynth_dset.get_studies_by_label(labels=[\"terms_abstract_tfidf__motor\"])\n",
        "# and those which do not (rest of the studies)\n",
        "nonmotor_study_id = list(set(neurosynth_dset.ids) - set(motor_study_id))\n",
        "# create a subset only containing motor studies\n",
        "motor_dset = neurosynth_dset.slice(ids=motor_study_id)\n",
        "# and a subset containing no motor studies\n",
        "nonmotor_dset = neurosynth_dset.slice(ids=nonmotor_study_id)\n",
        "\n",
        "# we can now create a dummy variable indicating which study involved motor tasks and which did not.\n",
        "neurosynth_dset.annotations[\"motor\"] = 'False'\n",
        "neurosynth_dset.annotations.loc[neurosynth_dset.annotations['id'].isin(motor_study_id), 'motor'] = 'True'\n",
        "# and a dummy indicating the absence of moderators (all values = 1)\n",
        "neurosynth_dset.annotations[\"moderator\"] = 1"
      ],
      "metadata": {
        "id": "11-DPKlYPYSG"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi\n",
        "import torch\n",
        "torch.cuda.is_available()\n",
        "import tensorflow as tf\n",
        "tf.test.gpu_device_name()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "id": "tD1g74HT-w3L",
        "outputId": "e0e32c73-f19e-4d66-ab6c-8ff41335b066"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'/device:GPU:0'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# set up the model\n",
        "import numpy as np\n",
        "import scipy\n",
        "from nilearn.plotting import plot_stat_map\n",
        "\n",
        "from nimare.generate import create_coordinate_dataset\n",
        "from nimare.meta import models\n",
        "from nimare.transforms import StandardizeField\n",
        "from nimare.meta.cbmr import CBMREstimator\n",
        "\n",
        "cbmr = CBMREstimator(\n",
        "    group_categories=[\"motor\"],\n",
        "    moderators=[\n",
        "        \"moderator\",\n",
        "    ],\n",
        "    spline_spacing=100,  # a reasonable choice is 10 or 5, 100 is for speed\n",
        "    model=models.PoissonEstimator,\n",
        "    penalty=False,\n",
        "    lr=1e-1,\n",
        "    tol=1e3,   # a reasonable choice is 1e-2, 1e3 is for speed\n",
        "    device=\"cuda\",  # \"cuda\" if you have GPU or \"cpu\" for CPU\n",
        ")\n",
        "results = cbmr.fit(dataset=neurosynth_dset)\n",
        "\n",
        "#############################"
      ],
      "metadata": {
        "id": "A0PIvYxy1QKh",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 417
        },
        "outputId": "064bd828-6461-42d9-f8dd-cc75cedc5ec0"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "error",
          "ename": "RuntimeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-9-ba45c24639e9>\u001b[0m in \u001b[0;36m<cell line: 23>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     21\u001b[0m     \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cuda\"\u001b[0m\u001b[0;34m,\u001b[0m  \u001b[0;31m# \"cuda\" if you have GPU or \"cpu\" for CPU\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m )\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcbmr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mneurosynth_dset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;31m#############################\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/estimator.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, dataset, drop_invalid)\u001b[0m\n\u001b[1;32m    123\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_collect_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdrop_invalid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdrop_invalid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    124\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_preprocess_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m         \u001b[0mmaps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescription\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    127\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"masker\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmasker\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/cbmr.py\u001b[0m in \u001b[0;36m_fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m    395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    396\u001b[0m         \u001b[0mmoderators_by_group\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"moderators_by_group\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmoderators\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m         self.model.fit(\n\u001b[0m\u001b[1;32m    398\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"coef_spline_bases\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    399\u001b[0m             \u001b[0mmoderators_by_group\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study)\u001b[0m\n\u001b[1;32m    295\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmoderators_by_group\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_voxel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_study\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    296\u001b[0m         \u001b[0;34m\"\"\"Fit the model and estimate standard error of estimates.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 297\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmoderators_by_group\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_voxel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_study\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    298\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextract_optimized_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmoderators_by_group\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    299\u001b[0m         self.standard_error_estimation(\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36m_optimizer\u001b[0;34m(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study)\u001b[0m\n\u001b[1;32m    277\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    278\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m             loss = self._update(\n\u001b[0m\u001b[1;32m    280\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    281\u001b[0m                 \u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, optimizer, coef_spline_bases, moderators, foci_per_voxel, foci_per_study, prev_loss)\u001b[0m\n\u001b[1;32m    188\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 190\u001b[0;31m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    191\u001b[0m         \u001b[0mscheduler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    192\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     67\u001b[0m                 \u001b[0minstance\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_count\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     68\u001b[0m                 \u001b[0mwrapped\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__get__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minstance\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 69\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     70\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m             \u001b[0;31m# Note that the returned function here is no longer a bound method,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    278\u001b[0m                                                f\"but got {result}.\")\n\u001b[1;32m    279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 280\u001b[0;31m                 \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    281\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_step_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    282\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    113\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    117\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/optim/lbfgs.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    310\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    311\u001b[0m         \u001b[0;31m# evaluate initial f(x) and df/dx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 312\u001b[0;31m         \u001b[0morig_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    313\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    314\u001b[0m         \u001b[0mcurrent_evals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    113\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mctx_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    116\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    117\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36mclosure\u001b[0;34m()\u001b[0m\n\u001b[1;32m    184\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    185\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 186\u001b[0;31m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmoderators\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_voxel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoci_per_study\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    187\u001b[0m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    188\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1499\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1500\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1502\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1503\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study)\u001b[0m\n\u001b[1;32m    866\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    867\u001b[0m                 \u001b[0mmoderators_coef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_moderators\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 868\u001b[0;31m             group_log_l = self._log_likelihood_single_group(\n\u001b[0m\u001b[1;32m    869\u001b[0m                 \u001b[0mgroup_spatial_coef\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    870\u001b[0m                 \u001b[0mmoderators_coef\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/nimare/meta/models.py\u001b[0m in \u001b[0;36m_log_likelihood_single_group\u001b[0;34m(self, group_spatial_coef, moderators_coef, coef_spline_bases, group_moderators, group_foci_per_voxel, group_foci_per_study, device)\u001b[0m\n\u001b[1;32m    782\u001b[0m         \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    783\u001b[0m     ):\n\u001b[0;32m--> 784\u001b[0;31m         \u001b[0mlog_mu_spatial\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef_spline_bases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_spatial_coef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    785\u001b[0m         \u001b[0mmu_spatial\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog_mu_spatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    786\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mmoderators_coef\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mRuntimeError\u001b[0m: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "plot_stat_map(\n",
        "    results.get_map(\"spatialIntensity_group-True\"),\n",
        "    cut_coords=[0, 0, -8],\n",
        "    draw_cross=False,\n",
        "    cmap=\"RdBu_r\",\n",
        "    title=\"Motor\",\n",
        "    threshold=0.0005, # can be adapted for visualisation purposes\n",
        "    vmax=1e-3,\n",
        ")\n",
        "plot_stat_map(\n",
        "    results.get_map(\"spatialIntensity_group-False\"),\n",
        "    cut_coords=[0, 0, -8],\n",
        "    draw_cross=False,\n",
        "    cmap=\"RdBu_r\",\n",
        "    title=\"Non-Motor\",\n",
        "    threshold=0.0005, # can be adapted for visualisation purposes\n",
        "    vmax=1e-3,\n",
        ")"
      ],
      "metadata": {
        "id": "c-nlS1Fqbtmk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from nimare.meta.cbmr import CBMRInference\n",
        "\n",
        "inference = CBMRInference(device=\"cuda\")\n",
        "inference.fit(result=results)\n",
        "t_con_groups = inference.create_contrast(\n",
        "    [\"True\", \"False\"], source=\"motor\"\n",
        ")\n",
        "contrast_result = inference.transform(t_con_groups=t_con_groups)"
      ],
      "metadata": {
        "id": "VsOJtFFDcraI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Test 2: Language Studies"
      ],
      "metadata": {
        "id": "hMpIxVOlYylw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# # test 2: language (1101 studies)\n",
        "# # compare results to https://neurosynth.org/analyses/terms/language/\n",
        "\n",
        "# # identify which studies examine language (functions)\n",
        "# language_study_id = neurosynth_dset.get_studies_by_label(labels=[\"terms_abstract_tfidf__language\"])\n",
        "# # and those which do not (rest of the studies)\n",
        "# nonlanguage_study_id = list(set(neurosynth_dset.ids) - set(language_study_id))\n",
        "# # create a subset containing only language studies\n",
        "# language_dset = neurosynth_dset.slice(ids=language_study_id)\n",
        "# # and a subset containing no language studies\n",
        "# nonlanguage_dset = neurosynth_dset.slice(ids=nonlanguage_study_id)\n",
        "\n",
        "# # we can now create dummy variables indicating which study involved language tasks and which did not.\n",
        "# neurosynth_dset.annotations[\"language\"] = 'False'\n",
        "# neurosynth_dset.annotations.loc[neurosynth_dset.annotations['id'].isin(language_study_id), 'language'] = 'True'"
      ],
      "metadata": {
        "id": "XbBr-2hyZKiF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# # test 2: language (1101 studies)\n",
        "# # compare results to https://neurosynth.org/analyses/terms/language/\n",
        "# language_study_id = neurosynth_dset.get_studies_by_label(labels=[\"terms_abstract_tfidf__language\"])\n",
        "# nonlanguage_study_id = list(set(neurosynth_dset.ids) - set(language_study_id))\n",
        "# language_dset = neurosynth_dset.slice(ids=language_study_id)\n",
        "# nonlanguage_dset = neurosynth_dset.slice(ids=nonlanguage_study_id)"
      ],
      "metadata": {
        "id": "RUfSGzeGCll8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# # test 3: addiction (135 studies)\n",
        "# # compare results to https://neurosynth.org/analyses/terms/addiction/\n",
        "# addiction_study_id = neurosynth_dset.get_studies_by_label(labels=[\"terms_abstract_tfidf__addiction\"])\n",
        "# nonaddiction_study_id = list(set(neurosynth_dset.ids) - set(addiction_study_id))\n",
        "# addiction_dset = neurosynth_dset.slice(ids=addiction_study_id)\n",
        "# nonaddiction_dset = neurosynth_dset.slice(ids=nonaddiction_study_id)"
      ],
      "metadata": {
        "id": "HxWjNlhhQF7G"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
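
For what it's worth, the traceback above ends in `_log_likelihood_single_group`, whose signature defaults to `device="cpu"`, so a CUDA-resident spline-basis tensor meets a CPU-resident coefficient tensor inside `torch.matmul`. The snippet below is a minimal standalone sketch of that failure mode and the usual PyTorch remedy (moving both operands to one device with `.to(device)` before the matmul); `log_mu_spatial` is a hypothetical helper for illustration, not NiMARE's actual fix:

```python
import torch


def log_mu_spatial(coef_spline_bases, group_spatial_coef, device="cpu"):
    """Sketch of the matmul from the traceback, with explicit device handling.

    torch.matmul raises "Expected all tensors to be on the same device"
    when its operands live on different devices; moving both to the same
    device first avoids that.
    """
    coef_spline_bases = coef_spline_bases.to(device)
    group_spatial_coef = group_spatial_coef.to(device)
    return torch.matmul(coef_spline_bases, group_spatial_coef.T)


device = "cuda" if torch.cuda.is_available() else "cpu"
bases = torch.randn(10, 4)            # created on CPU, like the default path
coef = torch.randn(1, 4).to(device)   # may live on the GPU
out = log_mu_spatial(bases, coef, device=device)
print(out.shape)  # torch.Size([10, 1])
```

Without the `.to(device)` calls, running this with `device="cuda"` on a GPU machine reproduces the same `RuntimeError` as in the notebook output above.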