API Documentation
Here lies the official top-level API for interacting with jax-unirep.
Calculating Representations
jax_unirep.get_reps
jax_unirep.get_reps(seqs, params=None, mlstm_size=1900)
Get reps of proteins.
This function generates representations of protein sequences using the mLSTM model from the UniRep paper.
Each element of the output 3-tuple is a np.array of shape (n_input_sequences, mlstm_size):
- h_avg: Average hidden state of the mLSTM over the whole sequence.
- h_final: Final hidden state of the mLSTM.
- c_final: Final cell state of the mLSTM.
You should not use this function if you want to do further JAX-based computations on the output vectors! In that case, the DeviceArray futures returned by mLSTM should be passed directly into the next step instead of converting them to np.arrays. The conversion to np.arrays is done in the dispatched rep_x_lengths functions to force Python to wait until the computation is completed before returning the values.
The keys of the params dictionary must be:
b, gh, gmh, gmx, gx, wh, wmh, wmx, wx
Parameters
- seqs: A list of sequences as strings or a single string.
- params: A dictionary of mLSTM weights.
- mlstm_size: Integer specifying the number of nodes in the mLSTM layer. Though the model architecture space is practically infinite, we assume that you are using the same number of nodes per mLSTM layer. (This is a common simplification used in the design of neural networks.)
Returns
A 3-tuple of np.arrays containing the reps, in the order h_avg, h_final, and c_final. Each np.array has shape (n_sequences, mlstm_size).
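For illustration, here is a minimal usage sketch based on the signature above; the toy sequences are made up, and any valid amino acid strings would work:

```python
from jax_unirep import get_reps

# Two made-up amino acid sequences of different lengths;
# get_reps handles grouping and batching internally.
sequences = ["MKTAYIAKQR", "MKLVINGKTLKG"]

h_avg, h_final, c_final = get_reps(sequences)

# With the default mlstm_size of 1900, each array has shape (2, 1900).
print(h_avg.shape, h_final.shape, c_final.shape)
```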
Evotuning
jax_unirep.fit
jax_unirep.fit(sequences, n_epochs, model_func=..., ...)
Return mLSTM weights fitted to predict the next letter in each AA sequence.
The training loop is as follows, depending on the batching strategy:
Length batching:
- At each iteration, one length is chosen at random from all sequence lengths present in sequences.
- Next, batch_size sequences of the chosen length get selected at random.
- If there are fewer sequences of a given length than batch_size, all sequences of that length get chosen.
- Those sequences then get passed through the model. No padding of sequences occurs.
To batch sequences by length, we call on batch_sequences from our utils.py module, which returns a list of sub-lists, in which each sub-list contains the indices in the original list of sequences that are of a particular length.
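For intuition, a small sketch of what batch_sequences returns, assuming only the grouping behaviour described above (the exact ordering of the sub-lists is an implementation detail):

```python
from jax_unirep.utils import batch_sequences

seqs = ["MKV", "MTS", "WXYZ"]

# Sequences at indices 0 and 1 share length 3; the sequence at index 2 has
# length 4, so we expect something like [[0, 1], [2]].
print(batch_sequences(seqs))
```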
Random batching:
- Before training, all sequences get padded to be the same length as the longest sequence in sequences.
- Then, at each iteration, we randomly sample batch_size sequences and pass them through the model.
The training loop does not adhere to the common notion of epochs, where all sequences would be seen by the model exactly once per epoch. Instead, sequences always get sampled at random, and one epoch approximately consists of round(len(sequences) / batch_size) weight updates (for example, with 1,000 sequences and a batch_size of 50, one epoch corresponds to roughly 20 weight updates). Asymptotically, this should be approximately equivalent to doing epoch passes over the dataset.
To learn more about the passing of params, have a look at the evotune function docstring.
You can optionally dump parameters and print weights every epochs_per_print epochs to monitor training progress. For ergonomics, training/holdout set losses are estimated on a batch of the same size as batch_size, rather than calculated exactly on the entire set. Set epochs_per_print to None to avoid parameter dumping.
Parameters
- sequences: List of sequences to evotune on.
- n_epochs: The number of iterations to evotune on.
- model_func: A function that accepts (params, x). Defaults to the mLSTM1900 model function.
- params: Optionally pass in the params you want to use. These params must yield a correctly-sized mLSTM, otherwise you will get cryptic shape errors! If None, params will be randomly generated, except for mlstm_size of 1900, where the pre-trained weights from the original publication are used.
- batch_method: One of "length" or "random".
- batch_size: If random batching is used, the number of sequences per batch. As a rule of thumb, a batch size of 50 consumes about 5GB of GPU RAM.
- step_size: The learning rate.
- holdout_seqs: Holdout set, an optional input.
- proj_name: The directory path for weights to be output to.
- epochs_per_print: Number of epochs to progress before printing and dumping of weights. Must be greater than or equal to 1.
- backend: Whether or not to use the GPU. Defaults to "cpu", but can be set to "gpu" if desired. If you set it to GPU, make sure you have a version of jax that is pre-compiled to work with GPUs.
Returns
Final optimized parameters.
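As a minimal sketch using only the parameters documented above (the toy sequences and settings are illustrative, not recommendations):

```python
from jax_unirep import fit

# Made-up toy sequences; in practice, pass in your own protein family.
sequences = ["MKTAYIAKQR", "MKLVINGKTLKG", "MKTAYIAKQL"]

# Fine-tune the default pre-trained mLSTM1900 weights on the CPU,
# using length batching and a small number of epochs.
evotuned_params = fit(
    sequences=sequences,
    n_epochs=5,
    batch_method="length",
    batch_size=2,
    backend="cpu",
)
```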
jax_unirep.evotune
jax_unirep.evotune(sequences, model_func=..., ...)
Evolutionarily tune the model to a set of sequences.
Evotuning is described in the original UniRep and eUniRep papers. This reimplementation of evotune provides a nicer API that automatically handles multiple sequences of variable lengths.
Evotuning always needs a starter set of weights. By default, the pre-trained weights from the Nature Methods paper are used. However, other pre-trained weights are legitimate.
We first use optuna to figure out how many epochs to fit before overfitting happens. To save on computation time, the number of trials run defaults to 20, but can be configured.
By default, mLSTM and Dense weights from the paper are used by setting mlstm_size=1900 and params=None in the partially-evaluated fit function (fit_func), but if you want to use randomly initialized weights:
from jax_unirep.evotuning import evotuning_funcs, fit
from jax.random import PRNGKey
from functools import partial

init_func, _ = evotuning_funcs(mlstm_size=256)  # works for any size
_, params = init_func(PRNGKey(0), input_shape=(-1, 26))
fit_func = partial(fit, mlstm_size=256, params=params)
or dumped weights:
from functools import partial

from jax_unirep.evotuning import fit
from jax_unirep.utils import load_params

params = load_params(folderpath="path/to/params/folder")
fit_func = partial(fit, mlstm_size=256, params=params)
The examples above use mLSTM sizes of 256, but any size works in theory! Just make sure that the mLSTM size of your randomly initialized or dumped params matches the one you set in the partially-evaluated fit function.
This function is intended as an automagic way of identifying the best model and training routine hyperparameters. If you want more control over how fitting happens, please use the fit() function directly. There is an example in the examples/ directory that shows how to use it.
Parameters
- sequences: Sequences to evotune against.
- model_func: Model apply func. Defaults to the mLSTM1900 apply function.
- params: Model params that are compatible with model apply func. Defaults to the mLSTM1900 params.
- n_trials: The number of trials Optuna should attempt.
- n_epochs_config: A dictionary of kwargs to trial.suggest_discrete_uniform, which are: name, low, high, q. This controls how many epochs to have Optuna test. See source code for default configuration, at the definition of n_epochs_kwargs.
- learning_rate_config: A dictionary of kwargs to trial.suggest_loguniform, which are: name, low, high. This controls the learning rate of the model. See source code for default configuration, at the definition of learning_rate_kwargs.
- n_splits: The number of folds of cross-validation to do.
- out_dom_seqs: Out-domain holdout set of sequences, on which loss is checked to prevent overfitting.
Returns
- study: The Optuna study object, containing information about all evotuning trials.
- evotuned_params: A dictionary of the final, optimized weights.
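A minimal sketch that relies only on the parameters and return values documented here (toy sequences and a small trial count, purely for illustration):

```python
from jax_unirep import evotune

# Made-up input sequences; replace with your own protein family.
sequences = ["MKTAYIAKQR", "MKLVINGKTLKG", "MKTAYIAKQL"]

# Let Optuna search for a good number of epochs, starting from the
# default mLSTM1900 weights, with 2-fold cross-validation.
study, evotuned_params = evotune(
    sequences=sequences,
    n_trials=5,
    n_splits=2,
)
```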
Sampling
jax_unirep.sample_one_chain
jax_unirep.sample_one_chain(starter_sequence, n_steps, scoring_func, is_accepted_kwargs={}, trust_radius=7, propose_kwargs={})
Return one chain of MCMC samples of new sequences.
Given a starter_sequence, this function will sample one chain of protein sequences, scored using a user-provided scoring_func.
Design choices made here include the following.
Firstly, we record all sequences that were sampled, and not just the accepted ones. This behaviour differs from other MCMC samplers that record only the accepted values. We do this just in case sequences that are still "good" (but not better than current) are rejected. The effect here is that we get a cluster of sequences that are one-apart from newly accepted sequences.
Secondly, we check the Hamming distance between the newly proposed sequences and the original. This corresponds to the "trust radius" specified in the jax-unirep paper. If the Hamming distance is greater than the trust radius, we reject the sequence outright.
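To make the trust-radius check concrete, here is an illustrative sketch of a Hamming distance check between equal-length sequences (not the library's internal implementation):

```python
def hamming_distance(a: str, b: str) -> int:
    # Number of positions at which two equal-length sequences differ.
    return sum(x != y for x, y in zip(a, b))

# A proposal two mutations away from the starter is well within the
# default trust radius of 7.
starter = "MKTAYIAKQR"
proposal = "MKTAYIAAQL"
assert hamming_distance(starter, proposal) <= 7
```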
A dictionary containing the following key-value pairs is returned:
- "sequences": All proposed sequences.
- "scores": All scores from the scoring function.
- "accept": Whether the sequence was accepted as the new 'current sequence' on which new sequences are proposed.
This can be turned into a pandas DataFrame.
Parameters
- starter_sequence: The starting sequence.
- n_steps: Number of steps for the MC chain to walk.
- scoring_func: Scoring function for a new sequence. It should only accept a string sequence.
- is_accepted_kwargs: Dictionary of kwargs to pass into the is_accepted function. See the is_accepted docstring for more details.
- trust_radius: Maximum allowed number of mutations away from the starter sequence.
- propose_kwargs: Dictionary of kwargs to pass into the propose function. See the propose docstring for more details.
- verbose: Whether or not to print the iteration number and associated sequence + score. Defaults to False.
Returns
A dictionary with sequences, accept, and score as keys.
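A minimal end-to-end sketch; the scoring function below is a made-up stand-in (in practice you would plug in a trained top model, e.g. an activity predictor built on top of the reps):

```python
import pandas as pd

from jax_unirep import get_reps, sample_one_chain

def scoring_func(sequence: str) -> float:
    # Made-up score: the mean of the sequence's average hidden state.
    h_avg, _, _ = get_reps(sequence)
    return float(h_avg.mean())

chain = sample_one_chain(
    starter_sequence="MKTAYIAKQR",
    n_steps=10,
    scoring_func=scoring_func,
    trust_radius=7,
)

# As noted above, the returned dictionary converts cleanly into a DataFrame.
df = pd.DataFrame(chain)
print(df.head())
```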