Having built the same LSTM network in both PyTorch and TensorFlow 2, this is an exercise in copying a trained model from one platform to the other. While issues may arise from floating-point precision mismatches between the two platforms, I believe the effect is usually not material. What I want to achieve is to copy the weight tensors from one model to another, given that the same architecture is built on both sides.

(Note: Please see a rewrite of this post for more detailed code.)

TensorFlow saves its model in HDF5 format while PyTorch has its own .pth format (which is a ZIP archive of the pickled model and other data). The major difference, however, is how the models are built: In TensorFlow 2, it is natural to build the architecture through the Keras interface without inventing any new data types. In PyTorch, it is idiomatic to create a class derived from nn.Module, create layers as member variables, and describe the computation pipeline in the forward() member function. Therefore a saved model from TensorFlow can be loaded easily as long as the TensorFlow package is installed. In PyTorch, however, we must define the same model class in the namespace before we can load it.
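
For reference, this is roughly how the two models would have been saved in the first place (a minimal sketch, assuming the trained model objects are named tfmodel and torchmodel, matching the names used below):

tfmodel.save('tf-keras-model.h5')            # Keras HDF5 format: architecture + weights
torch.save(torchmodel, 'pytorch-model.pth')  # pickles the entire nn.Module object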

With that in mind, this is how we load the two models as defined previously. Note that we only need to provide the code for the PyTorch model class; we do not need to instantiate one before loading the model from the .pth file:

import tensorflow as tf
import torch
import torch.nn as nn

class Model(nn.Module):
    'LSTM neural network model in PyTorch'
    def __init__(self, input_dim=1, hidden_dim=50, output_dim=1, num_layers=4, dropout=0.2):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # LSTM multilayer and dense layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, num_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, input, hidden=None):
        # input.shape = [batch_size, seq_length, input_dim]
        # lstm_out.shape = [batch_size, seq_length, hidden_dim]
        # self.hidden is a tuple of (h,c) which both are of shape [num_layers, batch_size, hidden_dim]
        lstm_out, self.hidden = self.lstm(input, hidden)
        # take output from the final timestep
        y_pred = self.linear(self.dropout(lstm_out[:,-1,:]))
        return y_pred # shape = [batch_size, out_dim]

tfmodel = tf.keras.models.load_model('tf-keras-model.h5')
torchmodel = torch.load('pytorch-model.pth')

To recap, these models consist of four LSTM layers and one fully-connected output layer. We can examine the models by printing tfmodel.summary() or printing torchmodel directly. To access each individual layer of the TensorFlow network and check its weight tensors, we can do this:

for i,l in enumerate(tfmodel.layers):
    print(i, l)
for i in [0,2,4,6,8]:
    print("L{}: {}".format(i, ", ".join(str(w.shape) for w in tfmodel.layers[i].weights)))

and this is the output:

0 <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x7fc457d37ac8>
1 <tensorflow.python.keras.layers.core.Dropout object at 0x7fc454606c18>
2 <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x7fc45456ae48>
3 <tensorflow.python.keras.layers.core.Dropout object at 0x7fc454584080>
4 <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x7fc4545b42b0>
5 <tensorflow.python.keras.layers.core.Dropout object at 0x7fc48daae3c8>
6 <tensorflow.python.keras.layers.recurrent_v2.LSTM object at 0x7fc454572f60>
7 <tensorflow.python.keras.layers.core.Dropout object at 0x7fc4546243c8>
8 <tensorflow.python.keras.layers.core.Dense object at 0x7fc45455cf60>
L0: (1, 200), (50, 200), (200,)
L2: (50, 200), (50, 200), (200,)
L4: (50, 200), (50, 200), (200,)
L6: (50, 200), (50, 200), (200,)
L8: (50, 1), (1,)

In other words, we can access the weight tensors of a TensorFlow model with tfmodel.layers[i].weights, and there may be multiple tensors in each layer (e.g., a dense layer has a weight matrix and a bias vector). The TensorFlow weights are easy to manipulate as they are merely member variables of each layer. An example of setting a weight tensor to zero is this:

import numpy as np
tfmodel.layers[0].weights[0].assign(np.zeros((1, 200)))

we just need to make sure the numpy array assigned to the weight tensor has the same shape. The PyTorch counterparts, however, are not so trivial. If we access the weights of each layer directly, the tensors are effectively read-only. This is how we make a copy of the input-hidden weights of the first LSTM layer (layer 0):

torchmodel.lstm.weight_ih_l0.view(-1).detach().clone().numpy()

This gives a numpy array cloned from the tensor, so any mutation of it will not propagate back to the model. The proper way to mutate the model is to build a state dict, make changes to it, and load it back. This is how to obtain the state dict:

torchparam = torchmodel.state_dict()
for k,v in torchparam.items():
    print("{:20s} {}".format(k, v.shape))

The output is:

lstm.weight_ih_l0    torch.Size([200, 1])
lstm.weight_hh_l0    torch.Size([200, 50])
lstm.bias_ih_l0      torch.Size([200])
lstm.bias_hh_l0      torch.Size([200])
lstm.weight_ih_l1    torch.Size([200, 50])
lstm.weight_hh_l1    torch.Size([200, 50])
lstm.bias_ih_l1      torch.Size([200])
lstm.bias_hh_l1      torch.Size([200])
lstm.weight_ih_l2    torch.Size([200, 50])
lstm.weight_hh_l2    torch.Size([200, 50])
lstm.bias_ih_l2      torch.Size([200])
lstm.bias_hh_l2      torch.Size([200])
lstm.weight_ih_l3    torch.Size([200, 50])
lstm.weight_hh_l3    torch.Size([200, 50])
lstm.bias_ih_l3      torch.Size([200])
lstm.bias_hh_l3      torch.Size([200])
linear.weight        torch.Size([1, 50])
linear.bias          torch.Size([1])

So we can see that the PyTorch model not only names its layers and tensors differently, but the tensors also have different shapes. For example, the forget gate of an LSTM layer in TensorFlow is

\[f_t = \sigma(x_t W_f + h_{t-1} U_f + b_f)\]

while that of PyTorch is

\[f_t = \sigma(W_f x_t + b_{if} + U_f h_{t-1} + b_{hf})\]

That is, TensorFlow right-multiplies the weight matrices with the state vectors while PyTorch left-multiplies them, so the corresponding weight matrices are transposes of each other. Note also that each LSTM layer has four such linear transformations (one per gate), hence for a hidden dimension of 50, the weight matrices as represented by the models have a dimension of 200.
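
As a quick sanity check (a sketch reusing the tfmodel and torchparam objects from above), we can confirm that the TensorFlow kernels are simply transposes of the corresponding PyTorch tensors:

# TF layer 0 holds a kernel, a recurrent kernel and a single bias;
# PyTorch has weight_ih, weight_hh and two bias vectors
tf_kernel, tf_recurrent, tf_bias = tfmodel.layers[0].weights
print(tf_kernel.shape, torchparam['lstm.weight_ih_l0'].shape)     # (1, 200) vs torch.Size([200, 1])
print(tf_recurrent.shape, torchparam['lstm.weight_hh_l0'].shape)  # (50, 200) vs torch.Size([200, 50])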

For our goal of porting the weight tensors from TensorFlow to PyTorch, we therefore need to take each matrix from the former, transpose it, and assign it to the corresponding counterpart of the latter. The code below also zeroes out one of the PyTorch bias vectors in each LSTM layer, as the TensorFlow model has only one bias per layer, not two:

# Overwrite PyTorch model with TensorFlow weights
for layer in range(4):
    t = 2*layer  # index of the corresponding LSTM layer in the TensorFlow model
    key = "lstm.bias_ih_l{}".format(layer)
    torchparam[key].zero_()  # zero out the ih bias; TensorFlow has only one bias per layer
    names = ["weight_ih", "weight_hh", "bias_hh"]
    for name, tfweight in zip(names, tfmodel.layers[t].weights):
        key = "lstm.{}_l{}".format(name, layer)
        torchparam[key] = torch.tensor(tfweight.numpy().T)
torchparam['linear.weight'] = torch.tensor(tfmodel.layers[8].weights[0].numpy().T)
torchparam['linear.bias'] = torch.tensor(tfmodel.layers[8].weights[1].numpy())
torchmodel.load_state_dict(torchparam)

After we load the state dict back into the PyTorch model, it should behave the same as the TensorFlow model, up to floating-point differences.
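
To convince ourselves, a minimal check is to feed the same input through both models and compare the outputs (a sketch; the sequence length of 30 here is an arbitrary assumption):

import numpy as np

x = np.random.rand(1, 30, 1).astype("float32")   # shape = (batch_size, seq_length, input_dim)
tf_out = tfmodel.predict(x)                      # Keras predict() runs in inference mode
torchmodel.eval()                                # disable dropout in the PyTorch model too
with torch.no_grad():
    torch_out = torchmodel(torch.from_numpy(x)).numpy()
print(np.allclose(tf_out, torch_out, atol=1e-5))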

Note: I notice the title of this post is the reverse of what the content covers. Indeed, setting weights in a TensorFlow model is much easier. Assuming you already have the weights (e.g., from a PyTorch model) and converted them into numpy arrays, you can manipulate the weights of each layer with the set_weights() function.

As a proof-of-concept, you can try the following:

import tensorflow as tf

# Load model (Keras HDF5 assumed)
model = tf.keras.models.load_model("mymodel.h5")
# Extract weights of one layer as a list of numpy arrays
weights = model.layers[0].get_weights()
# Modify one element in one of the arrays, or do any other manipulation
weights[0][0] = 0.1
# Overwrite the weights
model.layers[0].set_weights(weights)

Of course, you need to make sure each and every weight array is in the shape that the layer expects. If you obtained the weights from PyTorch, that usually means you need to transpose them first.
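
For instance, a minimal sketch of the reverse direction, copying the final dense layer of the PyTorch model above into the corresponding Keras layer (reusing torchmodel and tfmodel from earlier):

# PyTorch stores linear.weight as (out_dim, in_dim); Keras Dense expects (in_dim, out_dim)
w = torchmodel.linear.weight.detach().numpy().T   # (1, 50) -> (50, 1)
b = torchmodel.linear.bias.detach().numpy()       # (1,)
tfmodel.layers[8].set_weights([w, b])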