This is a remake of a previous post due to its popularity. Let’s consider a PyTorch network and a Tensorflow/Keras network. If they have exactly the same architecture, we should be able to copy the weights from one to the other. But how?

Before we proceed, note that NumPy arrays, PyTorch tensors, and Tensorflow tensors are all involved. They are similar but not identical. The results should be close but will not necessarily match bit for bit. After all, adding floating-point numbers in a different order may give a different result, and the floating-point units in a CPU and in a GPU may operate differently.
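
The point about ordering is easy to see even without a GPU. A minimal sketch in plain Python (the particular numbers are only an illustration):

```python
import math

# Floating-point addition is not associative: grouping changes the rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)              # False
print(math.isclose(left, right))  # True: equal within a relative tolerance
```

This is why the comparisons later in this post should be read as "equal within a tolerance" rather than "identical bits".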

PyTorch model

Let’s make a simple PyTorch model. LeNet-5 is a good choice since it has both convolutional layers and fully-connected layers. We use the MNIST dataset and train one in PyTorch:

import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

torch.set_printoptions(edgeitems=2)
np.set_printoptions(precision=4, linewidth=150)

# Load MNIST data
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), # required, otherwise MNIST are in PIL format
])
train = torchvision.datasets.MNIST('./_datafiles/', train=True, download=True, transform=transform)
test = torchvision.datasets.MNIST('./_datafiles/', train=False, download=True, transform=transform)

# As iterator for data and target
train_loader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=32, shuffle=True)

# LeNet5 model: ReLU can be used instead of tanh
torch_model = nn.Sequential(
    # input 1x28x28, output 6x28x28
    nn.Conv2d(1, 6, kernel_size=(5,5), stride=1, padding=2),
    nn.Tanh(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # output from pooling 6x14x14
    nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
    # output from conv2d 16x10x10
    nn.Tanh(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # output from pooling 16x5x5
    nn.Conv2d(16, 120, kernel_size=5, stride=1, padding=0),
    # output from conv2d 120x1x1
    nn.Tanh(),
    nn.Flatten(),
    nn.Linear(120, 84),
    nn.Tanh(),
    nn.Linear(84, 10),
    nn.Softmax(dim=1)
)
print(torch_model)

# Training loop
def training_loop(model, optimizer, loss_fn, train_loader, val_loader=None, n_epochs=100, early_stopping=None):
    best_loss, best_epoch = np.inf, -1
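    # state_dict() returns references to the live tensors, so take a copy
    best_state = {k: v.clone() for k, v in model.state_dict().items()}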

    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss = 0
        for data, target in train_loader:
            output = model(data)
            loss = loss_fn(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Validation
        model.eval()
        status = (f"{str(datetime.datetime.now())} End of epoch {epoch}, "
                  f"training loss={train_loss/len(train_loader)}")
        if val_loader:
            val_loss = 0
            for data, target in val_loader:
                output = model(data)
                loss = loss_fn(output, target)
                val_loss += loss.item()
            status += f", validation loss={val_loss/len(val_loader)}"
        print(status)
        # Early stopping logic
        if early_stopping is not None:
            if val_loss < best_loss:
                best_loss, best_epoch = val_loss, epoch
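                # clone, otherwise later training steps overwrite the saved values
                best_state = {k: v.clone() for k, v in model.state_dict().items()}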
            elif best_epoch + early_stopping < epoch:
                print(f"Patience of {early_stopping} exhausted. Early stopping "
                      f"and restoring weight at epoch {best_epoch}.")
                model.load_state_dict(best_state)
                break

# Run training
optimizer = optim.Adam(torch_model.parameters())
criterion = nn.CrossEntropyLoss()
training_loop(torch_model, optimizer, criterion, train_loader, test_loader, n_epochs=100, early_stopping=4)
# Save model
torch.save(torch_model, "lenet5.pt")

The code above builds a network in PyTorch and trains it. The print() call shows the network as follows:

Sequential(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): Tanh()
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): Tanh()
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (7): Tanh()
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=120, out_features=84, bias=True)
  (10): Tanh()
  (11): Linear(in_features=84, out_features=10, bias=True)
  (12): Softmax(dim=1)
)

After the model is trained, PyTorch allows us to dump all the weights as the state_dict. In fact, this is what the early stopping code in the training function uses to restore the best weights. We can see what the state dict looks like:

{k: (v.dtype, v.shape) for k, v in torch_model.state_dict().items()}

which gives

{'0.weight': (torch.float32, torch.Size([6, 1, 5, 5])),
 '0.bias': (torch.float32, torch.Size([6])),
 '3.weight': (torch.float32, torch.Size([16, 6, 5, 5])),
 '3.bias': (torch.float32, torch.Size([16])),
 '6.weight': (torch.float32, torch.Size([120, 16, 5, 5])),
 '6.bias': (torch.float32, torch.Size([120])),
 '9.weight': (torch.float32, torch.Size([84, 120])),
 '9.bias': (torch.float32, torch.Size([84])),
 '11.weight': (torch.float32, torch.Size([10, 84])),
 '11.bias': (torch.float32, torch.Size([10]))}

So each key contains the layer number. Each value is a Torch tensor.

Tensorflow/Keras model

Building exactly the same model in Keras looks like this:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Dropout, Flatten
from tensorflow.keras.datasets import mnist
from tensorflow.keras.callbacks import EarlyStopping

# LeNet5 model
keras_model = Sequential([
    Conv2D(6, (5,5), input_shape=(28,28,1), padding="same", activation="tanh"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(16, (5,5), activation="tanh"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(120, (5,5), activation="tanh"),
    Flatten(),
    Dense(84, activation="tanh"),
    Dense(10, activation="softmax")
])
keras_model.summary(line_length=100)

# Reshape data to shape of (n_sample, height, width, n_channel)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims(X_train, axis=3).astype('float32')
X_test = np.expand_dims(X_test, axis=3).astype('float32')

# Train
keras_model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
earlystopping = EarlyStopping(monitor="val_loss", patience=4, restore_best_weights=True)
keras_model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, batch_size=32, callbacks=[earlystopping])

# Save
keras_model.save("lenet5.h5")

The Keras model created will be the following:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 28, 28, 6)         156       
                                                                 
 average_pooling2d (AverageP  (None, 14, 14, 6)        0         
 ooling2D)                                                       
                                                                 
 conv2d_1 (Conv2D)           (None, 10, 10, 16)        2416      
                                                                 
 average_pooling2d_1 (Averag  (None, 5, 5, 16)         0         
 ePooling2D)                                                     
                                                                 
 conv2d_2 (Conv2D)           (None, 1, 1, 120)         48120     
                                                                 
 flatten (Flatten)           (None, 120)               0         
                                                                 
 dense (Dense)               (None, 84)                10164     
                                                                 
 dense_1 (Dense)             (None, 10)                850       
                                                                 
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
_________________________________________________________________

and this is how we can investigate the model weights:

{w.name: (w.dtype, w.shape) for w in keras_model.weights}

which gives:

{'conv2d/kernel:0': (tf.float32, TensorShape([5, 5, 1, 6])),
 'conv2d/bias:0': (tf.float32, TensorShape([6])),
 'conv2d_1/kernel:0': (tf.float32, TensorShape([5, 5, 6, 16])),
 'conv2d_1/bias:0': (tf.float32, TensorShape([16])),
 'conv2d_2/kernel:0': (tf.float32, TensorShape([5, 5, 16, 120])),
 'conv2d_2/bias:0': (tf.float32, TensorShape([120])),
 'dense/kernel:0': (tf.float32, TensorShape([120, 84])),
 'dense/bias:0': (tf.float32, TensorShape([84])),
 'dense_1/kernel:0': (tf.float32, TensorShape([84, 10])),
 'dense_1/bias:0': (tf.float32, TensorShape([10]))}

keras_model.weights gives all the weights as a list of Tensorflow variables, and the name of each one tells us which layer it belongs to. The other way to get the weights is to call keras_model.get_weights(), but the result is a list of NumPy arrays with no names attached.
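
As a quick sanity check that the two models really have the same architecture, the parameter counts implied by the PyTorch shapes should match the Param # column of the Keras summary. A small sketch in plain Python, with the shapes typed in by hand from the state_dict listing earlier:

```python
import math

# Shapes copied from the PyTorch state_dict listing shown earlier.
torch_shapes = {
    "0.weight": (6, 1, 5, 5),    "0.bias": (6,),
    "3.weight": (16, 6, 5, 5),   "3.bias": (16,),
    "6.weight": (120, 16, 5, 5), "6.bias": (120,),
    "9.weight": (84, 120),       "9.bias": (84,),
    "11.weight": (10, 84),       "11.bias": (10,),
}

# Sum weight and bias sizes per layer; the layer index is the key prefix.
per_layer = {}
for key, shape in torch_shapes.items():
    layer = key.split(".")[0]
    per_layer[layer] = per_layer.get(layer, 0) + math.prod(shape)

print(per_layer)                # {'0': 156, '3': 2416, '6': 48120, '9': 10164, '11': 850}
print(sum(per_layer.values()))  # 61706
```

The per-layer numbers and the total agree with the Keras summary, so only the layout of each array differs between the two frameworks.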

Converting from PyTorch to Keras

By merely comparing the weights of the PyTorch and Keras models, we can see that the two frameworks store the weights in different shapes. In particular,

  • For a convolutional layer, Keras stores the weight as (height, width, input, output) but PyTorch uses (output, input, height, width)
  • For a dense or fully-connected layer, the Keras weight is in (input, output) but the PyTorch weight is in (output, input), as the shapes listed above show. This is the difference between left-multiplying and right-multiplying by a matrix.
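
To make the two layouts concrete, here is a NumPy-only sketch. The random arrays are stand-ins with the same shapes as the second convolution and the first dense layer above, not the trained weights:

```python
import numpy as np

rng = np.random.default_rng(0)

# Convolution kernel: PyTorch layout (out, in, height, width).
torch_kernel = rng.standard_normal((16, 6, 5, 5)).astype("float32")
# Move axis 0 (out) to the end and axis 1 (in) to second-to-last.
keras_kernel = np.moveaxis(torch_kernel, [0, 1], [-1, -2])
print(keras_kernel.shape)  # (5, 5, 6, 16)
# The same element, addressed in each layout.
assert keras_kernel[2, 3, 4, 7] == torch_kernel[7, 4, 2, 3]

# Dense layer: PyTorch stores (out, in) and computes x @ W.T,
# Keras stores (in, out) and computes x @ W.
torch_w = rng.standard_normal((84, 120)).astype("float32")
keras_w = torch_w.T
x = rng.standard_normal((1, 120)).astype("float32")
print(np.allclose(x @ torch_w.T, x @ keras_w))  # True
```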

Now we have to consider how to set the weights in the Keras model. Since we can get the weights using get_weights(), we should be able to set them using set_weights(), whose argument is a list of NumPy arrays. The other way is to pick each layer, find its kernel, and assign a matrix to it, but that is specific to each layer type. So we use set_weights() for convenience.

Now we need to transform the PyTorch weights to Keras weights. This is how it is done:

# Fetch weight from torch model
torch_weights = torch_model.state_dict()
# Reshape weights for Keras model
keras_weights = [w.numpy() for w in torch_weights.values()]
for i in [0, 2, 4]:
    # conv2d layer: Torch (out,in,h,w) Keras (h,w,in,out)
    keras_weights[i] = np.moveaxis(keras_weights[i], [0,1], [-1,-2])
for i in [6, 8]:
    # dense layer: transpose
    keras_weights[i] = keras_weights[i].T
# Set to Keras model
keras_model.set_weights(keras_weights)

This works because the state dict is an ordered dict whose keys follow the order of the layers, which is also the order that set_weights() expects in Keras. All we need to do is rearrange each weight array into the shape Keras uses. If a shape is wrong, Keras will complain.
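
One reason a plain transpose is enough for the first dense layer here is that the last convolution outputs 120x1x1, so flattening gives the same 120 values whichever axis the channels sit on. If the feature map entering Flatten were larger than 1x1, PyTorch would flatten in (channel, height, width) order and Keras in (height, width, channel) order, and the input axis of the dense kernel would need reordering too. The following NumPy-only sketch illustrates this with made-up sizes (C, H, W, OUT and the random arrays are assumptions for illustration, not part of the model above):

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, OUT = 3, 2, 2, 4  # hypothetical sizes

# One feature map in PyTorch layout (C, H, W) and the same data in Keras layout (H, W, C).
feat_torch = rng.standard_normal((C, H, W))
feat_keras = np.moveaxis(feat_torch, 0, -1)

# PyTorch-style dense weight (out, in), where "in" is the flattened (C, H, W).
torch_w = rng.standard_normal((OUT, C * H * W))
torch_out = torch_w @ feat_torch.reshape(-1)

# Naive conversion: transpose only. The input order no longer matches.
naive_w = torch_w.T
naive_out = feat_keras.reshape(-1) @ naive_w

# Correct conversion: reorder the input axis from (C, H, W) to (H, W, C) first.
fixed_w = torch_w.reshape(OUT, C, H, W).transpose(2, 3, 1, 0).reshape(H * W * C, OUT)
fixed_out = feat_keras.reshape(-1) @ fixed_w

print(np.allclose(torch_out, naive_out))  # False
print(np.allclose(torch_out, fixed_out))  # True
```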

Now let’s check. test_loader yields batches from the test set, so we take one batch and verify whether the numbers tie out:

data, targets = next(iter(test_loader))
data = data.numpy()
print(data.shape)
print(torch_model(torch.tensor(data[:1])).detach().numpy())
print(keras_model.predict(np.moveaxis(data[:1], 1, -1), verbose=0))

This outputs:

(32, 1, 28, 28)
[[5.3237e-06 3.2647e-03 5.9966e-04 3.4837e-02 3.0882e-06 1.3731e-04 3.3251e-07 9.6026e-01 7.4162e-04 1.5280e-04]]
[[5.3237e-06 3.2647e-03 5.9966e-04 3.4837e-02 3.0882e-06 1.3731e-04 3.3251e-07 9.6026e-01 7.4162e-04 1.5280e-04]]

So the loader gives a batch of 32 samples in the shape (batch size, channel, height, width). We slice out only the first sample for our test. The Keras model expects its input in the shape (batch size, height, width, channel), so we need to move axis 1 to the last axis when calling predict(). As we can see above, the two outputs agree to every printed digit.
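
Comparing printed digits by eye does not scale, so a numeric check is handy. A minimal sketch follows; report_match is a hypothetical helper name, and the tolerances are arbitrary choices, not values required by either framework:

```python
import numpy as np

def report_match(a, b, rtol=1e-5, atol=1e-6):
    """Return (max absolute difference, whether all elements agree within tolerance)."""
    a, b = np.asarray(a), np.asarray(b)
    return float(np.abs(a - b).max()), bool(np.allclose(a, b, rtol=rtol, atol=atol))

# Stand-in arrays; in practice pass the PyTorch and Keras outputs.
p = np.array([[0.1, 0.2, 0.7]], dtype="float32")
q = p + np.float32(1e-7)   # tiny perturbation, as rounding differences might produce
print(report_match(p, q))
print(report_match(p, p[:, ::-1]))
```

The first call reports a match within tolerance; the second, where the columns are reversed, does not.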

Per-layer output

In fact, we can also compare the output layer by layer. Note that the activation function is part of a layer in Keras but a separate module in PyTorch. Hence we can only compare the results after activation.

In PyTorch, we can add a “forward hook” to capture each layer’s output:

# Add hooks: https://discuss.pytorch.org/t/how-can-i-extract-intermediate-layer-output-from-loaded-cnn-model/77301/3
torch_activation = {}
def get_activation(name):
    def hook(model, input, output):
        torch_activation[name] = output.detach()
    return hook

for key, layer in torch_model.named_children():
    layer.register_forward_hook(get_activation(key))

sample = data[:1]
torch_model(torch.Tensor(sample.reshape(1,1,28,28))).detach().numpy()

but in Keras, we have to build separate models:

submodels = {}
for layer in keras_model.layers:
    submodels[layer.name] = tf.keras.models.Model(inputs=keras_model.input, outputs=layer.output)

keras_activation = {k: m.predict(sample.reshape(1,28,28,1), verbose=0) for k,m in submodels.items()}

and then we can compare, for example, the result of first convolution layer after tanh activation:

print(torch_activation["1"].detach().numpy())
print(np.moveaxis(keras_activation["conv2d"], -1, 1))

or that of third convolution layer:

print(torch_activation["7"].detach().numpy().flatten())
print(np.moveaxis(keras_activation["conv2d_2"], -1, 1).flatten())

or that of the first dense layer:

print(torch_activation["10"].detach().numpy())
print(keras_activation["dense"])

Another way to compare is to create a standalone layer in each framework and copy the weights over:

torch_conv = nn.Conv2d(1, 2, kernel_size=(3,3), stride=1, padding=1)
torch_weights = torch_conv.state_dict()

keras_conv = tf.keras.layers.Conv2D(2, (3,3), padding="same")
keras_conv.build((1,28,28,1))
keras_weights = [w.numpy() for w in torch_weights.values()]
keras_weights[0] = np.moveaxis(keras_weights[0], [0,1], [-1,-2])
keras_conv.set_weights(keras_weights)

# Give both layers an explicit batch dimension so the output shapes line up
torch_out = torch_conv(torch.Tensor(sample.reshape(1,1,28,28)))
keras_out = keras_conv(sample.reshape(1,28,28,1))
print(torch_out.detach().numpy())
print(np.moveaxis(keras_out.numpy(), -1, 1))

but this is certainly less convenient.

Copying weights from Keras to PyTorch

As we can extract the weights from the Keras model and the PyTorch model:

# Keep the actual values as NumPy arrays, keyed by weight name
keras_weights = {w.name: w.numpy() for w in keras_model.weights}
torch_weights = torch_model.state_dict()

which,

print(keras_weights.keys())
print(torch_weights.keys())

gives

dict_keys(['conv2d/kernel:0', 'conv2d/bias:0', 'conv2d_1/kernel:0',
'conv2d_1/bias:0', 'conv2d_2/kernel:0', 'conv2d_2/bias:0', 'dense/kernel:0',
'dense/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0'])
odict_keys(['0.weight', '0.bias', '3.weight', '3.bias', '6.weight', '6.bias',
'9.weight', '9.bias', '11.weight', '11.bias'])

we can match layer-by-layer from the Keras model to the PyTorch model, since they are in the same order. So,

new_weights = {}

for k,t in zip(keras_weights.keys(), torch_weights.keys()):
    if k in ["conv2d/kernel:0", "conv2d_1/kernel:0", "conv2d_2/kernel:0"]:
        new_weights[t] = torch.Tensor(np.moveaxis(keras_weights[k], [-1, -2], [0, 1]))
    elif k in ["dense/kernel:0", "dense_1/kernel:0"]:
        new_weights[t] = torch.Tensor(keras_weights[k].T)
    else:
        new_weights[t] = torch.Tensor(keras_weights[k])

torch_model.load_state_dict(new_weights)
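
The two directions of the layout conversion are inverses of each other. A NumPy-only sketch that makes this explicit is below. torch_to_keras and keras_to_torch are hypothetical helper names, and choosing the conversion by number of dimensions is an assumption that holds for a model made only of Conv2D and Dense layers, like this one:

```python
import numpy as np

def torch_to_keras(w):
    """PyTorch layout -> Keras layout, chosen by number of dimensions."""
    if w.ndim == 4:                       # conv: (out, in, h, w) -> (h, w, in, out)
        return np.moveaxis(w, [0, 1], [-1, -2])
    if w.ndim == 2:                       # dense: (out, in) -> (in, out)
        return w.T
    return w                              # bias: unchanged

def keras_to_torch(w):
    """Keras layout -> PyTorch layout, the inverse of torch_to_keras."""
    if w.ndim == 4:                       # conv: (h, w, in, out) -> (out, in, h, w)
        return np.moveaxis(w, [-1, -2], [0, 1])
    if w.ndim == 2:                       # dense: (in, out) -> (out, in)
        return w.T
    return w

# Round trip on arrays shaped like the LeNet-5 weights.
rng = np.random.default_rng(0)
for shape in [(16, 6, 5, 5), (84, 120), (84,)]:
    w = rng.standard_normal(shape)
    back = keras_to_torch(torch_to_keras(w))
    print(shape, torch_to_keras(w).shape, np.array_equal(w, back))
```

Each line printed ends in True, confirming that converting to the Keras layout and back returns the original array.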

Then we can verify that the two models now behave the same:

from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Cast explicitly so both frameworks receive float32 input
sample = X_test[42].astype("float32")

keras_out = keras_model.predict(sample.reshape(1,28,28,1), verbose=0)
torch_out = torch_model(torch.Tensor(sample.reshape(1,1,28,28))).detach().numpy()

print(keras_out)
print(torch_out)