
Example workflow using fit

Introduction

In this notebook, we are going to use the fit function to train a UniRep model.

Imports

Here are the imports we need for this notebook.

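from jax.random import PRNGKey

# stax moved from jax.experimental to jax.example_libraries in newer JAX releases.
try:
    from jax.example_libraries.stax import Dense, Softmax, serial
except ImportError:  # older JAX versions
    from jax.experimental.stax import Dense, Softmax, serial

from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.utils import load_params
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates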
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Sequences

We'll prepare a bunch of dummy sequences.

In your actual use case, you'll probably need to find a way to load your sequences into memory as a list of strings. (We try our best to stick with Python idioms.)

sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = [
    "HASTA",
    "VISTA",
    "ALAVA",
    "LIMED",
    "HAST",
    "HASVALTA",
] * 5
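A common way to get real sequences into a list of strings is to read them from a FASTA file. Below is a minimal pure-Python parser. read_fasta_sequences is a hypothetical helper written for this example and is not part of jax_unirep. For real data, a dedicated library such as Biopython's SeqIO may be a better fit.

```python
from typing import List


def read_fasta_sequences(text: str) -> List[str]:
    """Parse FASTA-formatted text into a list of sequence strings.

    Hypothetical helper, not part of jax_unirep. Lines starting with '>'
    are headers; the sequence lines after a header are concatenated.
    Headers with no sequence lines are skipped.
    """
    sequences: List[str] = []
    current: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        if line.startswith(">"):
            # A new header closes the previous record, if there is one.
            if current:
                sequences.append("".join(current))
                current = []
        else:
            current.append(line)
    if current:
        sequences.append("".join(current))
    return sequences


fasta_text = """>seq1
HASTA
>seq2
VIS
TA
"""
print(read_fasta_sequences(fasta_text))  # ['HASTA', 'VISTA']
```

To load sequences from disk, read the file's contents and pass them to the helper. The file name you use is up to you.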

Example 1: Default mLSTM model

In this first example, we'll use the default mLSTM1900 model with the weights that ship with the package.

Nothing needs to be passed in except for:

  1. The sequences to evotune against, and
  2. The number of epochs.

This is the easiest and fastest way to get up and running.

# First way: Use the default mLSTM1900 weights with mLSTM1900 model.

tuned_params = fit(sequences, n_epochs=2)
INFO:evotuning:Random batching done: All sequences padded to max sequence length of 8

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 0: Estimated average loss: 0.20475944876670837. 

created directory at temp

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 1: Estimated average loss: 0.15383651852607727. 
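The log line "All sequences padded to max sequence length of 8" reflects the longest training sequence, "HASVASTA", which has 8 residues. The sketch below illustrates that padding step in isolation. pad_to_max_length is a hypothetical helper, and the pad character "-" is an assumption. jax_unirep's internal padding and encoding may differ from this.

```python
from typing import List


def pad_to_max_length(seqs: List[str], pad_char: str = "-") -> List[str]:
    """Right-pad every sequence to the length of the longest one.

    Illustrative only: the "-" pad character is an assumption, not
    necessarily what jax_unirep uses internally.
    """
    max_len = max(len(s) for s in seqs)
    return [s.ljust(max_len, pad_char) for s in seqs]


seqs = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"]
padded = pad_to_max_length(seqs)
print(max(len(s) for s in seqs))  # 8, matching the log message
print(padded[5])                  # HAS-----
```

Padding every sequence to a common length lets a batch be stacked into a single rectangular array.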

Example 2: Pre-built model architectures

The second way is to use one of the pre-built evotuning models. The pre-trained weights for the three model architectures from the paper (mLSTM1900, mLSTM256 and mLSTM64) are shipped with the repo. You can also use a JAX PRNGKey to initialize random parameters reproducibly.

In this example, we'll use the mlstm64 model. The mlstm256 model is also available and may perform better, at the cost of longer training time.

init_fun, apply_fun = mlstm64()

# init_fun always requires a PRNGKey,
# and input_shape should be set to (-1, 26).
# This creates randomly initialized parameters.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# Alternatively, you can load the paper weights
# (this replaces the randomly initialized params above).
params = load_params(paper_weights=64)


# Now we tune the params.
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
INFO:evotuning:Random batching done: All sequences padded to max sequence length of 8

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 0: Estimated average loss: 0.17912301421165466. 
INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 1: Estimated average loss: 0.1769668161869049. 

Example 3: Build your own model

Finally, the modular style of jax-unirep makes it easy to try out your own model architectures. For example, you could change the number of initial embedding dimensions or the mLSTM architecture. Let's try a model with 20 initial embedding dimensions instead of 10, and two stacked mLSTMs with 512 hidden states each:

model_layers = (
    AAEmbedding(20),
    mLSTM(512),
    mLSTMHiddenStates(),
    mLSTM(512),
    mLSTMHiddenStates(),
    Dense(25),
    Softmax,
)

init_fun, apply_fun = serial(*model_layers)

_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))

tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
INFO:evotuning:Random batching done: All sequences padded to max sequence length of 8

INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 0: Estimated average loss: 0.15384122729301453. 
INFO:evotuning:Calculations for training set:
INFO:evotuning:Epoch 1: Estimated average loss: 0.15379181504249573. 
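serial(*model_layers) works because every layer above follows the stax convention of an (init_fun, apply_fun) pair. Softmax is already such a pair, which is why it is passed without parentheses. The sketch below is a toy pure-Python analogue operating on nested lists, to illustrate what serial does with those pairs. It is not the JAX implementation: toy_dense and toy_serial are hypothetical names, and Python's random.Random stands in for jax.random PRNGKey splitting.

```python
import random
from typing import Callable, Tuple

# A stax-style layer is a pair (init_fun, apply_fun):
#   init_fun(rng, input_shape) -> (output_shape, params)
#   apply_fun(params, inputs)  -> outputs
Layer = Tuple[Callable, Callable]


def toy_dense(out_dim: int) -> Layer:
    """Toy fully connected layer on plain Python lists (hypothetical)."""

    def init_fun(rng: random.Random, input_shape):
        in_dim = input_shape[-1]
        W = [[rng.gauss(0.0, 0.1) for _ in range(out_dim)] for _ in range(in_dim)]
        b = [0.0] * out_dim
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fun(params, inputs):
        W, b = params
        # For each input row x, compute x @ W + b.
        return [
            [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(out_dim)]
            for x in inputs
        ]

    return init_fun, apply_fun


def toy_serial(*layers: Layer) -> Layer:
    """Chain layers: thread shapes through init, outputs through apply."""
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng: random.Random, input_shape):
        params = []
        for init in init_funs:
            # Give each layer its own child generator, loosely mirroring
            # how stax.serial splits the PRNG key once per layer.
            layer_rng = random.Random(rng.getrandbits(32))
            input_shape, p = init(layer_rng, input_shape)
            params.append(p)
        return input_shape, params

    def apply_fun(params, inputs):
        for apply, p in zip(apply_funs, params):
            inputs = apply(p, inputs)
        return inputs

    return init_fun, apply_fun


init_fun, apply_fun = toy_serial(toy_dense(4), toy_dense(3))
out_shape, params = init_fun(random.Random(42), (-1, 5))
print(out_shape)  # (-1, 3)
outputs = apply_fun(params, [[1.0, 0.0, 0.0, 0.0, 0.0]])
print(len(outputs), len(outputs[0]))  # 1 3
```

Calling init_fun again with random.Random(42) reproduces exactly the same parameters. The same idea is behind passing PRNGKey(42) to the real init_fun.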

Obviously...

...you would swap in your own set of sequences and train for longer :).