
API Documentation

Here lies the official top-level API for interacting with jax-unirep.

Calculating Representations


jax_unirep.get_reps(seqs, params=None, mlstm_size=1900)

Get reps of proteins.

This function generates representations of protein sequences using the mLSTM model from the UniRep paper.

Each element of the output 3-tuple is a np.array of shape (n_input_sequences, mlstm_size):

  • h_avg: Average hidden state of the mLSTM over the whole sequence.
  • h_final: Final hidden state of the mLSTM
  • c_final: Final cell state of the mLSTM
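The NumPy sketch below shows how the three outputs relate to each other. It reduces the per-position states of each sequence to fixed-size vectors and then stacks them. It assumes you already have each sequence's mLSTM hidden and cell states as arrays of shape (seq_len, mlstm_size). summarize_states and stack_reps are hypothetical helpers, not part of jax-unirep.

```python
import numpy as np


def summarize_states(hidden_states: np.ndarray, cell_states: np.ndarray):
    """Reduce the per-position mLSTM states of ONE sequence to three vectors.

    Hypothetical helper: both inputs are assumed to have shape
    (seq_len, mlstm_size), one row per amino-acid position.
    """
    h_avg = hidden_states.mean(axis=0)  # average hidden state over the sequence
    h_final = hidden_states[-1]         # hidden state after the last position
    c_final = cell_states[-1]           # cell state after the last position
    return h_avg, h_final, c_final


def stack_reps(per_sequence_states):
    """Stack reps of sequences with different lengths into three 2D arrays.

    Each reduction yields one fixed-size vector per sequence, so the stacked
    arrays all have shape (n_sequences, mlstm_size).
    """
    summaries = [summarize_states(h, c) for h, c in per_sequence_states]
    h_avg, h_final, c_final = (np.stack(cols) for cols in zip(*summaries))
    return h_avg, h_final, c_final
```

Each reduction produces one vector per sequence. Sequences of different lengths therefore end up as rows of equal width, which gives the (n_sequences, mlstm_size) output shape.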

You should not use this function if you want to do further JAX-based computations on the output vectors! In that case, pass the DeviceArray futures returned by the mLSTM directly into the next step instead of converting them to np.arrays. The conversion to np.arrays happens in the dispatched rep_x_lengths functions. This conversion forces Python to wait until the computation has completed before it returns the values.

The keys of the params dictionary must be:

b, gh, gmh, gmx, gx, wh, wmh, wmx, wx
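If you build or load a params dictionary yourself, a quick key check can catch mistakes early. This is a minimal sketch. mlstm_key_problems is a hypothetical helper. It checks only key names, not array shapes, and it treats the list above as exhaustive.

```python
from typing import Mapping, Set, Tuple

# Keys listed in the get_reps docstring.
MLSTM_PARAM_KEYS = frozenset({"b", "gh", "gmh", "gmx", "gx", "wh", "wmh", "wmx", "wx"})


def mlstm_key_problems(params: Mapping) -> Tuple[Set[str], Set[str]]:
    """Return (missing_keys, unexpected_keys) for an mLSTM params dictionary."""
    keys = set(params)
    return set(MLSTM_PARAM_KEYS - keys), keys - MLSTM_PARAM_KEYS
```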


Parameters:

  • seqs: A list of sequences as strings or a single string.
  • params: A dictionary of mLSTM weights.
  • mlstm_size: Integer specifying the number of nodes in the mLSTM layer. Though the model architecture space is practically infinite, we assume that you are using the same number of nodes per mLSTM layer. (This is a common simplification used in the design of neural networks.)


Returns:

A 3-tuple of np.arrays containing the reps, in the order h_avg, h_final, and c_final. Each np.array has shape (n_sequences, mlstm_size).

Evotuning


jax_unirep.fit(sequences, n_epochs, model_func=<mLSTM1900 apply function>, params=None, batch_method='random', batch_size=25, step_size=0.0001, holdout_seqs=None, proj_name='temp', epochs_per_print=1, backend='cpu')

Return mLSTM weights fitted to predict the next letter in each AA sequence.

The training loop is as follows, depending on the batching strategy:

Length batching:

  • At each iteration, one length is chosen at random from all sequence lengths present in sequences.
  • Next, batch_size sequences of the chosen length are selected at random.
  • If there are fewer than batch_size sequences of a given length, all sequences of that length are chosen.
  • Those sequences then get passed through the model. No padding of sequences occurs.

To batch sequences by length, we call batch_sequences from our module. It returns a list of sub-lists, where each sub-list contains the indices (into the original list of sequences) of all sequences of one particular length.
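The length-batching step can be sketched in plain Python as follows. group_indices_by_length is a hypothetical stand-in for batch_sequences. It returns the same kind of list of index sub-lists, but the real function's signature may differ. sample_length_batch is also illustrative only.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def group_indices_by_length(sequences: Sequence[str]) -> List[List[int]]:
    """Sub-lists of indices into `sequences`; each sub-list shares one length."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(sequences):
        groups[len(seq)].append(idx)
    return list(groups.values())


def sample_length_batch(
    sequences: Sequence[str], batch_size: int, rng: random.Random
) -> List[str]:
    """One length-batching iteration: choose a length, then up to batch_size sequences."""
    groups = group_indices_by_length(sequences)
    # Each group corresponds to exactly one length, so this picks a length
    # uniformly at random from the lengths present.
    chosen = rng.choice(groups)
    # Take all of them if fewer than batch_size sequences have this length.
    picked = rng.sample(chosen, k=min(batch_size, len(chosen)))
    return [sequences[i] for i in picked]  # equal lengths: no padding needed
```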

Random batching:

  • Before training, all sequences get padded to be the same length as the longest sequence in sequences.
  • Then, at each iteration, we randomly sample batch_size sequences and pass them through the model.
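Random batching can be sketched in the same way. The sketch makes three assumptions for illustration: right-padding, the padding character "-", and sampling without replacement. pad_sequences and sample_random_batch are hypothetical helpers, not jax-unirep functions.

```python
import random
from typing import List, Sequence

PAD_CHAR = "-"  # assumption: jax-unirep's actual padding character may differ


def pad_sequences(sequences: Sequence[str], pad_char: str = PAD_CHAR) -> List[str]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s.ljust(max_len, pad_char) for s in sequences]


def sample_random_batch(
    padded: Sequence[str], batch_size: int, rng: random.Random
) -> List[str]:
    """Draw up to batch_size padded sequences uniformly at random, without replacement."""
    return rng.sample(list(padded), k=min(batch_size, len(padded)))
```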

The training loop does not follow the common notion of epochs, in which the model sees every sequence exactly once per epoch. Instead, sequences are always sampled at random, and one epoch consists of approximately round(len(sequences) / batch_size) weight updates. Asymptotically, this should be roughly equivalent to making full epoch passes over the dataset.
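The epoch arithmetic above can be written out directly. For example, 1,000 sequences with batch_size=25 give round(1000 / 25) = 40 weight updates per epoch. The helper names are hypothetical.

```python
def updates_per_epoch(n_sequences: int, batch_size: int) -> int:
    """Approximate number of weight updates that make up one 'epoch'."""
    return round(n_sequences / batch_size)


def total_updates(n_sequences: int, batch_size: int, n_epochs: int) -> int:
    """Total weight updates over n_epochs such epochs."""
    return n_epochs * updates_per_epoch(n_sequences, batch_size)
```

Python's built-in round() rounds exact halves to the nearest even integer, so round(62.5) == 62.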

To learn more about passing params, see the evotune function docstring.

You can optionally dump parameters and print losses every epochs_per_print epochs to monitor training progress. For ergonomics, training and holdout set losses are estimated on a batch of batch_size sequences rather than calculated exactly over the entire set. Set epochs_per_print to None to avoid parameter dumping.
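The reporting schedule and the batch-based loss estimate can be sketched as follows. The sketch assumes that epochs are counted from 1 and that a loss function takes a list of sequences. report_epochs, estimate_loss, and loss_fn are hypothetical names, not part of jax-unirep.

```python
import random
from typing import Callable, List, Optional, Sequence


def report_epochs(n_epochs: int, epochs_per_print: Optional[int]) -> List[int]:
    """Epochs (counted from 1) after which losses are printed and params dumped."""
    if epochs_per_print is None:
        return []  # reporting and dumping disabled
    if epochs_per_print < 1:
        raise ValueError("epochs_per_print must be >= 1 or None")
    return list(range(epochs_per_print, n_epochs + 1, epochs_per_print))


def estimate_loss(
    loss_fn: Callable[[List[str]], float],
    sequences: Sequence[str],
    batch_size: int,
    rng: random.Random,
) -> float:
    """Estimate a set's loss from one random batch instead of the whole set."""
    batch = rng.sample(list(sequences), k=min(batch_size, len(sequences)))
    return loss_fn(batch)
```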


Parameters:

  • sequences: List of sequences to evotune on.
  • n_epochs: The number of epochs to evotune for.
  • model_func: A function that accepts (params, x). Defaults to the mLSTM1900 model function.
  • params: Optionally pass in the params you want to use. These params must yield a correctly-sized mLSTM, otherwise you will get cryptic shape errors! If None, params will be randomly generated, except for mlstm_size of 1900, where the pre-trained weights from the original publication are used.
  • batch_method: One of "length" or "random".
  • batch_size: Number of sequences per batch. With length batching, this is the maximum: length groups smaller than batch_size are used whole. As a rule of thumb, a batch size of 50 consumes about 5GB of GPU RAM.
  • step_size: The learning rate.
  • holdout_seqs: Holdout set, an optional input.
  • proj_name: The directory path for weights to be output to.
  • epochs_per_print: Number of epochs between printing progress and dumping weights. Must be greater than or equal to 1, or None to skip parameter dumping.
  • backend: Whether or not to use the GPU. Defaults to "cpu", but can be set to "gpu" if desired. If you set it to "gpu", make sure you have a version of jax that is pre-compiled to work with GPUs.


Returns:

Final optimized parameters.
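Before a long training run, you can collect the constraints stated in the parameter list above into a quick pre-flight check. check_fit_kwargs is a hypothetical helper, not part of jax-unirep. It covers the documented constraints plus an obvious positivity check on batch_size.

```python
from typing import List, Optional


def check_fit_kwargs(
    batch_method: str = "random",
    batch_size: int = 25,
    epochs_per_print: Optional[int] = 1,
    backend: str = "cpu",
) -> List[str]:
    """Return human-readable problems with fit() keyword arguments (empty if none)."""
    problems = []
    if batch_method not in ("length", "random"):
        problems.append(f"batch_method must be 'length' or 'random', got {batch_method!r}")
    if batch_size < 1:
        problems.append("batch_size must be a positive integer")
    if epochs_per_print is not None and epochs_per_print < 1:
        problems.append("epochs_per_print must be >= 1 (or None to skip dumping)")
    if backend not in ("cpu", "gpu"):
        problems.append(f"backend must be 'cpu' or 'gpu', got {backend!r}")
    return problems
```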


jax_unirep.evotune(sequences, model_func=<mLSTM1900 apply function>, params=None, n_trials=20, n_epochs_config=None, learning_rate_config=None, n_splits=5, out_dom_seqs=None)

Evolutionarily tune the model to a set of sequences.

Evotuning is described in the original UniRep and eUniRep papers. This reimplementation of evotune provides a nicer API that automatically handles multiple sequences of variable lengths.

Evotuning always needs a starter set of weights. By default, the pre-trained weights from the Nature Methods paper are used. However, other pre-trained weights can be used as well.

We first use Optuna to figure out how many epochs to fit before overfitting happens. To save on computation time, the number of trials defaults to 20, but can be configured.

By default, the mLSTM and Dense weights from the paper are used by setting mlstm_size=1900 and params=None in the partially-evaluated fit function (fit_func), but if you want to use randomly initialized weights:

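```python
from jax_unirep.evotuning import evotuning_funcs, fit
from jax.random import PRNGKey
from functools import partial

# evotuning_funcs returns an (init, apply) pair; only init is needed here.
init_func, _ = evotuning_funcs(mlstm_size=256)  # works for any size
# init_func returns (output_shape, params); keep only the randomly initialized params.
_, params = init_func(PRNGKey(0), input_shape=(-1, 26))
fit_func = partial(fit, mlstm_size=256, params=params)
```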

or dumped weights:

```python
from functools import partial

from jax_unirep.evotuning import fit
from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")
fit_func = partial(fit, mlstm_size=256, params=params)
```

The examples above use mLSTM sizes of 256, but any size works in theory! Just make sure that the mLSTM size of your randomly initialized or dumped params matches the one you set in the partially-evaluated fit function.

This function is intended as an automagic way of identifying the best model and training routine hyperparameters. If you want more control over how fitting happens, please use the fit() function directly. There is an example in the examples/ directory that shows how to use it.


Parameters:

  • sequences: Sequences to evotune against.
  • model_func: Model apply func. Defaults to the mLSTM1900 apply function.
  • params: Model params that are compatible with model apply func. Defaults to the mLSTM1900 params.
  • n_trials: The number of trials Optuna should attempt.
  • n_epochs_config: A dictionary of kwargs to trial.suggest_discrete_uniform, which are: name, low, high, q. This controls the range of epoch counts that Optuna tests. See source code for default configuration, at the definition of n_epochs_kwargs.
  • learning_rate_config: A dictionary of kwargs to trial.suggest_loguniform, which are: name, low, high. This controls the learning rate of the model. See source code for default configuration, at the definition of learning_rate_kwargs.
  • n_splits: The number of folds of cross-validation to do.
  • out_dom_seqs: Out-of-domain holdout set of sequences, on which the loss is checked to guard against overfitting.


Returns:

  • study: The Optuna study object, containing information about all evotuning trials.
  • evotuned_params: A dictionary of the final, optimized weights.
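The sketch below shows how the two config dictionaries are used: it unpacks each one as keyword arguments into the trial method named in the parameter list. It runs without Optuna by using a small recording stand-in for optuna.trial.Trial. The numeric ranges are made up for illustration; the real defaults live in the source code. RecordingTrial and suggest_hyperparameters are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

# Made-up ranges for illustration; see the library source for the real defaults.
n_epochs_config = {"name": "n_epochs", "low": 1, "high": 20, "q": 1}
learning_rate_config = {"name": "learning_rate", "low": 1e-5, "high": 1e-2}


@dataclass
class RecordingTrial:
    """Tiny stand-in for optuna.trial.Trial that records each call."""

    calls: List[Tuple[str, Dict[str, Any]]] = field(default_factory=list)

    def suggest_discrete_uniform(self, name, low, high, q):
        self.calls.append(
            ("suggest_discrete_uniform", {"name": name, "low": low, "high": high, "q": q})
        )
        return float(low)  # a real trial samples from {low, low + q, ..., high}

    def suggest_loguniform(self, name, low, high):
        self.calls.append(("suggest_loguniform", {"name": name, "low": low, "high": high}))
        return float(low)  # a real trial samples log-uniformly between low and high


def suggest_hyperparameters(trial, n_epochs_config, learning_rate_config):
    """Unpack each config dict as keyword arguments to the matching trial method."""
    n_epochs = int(trial.suggest_discrete_uniform(**n_epochs_config))
    learning_rate = trial.suggest_loguniform(**learning_rate_config)
    return n_epochs, learning_rate
```

Recent Optuna releases deprecate suggest_discrete_uniform and suggest_loguniform in favour of trial.suggest_float (with step= or log=True). Check which Optuna version you have installed.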



jax_unirep.sample_one_chain(starter_sequence, n_steps, scoring_func, is_accepted_kwargs={}, trust_radius=7, propose_kwargs={}, verbose=False)

Return one chain of MCMC samples of new sequences.

Given a starter_sequence, this function will sample one chain of protein sequences, scored using a user-provided scoring_func.

Design choices made here include the following.

Firstly, we record all sampled sequences, not just the accepted ones. This behaviour differs from other MCMC samplers, which record only the accepted values. We do this in case sequences that are still "good" (but not better than the current one) get rejected. As a result, we get a cluster of sequences that are one step away from newly accepted sequences.

Secondly, we check the Hamming distance between each newly proposed sequence and the original starter sequence. This corresponds to the "trust radius" specified in the jax-unirep paper. If the Hamming distance exceeds the trust radius, we reject the sequence outright.
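The trust-radius check amounts to a Hamming distance comparison. This is a minimal sketch. It assumes proposals only substitute residues, so the proposal and the starter sequence have equal length. hamming_distance and within_trust_radius are hypothetical helpers.

```python
def hamming_distance(a: str, b: str) -> int:
    """Number of positions at which two equal-length sequences differ."""
    if len(a) != len(b):
        raise ValueError("Hamming distance needs sequences of equal length")
    return sum(x != y for x, y in zip(a, b))


def within_trust_radius(proposal: str, starter: str, trust_radius: int = 7) -> bool:
    """False means the proposal is rejected outright."""
    return hamming_distance(proposal, starter) <= trust_radius
```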

A dictionary containing the following key-value pairs is returned:

  • "sequences": All proposed sequences.
  • "scores": All scores from the scoring function.
  • "accept": Whether the sequence was accepted as the new 'current sequence' on which new sequences are proposed.

This can be turned into a pandas DataFrame.
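A toy chain that combines these design choices might look like the sketch below. It records every proposal and rejects proposals outside the trust radius outright. It returns a dictionary with the three keys listed above. The proposal step (a single point mutation) and the acceptance rule (Metropolis-style) are stand-ins chosen for illustration, not jax-unirep's propose and is_accepted functions. sample_chain_sketch is hypothetical.

```python
import math
import random
from typing import Callable, Dict, List

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residues


def sample_chain_sketch(
    starter_sequence: str,
    n_steps: int,
    scoring_func: Callable[[str], float],
    trust_radius: int = 7,
    temperature: float = 1.0,
    seed: int = 0,
) -> Dict[str, List]:
    """Toy MCMC chain that records every proposal, accepted or not."""
    rng = random.Random(seed)
    current = starter_sequence
    current_score = scoring_func(current)
    chain: Dict[str, List] = {"sequences": [], "scores": [], "accept": []}
    for _ in range(n_steps):
        # Propose a single point mutation of the current sequence.
        pos = rng.randrange(len(current))
        proposal = current[:pos] + rng.choice(AMINO_ACIDS) + current[pos + 1 :]
        # Scored even if rejected below, so the three lists stay aligned.
        score = scoring_func(proposal)
        distance = sum(a != b for a, b in zip(proposal, starter_sequence))
        if distance > trust_radius:
            accepted = False  # outside the trust radius: rejected outright
        else:
            # Metropolis-style rule (illustrative only): always accept
            # improvements, sometimes accept worse scores.
            accepted = score >= current_score or rng.random() < math.exp(
                (score - current_score) / temperature
            )
        if accepted:
            current, current_score = proposal, score
        chain["sequences"].append(proposal)
        chain["scores"].append(score)
        chain["accept"].append(accepted)
    return chain
```

All three lists have the same length, so pandas.DataFrame(chain) yields one row per proposal.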


Parameters:

  • starter_sequence: The starting sequence.
  • n_steps: Number of steps for the MC chain to walk.
  • scoring_func: Scoring function for a new sequence. It should accept a single string sequence as its only argument.
  • is_accepted_kwargs: Dictionary of kwargs to pass into the is_accepted function. See the is_accepted docstring for more details.
  • trust_radius: Maximum allowed number of mutations away from the starter sequence.
  • propose_kwargs: Dictionary of kwargs to pass into the propose function. See the propose docstring for more details.
  • verbose: Whether or not to print the iteration number and associated sequence and score. Defaults to False.


Returns:

A dictionary with sequences, scores and accept as keys.