fusion-energy / openmc_geometry_plot

Create axis slice plots of OpenMC geomtry with specified zoom
6 stars 0 forks source link

plotly openmc.plot method #24

Open shimwell opened 7 months ago

shimwell commented 7 months ago

adapting the openmc plot method to allow plotly

todo normalise the color scale so that it goes from 0 to 1 and the with spaces that match the id number gaps

import openmc
import numpy as np
import warnings
import math
import typing
from tempfile import TemporaryDirectory
from pathlib import Path
from PIL import Image

def get_rgb_from_int(value: int) -> typing.Tuple[int, int, int]:
    blue = value & 255
    green = (value >> 8) & 255
    red = (value >> 16) & 255
    return red, green, blue

def get_int_from_rgb(rgb: typing.Tuple[int, int, int]) -> int:
    red = rgb[0]
    green = rgb[1]
    blue = rgb[2]
    return (red << 16) + (green << 8) + blue

def discrete_colorscale(bvals, colors):
    """
    bvals - list of values bounding intervals/ranges of interest
    colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0<=k < len(bvals)-1
    returns the plotly  discrete colorscale
    """
    print('bvals',bvals)
    print('colors',colors)
    if len(bvals) != len(colors)+1:
        raise ValueError('len(boundary values) should be equal to  len(colors)+1')
    bvals = sorted(bvals)     
    nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals]  #normalized values

    dcolorscale = [] #discrete colorscale
    for k in range(len(colors)):
        dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
    return dcolorscale

def plot(
    self,
    origin=None,
    width=None,
    pixels=40000,
    basis='xy',
    color_by='cell',
    colors=None,
    seed=None,
    openmc_exec='openmc',
    axes=None,
    legend=False,
    axis_units='cm',
    # legend_kwargs=_default_legend_kwargs,
    outline=False,
    **kwargs
):
        """Display a slice plot of the universe.

        Parameters
        ----------
        origin : iterable of float
            Coordinates at the origin of the plot. If left as None,
            universe.bounding_box.center will be used to attempt to ascertain
            the origin with infinite values being replaced by 0.
        width : iterable of float
            Width of the plot in each basis direction. If left as none then the
            universe.bounding_box.width() will be used to attempt to
            ascertain the plot width.  Defaults to (10, 10) if the bounding_box
            contains inf values
        pixels : Iterable of int or int
            If iterable of ints provided then this directly sets the number of
            pixels to use in each basis direction. If int provided then this
            sets the total number of pixels in the plot and the number of
            pixels in each basis direction is calculated from this total and
            the image aspect ratio.
        basis : {'xy', 'xz', 'yz'}
            The basis directions for the plot
        color_by : {'cell', 'material'}
            Indicate whether the plot should be colored by cell or by material
        colors : dict
            Assigns colors to specific materials or cells. Keys are instances of
            :class:`Cell` or :class:`Material` and values are RGB 3-tuples, RGBA
            4-tuples, or strings indicating SVG color names. Red, green, blue,
            and alpha should all be floats in the range [0.0, 1.0], for example:

            .. code-block:: python

               # Make water blue
               water = openmc.Cell(fill=h2o)
               universe.plot(..., colors={water: (0., 0., 1.))
        seed : int
            Seed for the random number generator
        openmc_exec : str
            Path to OpenMC executable.
        axes : matplotlib.Axes
            Axes to draw to

            .. versionadded:: 0.13.1
        legend : bool
            Whether a legend showing material or cell names should be drawn

            .. versionadded:: 0.14.0
        legend_kwargs : dict
            Keyword arguments passed to :func:`matplotlib.pyplot.legend`.

            .. versionadded:: 0.14.0
        outline : bool
            Whether outlines between color boundaries should be drawn

            .. versionadded:: 0.14.0
        axis_units : {'km', 'm', 'cm', 'mm'}
            Units used on the plot axis

            .. versionadded:: 0.14.0
        **kwargs
            Keyword arguments passed to :func:`matplotlib.pyplot.imshow`

        Returns
        -------
        matplotlib.axes.Axes
            Axes containing resulting image

        """
        import matplotlib.image as mpimg
        # import matplotlib.patches as mpatches
        import matplotlib.pyplot as plt
        import plotly.graph_objects as go

        # Determine extents of plot
        if basis == 'xy':
            x, y = 0, 1
            xlabel, ylabel = f'x [{axis_units}]', f'y [{axis_units}]'
        elif basis == 'yz':
            x, y = 1, 2
            xlabel, ylabel = f'y [{axis_units}]', f'z [{axis_units}]'
        elif basis == 'xz':
            x, y = 0, 2
            xlabel, ylabel = f'x [{axis_units}]', f'z [{axis_units}]'

        bb = self.bounding_box
        # checks to see if bounding box contains -inf or inf values
        if np.isinf(bb.extent[basis]).any():
            if origin is None:
                origin = (0, 0, 0)
            if width is None:
                width = (10, 10)
        else:
            if origin is None:
                # if nan values in the bb.center they get replaced with 0.0
                # this happens when the bounding_box contains inf values
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", RuntimeWarning)
                    origin = np.nan_to_num(bb.center)
            if width is None:
                bb_width = bb.width
                x_width = bb_width['xyz'.index(basis[0])]
                y_width = bb_width['xyz'.index(basis[1])]
                width = (x_width, y_width)

        if isinstance(pixels, int):
            aspect_ratio = width[0] / width[1]
            pixels_y = math.sqrt(pixels / aspect_ratio)
            pixels = (int(pixels / pixels_y), int(pixels_y))

        axis_scaling_factor = {'km': 0.00001, 'm': 0.01, 'cm': 1, 'mm': 10}

        x_min = (origin[x] - 0.5*width[0]) * axis_scaling_factor[axis_units]
        x_max = (origin[x] + 0.5*width[0]) * axis_scaling_factor[axis_units]
        y_min = (origin[y] - 0.5*width[1]) * axis_scaling_factor[axis_units]
        y_max = (origin[y] + 0.5*width[1]) * axis_scaling_factor[axis_units]

        # with TemporaryDirectory() as tmpdir:
        if 1==1:
            model = openmc.Model()
            model.geometry = openmc.Geometry(self)
            if seed is not None:
                model.settings.plot_seed = seed

            # Determine whether any materials contains macroscopic data and if
            # so, set energy mode accordingly
            for mat in self.get_all_materials().values():
                if mat._macroscopic is not None:
                    model.settings.energy_mode = 'multi-group'
                    break

            # Create plot object matching passed arguments
            plot = openmc.Plot()
            plot.origin = origin
            plot.width = width
            plot.pixels = pixels
            plot.basis = basis
            plot.color_by = color_by

            if colors is not None:

                colors_based_on_ids = {}
                for key, value in colors.items():
                    colors_based_on_ids[key] = get_rgb_from_int(key.id)
                plot.colors = colors_based_on_ids

            model.plots.append(plot)

            # Run OpenMC in geometry plotting mode
            model.plot_geometry(False, openmc_exec=openmc_exec)
            # model.plot_geometry(False, cwd=tmpdir, openmc_exec=openmc_exec)

            # Read image from file
            # img_path = Path(tmpdir) / f'plot_{plot.id}.png'
            img_path = Path(f'plot_{plot.id}.png')
            if not img_path.is_file():
                img_path = img_path.with_suffix('.ppm')
            img = mpimg.imread(str(img_path))
            image_values = Image.open(img_path)
            print('img.shape,',img)
            image_values = np.asarray(image_values)
            image_values = [
                [get_int_from_rgb(inner_entry) for inner_entry in outer_entry]
                for outer_entry in image_values
            ]

            image_values = np.array(image_values)
            image_values[image_values == 16777215] = 0
            # image_values[:] = [x if x != 16777215 else 0 for x in image_values]

            # Create a figure sized such that the size of the axes within
            # exactly matches the number of pixels specified
            # if axes is None:
            #     px = 1/plt.rcParams['figure.dpi']
            #     fig, axes = plt.subplots()
            #     axes.set_xlabel(xlabel)
            #     axes.set_ylabel(ylabel)
            #     params = fig.subplotpars
            #     width = pixels[0]*px/(params.right - params.left)
            #     height = pixels[1]*px/(params.top - params.bottom)
            #     fig.set_size_inches(width, height)

            data = []

            # if outline:
            #     # Combine R, G, B values into a single int
            #     rgb = (img * 256).astype(int)
            #     image_value = (rgb[..., 0] << 16) + \
            #         (rgb[..., 1] << 8) + (rgb[..., 2])

                # axes.contour(
                #     image_value,
                #     origin="upper",
                #     colors="k",
                #     linestyles="solid",
                #     linewidths=1,
                #     levels=np.unique(image_value),
                #     extent=(x_min, x_max, y_min, y_max),
                # )

                # data.append(
                #     go.Contour(
                #         z=image_value,
                #         contours_coloring='none',
                #         # colorscale=dcolorsc,
                #         showscale=False,
                #         x0=x_min,
                #         dx=abs(x_min - x_max) / (img.shape[0] - 1),
                #         y0=y_min,
                #         dy=abs(y_min - y_max) / (img.shape[1] - 1),
                #     )
                # )

            # add legend showing which colors represent which material
            # or cell if that was requested
            if legend:
                if plot.colors == {}:
                    raise ValueError("Must pass 'colors' dictionary if you "
                                     "are adding a legend via legend=True.")

                if color_by == "cell":
                    expected_key_type = openmc.Cell
                else:
                    expected_key_type = openmc.Material

                # patches = []
                for key, color in plot.colors.items():

                    if isinstance(key, int):
                        raise TypeError(
                            "Cannot use IDs in colors dict for auto legend.")
                    elif not isinstance(key, expected_key_type):
                        raise TypeError(
                            "Color dict key type does not match color_by")

                    # this works whether we're doing cells or materials
                    label = key.name if key.name != '' else key.id

                    # matplotlib takes RGB on 0-1 scale rather than 0-255. at
                    # this point PlotBase has already checked that 3-tuple
                    # based colors are already valid, so if the length is three
                    # then we know it just needs to be converted to the 0-1
                    # format.
                    if len(color) == 3 and not isinstance(color, str):
                        scaled_color = (
                            color[0]/255, color[1]/255, color[2]/255)
                    else:
                        scaled_color = color

                    # key_patch = mpatches.Patch(color=scaled_color, label=label)
                    # patches.append(key_patch)

                # axes.legend(handles=patches, **legend_kwargs)

            # Plot image and return the axes
            # axes.imshow(img, extent=(x_min, x_max, y_min, y_max), **kwargs)           
            print(image_values)
            # print(image_values.shape)

            # rgb = (img * 256).astype(int)
            # image_value = (rgb[..., 0] << 16) + \
            #     (rgb[..., 1] << 8) + (rgb[..., 2])

            list_of_unique_image_value= []
            list_of_unique_colors = []
            # colors=[]
            # for i, x in enumerate(image_value):
            #     for j, y in enumerate(x):
            #         if y not in list_of_unique_image_value:
            #             list_of_unique_image_value.append(y)
            #             c= img[i][j]
            #             colors.append(f'rgb({int(c[0]*255)}, {int(c[1]*255)}, {int(c[2]*255)})')
            print('list_of_unique_image_value', list_of_unique_image_value)
            print('list_of_unique_colors',list_of_unique_colors)

            # for val, col in zip(list_of_unique_image_value, list_of_unique_colors):
            #     colors.append(f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})')
                # colorscale.append([val, f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})'])

            # Z = [colors for _,colors in sorted(zip(list_of_unique_image_value,colors))]
            # print(Z)

            dcolorsc=discrete_colorscale([mat.id for mat in colors.keys()], [f'rgb({c[0]},{c[1]},{c[2]})' for c in list(colors.values())])
            # dcolorsc=discrete_colorscale(list_of_unique_image_value, colors)
            # dcolorsc=discrete_colorscale(list_of_unique_image_value+[max(list_of_unique_image_value)+1], colors)
            # dcolorsc=[
            #     [0, 'green'],
            #     [1, 'red'],
            #     [2, 'blue'],
            # ]
            # [0.1, 'green'],
            # [0.1, 'rgb(253, 237, 176)'],
            # [0.2, 'rgb(249, 198, 139)'],
            # [0.3, 'rgb(244, 159, 109)'],
            # [0.4, 'rgb(234, 120, 88)'],
            # [0.5, 'rgb(218, 83, 82)'],
            # [0.6, 'rgb(191, 54, 91)'],
            # [0.7, 'rgb(158, 35, 98)'],
            # [0.8, 'rgb(120, 26, 97)'],
            # [0.9, 'rgb(83, 22, 84)'],
            # [1.0, 'rgb(47, 15, 61)']]
            print('colors', colors)
            print('dcolorsc', dcolorsc)
            print(image_values)
            data.append(
                go.Heatmap(
                    z=image_values,
                    # showscale=True,
                    colorscale=dcolorsc,
                    x0=x_min,
                    dx=abs(x_min - x_max) / (img.shape[0] - 1),
                    y0=y_min,
                    dy=abs(y_min - y_max) / (img.shape[1] - 1),
                )
            )
            plot = go.Figure(data=data)

            plot.update_layout(
                xaxis={"title": xlabel},
                # reversed autorange is required to avoid image needing rotation/flipping in plotly
                yaxis={"title": ylabel, "autorange": "reversed"},
                # title=title,
                autosize=False,
                height=800,
            )
            plot.update_yaxes(
                scaleanchor="x",
                scaleratio=1,
            )
            return plot

openmc.Universe.plot = plot

surf = openmc.Sphere(r=10)
surf2 = openmc.ZCylinder(r=3)
mat1=openmc.Material()
mat1.add_nuclide('Li6',1)
mat1.set_density('g/cm3',1)
mat2=openmc.Material()
mat2.add_nuclide('Li6',1)
mat2.set_density('g/cm3',1)
cell = openmc.Cell(region=-surf & -surf2, cell_id=10, fill=mat1)
cell2 = openmc.Cell(region=-surf & +surf2, cell_id=100, fill=mat2)
geometry = openmc.Geometry([cell, cell2])
plot = geometry.plot(
    outline=True,
    basis='xz',
    pixels=1000,
    colors={mat1:[200,0,0], mat2:[0,200,0]},
    color_by='material'
)
plot.show()
shimwell commented 7 months ago
import openmc
import numpy as np
import warnings
import math
from tempfile import TemporaryDirectory
from pathlib import Path

def discrete_colorscale(bvals, colors):
    """
    bvals - list of values bounding intervals/ranges of interest
    colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0<=k < len(bvals)-1
    returns the plotly  discrete colorscale
    """
    if len(bvals) != len(colors)+1:
        raise ValueError('len(boundary values) should be equal to  len(colors)+1')
    bvals = sorted(bvals)     
    nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals]  #normalized values

    dcolorscale = [] #discrete colorscale
    for k in range(len(colors)):
        dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
    return dcolorscale

def plot(
    self,
    origin=None,
    width=None,
    pixels=40000,
    basis='xy',
    color_by='cell',
    colors=None,
    seed=None,
    openmc_exec='openmc',
    axes=None,
    legend=False,
    axis_units='cm',
    # legend_kwargs=_default_legend_kwargs,
    outline=False,
    **kwargs
):
        """Display a slice plot of the universe.

        Parameters
        ----------
        origin : iterable of float
            Coordinates at the origin of the plot. If left as None,
            universe.bounding_box.center will be used to attempt to ascertain
            the origin with infinite values being replaced by 0.
        width : iterable of float
            Width of the plot in each basis direction. If left as none then the
            universe.bounding_box.width() will be used to attempt to
            ascertain the plot width.  Defaults to (10, 10) if the bounding_box
            contains inf values
        pixels : Iterable of int or int
            If iterable of ints provided then this directly sets the number of
            pixels to use in each basis direction. If int provided then this
            sets the total number of pixels in the plot and the number of
            pixels in each basis direction is calculated from this total and
            the image aspect ratio.
        basis : {'xy', 'xz', 'yz'}
            The basis directions for the plot
        color_by : {'cell', 'material'}
            Indicate whether the plot should be colored by cell or by material
        colors : dict
            Assigns colors to specific materials or cells. Keys are instances of
            :class:`Cell` or :class:`Material` and values are RGB 3-tuples, RGBA
            4-tuples, or strings indicating SVG color names. Red, green, blue,
            and alpha should all be floats in the range [0.0, 1.0], for example:

            .. code-block:: python

               # Make water blue
               water = openmc.Cell(fill=h2o)
               universe.plot(..., colors={water: (0., 0., 1.))
        seed : int
            Seed for the random number generator
        openmc_exec : str
            Path to OpenMC executable.
        axes : matplotlib.Axes
            Axes to draw to

            .. versionadded:: 0.13.1
        legend : bool
            Whether a legend showing material or cell names should be drawn

            .. versionadded:: 0.14.0
        legend_kwargs : dict
            Keyword arguments passed to :func:`matplotlib.pyplot.legend`.

            .. versionadded:: 0.14.0
        outline : bool
            Whether outlines between color boundaries should be drawn

            .. versionadded:: 0.14.0
        axis_units : {'km', 'm', 'cm', 'mm'}
            Units used on the plot axis

            .. versionadded:: 0.14.0
        **kwargs
            Keyword arguments passed to :func:`matplotlib.pyplot.imshow`

        Returns
        -------
        matplotlib.axes.Axes
            Axes containing resulting image

        """
        import matplotlib.image as mpimg
        # import matplotlib.patches as mpatches
        import matplotlib.pyplot as plt
        import plotly.graph_objects as go

        # Determine extents of plot
        if basis == 'xy':
            x, y = 0, 1
            xlabel, ylabel = f'x [{axis_units}]', f'y [{axis_units}]'
        elif basis == 'yz':
            x, y = 1, 2
            xlabel, ylabel = f'y [{axis_units}]', f'z [{axis_units}]'
        elif basis == 'xz':
            x, y = 0, 2
            xlabel, ylabel = f'x [{axis_units}]', f'z [{axis_units}]'

        bb = self.bounding_box
        # checks to see if bounding box contains -inf or inf values
        if np.isinf(bb.extent[basis]).any():
            if origin is None:
                origin = (0, 0, 0)
            if width is None:
                width = (10, 10)
        else:
            if origin is None:
                # if nan values in the bb.center they get replaced with 0.0
                # this happens when the bounding_box contains inf values
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", RuntimeWarning)
                    origin = np.nan_to_num(bb.center)
            if width is None:
                bb_width = bb.width
                x_width = bb_width['xyz'.index(basis[0])]
                y_width = bb_width['xyz'.index(basis[1])]
                width = (x_width, y_width)

        if isinstance(pixels, int):
            aspect_ratio = width[0] / width[1]
            pixels_y = math.sqrt(pixels / aspect_ratio)
            pixels = (int(pixels / pixels_y), int(pixels_y))

        axis_scaling_factor = {'km': 0.00001, 'm': 0.01, 'cm': 1, 'mm': 10}

        x_min = (origin[x] - 0.5*width[0]) * axis_scaling_factor[axis_units]
        x_max = (origin[x] + 0.5*width[0]) * axis_scaling_factor[axis_units]
        y_min = (origin[y] - 0.5*width[1]) * axis_scaling_factor[axis_units]
        y_max = (origin[y] + 0.5*width[1]) * axis_scaling_factor[axis_units]

        # with TemporaryDirectory() as tmpdir:
        if 1==1:
            model = openmc.Model()
            model.geometry = openmc.Geometry(self)
            if seed is not None:
                model.settings.plot_seed = seed

            # Determine whether any materials contains macroscopic data and if
            # so, set energy mode accordingly
            for mat in self.get_all_materials().values():
                if mat._macroscopic is not None:
                    model.settings.energy_mode = 'multi-group'
                    break

            # Create plot object matching passed arguments
            plot = openmc.Plot()
            plot.origin = origin
            plot.width = width
            plot.pixels = pixels
            plot.basis = basis
            plot.color_by = color_by
            if colors is not None:
                plot.colors = colors
            model.plots.append(plot)

            # Run OpenMC in geometry plotting mode
            model.plot_geometry(False, openmc_exec=openmc_exec)
            # model.plot_geometry(False, cwd=tmpdir, openmc_exec=openmc_exec)

            # Read image from file
            # img_path = Path(tmpdir) / f'plot_{plot.id}.png'
            img_path = Path(f'plot_{plot.id}.png')
            if not img_path.is_file():
                img_path = img_path.with_suffix('.ppm')
            img = mpimg.imread(str(img_path))

            # Create a figure sized such that the size of the axes within
            # exactly matches the number of pixels specified
            # if axes is None:
            #     px = 1/plt.rcParams['figure.dpi']
            #     fig, axes = plt.subplots()
            #     axes.set_xlabel(xlabel)
            #     axes.set_ylabel(ylabel)
            #     params = fig.subplotpars
            #     width = pixels[0]*px/(params.right - params.left)
            #     height = pixels[1]*px/(params.top - params.bottom)
            #     fig.set_size_inches(width, height)

            data = []

            if outline:
                # Combine R, G, B values into a single int
                rgb = (img * 256).astype(int)
                image_value = (rgb[..., 0] << 16) + \
                    (rgb[..., 1] << 8) + (rgb[..., 2])

                # axes.contour(
                #     image_value,
                #     origin="upper",
                #     colors="k",
                #     linestyles="solid",
                #     linewidths=1,
                #     levels=np.unique(image_value),
                #     extent=(x_min, x_max, y_min, y_max),
                # )

                data.append(
                    go.Contour(
                        z=image_value,
                        contours_coloring='none',
                        # colorscale=dcolorsc,
                        showscale=False,
                        x0=x_min,
                        dx=abs(x_min - x_max) / (img.shape[0] - 1),
                        y0=y_min,
                        dy=abs(y_min - y_max) / (img.shape[1] - 1),
                    )
                )

            # add legend showing which colors represent which material
            # or cell if that was requested
            if legend:
                if plot.colors == {}:
                    raise ValueError("Must pass 'colors' dictionary if you "
                                     "are adding a legend via legend=True.")

                if color_by == "cell":
                    expected_key_type = openmc.Cell
                else:
                    expected_key_type = openmc.Material

                # patches = []
                for key, color in plot.colors.items():

                    if isinstance(key, int):
                        raise TypeError(
                            "Cannot use IDs in colors dict for auto legend.")
                    elif not isinstance(key, expected_key_type):
                        raise TypeError(
                            "Color dict key type does not match color_by")

                    # this works whether we're doing cells or materials
                    label = key.name if key.name != '' else key.id

                    # matplotlib takes RGB on 0-1 scale rather than 0-255. at
                    # this point PlotBase has already checked that 3-tuple
                    # based colors are already valid, so if the length is three
                    # then we know it just needs to be converted to the 0-1
                    # format.
                    if len(color) == 3 and not isinstance(color, str):
                        scaled_color = (
                            color[0]/255, color[1]/255, color[2]/255)
                    else:
                        scaled_color = color

                    # key_patch = mpatches.Patch(color=scaled_color, label=label)
                    # patches.append(key_patch)

                # axes.legend(handles=patches, **legend_kwargs)

            # Plot image and return the axes
            # axes.imshow(img, extent=(x_min, x_max, y_min, y_max), **kwargs)           
            print(img)
            print(image_value)
            print(img.shape)

            rgb = (img * 256).astype(int)
            image_value = (rgb[..., 0] << 16) + \
                (rgb[..., 1] << 8) + (rgb[..., 2])

            list_of_unique_image_value= []
            list_of_unique_colors = []
            colors=[]
            for i, x in enumerate(image_value):
                for j, y in enumerate(x):
                    if y not in list_of_unique_image_value:
                        list_of_unique_image_value.append(y)
                        c= img[i][j]
                        colors.append(f'rgb({int(c[0]*255)}, {int(c[1]*255)}, {int(c[2]*255)})')
            print('list_of_unique_image_value', list_of_unique_image_value)
            print('list_of_unique_colors',list_of_unique_colors)

            # for val, col in zip(list_of_unique_image_value, list_of_unique_colors):
            #     colors.append(f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})')
                # colorscale.append([val, f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})'])

            Z = [colors for _,colors in sorted(zip(list_of_unique_image_value,colors))]
            print(Z)

            dcolorsc=discrete_colorscale([1e12]+sorted(list_of_unique_image_value), Z)
            # dcolorsc=discrete_colorscale(list_of_unique_image_value, colors)
            # dcolorsc=discrete_colorscale(list_of_unique_image_value+[max(list_of_unique_image_value)+1], colors)

            print('colors', colors)
            print('dcolorsc', dcolorsc)
            data.append(
                go.Heatmap(
                    z=image_value,
                    showscale=True,
                    colorscale=dcolorsc,
                    x0=x_min,
                    dx=abs(x_min - x_max) / (img.shape[0] - 1),
                    y0=y_min,
                    dy=abs(y_min - y_max) / (img.shape[1] - 1),
                )
            )
            plot = go.Figure(data=data)

            plot.update_layout(
                xaxis={"title": xlabel},
                # reversed autorange is required to avoid image needing rotation/flipping in plotly
                yaxis={"title": ylabel, "autorange": "reversed"},
                # title=title,
                autosize=False,
                height=800,
            )
            plot.update_yaxes(
                scaleanchor="x",
                scaleratio=1,
            )
            return plot

openmc.Universe.plot = plot

surf = openmc.Sphere(r=10)
surf2 = openmc.ZCylinder(r=3)
cell = openmc.Cell(region=-surf & -surf2, cell_id=10)
cell2 = openmc.Cell(region=-surf & +surf2, cell_id=100)
geometry = openmc.Geometry([cell, cell2])
plot = geometry.plot(
    outline=True,
    basis='xz',
    pixels=100000
)
plot.show()
shimwell commented 7 months ago
import openmc
import numpy as np
import warnings
import math
import typing
from tempfile import TemporaryDirectory
from pathlib import Path
from PIL import Image

def get_rgb_from_int(value: int) -> typing.Tuple[int, int, int]:
    blue = value & 255
    green = (value >> 8) & 255
    red = (value >> 16) & 255
    return red, green, blue

def get_int_from_rgb(rgb: typing.Tuple[int, int, int]) -> int:
    red = rgb[0]
    green = rgb[1]
    blue = rgb[2]
    return (red << 16) + (green << 8) + blue

def discrete_colorscale(bvals, colors):
    """
    bvals - list of values bounding intervals/ranges of interest
    colors - list of rgb or hex colorcodes for values in [bvals[k], bvals[k+1]],0<=k < len(bvals)-1
    returns the plotly  discrete colorscale
    """
    print('bvals',bvals)
    print('colors',colors)
    if len(bvals) != len(colors)+1:
        raise ValueError('len(boundary values) should be equal to  len(colors)+1')
    bvals = sorted(bvals)     
    nvals = [(v-bvals[0])/(bvals[-1]-bvals[0]) for v in bvals]  #normalized values

    dcolorscale = [] #discrete colorscale
    for k in range(len(colors)):
        dcolorscale.extend([[nvals[k], colors[k]], [nvals[k+1], colors[k]]])
    return dcolorscale

def plot(
    self,
    origin=None,
    width=None,
    pixels=40000,
    basis='xy',
    color_by='cell',
    colors=None,
    seed=None,
    openmc_exec='openmc',
    axes=None,
    legend=False,
    axis_units='cm',
    # legend_kwargs=_default_legend_kwargs,
    outline=False,
    **kwargs
):
        """Display a slice plot of the universe.

        Parameters
        ----------
        origin : iterable of float
            Coordinates at the origin of the plot. If left as None,
            universe.bounding_box.center will be used to attempt to ascertain
            the origin with infinite values being replaced by 0.
        width : iterable of float
            Width of the plot in each basis direction. If left as none then the
            universe.bounding_box.width() will be used to attempt to
            ascertain the plot width.  Defaults to (10, 10) if the bounding_box
            contains inf values
        pixels : Iterable of int or int
            If iterable of ints provided then this directly sets the number of
            pixels to use in each basis direction. If int provided then this
            sets the total number of pixels in the plot and the number of
            pixels in each basis direction is calculated from this total and
            the image aspect ratio.
        basis : {'xy', 'xz', 'yz'}
            The basis directions for the plot
        color_by : {'cell', 'material'}
            Indicate whether the plot should be colored by cell or by material
        colors : dict
            Assigns colors to specific materials or cells. Keys are instances of
            :class:`Cell` or :class:`Material` and values are RGB 3-tuples, RGBA
            4-tuples, or strings indicating SVG color names. Red, green, blue,
            and alpha should all be floats in the range [0.0, 1.0], for example:

            .. code-block:: python

               # Make water blue
               water = openmc.Cell(fill=h2o)
               universe.plot(..., colors={water: (0., 0., 1.))
        seed : int
            Seed for the random number generator
        openmc_exec : str
            Path to OpenMC executable.
        axes : matplotlib.Axes
            Axes to draw to

            .. versionadded:: 0.13.1
        legend : bool
            Whether a legend showing material or cell names should be drawn

            .. versionadded:: 0.14.0
        legend_kwargs : dict
            Keyword arguments passed to :func:`matplotlib.pyplot.legend`.

            .. versionadded:: 0.14.0
        outline : bool
            Whether outlines between color boundaries should be drawn

            .. versionadded:: 0.14.0
        axis_units : {'km', 'm', 'cm', 'mm'}
            Units used on the plot axis

            .. versionadded:: 0.14.0
        **kwargs
            Keyword arguments passed to :func:`matplotlib.pyplot.imshow`

        Returns
        -------
        matplotlib.axes.Axes
            Axes containing resulting image

        """
        import matplotlib.image as mpimg
        # import matplotlib.patches as mpatches
        import matplotlib.pyplot as plt
        import plotly.graph_objects as go

        # Determine extents of plot
        if basis == 'xy':
            x, y = 0, 1
            xlabel, ylabel = f'x [{axis_units}]', f'y [{axis_units}]'
        elif basis == 'yz':
            x, y = 1, 2
            xlabel, ylabel = f'y [{axis_units}]', f'z [{axis_units}]'
        elif basis == 'xz':
            x, y = 0, 2
            xlabel, ylabel = f'x [{axis_units}]', f'z [{axis_units}]'

        bb = self.bounding_box
        # checks to see if bounding box contains -inf or inf values
        if np.isinf(bb.extent[basis]).any():
            if origin is None:
                origin = (0, 0, 0)
            if width is None:
                width = (10, 10)
        else:
            if origin is None:
                # if nan values in the bb.center they get replaced with 0.0
                # this happens when the bounding_box contains inf values
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", RuntimeWarning)
                    origin = np.nan_to_num(bb.center)
            if width is None:
                bb_width = bb.width
                x_width = bb_width['xyz'.index(basis[0])]
                y_width = bb_width['xyz'.index(basis[1])]
                width = (x_width, y_width)

        if isinstance(pixels, int):
            aspect_ratio = width[0] / width[1]
            pixels_y = math.sqrt(pixels / aspect_ratio)
            pixels = (int(pixels / pixels_y), int(pixels_y))

        axis_scaling_factor = {'km': 0.00001, 'm': 0.01, 'cm': 1, 'mm': 10}

        x_min = (origin[x] - 0.5*width[0]) * axis_scaling_factor[axis_units]
        x_max = (origin[x] + 0.5*width[0]) * axis_scaling_factor[axis_units]
        y_min = (origin[y] - 0.5*width[1]) * axis_scaling_factor[axis_units]
        y_max = (origin[y] + 0.5*width[1]) * axis_scaling_factor[axis_units]

        # with TemporaryDirectory() as tmpdir:
        if 1==1:
            model = openmc.Model()
            model.geometry = openmc.Geometry(self)
            if seed is not None:
                model.settings.plot_seed = seed

            # Determine whether any materials contains macroscopic data and if
            # so, set energy mode accordingly
            for mat in self.get_all_materials().values():
                if mat._macroscopic is not None:
                    model.settings.energy_mode = 'multi-group'
                    break

            # Create plot object matching passed arguments
            plot = openmc.Plot()
            plot.origin = origin
            plot.width = width
            plot.pixels = pixels
            plot.basis = basis
            plot.color_by = color_by

            if colors is not None:

                colors_based_on_ids = {}
                for key, value in colors.items():
                    colors_based_on_ids[key] = get_rgb_from_int(key.id)
                plot.colors = colors_based_on_ids

            model.plots.append(plot)

            # Run OpenMC in geometry plotting mode
            model.plot_geometry(False, openmc_exec=openmc_exec)
            # model.plot_geometry(False, cwd=tmpdir, openmc_exec=openmc_exec)

            # Read image from file
            # img_path = Path(tmpdir) / f'plot_{plot.id}.png'
            img_path = Path(f'plot_{plot.id}.png')
            if not img_path.is_file():
                img_path = img_path.with_suffix('.ppm')
            img = mpimg.imread(str(img_path))
            image_values = Image.open(img_path)
            print('img.shape,',img)
            image_values = np.asarray(image_values)
            image_values = [
                [get_int_from_rgb(inner_entry) for inner_entry in outer_entry]
                for outer_entry in image_values
            ]

            image_values = np.array(image_values)
            image_values[image_values == 16777215] = 0

            # image_values[:] = [x if x != 16777215 else 0 for x in image_values]

            # Create a figure sized such that the size of the axes within
            # exactly matches the number of pixels specified
            # if axes is None:
            #     px = 1/plt.rcParams['figure.dpi']
            #     fig, axes = plt.subplots()
            #     axes.set_xlabel(xlabel)
            #     axes.set_ylabel(ylabel)
            #     params = fig.subplotpars
            #     width = pixels[0]*px/(params.right - params.left)
            #     height = pixels[1]*px/(params.top - params.bottom)
            #     fig.set_size_inches(width, height)

            data = []

            # if outline:
            #     # Combine R, G, B values into a single int
            #     rgb = (img * 256).astype(int)
            #     image_value = (rgb[..., 0] << 16) + \
            #         (rgb[..., 1] << 8) + (rgb[..., 2])

                # axes.contour(
                #     image_value,
                #     origin="upper",
                #     colors="k",
                #     linestyles="solid",
                #     linewidths=1,
                #     levels=np.unique(image_value),
                #     extent=(x_min, x_max, y_min, y_max),
                # )

                # data.append(
                #     go.Contour(
                #         z=image_value,
                #         contours_coloring='none',
                #         # colorscale=dcolorsc,
                #         showscale=False,
                #         x0=x_min,
                #         dx=abs(x_min - x_max) / (img.shape[0] - 1),
                #         y0=y_min,
                #         dy=abs(y_min - y_max) / (img.shape[1] - 1),
                #     )
                # )

            # add legend showing which colors represent which material
            # or cell if that was requested
            if legend:
                if plot.colors == {}:
                    raise ValueError("Must pass 'colors' dictionary if you "
                                     "are adding a legend via legend=True.")

                if color_by == "cell":
                    expected_key_type = openmc.Cell
                else:
                    expected_key_type = openmc.Material

                # patches = []
                for key, color in plot.colors.items():

                    if isinstance(key, int):
                        raise TypeError(
                            "Cannot use IDs in colors dict for auto legend.")
                    elif not isinstance(key, expected_key_type):
                        raise TypeError(
                            "Color dict key type does not match color_by")

                    # this works whether we're doing cells or materials
                    label = key.name if key.name != '' else key.id

                    # matplotlib takes RGB on 0-1 scale rather than 0-255. at
                    # this point PlotBase has already checked that 3-tuple
                    # based colors are already valid, so if the length is three
                    # then we know it just needs to be converted to the 0-1
                    # format.
                    if len(color) == 3 and not isinstance(color, str):
                        scaled_color = (
                            color[0]/255, color[1]/255, color[2]/255)
                    else:
                        scaled_color = color

                    # key_patch = mpatches.Patch(color=scaled_color, label=label)
                    # patches.append(key_patch)

                # axes.legend(handles=patches, **legend_kwargs)

            # Plot image and return the axes
            # axes.imshow(img, extent=(x_min, x_max, y_min, y_max), **kwargs)           
            print(image_values)
            # print(image_values.shape)

            # rgb = (img * 256).astype(int)
            # image_value = (rgb[..., 0] << 16) + \
            #     (rgb[..., 1] << 8) + (rgb[..., 2])

            list_of_unique_image_value= []
            list_of_unique_colors = []
            # colors=[]
            # for i, x in enumerate(image_value):
            #     for j, y in enumerate(x):
            #         if y not in list_of_unique_image_value:
            #             list_of_unique_image_value.append(y)
            #             c= img[i][j]
            #             colors.append(f'rgb({int(c[0]*255)}, {int(c[1]*255)}, {int(c[2]*255)})')
            print('list_of_unique_image_value', list_of_unique_image_value)
            print('list_of_unique_colors',list_of_unique_colors)

            # for val, col in zip(list_of_unique_image_value, list_of_unique_colors):
            #     colors.append(f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})')
                # colorscale.append([val, f'rgb({int(col[0]*255)}, {int(col[1]*255)}, {int(col[2]*255)})'])

            # Z = [colors for _,colors in sorted(zip(list_of_unique_image_value,colors))]
            # print(Z)

            # dcolorsc=discrete_colorscale([mat.id for mat in colors.keys()], [f'rgb({c[0]},{c[1]},{c[2]})' for c in list(colors.values())])
            # dcolorsc.append
            # dcolorsc=discrete_colorscale(list_of_unique_image_value, colors)
            # dcolorsc=discrete_colorscale(list_of_unique_image_value+[max(list_of_unique_image_value)+1], colors)

            dcolorsc=[
                [0, 'green'],
                # [(1/20)*2, 'red'],
                # [(1/20)*20, 'blue'],
            ]

            for rgb_col, mat_id in zip([f'rgb({c[0]},{c[1]},{c[2]})' for c in list(colors.values())], [mat.id for mat in colors.keys()]):
                dcolorsc.append(((1/20)*mat_id,rgb_col))
            print('dcolorsc', dcolorsc)
            # [0.1, 'green'],
            # [0.1, 'rgb(253, 237, 176)'],
            # [0.2, 'rgb(249, 198, 139)'],
            # [0.3, 'rgb(244, 159, 109)'],
            # [0.4, 'rgb(234, 120, 88)'],
            # [0.5, 'rgb(218, 83, 82)'],
            # [0.6, 'rgb(191, 54, 91)'],
            # [0.7, 'rgb(158, 35, 98)'],
            # [0.8, 'rgb(120, 26, 97)'],
            # [0.9, 'rgb(83, 22, 84)'],
            # [1.0, 'rgb(47, 15, 61)']]
            print('colors', colors)
            print('dcolorsc', dcolorsc)
            print(image_values)
            data.append(
                go.Heatmap(
                    z=image_values,
                    # showscale=True,
                    colorscale=dcolorsc,
                    x0=x_min,
                    dx=abs(x_min - x_max) / (img.shape[0] - 1),
                    y0=y_min,
                    dy=abs(y_min - y_max) / (img.shape[1] - 1),
                )
            )
            plot = go.Figure(data=data)

            plot.update_layout(
                xaxis={"title": xlabel},
                # reversed autorange is required to avoid image needing rotation/flipping in plotly
                yaxis={"title": ylabel, "autorange": "reversed"},
                # title=title,
                autosize=False,
                height=800,
            )
            plot.update_yaxes(
                scaleanchor="x",
                scaleratio=1,
            )
            return plot

openmc.Universe.plot = plot

surf = openmc.Sphere(r=10)
surf2 = openmc.ZCylinder(r=3)
mat1=openmc.Material(material_id=2)
mat1.add_nuclide('Li6',1)
mat1.set_density('g/cm3',1)
mat2=openmc.Material(material_id=20)
mat2.add_nuclide('Li6',1)
mat2.set_density('g/cm3',1)
cell = openmc.Cell(region=-surf & -surf2, cell_id=10, fill=mat1)
cell2 = openmc.Cell(region=-surf & +surf2, cell_id=100, fill=mat2)
geometry = openmc.Geometry([cell, cell2])
plot = geometry.plot(
    outline=True,
    basis='xz',
    pixels=1000,
    colors={mat1:[200,0,0], mat2:[0,0, 255]},
    color_by='material'
)
plot.show()