Dynamic sequence length for transformer-based model - error when exporting from Python to MATLAB

We developed a simple transformer architecture in Python (see the code below). The model can handle sequences of different lengths, and I want to use it in MATLAB. I tried exporting it to ONNX and to the .pt format, but in both cases I had to fix the input shape in order to export. For the .pt export I used the torch.jit.script() function in Python to script the model. However, I think pytorchmex from the Deep Learning Toolbox Converter for PyTorch Models only works with models produced by torch.jit.trace.
I want to find a way to use a model in MATLAB that can accept inputs of any length.
Any help would be much appreciated.
# Python Code
import torch.nn as nn


# Model class to export
class TransformerModel(nn.Module):
    def __init__(
        self,
        input_dim,
        model_dim,
        n_classes,
        num_heads,
        num_layers,
    ):
        super(TransformerModel, self).__init__()
        self.model_dim = model_dim
        # Embedding Layer
        self.embedding = nn.Linear(input_dim, model_dim)
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        # Output Layer
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # x: (batch, seq_len, input_dim); padding_mask: boolean (batch, seq_len).
        # src_key_padding_mask expects True at positions to ignore, so invert.
        padding_mask = ~padding_mask
        x = self.embedding(x)
        # Transformer Encoder
        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
        # Model prediction
        output = self.output_layer(x)
        return output
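
For reference, here is a minimal sketch of how the model accepts different sequence lengths in eager PyTorch, before any export. The hyperparameter values, the sequence lengths, and the use of pad_sequence to build the batch are assumptions chosen for illustration, not taken from our real setup. It also assumes the mask convention implied by forward(): True marks a valid time step.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


class TransformerModel(nn.Module):
    def __init__(self, input_dim, model_dim, n_classes, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Linear(input_dim, model_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )
        self.output_layer = nn.Linear(model_dim, n_classes)

    def forward(self, x, padding_mask):
        # Invert: PyTorch wants True at the positions to ignore.
        x = self.embedding(x)
        x = self.transformer_encoder(x, src_key_padding_mask=~padding_mask)
        return self.output_layer(x)


# Illustrative sizes only (assumed values).
INPUT_DIM, N_CLASSES = 8, 3
model = TransformerModel(INPUT_DIM, model_dim=16, n_classes=N_CLASSES,
                         num_heads=4, num_layers=2)
model.eval()  # disable dropout for inference

# Case 1: a single unpadded sequence of length 5 (every step is valid).
x_short = torch.randn(1, 5, INPUT_DIM)
mask_short = torch.ones(1, 5, dtype=torch.bool)

# Case 2: a batch of two sequences (lengths 5 and 12), padded to length 12.
sequences = [torch.randn(5, INPUT_DIM), torch.randn(12, INPUT_DIM)]
lengths = torch.tensor([s.shape[0] for s in sequences])
x_batch = pad_sequence(sequences, batch_first=True)  # (2, 12, INPUT_DIM)
# True where the time index is smaller than the sequence's real length.
mask_batch = torch.arange(x_batch.shape[1])[None, :] < lengths[:, None]

with torch.no_grad():
    out_short = model(x_short, mask_short)
    out_batch = model(x_batch, mask_batch)

print(out_short.shape, out_batch.shape)
```

The same model instance handles both calls, so the fixed input shape only appears at export time, not in the PyTorch model itself.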

Answers (0)

Release

R2024b
