shanek16 opened 2 years ago
It may be the case that you need to transpose all of your weight matrices. Could you try this and let me know if that works?
Thank you for your comment, but it still doesn't work :(
I have modified my code to transpose all the weight matrices:
```python
import numpy as np


def extract_weights(net):
    weights = []
    state_dict = net.state_dict()
    for param_tensor in state_dict:
        tensor = state_dict[param_tensor].detach().numpy().astype(np.float64)
        if 'actor' in param_tensor:
            print('param_tensor:')
            print(param_tensor)
            print('\n')
            if 'weight' in param_tensor:
                # store the transpose of each weight matrix
                weights.append(tensor.T)
    return weights
```
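A minimal sketch of an alternative selection step, under two assumptions. First, LipSDP wants each matrix in the `(out_features, in_features)` orientation that `nn.Linear` already stores, as in the untransposed `extract_weights` further down in this thread, so no transpose is applied. Second, in the Stable-Baselines3 SAC actor, `actor.mu` and `actor.log_std` are, as far as we understand the SB3 source, two parallel linear heads on the output of `actor.latent_pi`, so only `latent_pi.*` followed by `mu` forms a sequential chain. The helper `select_chain_keys` is hypothetical.

```python
def select_chain_keys(keys):
    """Return actor weight keys forming the mean-action feed-forward chain.

    Hypothetical helper: keeps 'actor.latent_pi.<i>.weight' sorted by <i>,
    then appends 'actor.mu.weight'. Biases are skipped because they do not
    affect the Lipschitz constant; the parallel 'actor.log_std' head is skipped.
    """
    hidden = []
    head = None
    for key in keys:
        if not key.endswith('.weight'):
            continue
        if key.startswith('actor.latent_pi.'):
            # e.g. 'actor.latent_pi.2.weight' -> layer index 2
            index = int(key.split('.')[2])
            hidden.append((index, key))
        elif key == 'actor.mu.weight':
            head = key
    ordered = [key for _, key in sorted(hidden)]
    if head is not None:
        ordered.append(head)
    return ordered


# With a real model (requires torch and the loaded policy), something like:
#   sd = net.state_dict()
#   weights = [sd[k].detach().numpy().astype(np.float64) for k in select_chain_keys(sd)]
if __name__ == '__main__':
    keys = [
        'actor.latent_pi.0.weight', 'actor.latent_pi.0.bias',
        'actor.latent_pi.2.weight', 'actor.latent_pi.2.bias',
        'actor.mu.weight', 'actor.mu.bias',
        'actor.log_std.weight', 'actor.log_std.bias',
    ]
    print(select_chain_keys(keys))
```

Filtering by exact key names keeps the critic and the `log_std` head out of the saved `.mat` file, so the cell array passed to LipSDP only contains the layers that are actually composed in sequence.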
And this is the output of solve_sdp.py:
```
(lipenv) (base) shane16@dualarm-server:~/Project/model_guard/LipSDP/LipSDP$ python solve_sdp.py --form layer --weight-path examples/saved_weights/uav_sac_actor_weights_T.mat
Error using vertcat
Dimensions of arrays being concatenated are not consistent.

Error in lipschitz_multi_layer (line 143)
A_on_B = vertcat(A, B);

Error in solve_LipSDP (line 41)
L = lipschitz_multi_layer(weights, lip_params.formulation, ...

Traceback (most recent call last):
  File "solve_sdp.py", line 112, in
    main(args)
  File "solve_sdp.py", line 35, in main
    L = eng.solve_LipSDP(network, lip_params, nargout=1)
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/matlabengine.py", line 71, in call
    _stderr, feval=True).result()
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/futureresult.py", line 67, in result
    return self.__future.result(timeout)
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/fevalfuture.py", line 82, in result
    self._result = pythonengine.getFEvalResult(self._future,self._nargout, None, out=self._out, err=self._err)
matlab.engine.MatlabExecutionError:
  File /home/shane16/Project/model_guard/LipSDP/LipSDP/matlab_engine/lipschitz_multi_layer.m, line 143, in lipschitz_multi_layer
  File /home/shane16/Project/model_guard/LipSDP/LipSDP/matlab_engine/solve_LipSDP.m, line 41, in solve_LipSDP
Dimensions of arrays being concatenated are not consistent.
```
For additional information, this is what the original weights `.mat` file looks like when loaded in Python:
weights {'header': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Sat Dec 18 14:20:21 2021', 'version': '1.0', 'globals': [], 'weights': array([[array([[-6.85143769e-02, -3.36603433e-01], [-4.37636882e-01, 2.34987199e-01], [-3.16818148e-01, -5.45625389e-01], [-2.89836824e-01, 9.00892913e-01], [-1.38061807e-01, 3.41346651e-01], [ 3.85727435e-02, 4.22988925e-03], [-2.81139523e-01, 1.06242709e-01], [-3.91652554e-01, -8.87769461e-01], [-2.32657984e-01, 1.38013482e+00], [ 9.39288875e-05, -1.53773642e+00], [-5.49701415e-02, 3.08358729e-01], [-1.47983909e-01, -3.22792605e-02], [-9.24302191e-02, -3.49573493e-02], [ 6.25483915e-02, 6.03022352e-02], [-1.94695652e-01, -4.27681386e-01], [-3.33421618e-01, 9.96067703e-01], [ 4.45502810e-02, -2.12907977e-03], [-3.57661456e-01, 1.95654221e-02], [-3.79328549e-01, 6.77343726e-01], [-1.79474965e-01, 3.16777587e-01], [-3.38969111e-01, -6.44439220e-01], [-3.35297108e-01, -1.03046894e+00], [-7.81602859e-02, 5.66973686e-02], [-5.52919686e-01, 5.90989351e-01], [-1.22955883e+00, 8.56232643e-01], [-1.88813895e-01, 1.01321213e-01], [-4.30013150e-01, 1.42168307e+00], [-3.50720137e-01, -1.01123102e-01], [-5.99819124e-01, -7.11649895e-01], [-1.35085851e-01, 1.60574126e+00], [ 8.81476048e-03, -4.27752674e-01], [-1.72235578e-01, 8.41235936e-01], [-3.53380352e-01, -8.74089301e-01], [-6.54951930e-01, -9.02151346e-01], [-1.10342883e-01, -4.87088114e-02], [-2.17704028e-01, 7.80149519e-01], [-4.58586782e-01, -6.29357770e-02], [-1.50429919e-01, -1.26546144e+00], [-1.21436864e-01, -8.77058029e-01], [-1.29360640e+00, 5.71827412e-01], [ 1.99044764e-01, 7.30524063e-01], [-6.54722393e-01, 2.41861656e-01], [ 5.61621450e-02, 1.21515773e-01], [-4.69913393e-01, -6.45179987e-01], [ 7.27659389e-02, 5.60564213e-02], [-1.42742127e-01, -1.36110112e-01], [-2.32477710e-01, -8.05989563e-01], [-1.93128836e+00, 5.32305419e-01], [ 2.26508337e-03, -7.47355819e-01], [-6.71177506e-02, 5.66721976e-01], [-2.86900342e-01, -1.23376977e+00], [-6.13334887e-02, 
-4.16614674e-02], [-7.39199162e-01, -9.74034727e-01], [-1.50688469e-01, -1.60530448e-01], [-2.20665395e-01, -2.36648083e-01], [-1.67475808e+00, 5.13051689e-01], [-5.70961177e-01, 1.26409459e+00], [-1.05320908e-01, -1.28656551e-01], [-4.86176521e-01, 1.02747381e+00], [-2.66896546e-01, 1.47360966e-01], [-1.57919437e-01, -2.67666340e-01], [-5.33818424e-01, -9.34508979e-01], [-4.04241353e-01, 7.50177324e-01], [ 6.49868324e-02, 7.99891353e-02], [-1.19405729e-03, -1.13087334e-01], [-9.05777931e-01, -4.53344226e-01], [-5.49191058e-01, -8.32641304e-01], [-2.57960826e-01, 1.04893334e-01], [-3.46958756e-01, 1.20738077e+00], [-1.78298444e-01, -1.71872750e-01], [-4.75126773e-01, 1.00679123e+00], [-1.87536225e-01, 4.18525279e-01], [-7.60911167e-01, 1.10811162e+00], [ 7.68872947e-02, -4.26468067e-02], [-1.81280479e-01, -9.58161056e-01], [-4.01476800e-01, -1.92179576e-01], [-1.52010107e+00, 2.59388573e-02], [-3.68176311e-01, -1.16760635e+00], [-8.44833851e-02, -7.07813501e-02], [-2.66123056e-01, 3.48530449e-02], [-2.58686244e-01, -8.75191510e-01], [-1.49426848e-01, 3.01911622e-01], [-4.93374348e-01, 1.09098125e+00], [ 1.27181664e-01, -4.39624816e-01], [-4.56174910e-01, -7.75252342e-01], [-5.62372655e-02, -3.87163460e-02], [-9.74609017e-01, 8.86448145e-01], [-3.60333174e-01, 1.31217504e+00], [-1.43431008e-01, 8.06200087e-01], [-8.37842468e-03, 1.27602530e+00], [-3.47854078e-01, -3.26256603e-01], [-4.26802456e-01, -9.97594774e-01], [-6.67393655e-02, -5.63718751e-02], [-6.38821423e-02, -1.12170398e+00], [-1.17282771e-01, -1.46603751e+00], [-1.28887296e-01, -2.08965749e-01], [-2.38886729e-01, 1.74527481e-01], [-4.71011966e-01, -9.86448765e-01], [ 1.82654373e-02, 3.13659422e-02], [-5.58682740e-01, -7.24108636e-01], [-1.95165485e-01, -7.98933804e-01], [-5.10833971e-02, -9.89437819e-01], [-1.70823678e-01, -1.77882314e-01], [-5.26996613e-01, -5.64561844e-01], [-2.56211758e-01, 1.81155765e+00], [-7.41678834e-01, 1.36839986e+00], [-1.49451643e-01, -1.72329140e+00], [-3.89405459e-01, 
-7.70909250e-01], [-1.70388266e-01, 1.05704975e+00], [-5.42294145e-01, 9.16380644e-01], [-1.37954861e-01, 1.38688207e+00], [-1.07425487e+00, 2.50464559e-01], [-3.92255992e-01, 1.53620625e+00], [-2.35263690e-01, -1.36332703e+00], [-8.62313569e-01, 7.17299223e-01], [ 1.96617860e-02, 1.17209822e-01], [-2.71294802e-01, 2.41939407e-02], [-5.24426922e-02, 6.36409894e-02], [-7.49675214e-01, 5.92391551e-01], [-3.09047669e-01, -4.93131012e-01], [-2.05987543e-01, -4.89378721e-02], [-4.86499310e-01, 1.11728787e+00], [-1.22964919e+00, 1.64474010e-01], [-1.10370472e-01, -7.81808719e-02], [-1.24383837e-01, -9.99551788e-02], [-3.14716071e-01, 7.24924803e-01], [ 1.75869409e-02, -1.63655053e-03], [-5.08730948e-01, -7.42523849e-01], [-3.84562939e-01, 1.13943541e+00], [-5.03208518e-01, -1.25559962e+00], [-3.08119088e-01, 1.35422385e+00], [-1.41793382e+00, 9.27046895e-01], [ 5.53365014e-02, 1.61926597e-02], [-1.16442271e-01, -1.49270087e-01], [-3.38886559e-01, 1.52210808e+00], [-1.49027467e-01, 1.60809851e+00], [-7.38697220e-03, 1.50049835e-01], [-4.15518582e-01, -8.22724342e-01], [-1.00603953e-01, 8.28674808e-02], [ 4.77627143e-02, 4.48429435e-01], [-2.91403979e-02, -9.85705495e-01], [-2.90248126e-01, 1.56998858e-01], [-6.70203641e-02, -1.30562857e-01], [-5.61595201e-01, 7.16114819e-01], [-1.36431471e-01, -1.51457608e+00], [-2.32963413e-01, 1.45835304e+00], [ 4.27815244e-02, -6.86439335e-01], [-1.54693499e-01, -8.10685009e-02], [-2.28011087e-02, -1.00114523e-02], [-2.51068085e-01, -2.12578982e-01], [-6.95462897e-02, -1.07704490e-01], [-3.25500183e-02, -9.34520602e-01], [-4.21439379e-01, -8.72218668e-01], [-5.55282474e-01, -1.67548263e+00], [-2.24645421e-01, -7.08654225e-02], [-1.15941226e-01, -6.79708868e-02], [-9.64069068e-01, 2.53156602e-01], [-5.07299364e-01, -9.72091734e-01], [ 1.13090523e-01, 1.55948391e-02], [ 8.96759331e-02, -4.92816031e-01], [-3.53340477e-01, 1.26766729e+00], [-1.13922739e+00, -1.91750079e-01], [-5.32649577e-01, 1.33047259e+00], [-4.38321948e-01, 
-1.12845027e+00], [-1.65716678e-01, -7.64620900e-02], [-5.69151282e-01, -1.12235856e+00], [-2.21188977e-01, 8.66389453e-01], [-6.76088691e-01, 1.67762414e-01], [-2.42386699e-01, 6.05390668e-02], [-1.61106035e-01, 1.02788734e+00], [-4.01457489e-01, 9.43940878e-01], [-1.05722854e-02, 1.15793375e-02], [-5.89403391e-01, 1.71686387e+00], [-1.33220807e-01, 2.46252894e-01], [-8.95949900e-02, -1.14768989e-01], [-2.41824202e-02, -8.48952904e-02], [ 1.46331182e-02, -3.52581859e-01], [-5.75874513e-03, 1.38468480e+00], [-3.01785581e-02, 1.72366023e+00], [-1.24253404e+00, 5.04970372e-01], [-4.64695930e-01, -1.04173005e+00], [-3.10319602e-01, -5.31640410e-01], [-6.83822930e-02, -1.04936886e+00], [ 1.29613737e-02, -4.21936624e-02], [-6.43785819e-02, -5.02525829e-02], [-1.39409497e-01, 1.92209169e-01], [-8.30474421e-02, -1.11248195e-01], [-3.75379890e-01, -8.50651741e-01], [-3.56702924e-01, 7.49479234e-01], [-5.23479939e-01, -9.11294639e-01], [-2.27834895e-01, 2.05277994e-01], [-7.52443373e-01, -4.09139544e-01], [-6.73535466e-02, -8.07213664e-01], [-3.64913195e-01, -4.23726022e-01], [-1.32498294e-01, 4.04865235e-01], [-7.33781695e-01, 6.03865802e-01], [-3.58522981e-01, -7.16318130e-01], [-2.78221190e-01, 6.77906930e-01], [-2.47608081e-01, -3.52997720e-01], [-7.94031084e-01, -1.16398251e+00], [-2.82960027e-01, -8.17655265e-01], [-2.80453920e-01, -1.67103291e-01], [-5.46526834e-02, 1.73139632e-01], [-1.26927841e+00, 4.68366355e-01], [-4.26691115e-01, 8.85109663e-01], [-8.21300328e-01, -1.50739223e-01], [-5.74597359e-01, -1.56419683e+00], [-3.46872926e-01, -6.20355964e-01], [-4.16146405e-02, 4.41526845e-02], [-1.61938414e-01, 1.45489848e+00], [ 1.14040999e-02, 1.74344536e-02], [-3.67602527e-01, 1.36658502e+00], [-4.77526844e-01, -1.67623281e+00], [-1.50420338e-01, 2.10183069e-01], [-6.03150606e-01, 1.44156480e+00], [ 5.73831350e-02, -5.08895040e-01], [-1.98268175e-01, -1.13669908e+00], [-9.54810828e-02, -3.64515521e-02], [-3.34659815e-01, 8.64602149e-01], [-1.23859912e-01, 
-8.52473855e-01], [ 1.69087097e-01, 8.02972436e-01], [-6.36473060e-01, -1.63622725e+00], [-4.16115105e-01, 8.53862286e-01], [-7.94829547e-01, 5.62543571e-01], [ 4.64861132e-02, 1.01891172e+00], [-5.36335230e-01, 2.74964459e-02], [-2.80157536e-01, -9.58885550e-01], [ 9.72216278e-02, -5.01175821e-01], [-1.42968902e-02, -1.16730398e-02], [-6.91494904e-03, 3.60297374e-02], [-5.53260922e-01, -6.89113796e-01], [ 1.15756273e-01, -3.17524821e-01], [-1.13445848e-01, -6.50840029e-02], [-6.09598272e-02, -3.05257272e-03], [-5.08547962e-01, -1.38159800e+00], [-4.68757123e-01, -1.30974793e+00], [-2.33617872e-01, -1.09636915e+00], [-2.61846725e-02, -1.61404833e-01], [-2.22611204e-01, 1.12323213e+00], [-2.33813077e-01, -3.32386754e-02], [-2.58558065e-01, -3.00171133e-02], [-2.16094732e-01, 1.63878644e+00], [ 1.42942201e-02, 3.10745575e-02], [-6.74639046e-01, -1.85138553e-01], [-2.58727401e-01, 1.48004818e+00], [-1.12628365e+00, -2.73787439e-01], [-4.06753957e-01, -2.30762169e-01], [-5.90893686e-01, 2.64795601e-01], [-1.42539322e+00, -1.42960802e-01], [-1.45081520e-01, 1.32808387e+00], [-3.25415015e-01, -4.25004959e-01], [-2.91306406e-01, -1.61577329e-01], [-4.24730957e-01, -6.60539389e-01], [ 1.72714591e-02, -6.05386728e-03], [-5.16779840e-01, -1.04821813e+00], [-5.17340600e-01, 9.92565155e-01], [-6.52867794e-01, -6.14330769e-02], [-1.03650130e-01, -6.12370595e-02], [-9.98392254e-02, -4.62541962e-03], [-2.02449262e-02, 7.99310565e-01], [-6.09915912e-01, -9.49235678e-01], [-1.48535445e-01, 1.36771011e+00], [-5.40006936e-01, -6.01608992e-01], [-1.23508260e-01, -6.64888546e-02], [-4.84287888e-01, -9.66315746e-01], [-2.34719744e-04, -5.50626405e-02], [-4.42205936e-01, -6.38400972e-01], [-2.05219805e-01, 3.64799947e-01], [-1.13294709e+00, 1.14065530e-02], [-3.85340929e-01, -1.26428449e+00], [-1.55003415e-02, 8.14648986e-01], [-7.72838891e-01, 1.11270718e-01], [-7.32219368e-02, -1.15076780e+00], [ 8.17508623e-02, -7.19908625e-02], [-4.63112712e-01, 1.10501039e+00], [-2.73998827e-01, 
1.71555257e+00], [-2.07605869e-01, -5.73203385e-01], [-1.71772227e-01, 1.37884307e+00], [-1.60799873e+00, 2.73396343e-01], [-4.03662294e-01, -1.06119895e+00], [-4.37071621e-02, 5.61125651e-02], [-4.53095555e-01, 9.18152392e-01], [-7.92810857e-01, -1.28594398e+00], [-3.87161881e-01, 8.97737071e-02], [-7.25102305e-01, -6.51422977e-01], [-5.35262406e-01, -1.38639987e+00], [-7.61458516e-01, -1.15196586e+00], [-1.01287387e-01, 1.38946444e-01], [-3.52832675e-01, -9.41418350e-01], [-2.95413464e-01, -5.05823135e-01], [-2.78604954e-01, -2.15803653e-01], [-1.97922081e-01, -1.12287259e+00], [-6.69863939e-01, -9.73679900e-01], [-1.70272279e+00, 3.12141716e-01], [ 1.04595587e-01, -1.70570984e-02], [-1.23290658e-01, 1.05429757e+00], [-2.17610821e-01, -5.69162309e-01], [-6.52480364e-01, 7.28567421e-01], [-2.16790274e-01, 1.40694821e+00], [-2.74692088e-01, -4.16959941e-01], [-4.88729596e-01, 9.77496862e-01], [-5.66260755e-01, -1.20392644e+00], [-1.89852685e-01, -1.93601847e-01], [-7.96760246e-03, 3.99434537e-01], [-2.00265199e-01, -3.95467281e-01], [-4.01530648e-03, 2.81437729e-02], [-3.20418000e-01, 9.92648244e-01], [-4.80500311e-01, -8.89828429e-02], [-2.78092802e-01, -7.71939754e-01], [-2.91511208e-01, 1.67494524e+00], [-1.08414382e-01, 4.60278869e-01], [-6.31529033e-01, -6.26300931e-01], [-9.30277165e-03, 9.46571052e-01], [-1.08266443e-01, 1.43687129e+00], [-2.93854568e-02, -1.69754550e-02], [-3.95627201e-01, -4.75571632e-01], [-4.08647001e-01, -7.35336125e-01], [-5.34992576e-01, -1.34267569e+00], [-6.32516086e-01, 1.05796075e+00], [-4.52237368e-01, -4.09887359e-02], [-3.79665613e-01, 1.34319663e-01], [-1.47772402e-01, 1.98458731e-01], [-1.75058022e-01, 7.74153948e-01], [-2.94035636e-02, -8.71903747e-02], [-5.73199149e-03, 9.78635177e-02], [-2.99606383e-01, -8.99923384e-01], [-1.29079714e-01, -1.48870781e-01], [-1.74793869e-01, 6.67257458e-02], [-3.53810370e-01, -4.13035899e-01], [-5.32069206e-01, 1.37551713e+00], [-1.02681100e-01, -1.17860787e-01], [-1.32749140e-01, 
2.58607686e-01], [-5.20431876e-01, -1.64319134e+00], [-6.58625305e-01, 1.37898874e+00], [-2.35748142e-01, 1.55201629e-02], [ 9.05706733e-03, 6.40194952e-01], [-6.18926398e-02, -1.00056641e-01], [-9.15057287e-02, 3.68416458e-01], [-2.62450337e-01, 1.44937241e+00], [-5.94881535e-01, 1.82293542e-02], [-1.80450052e-01, -9.03574079e-02], [-1.30852059e-01, -6.66531444e-01], [-1.65500507e-01, -3.96247894e-01], [-2.78791666e-01, 1.29573858e+00], [-1.09676614e-01, 2.26825535e-01], [-1.48049921e-01, 2.73054898e-01], [-2.63758719e-01, -6.62458956e-01], [ 7.36281276e-02, 6.73432425e-02], [-4.21359539e-01, 1.11743033e+00], [-7.49257922e-01, 6.76169634e-01], [-1.33202612e-01, -1.00994837e+00], [-5.71687281e-01, -9.77057397e-01], [-2.79146940e-01, 1.77596420e-01], [-1.24701940e-01, 7.69110680e-01], [-1.15750335e-01, -1.36288717e-01], [-3.09202075e-01, -1.16831899e+00], [-4.84196007e-01, -9.10362661e-01], [-3.13956439e-01, -9.92035627e-01], [-2.01910920e-02, -8.82056057e-01], [ 1.00143977e-01, -3.83660972e-01], [-1.21029124e-01, 1.02580273e+00], [-3.39247376e-01, -9.33523476e-01], [-5.01396000e-01, 1.57736260e-02], [-3.62887561e-01, 1.27967525e+00], [-2.60262996e-01, 1.65057218e+00], [-1.52231306e-01, -7.94253826e-01], [-4.91619080e-01, -9.51140523e-01], [-7.46492445e-01, 1.43556333e+00], [-5.18593192e-02, 5.10287285e-02], [-1.29660487e-01, 2.11285043e+00], [-5.06802559e-01, -1.02589488e+00], [-9.87785589e-03, 8.68290842e-01], [-6.40534282e-01, -8.45777333e-01], [-1.61280204e-02, 9.94288087e-01], [-2.45859608e-01, -6.00105882e-01], [-4.40825522e-01, -2.50123441e-01], [-2.39165410e-01, 1.07840657e+00], [-3.09910148e-01, 1.95505893e+00], [-3.55430365e-01, 5.83627634e-02], [-5.40120155e-02, 1.53675467e-01], [-3.05950433e-01, 2.08115768e+00], [-3.27613026e-01, -8.74360979e-01], [ 1.89744234e-01, -1.44169092e-01], [-6.39236271e-02, 6.95572644e-02], [-2.76789933e-01, -9.16198730e-01], [-5.59525155e-02, -3.42101194e-02], [-9.92162153e-02, -1.55778718e+00], [-1.54982554e-02, 
1.32259846e-01], [-1.09177075e-01, 3.04755092e-01], [-4.89147335e-01, -1.10692799e+00], [-1.53640181e-01, 3.51888239e-01], [-6.84823021e-02, 1.08417833e+00], [-7.53510416e-01, -1.61606646e+00], [-2.30860300e-02, -4.91644442e-02], [-1.58516362e-01, 6.54639065e-01], [-2.95944661e-01, -1.38793215e-01], [-2.98333019e-01, -1.17142625e-01], [-5.10916889e-01, -8.65615547e-01], [-1.16052195e-01, -4.22744155e-02], [-1.79733589e-01, 2.59376615e-02]]), array([[-0.00347338, -0.05872722, -0.00050712, ..., -0.01385312, -0.05683151, -0.06845538], [-0.05932254, 0.12597044, -0.05169398, ..., 0.15458557, -0.07934367, -0.02644371], [-0.03535785, -0.16315122, 0.02602659, ..., 0.01265492, -0.06003028, -0.14426833], ..., [-0.05476233, -0.12519033, -0.04841653, ..., -0.00105114, -0.00633659, -0.07282975], [-0.32794312, -0.0630713 , 0.01267534, ..., 0.03861733, -0.00643496, 0.4518356 ], [-0.00304984, 0.00233808, 0.03937962, ..., -0.02049346, -0.02721982, -0.0292502 ]]) , array([[-1.37788216e-02, -1.75700448e-02, -9.01518762e-03, -1.77550353e-02, 4.48740196e+00, 1.89883802e-02, 2.92103570e-02, -1.76942557e-01, 1.51347779e-02, 9.97370005e-01, -1.78765431e-02, 3.84710759e-01, -2.44137086e-02, 4.35012221e-01, 1.99974704e+00, -5.59847616e-02, 9.84124601e-01, -5.38679212e-02, -1.63768064e-02, -1.00092143e-02, -1.70517594e-01, 2.33019572e-02, 2.60525439e-02, -2.46700859e+00, 4.64061387e-02, 4.34389897e-03, 5.39909676e-02, 2.23193690e-03, 5.57129741e-01, -3.44100088e-04, -1.72053460e-05, -2.26990320e-03, 7.69877434e-03, 1.64730754e-02, -1.05446177e-02, 5.27715869e-03, 2.59022743e-01, -9.84112620e-01, -1.49734449e+00, -7.68469227e-03, 2.10542791e-02, 1.43484343e-04, 2.10938156e-01, -8.32508318e-03, 5.22896767e-01, 5.75597212e-03, -4.29411931e-03, 1.57470796e-02, 2.23527804e-01, -7.47562014e-03, 4.28060561e-01, -1.68241020e-02, 1.86267905e-02, 3.85014391e+00, -2.05261167e-02, 9.28721577e-03, 7.10609853e-01, -4.94079571e-03, 5.56911994e-03, 1.69520485e+00, 5.63182414e-01, -6.11842096e-01, 
1.02576884e-02, 3.36717740e-02, -1.75212771e-02, -4.00439322e-01, -2.21974608e-02, 4.00515506e-03, 5.63668251e-01, -4.95153153e-03, -9.96480323e-03, -1.43687241e-02, 3.03780264e-03, -1.21312551e-02, -1.95983406e-02, 7.89121352e-03, -1.31047172e-02, -1.68624520e-02, 1.02414126e-02, 2.13016104e-02, 1.13913687e-02, 1.37405172e-02, -7.80674722e-03, -4.93164539e-01, 1.21939175e-01, 2.05372949e-03, 2.98700687e-02, 2.62831021e-02, -1.59251448e-02, -9.90365297e-02, -2.36145407e-02, -3.32626775e-02, -6.20973855e-02, -6.62014063e-04, -5.19719347e-02, -1.66715831e-02, -5.59134968e-03, 1.69547629e-02, 3.34512681e-01, 1.52108008e-02, 6.04002737e-03, 3.56838882e-01, 2.13405099e-02, 3.62083167e-02, 1.56843979e-02, -2.46515833e-02, -2.51025893e-02, 2.29414552e-02, 2.07180884e-02, 2.70087551e-03, -1.65229850e-02, -3.49190040e-03, 5.45528568e-02, 3.03819701e-02, 1.68553479e-02, 2.11687945e-03, 2.10794825e-02, 3.84027255e-04, -3.37043181e-02, -2.07593828e-03, -2.69289711e-04, -9.85833108e-01, 8.82928912e-03, -1.09777395e-02, -1.23013341e+00, 3.85869473e-01, -2.64463630e-02, -2.38611341e-01, 4.42886502e-01, 3.00123612e-03, -6.56607375e-03, 5.05365282e-02, 3.56168509e-03, -7.30367377e-03, -9.09072720e-03, -4.94694198e-03, 5.23686083e-03, 1.36592016e-02, 8.69852483e-01, 1.23906992e-02, -2.15562433e-03, -7.44694710e-01, 2.52146795e-02, -1.83370896e-02, 3.60017605e-02, 9.21368320e-03, -1.55515829e-02, -1.96728408e-02, 2.78673461e-03, 9.19356849e-03, -3.46681094e+00, -5.48373871e-02, -1.99966282e-01, 1.20204855e-02, -7.85164088e-02, 1.76548064e-02, 1.11598253e+00, -1.09377825e+00, 2.12640548e+00, 5.00349440e-02, 2.57247556e-02, -2.65226457e-02, 2.15803310e-02, -3.23484726e-02, 1.16806138e+00, 1.48812353e-04, -1.06846234e-02, -4.15321887e-02, 5.87398410e-02, -1.01899743e+00, 1.70089435e-02, 9.23677254e-03, -1.21974921e+00, 1.49432672e-02, 1.21426145e-02, 1.96776226e-01, 7.53171742e-04, -1.50537649e-02, -1.95960719e-02, -1.25899959e+00, 9.12787378e-01, 7.84498639e-03, 8.91653970e-02, 
-1.96704883e-02, 1.48478687e-01, 1.05110770e-02, -3.27640623e-01, -4.32579547e-01, -2.60793614e+00, -1.00472581e+00, 6.77171111e-01, -2.29745964e-03, 9.99804493e-03, 1.25056133e-02, 1.26885734e-02, -1.26918573e-02, -2.21185386e-02, -1.27081862e-02, 4.78114873e-01, 1.24667687e-02, 9.05451830e-03, 1.21412734e-02, -6.24935096e-03, 1.41170982e-03, 9.97739937e-03, 8.13253038e-03, -2.94227642e-03, 5.66245290e-03, 5.82674844e-03, 8.94912891e-03, 3.13076526e-02, -1.00340135e-03, 4.29024268e-03, 1.67629030e-02, 4.02269125e-01, 1.18631346e-03, 2.10501347e-03, 2.29349174e-03, -3.39133963e-02, -4.77075716e-03, 7.09763961e-03, -2.77705695e-02, 1.90233323e-03, -1.97377112e-02, -1.64122629e+00, 2.51169764e-02, 1.75400134e-02, -3.97822201e-01, 3.52268224e-03, 1.65995155e-02, 3.87875289e-02, 1.25200944e-02, -2.40553364e-01, -7.48400996e-03, -2.18713665e+00, -4.22828272e-02, -3.44225243e-02, 1.57810468e-02, -6.89691259e-03, 2.13804934e-02, -8.75570178e-01, -2.78839767e-02, 3.79608452e-01, 8.09172988e-02, -2.54286557e-01, 2.01689117e-02, 2.02423253e-04, 4.45776014e-03, -2.67576445e-02, 9.24475864e-03, 5.05982153e-02, -1.72549915e-02, -2.36192122e-02, -6.47810986e-03, -1.87972058e-02, -5.76387485e-03, 6.16780901e-03, -2.53893454e-02, 1.34960674e-02, 1.68234352e-02, 3.94389080e-03, -1.97180007e-02, 7.25476013e-04, -4.59107012e-02, -3.46406877e-01, 1.08132049e-01, 1.89604901e-03, -3.27849925e-01, 3.00043561e-02, -4.83688377e-02, -5.65785408e-01, 2.70299971e-01, 4.55066323e-01, -2.02022216e-04, 3.45581830e-01, -1.77612770e-02, -2.62108147e-02, 5.99280233e-03, -5.48362017e-01, -2.61357590e-03, 1.66500568e+00, 5.62610570e-03, -1.98433157e-02, -9.09086782e-03, -1.41522856e-02, 1.71686149e+00, -1.00530861e-02, 9.72271082e-04, -1.33344103e-02, -4.97914804e-03, 3.12344972e-02, 2.99100041e+00, -1.77206309e-03, -4.14526284e-01, -3.52810533e-03, -3.74223180e-02, 2.51768157e-02, 4.66037989e-02, 7.02720344e-01, 6.23262115e-03]]), array([[ 4.65491004e-02, 2.75948867e-02, -1.62194506e-03, 
-1.90341882e-02, 1.44851983e+00, 1.10667255e-02, 2.39623636e-02, 2.76989758e-01, -1.17389951e-02, -4.17082936e-01, 6.55591767e-03, -2.16769204e-01, -9.46768448e-02, 7.53203928e-02, -1.10858262e+00, 3.72481905e-02, -3.53154987e-01, -2.42546070e-02, 1.24317978e-03, -2.48324890e-02, 1.64321244e-01, 1.09806852e-02, 4.67169890e-03, 2.45411682e+00, 1.53180154e-03, -3.64842601e-02, 6.79149553e-02, 1.65834334e-02, -6.34780154e-02, 1.13838846e-02, 9.09686238e-02, 2.08018348e-02, -3.00305765e-02, -1.80672072e-02, 7.61659537e-03, -1.93655025e-02, 4.58933115e-02, 1.19281299e-01, 1.39705300e+00, -1.02607235e-02, 6.25323355e-02, 4.59831581e-02, 4.78866309e-01, 5.03114387e-02, -2.87342146e-02, -9.86713264e-03, -8.46157968e-03, -4.51642787e-03, -2.23251302e-02, -4.30305302e-02, -1.84713259e-01, 5.70968688e-02, 9.77251865e-03, 2.34741426e+00, 3.65196578e-02, -2.70402618e-02, -4.81732965e-01, -5.01764379e-03, -5.36360079e-04, 1.20130432e+00, 3.77781332e-01, 2.80936956e-01, 1.06814718e-02, 6.97409287e-02, 3.01600751e-02, -8.28526020e-02, 2.11727303e-02, -2.84313634e-02, 1.20653138e-01, 2.81074215e-02, 4.79727276e-02, 1.52691305e-02, 2.76692081e-02, -1.21541368e-02, -1.08070774e-02, 4.17977497e-02, 3.26774083e-02, 1.81504153e-02, 2.22151335e-02, 1.05428826e-02, 3.82720605e-02, -6.79117255e-03, 5.13538625e-03, 8.16609114e-02, 9.61923152e-02, 1.24034220e-02, -5.82645088e-02, -2.70967442e-03, 7.93109590e-04, 6.16127029e-02, -3.26875634e-02, 2.04683412e-02, -4.48205434e-02, -1.75952371e-02, 3.83201316e-02, -7.41406111e-03, 6.67514354e-02, -2.02544443e-02, 9.46394950e-02, -6.43957034e-03, -3.23279947e-02, -7.12516785e-01, 2.95855645e-02, 3.71551886e-02, 2.19547730e-02, -4.17075604e-02, 2.58115921e-02, 2.16520410e-02, -7.64212920e-04, 2.48281024e-02, -1.75198652e-02, 1.57203842e-02, -5.17965853e-03, 1.80782154e-02, 2.44993679e-02, -2.04011146e-02, -3.39620747e-02, -5.48322313e-03, 8.99045989e-02, 1.31340884e-02, 3.92878242e-02, 1.35290429e-01, -3.55715607e-03, 1.65412407e-02, 
2.17481613e-01, -6.66095689e-02, 1.86464377e-02, -8.81323963e-02, -5.88616610e-01, -7.66875893e-02, 1.05265090e-02, -3.42469127e-03, 3.46742123e-02, 1.86264180e-02, 4.76501137e-02, 5.42733409e-02, -1.68142151e-02, 3.63312140e-02, -2.84217060e-01, 7.62471231e-03, 1.74686138e-03, 2.96686441e-01, 3.31194066e-02, 3.55131575e-03, -2.98136175e-02, -3.50989215e-02, 7.25846784e-03, 5.48427850e-02, -1.26271416e-02, 6.23666868e-02, 1.61194599e+00, 2.89767869e-02, 4.97422755e-01, 1.69700921e-01, -1.05567493e-01, 5.60784414e-02, -3.19344312e-01, 6.27972782e-01, 2.76824737e+00, 2.67419070e-02, 1.23290671e-02, -1.83889410e-03, 4.55528088e-02, 7.78805688e-02, -1.32425353e-01, -2.61146501e-02, -1.52382301e-02, 3.27687152e-02, -1.26175154e-02, -1.75102532e-01, 6.55494392e-01, 1.38704153e-02, 1.62684083e+00, 3.06072109e-03, 3.32440622e-02, 2.59396225e-01, -5.18346168e-02, -1.11456541e-03, -5.24866767e-03, -3.15264650e-02, 6.69284642e-01, -6.19462691e-03, 2.53582507e-01, 1.22156642e-01, -1.83854606e-02, -3.18263359e-02, 1.95150688e-01, 9.79222417e-01, 1.78013766e+00, -8.12911689e-02, -4.08305883e-01, -4.82058041e-02, 3.85942832e-02, 9.65110585e-03, 1.56665314e-02, 3.95751558e-02, 4.19400120e-03, 4.49118577e-02, 6.79667294e-01, 1.25250015e-02, 2.21961108e-03, 6.58174139e-03, 1.03429057e-01, 9.12185200e-03, 5.24358861e-02, -1.10104904e-02, -7.30941445e-02, 1.88877787e-02, 1.62834383e-03, -2.07051588e-03, -5.21547580e-03, -2.97072511e-02, 2.52311770e-02, -3.16604786e-02, -5.22877499e-02, -8.89985204e-01, 1.13187780e-04, -2.47289557e-02, 1.43413739e-02, -1.06425263e-01, -8.21397174e-03, -4.16969091e-01, 3.24987881e-02, 1.74501780e-02, 6.60500824e-01, -2.61616055e-02, -2.07697041e-02, 1.30887359e-01, 1.19397901e-02, 1.45824328e-02, -1.68448295e-02, 2.92910114e-02, 2.57602364e-01, -1.56166209e-02, 1.44536316e+00, 1.72522217e-02, -1.47966470e-03, -5.82053373e-03, 8.50211829e-04, 5.40737342e-03, -1.00989444e-02, 5.58562540e-02, 1.61847025e-01, -6.82225125e-03, 1.02822281e-01, 
-1.77668948e-02, -1.68647394e-02, 3.52216773e-02, 6.90010414e-02, 5.91865880e-03, 2.74301544e-02, 5.49219213e-02, 1.46120572e-02, -2.74932589e-02, 8.28231201e-02, -5.51200733e-02, 1.58854891e-02, 3.34876031e-02, -7.44700879e-02, -1.21594174e-02, -4.13211547e-02, -2.10661348e-03, 4.07803096e-02, -4.15315703e-02, -9.56187993e-02, 5.19551896e-03, 1.87933873e-02, -9.95918512e-02, 2.23887563e-02, 4.22002114e-02, 7.04833627e-01, -2.83067743e-03, -2.32858583e-01, 5.20206839e-02, 2.11295053e-01, -7.97515456e-03, 3.62693667e-02, -7.54455710e-03, 8.48646939e-01, -4.07221615e-02, 1.29890025e+00, -1.21367332e-02, 8.09317827e-03, -4.69723158e-02, 1.39973592e-04, 5.62448621e-01, 7.62438923e-02, 3.17619555e-02, -3.27488519e-02, -4.55514118e-02, 6.47964180e-02, 1.23330510e+00, 6.75912993e-03, -2.15177074e-01, 1.05284989e-01, 2.35673189e-02, 1.46848306e-01, -4.28327136e-02, -8.18562284e-02, 1.06871519e-02]])]], dtype=object)}
```
(lipenv) (base) shane16@dualarm-server:~/Project/model_guard/LipSDP/LipSDP$ python solve_sdp.py --form network --weight-path examples/saved_weights/uav_sac_actor_weights_T.mat
Error using sparse
Requested 246053x60196622500 (448.5GB) array exceeds maximum array size preference. Creation of arrays greater than this limit may take a long time and cause MATLAB to become unresponsive. See array size limit or preference panel for more information.

Error in cvx/sparse (line 51)
xb = sparse( ix, ij, vx, max(ix), m * n );

Error in cvx/diag (line 35)
y = sparse( roff + 1 : roff + nn, coff + 1 : coff + nn, v, nel, nel );

Error in lipschitz_multi_layer (line 55)
T = T + E diag(zeta) E';

Error in solve_LipSDP (line 41)
L = lipschitz_multi_layer(weights, lip_params.formulation, ...

Traceback (most recent call last):
  File "solve_sdp.py", line 112, in
    main(args)
  File "solve_sdp.py", line 35, in main
    L = eng.solve_LipSDP(network, lip_params, nargout=1)
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/matlabengine.py", line 71, in call
    _stderr, feval=True).result()
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/futureresult.py", line 67, in result
    return self.__future.result(timeout)
  File "/home/shane16/Project/model_guard/LipSDP/lipenv/lib/python3.7/site-packages/matlab/engine/fevalfuture.py", line 82, in result
    self._result = pythonengine.getFEvalResult(self._future,self._nargout, None, out=self._out, err=self._err)
matlab.engine.MatlabExecutionError:
  File /home/shane16/Project/model_guard/LipSDP/cvx/builtins/@cvx/sparse.m, line 51, in sparse
  File /home/shane16/Project/model_guard/LipSDP/cvx/builtins/@cvx/diag.m, line 35, in diag
  File /home/shane16/Project/model_guard/LipSDP/LipSDP/matlab_engine/lipschitz_multi_layer.m, line 55, in lipschitz_multi_layer
  File /home/shane16/Project/model_guard/LipSDP/LipSDP/matlab_engine/solve_LipSDP.m, line 41, in solve_LipSDP
Requested 246053x60196622500 (448.5GB) array exceeds maximum array size preference. Creation of arrays greater than this limit may take a long time and cause MATLAB to become unresponsive. See array size limit or preference panel for more information.
```
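A back-of-the-envelope check on the size of that request, assuming the `network` formulation couples every pair of hidden neurons (which is what the `E diag(zeta) E'` term at line 55 of `lipschitz_multi_layer` appears to build), so that it needs one multiplier per unordered pair, n(n-1)/2 for n hidden neurons. The function names are illustrative only. The arithmetic itself is exact: the reported column count 60196622500 equals 245,350², and 245,350 = 701 × 700 / 2, which is consistent with `diag` indexing an N × N matrix built from about that many pair multipliers.

```python
def num_neuron_pairs(n_hidden):
    """Number of unordered pairs of hidden neurons: n(n-1)/2."""
    return n_hidden * (n_hidden - 1) // 2


def diag_index_size(n_multipliers):
    """Number of entries in a square N x N matrix built from N multipliers."""
    return n_multipliers * n_multipliers


if __name__ == '__main__':
    reported_cols = 60196622500  # second dimension in the MATLAB error message
    for n in (700, 701):  # 700 = 400 + 300 hidden neurons in the actor
        pairs = num_neuron_pairs(n)
        size = diag_index_size(pairs)
        print(n, pairs, size, size == reported_cols)
```

Under the same assumption, even correctly shaped weights (400 + 300 = 700 hidden neurons) would still give roughly 244,650 pair multipliers, so the `network` form would likely remain impractical at this network size.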
Thank you for reading my long story, and I would very much appreciate your help.
Regards, Shane K.
Can you print out the dimensions of your weight matrices in order, from the first to the last layer?
Sorry for my late response.
I have printed the shapes of the weight matrices in order:
```
param_tensor: actor.latent_pi.0.weight
tensor: (400, 2)
param_tensor: actor.latent_pi.2.weight
tensor: (300, 400)
param_tensor: actor.mu.weight
tensor: (1, 300)
param_tensor: actor.log_std.weight
tensor: (1, 300)
```
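A small sketch for checking a printout like this, assuming LipSDP treats the list as a sequential chain in which every matrix is `(out_features, in_features)`, so each matrix's column count must equal the previous matrix's row count. `find_chain_breaks` is a hypothetical helper. On the four shapes above it flags only `actor.log_std.weight`: a matrix expecting 300 inputs cannot follow the 1-dimensional output of `actor.mu`. On the transposed shapes tried earlier it flags every link, which is consistent with the `vertcat` dimension error.

```python
def find_chain_breaks(shapes):
    """Return the indices i where shapes[i] cannot follow shapes[i - 1].

    Each shape is (rows, cols) = (out_features, in_features); layer i can
    follow layer i - 1 only if its input size equals the previous output size.
    """
    breaks = []
    for i in range(1, len(shapes)):
        prev_out = shapes[i - 1][0]
        cur_in = shapes[i][1]
        if cur_in != prev_out:
            breaks.append(i)
    return breaks


if __name__ == '__main__':
    names = ['actor.latent_pi.0.weight', 'actor.latent_pi.2.weight',
             'actor.mu.weight', 'actor.log_std.weight']
    shapes = [(400, 2), (300, 400), (1, 300), (1, 300)]
    for i in find_chain_breaks(shapes):
        print('chain breaks at', names[i], shapes[i])
    # Transposing every matrix swaps (rows, cols) for each layer.
    print(find_chain_breaks([(c, r) for r, c in shapes]))
```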
Code:
```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torchsummary import summary
from scipy.io import savemat
import numpy as np
import os
from stable_baselines3 import SAC
from unicycle import Unicycle

model_path = 'sac_uav_lunar0.019'
k1 = 0.019
sigma = 0.0


def main():
    fname = os.path.join(os.getcwd(), 'saved_weights/uav_sac_actor_weights.mat')
    env = Unicycle(k1=k1, sigma=sigma)
    model = SAC("MlpPolicy", env, verbose=1,
                buffer_size=1000000,
                batch_size=256,
                learning_rate=7.3e-4,
                ent_coef='auto',
                train_freq=1,
                gradient_steps=1,
                gamma=0.99,
                tau=0.01,
                learning_starts=10000,
                policy_kwargs=dict(net_arch=[400, 300]),
                tensorboard_log="./uav_lunar{}_tensorboard/".format(k1)
                )
    # Stock Stable-Baselines3 SAC.load returns only the model; this four-value
    # return relies on the local SB3 modification that exposes `attr`.
    model, torch_data, torch_params, attr = SAC.load("sac_uav_lunar{}".format(k1))
    # save data to saved_weights/ directory
    weights = extract_weights(attr)
    data = {'weights': np.array(weights, dtype=object)}  # np.object
    savemat(fname, data)


def extract_weights(net):
    """Extract weights of trained neural network model

    params:
        * net: torch.nn instance - trained neural network model
    returns:
        * weights: list of arrays - weights of neural network
    """
    weights = []
    for param_tensor in net.state_dict():
        tensor = net.state_dict()[param_tensor].detach().numpy().astype(np.float64)
        if 'actor' in param_tensor:
            if 'weight' in param_tensor:
                print('param_tensor:', param_tensor)
                print('tensor: ', tensor.shape)
                weights.append(tensor)
    return weights


if __name__ == '__main__':
    main()
```
Thank you.
Dear Alex Robey,
Could the solution be to eliminate this layer: `actor.log_std.weight`, with shape (1, 300)?
After removing it, the code runs, but the resulting Lipschitz constant is about 1391, which is much larger than I expected.
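As a rough point of comparison, the classic (usually loose) l2 upper bound for a feed-forward network with 1-Lipschitz activations such as ReLU is the product of the spectral norms of its weight matrices. Below is a pure-Python sketch using power iteration on WᵀW; the function names are hypothetical and matrices are plain nested lists. In our understanding, LipSDP's estimate should generally not exceed this product, so computing it for the three remaining matrices may help judge whether 1391 is plausible for this network or points to an extraction problem.

```python
import math
import random


def mat_vec(a, x):
    """Multiply matrix a (list of rows) by vector x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def mat_t_vec(a, y):
    """Multiply the transpose of matrix a by vector y."""
    out = [0.0] * len(a[0])
    for row, y_i in zip(a, y):
        for j, a_ij in enumerate(row):
            out[j] += a_ij * y_i
    return out


def spectral_norm(a, iters=500, seed=0):
    """Estimate the largest singular value of a by power iteration on a^T a."""
    if all(v == 0.0 for row in a for v in row):
        return 0.0
    rng = random.Random(seed)
    x = [rng.uniform(-1.0, 1.0) for _ in range(len(a[0]))]
    sigma = 0.0
    for _ in range(iters):
        y = mat_t_vec(a, mat_vec(a, x))  # y = a^T a x
        norm = math.sqrt(sum(v * v for v in y))
        if norm == 0.0:
            raise ValueError('start vector is in the null space; try another seed')
        x = [v / norm for v in y]
        # for unit x, ||a x|| converges to the largest singular value
        sigma = math.sqrt(sum(v * v for v in mat_vec(a, x)))
    return sigma


def naive_lipschitz_bound(weights):
    """Product of spectral norms: a valid but loose l2 Lipschitz bound."""
    bound = 1.0
    for w in weights:
        bound *= spectral_norm(w)
    return bound


if __name__ == '__main__':
    w0 = [[3.0, 0.0], [0.0, 4.0]]  # spectral norm 4
    w1 = [[1.0, 1.0]]              # spectral norm sqrt(2)
    print(naive_lipschitz_bound([w0, w1]))
```

Pure Python is slow for the real 300 × 400 matrix; with NumPy, `numpy.linalg.norm(W, 2)` returns the largest singular value of a 2-D array directly.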
I appreciate your help. Thank you.
Regards, Shane K.
Dear users and developers, especially the authors of this repository,
I encountered the following error message while running solve_sdp.py, and I am stuck. Please help me solve this problem.
```
python solve_sdp.py --form layer --weight-path examples/saved_weights/sac_actor_weights.mat
```
I'm not sure how to solve this error, because I extracted the weights from my network and saved them as shown in the README. My code to extract the weights of the actor network from my SAC network:
```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torchsummary import summary
from scipy.io import savemat
import numpy as np
import os
from stable_baselines3 import SAC
from unicycle import Unicycle

model_path = 'sac_uav_lunar0.019'
k1 = 0.019
sigma = 0.0


def main():
    ...


def extract_weights(net):
    """Extract weights of trained neural network model"""
    ...


if __name__ == '__main__':
    main()
```
I modified the Stable-Baselines3 code so that `attr` contains the `state_dict` information.
And the output of this code:
Thank you for your help.