APIs that support "advanced" tasks are available in jax-unirep.
Read on to learn how to use them.
In the original UniRep paper, the authors introduced the concept of 'evolutionary fine-tuning'. Here, the pre-trained mLSTM weights are fine-tuned through weight updates, using sequences homologous to a given protein of interest as input.
This feature is available in jax-unirep as well.
Given a set of starter weights for the mLSTM (defaulting to
the weights from the paper) as well as a set of sequences,
the weights are fine-tuned so that the test set loss
in the 'next-aa prediction task' is minimized.
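Here is what the 'next-aa prediction task' loss means. At each position, the model outputs a probability distribution over amino acids for the residue that comes next. The loss is the average negative log-likelihood of the residues that actually follow. The plain-Python sketch below computes that quantity for a toy four-letter vocabulary. The function name and the dictionary-based input format are illustrative assumptions, not jax-unirep's loss implementation.

```python
import math
from typing import Dict, List


def next_aa_loss(predictions: List[Dict[str, float]], sequence: str) -> float:
    """Average negative log-likelihood of next-residue predictions.

    predictions[i] is a probability distribution over residues, predicted
    after reading sequence[: i + 1] and scored against sequence[i + 1].
    """
    targets = sequence[1:]  # the residue that actually follows each position
    if not targets or len(predictions) != len(targets):
        raise ValueError("need exactly one prediction per next residue")
    nll = [-math.log(dist[aa]) for dist, aa in zip(predictions, targets)]
    return sum(nll) / len(nll)


# Toy example: a "model" that is uniform over a 4-letter vocabulary.
uniform = {aa: 0.25 for aa in "ACDE"}
loss = next_aa_loss([uniform] * 3, "ACDE")
print(round(loss, 4))  # 1.3863, i.e. log(4)
```

Averaged over held-out sequences, this is the kind of test set loss that fine-tuning tries to drive down. A model that always puts probability 1 on the correct next residue would score 0.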
There are two functions with differing levels of control available.
The evotune function uses
optuna under the hood
to automatically find:
- the optimal number of epochs to train for, and
- the optimal learning rate,
given a set of sequences.
The study object will contain all the information
about the training process of each trial.
evotuned_params will contain the fine-tuned mLSTM and dense weights
from the trial with the lowest test set loss.
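Conceptually, each trial trains with one (number of epochs, learning rate) pair and reports its test set loss. The parameters of the lowest-loss trial are kept. The sketch below illustrates only that selection logic. It swaps optuna's samplers for plain random search from the standard library, and train_and_evaluate is a hypothetical callback standing in for actual mLSTM training. It is not how evotune is implemented.

```python
import math
import random
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical callback: trains for n_epochs at learning_rate and returns
# (test_set_loss, trained_params).
TrainFn = Callable[[int, float], Tuple[float, Any]]


def random_search(
    train_and_evaluate: TrainFn,
    n_trials: int = 10,
    epoch_range: Tuple[int, int] = (1, 50),
    lr_range: Tuple[float, float] = (1e-5, 1e-2),
    seed: int = 0,
) -> Tuple[List[Dict[str, Any]], Any]:
    """Run n_trials random trials; return every trial record and the best params."""
    rng = random.Random(seed)
    log_lo, log_hi = math.log10(lr_range[0]), math.log10(lr_range[1])
    trials: List[Dict[str, Any]] = []
    best_loss, best_params = math.inf, None
    for trial_id in range(n_trials):
        n_epochs = rng.randint(*epoch_range)  # inclusive on both ends
        # Sample the learning rate log-uniformly, a common choice for step sizes.
        learning_rate = 10 ** rng.uniform(log_lo, log_hi)
        loss, params = train_and_evaluate(n_epochs, learning_rate)
        trials.append({"trial": trial_id, "n_epochs": n_epochs,
                       "learning_rate": learning_rate, "loss": loss})
        if loss < best_loss:  # keep params from the lowest-loss trial
            best_loss, best_params = loss, params
    return trials, best_params


# Toy objective with its minimum at 20 epochs and a learning rate of 1e-3.
def toy_train(n_epochs: int, learning_rate: float) -> Tuple[float, Any]:
    loss = abs(n_epochs - 20) + abs(math.log10(learning_rate) + 3)
    return loss, {"n_epochs": n_epochs, "learning_rate": learning_rate}


trials, best = random_search(toy_train, n_trials=25)
```

By default, optuna picks each new trial using the results of earlier ones. This is generally more sample-efficient than the independent random draws used here.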
Speed freaks read this!
As a heads-up, using
evotune is kind of slow,
so if you're of the impatient kind, read on!
If you want to directly fine-tune the weights
for a fixed number of epochs
while using a fixed learning rate,
you should use the
fit function instead.
The fit function has further customization options,
such as different batching strategies.
Please see the function docstring for more information.
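Two common batching strategies are sketched below. The first shuffles sequences into fixed-size random batches. The second groups sequences of identical length, so each batch can be stacked into a single array without padding. The helper names are hypothetical, and this is plain Python for illustration only. Consult the fit docstring for the batching options the library actually exposes.

```python
import random
from collections import defaultdict
from typing import Dict, List


def random_batches(seqs: List[str], batch_size: int, seed: int = 0) -> List[List[str]]:
    """Shuffle the sequences, then cut them into chunks of batch_size."""
    shuffled = list(seqs)
    random.Random(seed).shuffle(shuffled)
    return [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]


def length_batches(seqs: List[str]) -> List[List[str]]:
    """Group sequences by length, so every batch is rectangular without padding."""
    by_length: Dict[int, List[str]] = defaultdict(list)
    for s in seqs:
        by_length[len(s)].append(s)
    return [by_length[n] for n in sorted(by_length)]


seqs = ["MKV", "MKL", "MKLNE", "MA", "MKLNQ", "MG"]
print(length_batches(seqs))
# [['MA', 'MG'], ['MKV', 'MKL'], ['MKLNE', 'MKLNQ']]
```

The trade-off is that length batching avoids padding, but batch sizes then depend on how many sequences share each length.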
The fit function will always default to using a
GPU backend, if available, for the forward and backward passes
during training of the LSTM.
However, for the calculation of the average loss
on the dataset after every epoch, you can decide
whether the CPU or GPU
backend should be used (default is CPU).
Read the docs!
Can't emphasize this enough:
be sure to read the API docs
to learn about what's going on under the hood!
If you want to pass a set of embedding, mLSTM and dense weights that were dumped in an earlier run, create params as follows:
```python
from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")
```
Make sure that the params were created using the same model architecture that you want to use them with!
If you want to start from randomly initialized embedding, mLSTM and dense weights instead:
```python
from jax.random import PRNGKey
from jax_unirep.evotuning_models import mlstm1900

init_fun, apply_fun = mlstm1900()
# Initialize embedding, mLSTM and dense weights from a fixed random seed.
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
```
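You can catch an architecture mismatch early by comparing the nested structure and leaf shapes of loaded params against freshly initialized ones, such as the two params objects created above. The helper below is a hypothetical sketch. It assumes params are nested tuples, lists or dicts whose leaves expose a .shape attribute, as NumPy and JAX arrays do. With JAX installed, jax.tree_util provides more general pytree utilities.

```python
from typing import Any, List


def _kind(x: Any) -> str:
    if isinstance(x, dict):
        return "dict"
    if isinstance(x, (list, tuple)):
        return "sequence"
    return "leaf"


def shape_mismatches(a: Any, b: Any, path: str = "params") -> List[str]:
    """List differences in nesting structure or leaf shape between two param trees."""
    kind_a, kind_b = _kind(a), _kind(b)
    if kind_a != kind_b:
        return [f"{path}: {kind_a} vs {kind_b}"]
    if kind_a == "dict":
        if set(a) != set(b):
            return [f"{path}: keys differ"]
        return [m for k in a for m in shape_mismatches(a[k], b[k], f"{path}[{k!r}]")]
    if kind_a == "sequence":
        if len(a) != len(b):
            return [f"{path}: {len(a)} vs {len(b)} entries"]
        return [m for i, (x, y) in enumerate(zip(a, b))
                for m in shape_mismatches(x, y, f"{path}[{i}]")]
    # Leaves: compare .shape (present on NumPy and JAX arrays).
    shape_a, shape_b = getattr(a, "shape", None), getattr(b, "shape", None)
    return [] if shape_a == shape_b else [f"{path}: shape {shape_a} vs {shape_b}"]


class FakeArray:
    """Stand-in for an array, so this example runs without NumPy or JAX."""

    def __init__(self, shape: tuple):
        self.shape = shape


loaded = ((FakeArray((4, 2)),), (FakeArray((2, 3)), FakeArray((3,))))
fresh = ((FakeArray((4, 2)),), (FakeArray((2, 3)), FakeArray((5,))))
print(shape_mismatches(loaded, fresh))  # ['params[1][1]: shape (3,) vs (5,)']
```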
End-to-end differentiable models
As a user, you might want to write custom "top models", such as a linear model on top of the reps, and jointly optimize the UniRep weights together with the top model's weights. You're in luck!
We implemented the mLSTM layers in such a way that
they are compatible with jax.experimental.stax.
This means that they can easily be plugged into
a stax.serial model, e.g. to train both the mLSTM
and a top-model at once:
```python
# Note: newer JAX releases moved stax to jax.example_libraries.stax.
from jax.experimental.stax import Dense, Relu, serial
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden

init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    # A dense layer that results in 512-dim activations, followed by a ReLU
    Dense(512),
    Relu(),
    # And then a linear layer to produce a 1-dim activation
    Dense(1),
)
```
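In stax, every layer is a pair (init_fun, apply_fun). init_fun(rng, input_shape) returns (output_shape, params), and apply_fun(params, inputs) computes the layer's output. serial chains such pairs, so the parameters for the whole stack, mLSTM and top model alike, live in one tuple. The pure-Python sketch below mimics that calling convention with toy scalar layers. It is a simplified illustration with no RNG key splitting and no arrays, not JAX's implementation.

```python
import random
from typing import Any, Callable, List, Tuple

InitFun = Callable[[random.Random, Tuple[int, ...]], Tuple[Tuple[int, ...], Any]]
ApplyFun = Callable[[Any, List[float]], List[float]]
Layer = Tuple[InitFun, ApplyFun]


def Scale() -> Layer:
    """Toy layer: multiplies every input by one learnable scalar."""
    def init_fun(rng, input_shape):
        return input_shape, rng.uniform(-1.0, 1.0)

    def apply_fun(params, inputs):
        return [params * x for x in inputs]

    return init_fun, apply_fun


def Relu() -> Layer:
    """Parameter-free layer; as in stax, its params are an empty tuple."""
    def init_fun(rng, input_shape):
        return input_shape, ()

    def apply_fun(params, inputs):
        return [max(0.0, x) for x in inputs]

    return init_fun, apply_fun


def serial(*layers: Layer) -> Layer:
    """Compose layers; the stack's params form one tuple, one entry per layer."""
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            input_shape, p = init(rng, input_shape)
            params.append(p)
        return input_shape, tuple(params)

    def apply_fun(params, inputs):
        for apply, p in zip(apply_funs, params):
            inputs = apply(p, inputs)
        return inputs

    return init_fun, apply_fun


init_fun, apply_fun = serial(Scale(), Relu(), Scale())
out_shape, params = init_fun(random.Random(42), (3,))
print(apply_fun((2.0, (), 0.5), [1.0, -1.0, 3.0]))  # [1.0, 0.0, 3.0]
```

Because the whole stack's parameters live in one tuple, differentiating a loss with respect to that tuple yields gradients for every layer at once. This is what makes joint optimization of the mLSTM and the top model possible.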
Sampling new protein sequences
When doing protein engineering,
one core task is proposing new sequences to order by gene synthesis.
jax-unirep provides a number of utility functions
that help with this task.
The key one to focus on is the sample_one_chain function.
This function takes in a starting sequence and uses Monte Carlo sampling alongside the Metropolis-Hastings criterion to score and rank-order new sequences to try out. The usage pattern is as follows.
Firstly, you must define a scoring function that takes in a string sequence and outputs a number. This can be, for example, a pre-trained machine learning model that you have created.
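```python
from jax_unirep import get_reps

# SomeSKLearnModel, training_X and training_y are placeholders for your own
# scikit-learn-style model and training data.
model = SomeSKLearnModel()
model.fit(training_X, training_y)


def scoring_func(sequence: str) -> float:
    # get_reps returns (h_avg, h_final, c_final); use the averaged hidden state.
    reps, _, _ = get_reps(sequence)
    # predict returns one prediction per row; extract the single number.
    return float(model.predict(reps)[0])
```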
Now, we can use MCMC sampling to propose new sequences.
```python
import pandas as pd
from jax_unirep import sample_one_chain

starter_seq = "MKLNEQLJLA"  # can be longer!
sampled_sequences = sample_one_chain(starter_seq, n_steps=10, scoring_func=scoring_func)
sampled_seqs_df = pd.DataFrame(sampled_sequences)
```
sampled_sequences is a dictionary
that can be converted directly into a pandas DataFrame.
In there, every single sequence that was ever sampled is recorded,
as well as its score (given by the scoring function)
and whether it was accepted by the MCMC sampler or not.
(All generated sequences are recorded,
just in case there was something good that was rejected!)
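The pure-Python sketch below shows what one Metropolis-Hastings chain does at each step:

1. Propose a single point mutation.
2. Score the new sequence.
3. Always accept improvements, and accept worse sequences with probability exp(delta / temperature).
4. Record every proposal, whether or not it was accepted.

Because the proposal is symmetric, the Hastings correction drops out. The function names, the temperature parameter, the record keys and the toy scoring function are assumptions for illustration. This is not sample_one_chain's implementation, which may use a different proposal scheme and acceptance rule.

```python
import math
import random
from typing import Callable, Dict, List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residues


def propose(seq: str, rng: random.Random) -> str:
    """Propose a new sequence via one random point mutation (a symmetric proposal)."""
    pos = rng.randrange(len(seq))
    new_aa = rng.choice([aa for aa in AMINO_ACIDS if aa != seq[pos]])
    return seq[:pos] + new_aa + seq[pos + 1:]


def metropolis_chain(
    starter_seq: str,
    n_steps: int,
    scoring_func: Callable[[str], float],
    temperature: float = 1.0,
    seed: int = 0,
) -> Dict[str, List]:
    """Sample one chain, recording every proposal, its score and whether it was accepted."""
    rng = random.Random(seed)
    current, current_score = starter_seq, scoring_func(starter_seq)
    record = {"sequences": [current], "scores": [current_score], "accept": [True]}
    for _ in range(n_steps):
        candidate = propose(current, rng)
        score = scoring_func(candidate)
        # Metropolis criterion: accept improvements outright, and worse
        # candidates with probability exp(delta / temperature).
        delta = score - current_score
        accepted = delta >= 0 or rng.random() < math.exp(delta / temperature)
        if accepted:
            current, current_score = candidate, score
        record["sequences"].append(candidate)
        record["scores"].append(score)
        record["accept"].append(accepted)
    return record


# Hypothetical toy score: the fraction of hydrophobic residues.
def toy_score(seq: str) -> float:
    return sum(aa in "AILMFVW" for aa in seq) / len(seq)


chain = metropolis_chain("MKLNEQLALA", n_steps=10, scoring_func=toy_score)
```

Lowering temperature makes the chain greedier. Raising it lets the chain move through lower-scoring sequences more often.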
If you want to do parallel sampling, you can use any library that does parallel processing. We're going to show you one example using Dask, which happens to be our favourite library for scalable Python!
Assuming you have a Dask
client object instantiated:
```python
import pandas as pd
from dask.distributed import Client
from jax_unirep import sample_one_chain

client = Client(...)  # you'll have to configure this according to your own circumstances

starter_seq = "MKLNEQLJLA"  # can be longer!
chain_results_futures = []
for i in range(100):  # sample 100 independent chains
    chain_results_futures.append(
        # Submit tasks to workers
        client.submit(
            sample_one_chain,
            starter_seq,
            n_steps=10,
            scoring_func=scoring_func,
            # This is important, esp. with random sampling methods: without it,
            # Dask treats identical calls as one task and reuses a single result.
            pure=False,
        )
    )

# Gather results from distributed workers
chain_results = client.gather(chain_results_futures)
# Convert everything into a single DataFrame
chain_data = pd.concat([pd.DataFrame(r) for r in chain_results])
```
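If you would rather not set up Dask, Python's standard library can run independent chains too. The sketch below uses concurrent.futures.ThreadPoolExecutor and merges the per-chain dictionaries into one, adding a chain column. It assumes a hypothetical sampler function that accepts a seed keyword. Giving each chain its own seed serves a similar purpose to pure=False above: it keeps the chains from collapsing into identical copies.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List


def run_chains(
    sample_fn: Callable[..., Dict[str, List[Any]]],
    starter_seq: str,
    n_chains: int,
    max_workers: int = 4,
    **kwargs: Any,
) -> Dict[str, List[Any]]:
    """Run n_chains independent chains in parallel and merge their records.

    sample_fn is assumed (hypothetically) to accept a `seed` keyword; chain i
    uses seed=i. A "chain" column records which chain each row came from.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(sample_fn, starter_seq, seed=i, **kwargs)
                   for i in range(n_chains)]
        results = [f.result() for f in futures]  # preserves chain order
    merged: Dict[str, List[Any]] = {"chain": []}
    for chain_id, record in enumerate(results):
        n_rows = len(next(iter(record.values())))
        merged["chain"].extend([chain_id] * n_rows)
        for key, values in record.items():
            merged.setdefault(key, []).extend(values)
    return merged


# Usage with a hypothetical stand-in sampler:
def dummy_chain(starter_seq: str, n_steps: int, seed: int) -> Dict[str, List[Any]]:
    return {"sequences": [starter_seq] * (n_steps + 1), "scores": [float(seed)] * (n_steps + 1)}


merged = run_chains(dummy_chain, "MKV", n_chains=3, n_steps=2)
```

Threads are limited by the global interpreter lock when the scoring function is CPU-bound pure Python. ProcessPoolExecutor offers the same submit interface and can be swapped in, provided sample_fn and its arguments are picklable (for example, defined at module level).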
Your contribution here
Is there an "advanced" protocol that you've developed surrounding jax-unirep?
If so, please consider contributing it here!