Turned out to be very easy with a lot of existing astropy.wcs functions:
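import matplotlib.pyplot as plt
from astropy import units as u
from astropy.coordinates import SkyCoord
# `wcs` is the original astropy.wcs.WCS of the image, and `make_offset_wcs`
# is a helper defined separately (not shown here).
c = SkyCoord(0. * u.deg, 0. * u.deg, frame='galactic')  # position the offsets are measured from
new_wcs = make_offset_wcs(wcs, c)
fig = plt.figure()
# Passing the WCS as `projection` creates WCSAxes that use the offset WCS.
ax = fig.add_axes([0.15, 0.1, 0.8, 0.8], projection=new_wcs)
# etc.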