Example workflow using fit
Introduction
In this notebook, we are going to use the fit function to train a UniRep model.
Imports
Here are the imports that we are going to need for the notebook.
from jax.random import PRNGKey
from jax.experimental.stax import Dense, Softmax, serial
from jax_unirep import fit
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.utils import load_params
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMHiddenStates
Sequences
We'll prepare a bunch of dummy sequences.
In your actual use case, you'll probably need to find a way to load your sequences into memory as a list of strings. (We try our best to stick with Python idioms.)
sequences = ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"] * 5
holdout_sequences = [
"HASTA",
"VISTA",
"ALAVA",
"LIMED",
"HAST",
"HASVALTA",
] * 5
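For real data, one common starting point is a plain FASTA file. Here is a minimal sketch of reading one into a list of strings; `read_fasta` is a hypothetical helper written for this notebook, not part of jax-unirep, and it assumes simple ">"-header records with no comments.

```python
# Hypothetical helper: collect each FASTA record's sequence lines into one string.
# Assumes plain ">"-header records; adapt to however your sequences are stored.
def read_fasta(lines):
    sequences, current = [], []
    for line in lines:
        line = line.strip()
        if line.startswith(">"):
            # A new header ends the previous record.
            if current:
                sequences.append("".join(current))
            current = []
        elif line:
            current.append(line)
    if current:
        sequences.append("".join(current))
    return sequences

# Demonstration on an inline record set; with a real file you would pass
# the open file handle instead, e.g. read_fasta(open("seqs.fasta")).
records = """\
>seq1
HASTA
>seq2
VIS
TA
""".splitlines()

sequences = read_fasta(records)
print(sequences)  # ['HASTA', 'VISTA']
```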
Example 1: Default mLSTM model
In this first example, we'll use the default mLSTM1900 model with the weights shipped in the package.
Nothing needs to be passed in except for:
- The sequences to evotune against, and
- The number of epochs.
It's the easiest/fastest way to get up and running.
# First way: Use the default mLSTM1900 weights with mLSTM1900 model.
tuned_params = fit(sequences, n_epochs=2)
Example 2: Pre-built model architectures
The second way is to use one of the pre-built evotuning models. The pre-trained weights for the three model architectures from the paper are shipped with the repo (1900, 256, 64). You can also leverage JAX to reproducibly initialize random parameters.
In this example, we'll use the mlstm64 model. The mlstm256 model is also available, and it might give you better performance, though at the price of longer training time.
init_fun, apply_fun = mlstm64()
# The init_fun always requires a PRNGKey,
# and input_shape should be set to (-1, 26).
# This creates randomly initialized parameters.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
# Alternatively, you can load the paper weights
params = load_params(paper_weights=64)
# Now we tune the params.
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
Example 3: Build your own model
Finally, the modular style of jax-unirep allows you to easily try out your own model architectures. You could, for example, change the number of initial embedding dimensions or the mLSTM architecture. Let's try a model with 20 initial embedding dimensions instead of 10, and two stacked mLSTMs with 512 hidden states each:
model_layers = (
AAEmbedding(20),
mLSTM(512),
mLSTMHiddenStates(),
mLSTM(512),
mLSTMHiddenStates(),
Dense(25),
Softmax,
)
init_fun, apply_fun = serial(*model_layers)
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
tuned_params = fit(sequences, n_epochs=2, model_func=apply_fun, params=params)
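After tuning, you'll usually want to persist the returned parameters for later reuse. One generic approach, sketched here with a stand-in nested structure in place of the real pytree returned by fit, is Python's pickle:

```python
import pickle

# Stand-in for the nested parameter structure (pytree) returned by `fit`;
# in your notebook you would pickle the real `tuned_params` instead.
tuned_params = {"embedding": [0.1, 0.2], "mlstm": {"wx": [1.0]}}

# Write the parameters to disk...
with open("tuned_params.pkl", "wb") as f:
    pickle.dump(tuned_params, f)

# ...and read them back in a later session.
with open("tuned_params.pkl", "rb") as f:
    reloaded = pickle.load(f)
```

The reloaded structure can then be passed back to fit (or to an apply_fun) via the params argument.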
Obviously...
...you would probably swap in/out a different set of sequences and train for longer :).