# Advanced Usage

APIs that support "advanced" tasks are available in `jax-unirep`. Read on to learn how to use them.
## Evotuning

In the original UniRep paper, the authors introduced the concept of 'evolutionary fine-tuning': the pre-trained mLSTM weights are fine-tuned through weight updates, using homologous sequences of a protein of interest as input.
This feature is also available in `jax-unirep`. Given a set of starter weights for the mLSTM (defaulting to the weights from the paper) as well as a set of sequences, the weights are fine-tuned such that the test-set loss on the next-amino-acid prediction task is minimized.
There are two functions available, offering differing levels of control.

The `evotune` function uses `optuna` under the hood to automatically find, given a set of sequences:

- the optimal number of epochs to train for, and
- the optimal learning rate.
The `study` object will contain all the information about the training process of each trial. `evotuned_params` will contain the fine-tuned mLSTM and dense weights from the trial with the lowest test set loss.
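A minimal sketch of what a call might look like (`n_trials` and the exact keyword arguments are illustrative assumptions on our part; consult the `evotune` docstring for the version-specific signature):

```python
from jax_unirep import evotune

# Dummy homolog sequences; substitute your own set.
sequences = ["MKLNEQLJLA", "MKLNEQLJLG"]

# `n_trials` is an illustrative keyword; check the evotune docstring.
study, evotuned_params = evotune(sequences=sequences, n_trials=20)
```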
**Speed freaks, read this!** As a heads-up, `evotune` is rather slow, so if you're the impatient kind, read on and use `fit` instead!
If you want to directly fine-tune the weights for a fixed number of epochs at a fixed learning rate, you should use the `fit` function instead. The `fit` function has further customization options, such as different batching strategies. Please see the function docstring here for more information.
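As a rough sketch (argument names beyond `sequences` and `n_epochs` are assumptions on our part; the docstring is authoritative):

```python
from jax_unirep import fit

# Dummy homolog sequences; substitute your own set.
sequences = ["MKLNEQLJLA", "MKLNEQLJLG"]

# Fine-tune for a fixed number of epochs; additional keyword arguments
# (learning rate, batching strategy, ...) are described in the docstring.
evotuned_params = fit(sequences, n_epochs=20)
```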
### GPU usage

The `fit` function will always default to using a GPU backend, if available, for the forward and backward passes during training of the mLSTM. However, for the calculation of the average loss on the dataset after every epoch, you can decide whether the CPU or GPU backend should be used (the default is CPU).
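For instance, selecting the GPU for the per-epoch loss calculation might look like this (the `backend` keyword name is our assumption; verify it against the `fit` docstring):

```python
from jax_unirep import fit

sequences = ["MKLNEQLJLA", "MKLNEQLJLG"]  # dummy sequences

# `backend` is assumed here to control where the per-epoch average
# loss is computed; confirm the exact keyword in the fit docstring.
evotuned_params = fit(sequences, n_epochs=20, backend="gpu")
```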
You can find an example usage of the evotuning function here. For an example workflow using `fit`, have a look at the notebook in the next section.
**Read the docs!** We can't emphasize this enough: be sure to read the API docs for the `fit` function to learn what's going on under the hood!
If you want to pass a set of embedding, mLSTM, and dense weights that were dumped in an earlier run, create `params` as follows:
```python
from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")
```
Make sure that the params were created using the same model architecture that you want to use them with!
If you want to start from randomly initialized embedding, mLSTM and dense weights instead:
```python
from jax.random import PRNGKey

from jax_unirep.evotuning_models import mlstm1900

init_fun, apply_fun = mlstm1900()
_, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
```
## End-to-end differentiable models

As a user, you might want to write custom "top models", such as a linear model on top of the reps, while jointly optimizing the UniRep weights together with the top model. You're in luck!
We implemented the mLSTM layers in such a way that they are compatible with `jax.experimental.stax`. This means that they can easily be plugged into a `stax.serial` model, e.g. to train both the mLSTM and a top model at once:
```python
from jax.experimental.stax import Dense, Relu, serial

from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden

init_fun, apply_fun = serial(
    AAEmbedding(10),
    mLSTM(1900),
    mLSTMAvgHidden(),
    # Add two layers, one dense layer that results in 512-dim activations
    Dense(512), Relu(),
    # And then a linear layer to produce a 1-dim activation
    Dense(1),
)
```
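To illustrate the end-to-end differentiability, here is a minimal sketch of taking gradients through the whole stack. The one-hot input shape follows the `input_shape=(-1, 26)` convention above, but treat the dummy shapes and targets as assumptions, not prescriptions:

```python
import jax.numpy as jnp
from jax import grad
from jax.random import PRNGKey

_, params = init_fun(PRNGKey(0), input_shape=(-1, 26))

x = jnp.zeros((10, 26))  # dummy one-hot-encoded sequence of length 10
y = jnp.array([1.0])     # dummy regression target

def mse_loss(params, x, y):
    pred = apply_fun(params, x)
    return jnp.mean((pred - y) ** 2)

# Gradients flow through the top model *and* the mLSTM weights.
dparams = grad(mse_loss)(params, x, y)
```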
Have a look at the documentation and examples for more information about how to implement a model in `jax`.
## Sampling new protein sequences

When doing protein engineering, one core task is proposing new sequences to order by gene synthesis. `jax-unirep` provides a number of utility functions inside `jax_unirep.sampler` that help with this task.
### Basic sampling

The key one to focus on is the `sample_one_chain` function. This function takes in a starting sequence, and uses Monte Carlo sampling alongside the Metropolis-Hastings criterion to score and rank-order new sequences to try out. The usage pattern is as follows.

Firstly, you must have a scoring function defined that takes in a string sequence and outputs a number. This can be, for example, a pre-trained machine learning model that you have created:
```python
from jax_unirep import get_reps

# SomeSKLearnModel stands in for any scikit-learn-style model
# trained on your own data.
model = SomeSKLearnModel()
model.fit(training_X, training_y)

def scoring_func(sequence: str) -> float:
    reps, _, _ = get_reps(sequence)
    return model.predict(reps)
```
Now, we can use MCMC sampling to propose new sequences.
```python
import pandas as pd

from jax_unirep import sample_one_chain

starter_seq = "MKLNEQLJLA"  # can be longer!
sampled_sequences = sample_one_chain(starter_seq, n_steps=10, scoring_func=scoring_func)
sampled_seqs_df = pd.DataFrame(sampled_sequences)
```
`sampled_sequences` is a dictionary that can be converted directly into a `pandas.DataFrame`. In there, every single sequence that was ever sampled is recorded, as well as its score (given by the scoring function) and whether it was accepted by the MCMC sampler or not. (All generated sequences are recorded, just in case there was something good that was rejected!)
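For example, you might rank the sampled sequences by score. The column names below are assumptions for illustration; inspect the keys of the dictionary returned by your version of `sample_one_chain`:

```python
# Column names ("accepted", "score") are illustrative assumptions;
# inspect sampled_seqs_df.columns for the real ones.
best_seqs = (
    sampled_seqs_df
    .query("accepted == True")
    .sort_values("score", ascending=False)
    .head(10)
)
```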
### Parallel sampling

If you want to do parallel sampling, you can use any library that does parallel processing. We're going to show you one example using Dask, which happens to be our favourite library for scalable Python!
Assuming you have a Dask `Client` object instantiated:
```python
import pandas as pd
from dask.distributed import Client

client = Client(...)  # you'll have to configure this according to your own circumstances

starter_seq = "MKLNEQLJLA"  # can be longer!

chain_results_futures = []
for i in range(100):  # sample 100 independent chains
    chain_results_futures.append(
        # Submit tasks to workers
        client.submit(
            sample_one_chain,
            starter_seq,
            n_steps=10,
            scoring_func=scoring_func,
            pure=False,  # this is important, esp. with random sampling methods
        )
    )

# Gather results from distributed workers
chain_results = client.gather(chain_results_futures)

# Convert everything into a single DataFrame
chain_data = pd.concat([pd.DataFrame(r) for r in chain_results])
```
## Your contribution here

Is there an "advanced" protocol that you've developed surrounding `jax-unirep`? If so, please consider contributing it here!