yanhong-zhao-ef closed this issue 1 year ago.
@yanhong-zhao-ef Thanks for opening this one!
This is a known issue that we found recently, and we will be sending out a fix in the coming days.
Thank you for the swift response @pabloduque0! Looking forward to the fix.
Will test this out and report back, thanks @pabloduque0!
Hey @pabloduque0, I just tested the latest code and I seem to run into this issue when plotting the results:
loc("jit(_unstack)/jit(main)/squeeze[dimensions=(0,)]"("/Users/yanhongzhao/miniforge3/lib/python3.9/site-packages/matplotlib/cbook/__init__.py":1647:1)): error: 'mhlo.reshape' op requires the same element type for all operands and results
Hello @yanhong-zhao-ef,
What version of matplotlib are you using? Make sure it is matplotlib==3.3.4.
If it's not that, I would need a reproducible code example (or a Colab).
But by the looks of it, it might be tied to a version mismatch, as our tests pass and the plot works fine in the updated examples.
Hey @pabloduque0, I have updated the dependencies to matplotlib==3.3.4, and the problem seems to be that the previous budget allocation and the optimal budget allocation are different array types for some reason:
So I ran:
import jax.numpy as jnp
print(previous_budget_allocation)
print(optimal_buget_allocation)
print(jnp.shape(previous_budget_allocation))
print(jnp.shape(optimal_buget_allocation))
print(previous_budget_allocation.dtype)
print(optimal_buget_allocation.dtype)
print(type(previous_budget_allocation))
print(type(optimal_buget_allocation))
Here, previous_budget_allocation comes from the starting values, and this is the output:
[1.74966221e+04 4.61873158e+03 3.51868975e+05 1.65627613e+05
2.46767943e+03 2.50870759e+01 1.32379315e+03 3.06571498e+05]
[1.93767937e+05 8.06692206e+04 3.55350656e+05 1.67266469e+05
2.81926225e+04 1.72227177e+02 2.14848183e+04 3.09604983e+03]
(8,)
(8,)
float64
float64
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
When I plot like this, the whole thing works:
import jax.numpy as jnp
from lightweight_mmm import plot

budget_allocation_plot = plot.plot_pre_post_budget_allocation_comparison(
    media_mix_model=mmm_model_obj,
    kpi_with_optim=solution["fun"],
    kpi_without_optim=kpi_without_optim,
    optimal_buget_allocation=optimal_buget_allocation,
    previous_budget_allocation=jnp.array(previous_budget_allocation),
    figure_size=(10, 10),
)
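To keep both budget arrays in one array type before plotting, one option is a small helper that generalizes the `jnp.array(...)` workaround in the call above. This is a hedged sketch: `to_jax_arrays` is a hypothetical name, not part of lightweight_mmm, and it assumes the plotting function behaves as long as both inputs are JAX arrays with a shared dtype, which is what the workaround suggests.

```python
import jax.numpy as jnp
import numpy as np


def to_jax_arrays(*arrays):
    """Hypothetical helper: convert NumPy or JAX inputs to JAX arrays with one shared dtype."""
    converted = [jnp.asarray(a) for a in arrays]
    # Promote to a common dtype (float32 unless JAX x64 mode is enabled).
    common_dtype = jnp.result_type(*converted)
    return tuple(a.astype(common_dtype) for a in converted)


# Example inputs mimicking the thread: one JAX array and one NumPy array.
previous_budget_allocation = jnp.array([17496.6, 4618.7])
optimal_buget_allocation = np.array([193767.9, 80669.2])

previous_budget_allocation, optimal_buget_allocation = to_jax_arrays(
    previous_budget_allocation, optimal_buget_allocation
)
print(type(previous_budget_allocation), previous_budget_allocation.dtype)
print(type(optimal_buget_allocation), optimal_buget_allocation.dtype)
```

Both converted arrays can then be passed as `previous_budget_allocation` and `optimal_buget_allocation` to `plot.plot_pre_post_budget_allocation_comparison`; the keyword `optimal_buget_allocation` is spelled as in the original call.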
Yes, indeed they are different types. We updated the naming and order of things in the notebook a bit so it is clearer now; hopefully following that helps:
All good thank you!
@yanhong-zhao-ef thanks for reporting back!
Hi, when will this be included in a release on PyPI?
@itaher-aclu v0.1.6 is live now! :)
https://github.com/google/lightweight_mmm/blob/main/lightweight_mmm/optimize_media.py#L145
Here, in the function that generates the starting values, the prices for each media channel are not passed in, so the starting values are not expressed in actual monetary terms.
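The effect of leaving prices out can be seen in a small worked example. The sketch below is hypothetical and is not the lightweight_mmm function linked above: `starting_values` is an illustrative name, and it assumes `media` holds units per time period and channel, `prices` holds a cost per unit for each channel, and starting values are historical usage rescaled so they sum to the total `budget`.

```python
import numpy as np


def starting_values(media, n_time_periods, budget, prices=None):
    """Hypothetical helper: split `budget` across channels in proportion to
    historical usage, converted to money when `prices` is given."""
    media = np.asarray(media, dtype=float)  # shape (time, channels), in units
    allocation = media.mean(axis=0) * n_time_periods  # expected units per channel
    if prices is not None:
        # Convert units to monetary spend before comparing with the budget.
        allocation = allocation * np.asarray(prices, dtype=float)
    # Rescale so the allocation sums to the budget.
    return allocation * (budget / allocation.sum())


media = [[1.0, 2.0], [3.0, 4.0]]  # 2 time periods, 2 channels
prices = [10.0, 1.0]  # channel 0 costs 10x more per unit (assumed values)

without_prices = starting_values(media, n_time_periods=10, budget=230.0)
with_prices = starting_values(media, n_time_periods=10, budget=230.0, prices=prices)
```

Here `without_prices` comes out at about [92, 138] because it splits the budget by units, while `with_prices` is about [200, 30] because each unit of channel 0 costs ten times as much. That gap is the kind of discrepancy the comment describes.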