UQ3QU opened 9 months ago
To use the model, follow these steps:
Load the pre-trained model:
import torch  # DKiS, get_k and utils come from the project repository
net = DKiS(alpha_list=[0.6 ** i for i in range(16)])
net.load_state_dict(torch.load('model_state.pt'))
Set the random seed used to generate the key k:
torch.random.manual_seed(your_seed)
k = get_k([16, batch_size // 2, 12, img_size1 // 2, img_size2 // 2])
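A minimal sketch of why the seed matters, assuming get_k draws k from torch's global RNG (which would explain why the seed is set first). It uses numpy's seeded Generator in place of torch, and `make_key` is a hypothetical stand-in for get_k, not its actual implementation. The same seed and shape reproduce the same key, while a different seed does not, so extraction presumably needs the same seed and shape arguments as hiding.

```python
import numpy as np

def make_key(shape, seed):
    """Hypothetical stand-in for get_k: draw a Gaussian key from a seeded RNG."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal(tuple(shape))

# Shape copied from the snippet; 16 appears to match len(alpha_list).
batch_size, img_size1, img_size2 = 4, 64, 64
shape = [16, batch_size // 2, 12, img_size1 // 2, img_size2 // 2]

k_hide = make_key(shape, seed=1234)
k_extract = make_key(shape, seed=1234)   # same seed -> same key
k_wrong = make_key(shape, seed=4321)     # different seed -> different key

print(k_hide.shape)                       # (16, 2, 12, 32, 32)
print(np.array_equal(k_hide, k_extract))  # True
print(np.array_equal(k_hide, k_wrong))    # False
```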
Define the DWT and the inverse DWT (IWT):
iwt = utils.NIWT(a=0.43, b=0.28) # a = traindataset.mean, b = traindataset.std
dwt = utils.NDWT(a=0.43, b=0.28)
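The 12 in get_k and in the [:, :12] / [:, 12:] slices is consistent with a one-level 2D wavelet transform, which turns each 3-channel RGB image into 4 sub-bands of 3 channels (12 channels) at half resolution. The sketch below is a plain orthonormal Haar DWT/IWT in numpy under that assumption. It omits the normalization by a and b that utils.NDWT/NIWT presumably apply, and the repository's sub-band ordering may differ.

```python
import numpy as np

def haar_dwt(x):
    """One-level 2D Haar DWT: (N, C, H, W) -> (N, 4*C, H/2, W/2)."""
    a = x[:, :, 0::2, 0::2]
    b = x[:, :, 0::2, 1::2]
    c = x[:, :, 1::2, 0::2]
    d = x[:, :, 1::2, 1::2]
    ll = (a + b + c + d) / 2
    lh = (-a - b + c + d) / 2
    hl = (-a + b - c + d) / 2
    hh = (a - b - c + d) / 2
    # Sub-bands stacked along the channel axis: [LL, LH, HL, HH], C channels each
    return np.concatenate([ll, lh, hl, hh], axis=1)

def haar_iwt(y):
    """Inverse of haar_dwt: (N, 4*C, H/2, W/2) -> (N, C, H, W)."""
    n, c4, h, w = y.shape
    c = c4 // 4
    ll, lh, hl, hh = (y[:, i * c:(i + 1) * c] for i in range(4))
    x = np.empty((n, c, 2 * h, 2 * w), dtype=y.dtype)
    x[:, :, 0::2, 0::2] = (ll - lh - hl + hh) / 2
    x[:, :, 0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[:, :, 1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[:, :, 1::2, 1::2] = (ll + lh + hl + hh) / 2
    return x

x = np.random.rand(2, 3, 64, 64)    # two RGB images
y = haar_dwt(x)
print(y.shape)                      # (2, 12, 32, 32)
print(np.allclose(haar_iwt(y), x))  # True
```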
To hide: load the host and secret images (both must be the same size).
host = dwt(host)
secret = dwt(secret)
input = torch.cat([host, secret], dim=1)
container = net(input, k, rev='False')
container = iwt(container([:, :12]))
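One pitfall worth checking: the snippet passes rev='False' and rev='True' as strings. Any non-empty string is truthy in Python, so if DKiS's forward() tests rev with `if rev:` or `if not rev:` (not verified here), both calls would take the same branch. The toy function below is hypothetical and only demonstrates the Python behavior; if the real forward() expects booleans, pass rev=False / rev=True instead.

```python
def toy_forward(x, rev=False):
    """Hypothetical stand-in that branches on the truthiness of rev."""
    return "reverse" if rev else "forward"

print(bool("False"))                   # True: non-empty string
print(toy_forward(None, rev="False"))  # reverse
print(toy_forward(None, rev=False))    # forward
```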
To extract: load the container image.
container = dwt(container)
z = torch.randn(container.shape)
input = torch.cat([container, z], dim=1)
extracted = net(input, k, rev='True')
extracted = iwt(extracted[:, 12:])
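A shape-only sketch of the channel bookkeeping in the hide and extract steps, assuming 12 channels per image after dwt. `fake_net` is a hypothetical identity stand-in for DKiS, so the values mean nothing (with an identity network, `extracted` is just the noise z); only the concatenations and slices mirror the snippet.

```python
import numpy as np

def fake_net(x, k, rev=False):
    """Hypothetical identity stand-in for DKiS: returns its input unchanged."""
    return x

n, h, w = 2, 32, 32
host = np.random.rand(n, 12, h, w)    # host after dwt: 12 channels
secret = np.random.rand(n, 12, h, w)  # secret after dwt: 12 channels

# Hiding: 24-channel input; the first 12 output channels are the container.
out = fake_net(np.concatenate([host, secret], axis=1), k=None, rev=False)
container = out[:, :12]

# Extraction: pad the container with Gaussian noise of the same shape,
# run the reverse pass and keep channels 12 onward.
z = np.random.randn(*container.shape)
out = fake_net(np.concatenate([container, z], axis=1), k=None, rev=True)
extracted = out[:, 12:]

print(container.shape, extracted.shape)  # (2, 12, 32, 32) (2, 12, 32, 32)
```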
I'm trying to test this model but have run into some problems. Could you upload the test.py file?
Also, I think you should fix the hiding code in your comment, changing
container = iwt(container([:, :12]))
to
container = iwt(container[:, :12])
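The reported error can be confirmed without the model: `[:, :12]` is only valid inside a subscript, so writing it inside call parentheses fails at parse time. This sketch (the helper name `parses` is ours) compiles both lines without running them, so iwt and container do not need to exist.

```python
def parses(src):
    """Return True if src is syntactically valid Python (it is not executed)."""
    try:
        compile(src, "<snippet>", "exec")
    except SyntaxError:
        return False
    return True

print(parses("container = iwt(container([:, :12]))"))  # False
print(parses("container = iwt(container[:, :12])"))    # True
```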
The original line raises a SyntaxError!
I have trained the model according to the instructions. How do I use it to check the actual results?
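One common way to check the actual effect is to measure PSNR between the host and the container (how invisible the hiding is) and between the secret and the extracted image (how well it is recovered). This is a minimal numpy sketch, not an evaluation script from the repository. It assumes both images are arrays on the same scale, with max_value=1.0 for images in [0, 1] (use 255 for 8-bit images); the range that iwt returns is not stated in the thread.

```python
import numpy as np

def psnr(reference, test, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((reference - test) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_value ** 2 / mse)

# Hypothetical arrays standing in for a host image and its container.
host = np.zeros((3, 64, 64))
container = host + 0.1
print(round(psnr(host, container), 2))  # 20.0
```

An image that is off by 0.1 everywhere scores 20 dB; higher is better, and identical images return infinity.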