diff --git a/af/examples/af_single.ipynb b/af/examples/af_single.ipynb new file mode 100644 index 00000000..2b207c2c --- /dev/null +++ b/af/examples/af_single.ipynb @@ -0,0 +1,302 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "AlphaFold_single.ipynb", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# AlphaFold - single sequence input\n", + "- WARNING - For DEMO and educational purposes only.\n", + "- For natural proteins you often need more than a single sequence to accurately predict the structure. See the [ColabFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) notebook if you want to predict the protein structure from a multiple sequence alignment. That being said, this notebook can be useful for evaluating *de novo* designed proteins and learning the idealized principles of proteins.\n", + "\n", + "### Tips and Instructions\n", + "- Patience... The first time you run the cell below it will take about 1 minute to set up; after that it should run in seconds (after each change).\n", + "- Click the little ▶ play icon to the left of each cell below.\n", + "- In the 3D display, hover over an amino acid to see its name and position number.\n", + "- Use \"/\" to specify chain breaks (e.g. sequence=\"AAA/AAA\").\n" + ], + "metadata": { + "id": "VpfCw7IzVHXv" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Enter the amino acid sequence to fold ⬇️\n", + "\n", + "###############################################################################\n", + "###############################################################################\n", + "#@title Setup\n", + "# import libraries\n", + "import os,sys,re,time\n", + "\n", + "if \"SETUP_DONE\" not in dir():\n", + " from IPython.utils import io\n", + " from IPython.display import HTML\n", + " import numpy as np\n", + " import matplotlib\n", + " from matplotlib import animation\n", + " import matplotlib.pyplot as plt\n", + " import tqdm.notebook\n", + " TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", + "\n", + " if not os.path.isdir(\"params\"):\n", + " os.system(\"wget -qnc https://raw.githubusercontent.com/sokrypton/ColabFold/main/beta/colabfold.py\")\n", + " # get code\n", + " print(\"installing ColabDesign...\")\n", + " os.system(\"(mkdir params; apt-get install aria2 -qq; \\\n", + " aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar; \\\n", + " tar -xf alphafold_params_2021-07-14.tar -C params; \\\n", + " touch params/done.txt )&\")\n", + "\n", + " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@beta\")\n", + " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n", + "\n", + " # download params\n", + " if not os.path.isfile(\"params/done.txt\"):\n", + " print(\"downloading AlphaFold params...\")\n", + " while not os.path.isfile(\"params/done.txt\"):\n", + " time.sleep(5)\n", + "\n", + " # configure which device to use\n", + " import jax\n", + " # disable triton_gemm for jax versions > 0.3\n", + " if int(jax.__version__.split(\".\")[1]) > 3:\n", + " os.environ[\"XLA_FLAGS\"] = 
\"--xla_gpu_enable_triton_gemm=false\"\n", + " import jax.numpy as jnp\n", + " try:\n", + " # check if TPU is available\n", + " import jax.tools.colab_tpu\n", + " jax.tools.colab_tpu.setup_tpu()\n", + " print('Running on TPU')\n", + " DEVICE = \"tpu\"\n", + " except:\n", + " if jax.local_devices()[0].platform == 'cpu':\n", + " print(\"WARNING: no GPU detected, will be using CPU\")\n", + " DEVICE = \"cpu\"\n", + " else:\n", + " print('Running on GPU')\n", + " DEVICE = \"gpu\"\n", + "\n", + " # import libraries\n", + " sys.path.append('af_backprop')\n", + "\n", + " SETUP_DONE = True\n", + "\n", + "if \"LIBRARY_IMPORTED\" not in dir():\n", + " from colabdesign.af.loss import get_plddt, get_pae\n", + " from colabdesign.af.prep import prep_input_features\n", + " from colabdesign.af.inputs import update_seq, update_aatype\n", + " from colabdesign.af.alphafold.common import protein\n", + " from colabdesign.af.alphafold.model import data, config, model\n", + " from colabdesign.af.alphafold.common import residue_constants\n", + " from colabdesign.rf.utils import make_animation\n", + " import py3Dmol\n", + " import colabfold as cf\n", + "\n", + " # setup model\n", + " cfg = config.model_config(\"model_5_ptm\")\n", + " cfg.model.num_recycle = 0\n", + " cfg.model.global_config.subbatch_size = None\n", + " model_name=\"model_2_ptm\"\n", + " model_params = data.get_model_haiku_params(model_name=model_name,\n", + " data_dir=\".\",\n", + " fuse=True)\n", + " model_runner = model.RunModel(cfg, model_params)\n", + "\n", + " def setup_model(max_len):\n", + "\n", + " seq = \"A\" * max_len\n", + " length = len(seq)\n", + " inputs = prep_input_features(length)\n", + "\n", + " def runner(I):\n", + " # update sequence\n", + " inputs = I[\"inputs\"]\n", + " inputs[\"prev\"] = I[\"prev\"]\n", + "\n", + " seq_oh = jax.nn.one_hot(I[\"seq\"],20)[None]\n", + " update_seq(seq_oh, inputs)\n", + " update_aatype(seq_oh, inputs)\n", + "\n", + " # mask prediction\n", + " mask = jnp.arange(inputs[\"residue_index\"].shape[0]) < I[\"length\"]\n", + " inputs[\"seq_mask\"] = inputs[\"seq_mask\"].at[:].set(mask)\n", + " inputs[\"msa_mask\"] = inputs[\"msa_mask\"].at[:].set(mask)\n", + " inputs[\"residue_index\"] = jnp.where(mask, inputs[\"residue_index\"], 0)\n", + "\n", + " # get prediction\n", + " key = jax.random.PRNGKey(0)\n", + " outputs = model_runner.apply(I[\"params\"], key, inputs)\n", + "\n", + " aux = {\"final_atom_positions\":outputs[\"structure_module\"][\"final_atom_positions\"],\n", + " \"final_atom_mask\":outputs[\"structure_module\"][\"final_atom_mask\"],\n", + " \"plddt\":get_plddt(outputs),\"pae\":get_pae(outputs),\n", + " \"length\":I[\"length\"], \"seq\":I[\"seq\"],\n", + " \"prev\":outputs[\"prev\"],\n", + " \"residue_idx\":inputs[\"residue_index\"]}\n", + " return aux\n", + "\n", + " return jax.jit(runner), {\"inputs\":inputs, \"params\":model_params, \"length\":max_length}\n", + "\n", + " def save_pdb(outs, filename):\n", + " '''save pdb coordinates'''\n", + " p = {\"residue_index\":outs[\"residue_idx\"] + 1,\n", + " \"aatype\":outs[\"seq\"],\n", + " \"atom_positions\":outs[\"final_atom_positions\"],\n", + " \"atom_mask\":outs[\"final_atom_mask\"],\n", + " \"plddt\":outs[\"plddt\"]}\n", + " p = jax.tree_util.tree_map(lambda x:x[:outs[\"length\"]], p)\n", + " b_factors = 100 * p.pop(\"plddt\")[:,None] * p[\"atom_mask\"]\n", + " p = protein.Protein(**p,b_factors=b_factors)\n", + " pdb_lines = protein.to_pdb(p)\n", + " with open(filename, 'w') as f: f.write(pdb_lines)\n", + "\n", + " LIBRARY_IMPORTED = 
True\n", + "\n", + "###############################################################################\n", + "###############################################################################\n", + "\n", + "# initialize\n", + "if \"current_seq\" not in dir():\n", + " current_seq = \"\"\n", + " r = -1\n", + " max_length = -1\n", + "\n", + "# collect user inputs\n", + "sequence = 'GGGGGGGGGG' #@param {type:\"string\"}\n", + "recycles = 0 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\", \"48\"] {type:\"raw\"}\n", + "\n", + "######\n", + "# Define the allowed amino acids\n", + "AA20 = set(\"ACDEFGHIKLMNPQRSTVWY\")\n", + "\n", + "# Fix regex (no SyntaxWarning), keep only A–Z and separators\n", + "ori_sequence = re.sub(r\"[^A-Z/:]\", \"\", sequence.upper())\n", + "\n", + "# Replace non-AA letters with \"G\", but preserve ':' and '/'\n", + "ori_sequence = \"\".join(\n", + " ch if (ch in AA20 or ch in [\":\", \"/\"]) else (\"G\" if ch.isalpha() else \"\")\n", + " for ch in ori_sequence\n", + ")\n", + "######\n", + "\n", + "Ls = [len(s) for s in ori_sequence.replace(\":\",\"/\").split(\"/\")]\n", + "sequence = re.sub(\"[^A-Z]\",\"\",ori_sequence)\n", + "length = len(sequence)\n", + "\n", + "# avoid recompiling if length within 10\n", + "if length > max_length or (max_length - length) > 20:\n", + " max_length = length + 10\n", + " runner, I = setup_model(max_length)\n", + "\n", + "if ori_sequence != current_seq:\n", + " outs = []\n", + " positions = []\n", + " plddts = []\n", + " paes = []\n", + " r = -1\n", + "\n", + " # pad sequence to max length\n", + " seq = np.array([residue_constants.restype_order.get(aa,0) for aa in sequence])\n", + " seq = np.pad(seq,[0,max_length-length],constant_values=-1)\n", + "\n", + " # update inputs, restart recycle\n", + " I.update({\"seq\":seq, \"length\":length,\n", + " \"prev\":{'prev_msa_first_row': np.zeros([max_length, 256]),\n", + " 'prev_pair': np.zeros([max_length, max_length, 128]),\n", + " 'prev_pos': np.zeros([max_length, 37, 3])}})\n", + "\n", + " I[\"inputs\"][\"use_dropout\"] = False\n", + " I[\"inputs\"]['residue_index'][:] = cf.chain_break(np.arange(max_length), Ls, length=32)\n", + " current_seq = ori_sequence\n", + "\n", + "# run for defined number of recycles\n", + "with tqdm.notebook.tqdm(total=(recycles+1), bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " p = 0\n", + " while p < min(r+1,recycles+1):\n", + " pbar.update(1)\n", + " p += 1\n", + " while r < recycles:\n", + " O = runner(I)\n", + " O = jax.tree_util.tree_map(lambda x:np.asarray(x), O)\n", + " positions.append(O[\"final_atom_positions\"][:length])\n", + " plddts.append(O[\"plddt\"][:length])\n", + " paes.append(O[\"pae\"][:length,:length])\n", + " I[\"prev\"] = O[\"prev\"]\n", + " outs.append(O)\n", + " r += 1\n", + " pbar.update(1)\n", + "\n", + "#@markdown #### Display options\n", + "color = \"confidence\" #@param [\"chain\", \"confidence\", \"rainbow\"]\n", + "if color == \"confidence\": color = \"lDDT\"\n", + "show_sidechains = True #@param {type:\"boolean\"}\n", + "show_mainchains = False #@param {type:\"boolean\"}\n", + "\n", + "print(f\"plotting prediction at recycle={recycles}\")\n", + "save_pdb(outs[recycles], \"out.pdb\")\n", + "v = cf.show_pdb(\"out.pdb\", show_sidechains, show_mainchains, color,\n", + " color_HP=True, size=(800,480), Ls=Ls)\n", + "v.setHoverable({}, True,\n", + " '''function(atom,viewer,event,container){if(!atom.label){atom.label=viewer.addLabel(\" \"+atom.resn+\":\"+atom.resi,{position:atom,backgroundColor:'mintcream',fontColor:'black'});}}''',\n", + " 
'''function(atom,viewer){if(atom.label){viewer.removeLabel(atom.label);delete atom.label;}}''')\n", + "v.show()\n", + "if color == \"lDDT\":\n", + " cf.plot_plddt_legend().show()\n", + "\n", + "# add confidence plots\n", + "cf.plot_confidence(plddts[recycles]*100, paes[recycles], Ls=Ls).show()" + ], + "metadata": { + "cellView": "form", + "id": "cAoC4ar8G7ZH" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Animate\n", + "#@markdown - Animate trajectory if more than 0 recycle(s)\n", + "HTML(make_animation(np.asarray(positions)[...,1,:],\n", + " np.asarray(plddts) * 100.0,\n", + " Ls=Ls,\n", + " ref=-1, align_to_ref=True,\n", + " verbose=True))" + ], + "metadata": { + "cellView": "form", + "id": "tdjdC0KFPjWw" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/all_atom.py b/colabdesign/af/alphafold/model/all_atom.py index 43331586..d90dd8d3 100644 --- a/colabdesign/af/alphafold/model/all_atom.py +++ b/colabdesign/af/alphafold/model/all_atom.py @@ -490,11 +490,11 @@ def torsion_angles_to_frames( # chi2, chi3, and chi4 frames do not transform to the backbone frame but to # the previous frame. So chain them up accordingly. - chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames) - chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames) - chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames) + chi2_frame_to_frame = jax.tree_util.tree_map(lambda x: x[:, 5], all_frames) + chi3_frame_to_frame = jax.tree_util.tree_map(lambda x: x[:, 6], all_frames) + chi4_frame_to_frame = jax.tree_util.tree_map(lambda x: x[:, 7], all_frames) - chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames) + chi1_frame_to_backb = jax.tree_util.tree_map(lambda x: x[:, 4], all_frames) chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame) chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb, @@ -507,7 +507,7 @@ def _concat_frames(xall, x5, x6, x7): return jnp.concatenate( [xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1) - all_frames_to_backb = jax.tree_map( + all_frames_to_backb = jax.tree_util.tree_map( _concat_frames, all_frames, chi2_frame_to_backb, @@ -517,7 +517,7 @@ def _concat_frames(xall, x5, x6, x7): # Create the global frames. # shape (N, 8) all_frames_to_global = r3.rigids_mul_rigids( - jax.tree_map(lambda x: x[:, None], backb_to_global), + jax.tree_util.tree_map(lambda x: x[:, None], backb_to_global), all_frames_to_backb) return all_frames_to_global @@ -543,7 +543,7 @@ def frames_and_literature_positions_to_atom14_pos( group_mask = jax.nn.one_hot(residx_to_group_idx, num_classes=8) # shape (N, 14, 8) # r3.Rigids with shape (N, 14) - map_atoms_to_global = jax.tree_map( + map_atoms_to_global = jax.tree_util.tree_map( lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), all_frames_to_global) @@ -558,7 +558,7 @@ def frames_and_literature_positions_to_atom14_pos( # Mask out non-existing atoms. mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) - pred_positions = jax.tree_map(lambda x: x * mask, pred_positions) + pred_positions = jax.tree_util.tree_map(lambda x: x * mask, pred_positions) return pred_positions @@ -1040,14 +1040,14 @@ def frame_aligned_point_error( # Compute array of predicted positions in the predicted frames. 
# r3.Vecs (num_frames, num_positions) local_pred_pos = r3.rigids_mul_vecs( - jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)), - jax.tree_map(lambda x: x[None, :], pred_positions)) + jax.tree_util.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)), + jax.tree_util.tree_map(lambda x: x[None, :], pred_positions)) # Compute array of target positions in the target frames. # r3.Vecs (num_frames, num_positions) local_target_pos = r3.rigids_mul_vecs( - jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)), - jax.tree_map(lambda x: x[None, :], target_positions)) + jax.tree_util.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)), + jax.tree_util.tree_map(lambda x: x[None, :], target_positions)) # Compute errors between the structures. # jnp.ndarray (num_frames, num_positions) @@ -1119,8 +1119,8 @@ def get_alt_atom14(aatype, positions, mask): renaming_transform = utils.batched_gather( jnp.asarray(RENAMING_MATRICES), aatype) - positions = jax.tree_map(lambda x: x[:, :, None], positions) - alternative_positions = jax.tree_map( + positions = jax.tree_util.tree_map(lambda x: x[:, :, None], positions) + alternative_positions = jax.tree_util.tree_map( lambda x: jnp.sum(x, axis=1), positions * renaming_transform) # Create the mask for the alternative ground truth (differs from the diff --git a/colabdesign/af/alphafold/model/all_atom_multimer.py b/colabdesign/af/alphafold/model/all_atom_multimer.py index fc45c30a..c724661f 100644 --- a/colabdesign/af/alphafold/model/all_atom_multimer.py +++ b/colabdesign/af/alphafold/model/all_atom_multimer.py @@ -242,7 +242,7 @@ def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): # create a mask for known groundtruth positions atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) # gather the groundtruth positions - atom14_positions = jax.tree_map( + atom14_positions = jax.tree_util.tree_map( lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1), all_atom_pos) atom14_positions = atom14_mask * atom14_positions @@ -256,7 +256,7 @@ def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask): renaming_transform = utils.batched_gather( jnp.asarray(RENAMING_MATRICES), aatype) - alternative_positions = jax.tree_map( + alternative_positions = jax.tree_util.tree_map( lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform) # Create the mask for the alternative ground truth (differs from the @@ -284,7 +284,7 @@ def atom37_to_frames( # If there is a batch axis, just flatten it away, and reshape everything # back at the end of the function. aatype = jnp.reshape(aatype, [-1]) - all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]), + all_atom_positions = jax.tree_util.tree_map(lambda x: jnp.reshape(x, [-1, 37]), all_atom_positions) all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) @@ -294,7 +294,7 @@ def atom37_to_frames( RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) # Gather the base atom positions for each rigid group. 
- base_atom_pos = jax.tree_map( + base_atom_pos = jax.tree_util.tree_map( lambda x: utils.batched_gather( # pylint: disable=g-long-lambda x, residx_rigidgroup_base_atom37_idx, batch_dims=1), all_atom_positions) @@ -351,11 +351,11 @@ def atom37_to_frames( fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,)) # reshape back to original residue layout - gt_frames = jax.tree_map(fix_shape, gt_frames) + gt_frames = jax.tree_util.tree_map(fix_shape, gt_frames) gt_exists = fix_shape(gt_exists) group_exists = fix_shape(group_exists) residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous) - alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames) + alt_gt_frames = jax.tree_util.tree_map(fix_shape, alt_gt_frames) return { 'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8) @@ -423,7 +423,7 @@ def torsion_angles_to_frames( chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6] chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7] - all_frames_to_backb = jax.tree_map( + all_frames_to_backb = jax.tree_util.tree_map( lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5], chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None], chi4_frame_to_backb[:, None]) @@ -448,7 +448,7 @@ def frames_and_literature_positions_to_atom14_pos( residx_to_group_idx, num_classes=8) # shape (N, 14, 8) # geometry.Rigid3Array with shape (N, 14) - map_atoms_to_global = jax.tree_map( + map_atoms_to_global = jax.tree_util.tree_map( lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), all_frames_to_global) @@ -915,7 +915,7 @@ def compute_chi_angles(positions: geometry.Vec3Array, atom_indices = utils.batched_gather( params=chi_atom_indices, indices=aatype, axis=0) # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. - chi_angle_atoms = jax.tree_map( + chi_angle_atoms = jax.tree_util.tree_map( lambda x: utils.batched_gather( # pylint: disable=g-long-lambda params=x, indices=atom_indices, axis=-1, batch_dims=1), positions) a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] diff --git a/colabdesign/af/alphafold/model/config.py b/colabdesign/af/alphafold/model/config.py index 15eb8f17..32f001e7 100644 --- a/colabdesign/af/alphafold/model/config.py +++ b/colabdesign/af/alphafold/model/config.py @@ -307,7 +307,8 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'multimer_mode': False, 'subbatch_size': 4, 'use_remat': False, - 'zero_init': True + 'zero_init': True, + 'use_dgram_pred': False, }, 'heads': { 'distogram': { @@ -536,7 +537,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'subbatch_size': 4, 'use_remat': False, 'zero_init': True, - 'use_dgram': False + 'use_dgram_pred': False, }, 'heads': { 'distogram': { diff --git a/colabdesign/af/alphafold/model/data.py b/colabdesign/af/alphafold/model/data.py index d34918a5..cb5e896d 100644 --- a/colabdesign/af/alphafold/model/data.py +++ b/colabdesign/af/alphafold/model/data.py @@ -28,7 +28,7 @@ def casp_model_names(data_dir: str) -> List[str]: return [os.path.splitext(filename)[0] for filename in params] -def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None) -> hk.Params: +def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None, rm_templates: bool = False) -> hk.Params: """Get the Haiku parameters from a model name.""" path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') @@ -38,4 +38,4 @@ def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = None) -> if os.path.isfile(path): with open(path, 'rb') as f: params = 
np.load(io.BytesIO(f.read()), allow_pickle=False) - return utils.flat_params_to_haiku(params, fuse=fuse) \ No newline at end of file + return utils.flat_params_to_haiku(params, fuse=fuse, rm_templates=rm_templates) \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/folding.py b/colabdesign/af/alphafold/model/folding.py index ac125859..edadd125 100644 --- a/colabdesign/af/alphafold/model/folding.py +++ b/colabdesign/af/alphafold/model/folding.py @@ -667,8 +667,8 @@ def sidechain_loss(batch, value, config): def _slice_last_layer_and_flatten(x): return jnp.reshape(x[-1], [-1]) - flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) - flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, pred_positions) + flat_pred_frames = jax.tree_util.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_util.tree_map(_slice_last_layer_and_flatten, pred_positions) # FAPE Loss on sidechains fape = all_atom.frame_aligned_point_error( pred_frames=flat_pred_frames, diff --git a/colabdesign/af/alphafold/model/folding_multimer.py b/colabdesign/af/alphafold/model/folding_multimer.py index 14db08ae..c1a9c721 100644 --- a/colabdesign/af/alphafold/model/folding_multimer.py +++ b/colabdesign/af/alphafold/model/folding_multimer.py @@ -331,7 +331,7 @@ def __call__( name='v_point_projection')(inputs_1d, rigid) - result_point_global = jax.tree_map( + result_point_global = jax.tree_util.tree_map( lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None]) # Features used in the linear output projection. Should have the size @@ -344,7 +344,7 @@ def __call__( result_scalar = jnp.reshape(result_scalar, flat_shape) output_features.append(result_scalar) - result_point_global = jax.tree_map(lambda r: jnp.reshape(r, flat_shape), + result_point_global = jax.tree_util.tree_map(lambda r: jnp.reshape(r, flat_shape), result_point_global) result_point_local = rigid[..., None].apply_inverse_to_point( result_point_global) @@ -464,7 +464,7 @@ def safe_dropout_fn(tensor, safe_key): outputs = {'rigid': rigid, 'sc': sc} - rotation = rigid.rotation #jax.tree_map(jax.lax.stop_gradient, rigid.rotation) + rotation = rigid.rotation #jax.tree_util.tree_map(jax.lax.stop_gradient, rigid.rotation) rigid = geometry.Rigid3Array(rotation, rigid.translation) new_activations = { @@ -694,7 +694,7 @@ def compute_frames( alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames'] use_alt = use_alt[:, None] - renamed_gt_frames = jax.tree_map( + renamed_gt_frames = jax.tree_util.tree_map( lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames) return renamed_gt_frames, frames_batch['rigidgroups_gt_exists'] @@ -710,18 +710,18 @@ def sidechain_loss(gt_frames: geometry.Rigid3Array, ) -> Dict[str, jnp.ndarray]: """Sidechain Loss using cleaned up rigids.""" - flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames) + flat_gt_frames = jax.tree_util.tree_map(jnp.ravel, gt_frames) flat_frames_mask = jnp.ravel(gt_frames_mask) - flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions) + flat_gt_positions = jax.tree_util.tree_map(jnp.ravel, gt_positions) flat_positions_mask = jnp.ravel(gt_mask) # Compute frame_aligned_point_error score for the final layer. 
def _slice_last_layer_and_flatten(x): return jnp.ravel(x[-1]) - flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) - flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, + flat_pred_frames = jax.tree_util.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_util.tree_map(_slice_last_layer_and_flatten, pred_positions) fape = all_atom_multimer.frame_aligned_point_error( pred_frames=flat_pred_frames, diff --git a/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py b/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py index 4c7bb105..29ece845 100644 --- a/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py +++ b/colabdesign/af/alphafold/model/geometry/rigid_matrix_vector.py @@ -56,7 +56,7 @@ def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: def compose_rotation(self, other_rotation): rot = self.rotation @ other_rotation - trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), + trans = jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), self.translation) return Rigid3Array(rot, trans) diff --git a/colabdesign/af/alphafold/model/geometry/struct_of_array.py b/colabdesign/af/alphafold/model/geometry/struct_of_array.py index 97a89fd4..530f50e5 100644 --- a/colabdesign/af/alphafold/model/geometry/struct_of_array.py +++ b/colabdesign/af/alphafold/model/geometry/struct_of_array.py @@ -133,7 +133,7 @@ def flatten(instance): inner_treedefs = [] num_arrays = [] for array_like in array_likes: - flat_array_like, inner_treedef = jax.tree_flatten(array_like) + flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like) inner_treedefs.append(inner_treedef) flat_array_likes += flat_array_like num_arrays.append(len(flat_array_like)) diff --git a/colabdesign/af/alphafold/model/geometry/vector.py b/colabdesign/af/alphafold/model/geometry/vector.py index 8b5e653b..97cb3b17 100644 --- a/colabdesign/af/alphafold/model/geometry/vector.py +++ b/colabdesign/af/alphafold/model/geometry/vector.py @@ -53,25 +53,25 @@ def __post_init__(self): assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) def __add__(self, other: Vec3Array) -> Vec3Array: - return jax.tree_map(lambda x, y: x + y, self, other) + return jax.tree_util.tree_map(lambda x, y: x + y, self, other) def __sub__(self, other: Vec3Array) -> Vec3Array: - return jax.tree_map(lambda x, y: x - y, self, other) + return jax.tree_util.tree_map(lambda x, y: x - y, self, other) def __mul__(self, other: Float) -> Vec3Array: - return jax.tree_map(lambda x: x * other, self) + return jax.tree_util.tree_map(lambda x: x * other, self) def __rmul__(self, other: Float) -> Vec3Array: return self * other def __truediv__(self, other: Float) -> Vec3Array: - return jax.tree_map(lambda x: x / other, self) + return jax.tree_util.tree_map(lambda x: x / other, self) def __neg__(self) -> Vec3Array: - return jax.tree_map(lambda x: -x, self) + return jax.tree_util.tree_map(lambda x: -x, self) def __pos__(self) -> Vec3Array: - return jax.tree_map(lambda x: x, self) + return jax.tree_util.tree_map(lambda x: x, self) def cross(self, other: Vec3Array) -> Vec3Array: """Compute cross product between 'self' and 'other'.""" diff --git a/colabdesign/af/alphafold/model/mapping.py b/colabdesign/af/alphafold/model/mapping.py index 1371b618..4bbb2e77 100644 --- a/colabdesign/af/alphafold/model/mapping.py +++ b/colabdesign/af/alphafold/model/mapping.py @@ -17,7 +17,7 @@ import functools import inspect -from typing import Any, 
Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, TypeVar, Union import haiku as hk import jax @@ -30,6 +30,16 @@ partial = functools.partial PROXY = object() +T = TypeVar('T') +def _set_docstring(docstr: str) -> Callable[[T], T]: + """Decorator for setting the docstring of a function.""" + + def wrapped(fun: T) -> T: + fun.__doc__ = docstr.format(fun=getattr(fun, '__name__', repr(fun))) + return fun + + return wrapped + def _maybe_slice(array, i, slice_size, axis): if axis is PROXY: @@ -47,11 +57,11 @@ def _maybe_get_size(array, axis): def _expand_axes(axes, values, name='sharded_apply'): - values_tree_def = jax.tree_flatten(values)[1] + values_tree_def = jax.tree_util.tree_flatten(values)[1] flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) # Replace None's with PROXY flat_axes = [PROXY if x is None else x for x in flat_axes] - return jax.tree_unflatten(values_tree_def, flat_axes) + return jax.tree_util.tree_unflatten(values_tree_def, flat_axes) def sharded_map( @@ -119,13 +129,14 @@ def sharded_apply( if shard_size is None: return fun - @jax.util.wraps(fun, docstr=docstr) + @_set_docstring(docstr) + @functools.wraps(fun) def mapped_fn(*args): # Expand in axes and Determine Loop range in_axes_ = _expand_axes(in_axes, args) in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_) - flat_sizes = jax.tree_flatten(in_sizes)[0] + flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0] in_size = max(flat_sizes) assert all(i in {in_size, -1} for i in flat_sizes) @@ -143,14 +154,14 @@ def apply_fun_to_slice(slice_start, slice_size): remainder_shape_dtype = hk.eval_shape( partial(apply_fun_to_slice, 0, last_shard_size)) - out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) - out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) + out_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, remainder_shape_dtype) + out_shapes = jax.tree_util.tree_map(lambda x: x.shape, remainder_shape_dtype) out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) if num_extra_shards > 0: regular_shard_shape_dtype = hk.eval_shape( partial(apply_fun_to_slice, 0, shard_size)) - shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) + shard_shapes = jax.tree_util.tree_map(lambda x: x.shape, regular_shard_shape_dtype) def make_output_shape(axis, shard_shape, remainder_shape): return shard_shape[:axis] + ( @@ -219,4 +230,4 @@ def run_module(*batched_args): shard_size=subbatch_size, in_axes=input_subbatch_dim, out_axes=output_subbatch_dim) - return sharded_module(*batched_args) + return sharded_module(*batched_args) \ No newline at end of file diff --git a/colabdesign/af/alphafold/model/model.py b/colabdesign/af/alphafold/model/model.py index 63d35919..3c3a2b47 100644 --- a/colabdesign/af/alphafold/model/model.py +++ b/colabdesign/af/alphafold/model/model.py @@ -81,7 +81,7 @@ def loop(prev, sub_key): keys = jax.random.split(key, self.config.model.num_recycle + 1) _, o = jax.lax.scan(loop, prev, keys) - results = jax.tree_map(lambda x:x[-1], o) + results = jax.tree_util.tree_map(lambda x:x[-1], o) if "add_prev" in self.mode: for k in ["distogram","predicted_lddt","predicted_aligned_error"]: diff --git a/colabdesign/af/alphafold/model/modules.py b/colabdesign/af/alphafold/model/modules.py index 592e5e0f..dac35e47 100644 --- a/colabdesign/af/alphafold/model/modules.py +++ b/colabdesign/af/alphafold/model/modules.py @@ -150,20 +150,28 @@ def __init__(self, config, name='alphafold'): def __call__(self, batch, 
**kwargs): """Run the AlphaFold model.""" impl = AlphaFoldIteration(self.config, self.global_config) - - def get_prev(ret): + def get_prev(ret, use_dgram=False): new_prev = { 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], - 'prev_pos': ret['structure_module']['final_atom_positions'] } - if self.global_config.use_dgram: - new_prev['prev_dgram'] = ret["distogram"]["logits"] + if use_dgram: + if self.global_config.use_dgram_pred: + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + else: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + else: + new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] + return new_prev prev = batch.pop("prev") ret = impl(batch={**batch, **prev}) - ret["prev"] = get_prev(ret) + ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev) return ret class TemplatePairStack(hk.Module): @@ -1280,6 +1288,19 @@ def __call__(self, activations, masks, use_dropout, safe_key=None): safe_key, *sub_keys = safe_key.split(10) sub_keys = iter(sub_keys) + outer_module = OuterProductMean( + config=c.outer_product_mean, + global_config=self.global_config, + num_output_channel=int(pair_act.shape[-1]), + name='outer_product_mean') + if c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) + msa_act = dropout_wrapper_fn( MSARowAttentionWithPairBias( c.msa_row_attention_with_pair_bias, gc, @@ -1307,16 +1328,13 @@ def __call__(self, activations, masks, use_dropout, safe_key=None): msa_mask, safe_key=next(sub_keys)) - pair_act = dropout_wrapper_fn( - OuterProductMean( - config=c.outer_product_mean, - global_config=self.global_config, - num_output_channel=int(pair_act.shape[-1]), - name='outer_product_mean'), - msa_act, - msa_mask, - safe_key=next(sub_keys), - output_act=pair_act) + if not c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_outgoing, gc, @@ -1383,6 +1401,7 @@ def __call__(self, batch, safe_key=None): msa_feat = batch['msa_feat'].astype(dtype) target_feat = jnp.pad(batch["target_feat"].astype(dtype),[[0,0],[1,1]]) preprocess_1d = common_modules.Linear(c.msa_channel, name='preprocess_1d')(target_feat) + preprocess_1d = jnp.where(target_feat.sum(-1,keepdims=True) == 0, 0, preprocess_1d) preprocess_msa = common_modules.Linear(c.msa_channel, name='preprocess_msa')(msa_feat) msa_activations = preprocess_1d[None] + preprocess_msa @@ -1397,19 +1416,11 @@ def __call__(self, batch, safe_key=None): # Jumper et al. (2021) Suppl. Alg. 2 "Inference" line 6 # Jumper et al. (2021) Suppl. Alg. 
32 "RecyclingEmbedder" - if gc.use_dgram: - # use predicted distogram input (from Sergey) - dgram = jax.nn.softmax(batch["prev_dgram"]) - dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) - dgram = dgram @ dgram_map - + if "prev_dgram" in batch: + dgram = batch["prev_dgram"] else: - # use predicted position input prev_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) - if c.backprop_dgram: - dgram = dgram_from_positions_soft(prev_pseudo_beta, temp=c.backprop_dgram_temp, **c.prev_pos) - else: - dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) + dgram = dgram_from_positions(prev_pseudo_beta, **c.prev_pos) dgram = dgram.astype(dtype) pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) @@ -1438,6 +1449,9 @@ def __call__(self, batch, safe_key=None): else: pos = batch['residue_index'] offset = pos[:, None] - pos[None, :] + if "asym_id" in batch: + o = batch['asym_id'][:,None] - batch['asym_id'][None,:] + offset = jnp.where(o == 0, offset, jnp.where(o > 0, 2*c.max_relative_feature, 0)) rel_pos = jax.nn.one_hot( jnp.clip( offset + c.max_relative_feature, diff --git a/colabdesign/af/alphafold/model/modules_multimer.py b/colabdesign/af/alphafold/model/modules_multimer.py index 8822c6d8..de0fa70c 100644 --- a/colabdesign/af/alphafold/model/modules_multimer.py +++ b/colabdesign/af/alphafold/model/modules_multimer.py @@ -111,10 +111,8 @@ def __call__(self, self.heads[head_name] = (head_config, head_factory(head_config, self.global_config)) - structure_module_output = None - if 'entity_id' in batch and 'all_atom_positions' in batch: - _, fold_module = self.heads['structure_module'] - structure_module_output = fold_module(representations, batch) + _, fold_module = self.heads['structure_module'] + structure_module_output = fold_module(representations, batch) ret = {} @@ -178,12 +176,23 @@ def __call__( assert isinstance(batch, dict) num_res = batch['aatype'].shape[0] - def get_prev(ret): + def get_prev(ret, use_dgram=False): new_prev = { - 'prev_pos': ret['structure_module']['final_atom_positions'], 'prev_msa_first_row': ret['representations']['msa_first_row'], 'prev_pair': ret['representations']['pair'], } + if use_dgram: + if self.global_config.use_dgram_pred: + dgram = jax.nn.softmax(ret["distogram"]["logits"]) + dgram_map = jax.nn.one_hot(jnp.repeat(jnp.append(0,jnp.arange(15)),4),15).at[:,0].set(0) + new_prev['prev_dgram'] = dgram @ dgram_map + else: + pos = ret['structure_module']['final_atom_positions'] + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], pos, None) + new_prev['prev_dgram'] = modules.dgram_from_positions(prev_pseudo_beta, min_bin=3.25, max_bin=20.75, num_bins=15) + else: + new_prev['prev_pos'] = ret['structure_module']['final_atom_positions'] + return new_prev def apply_network(prev, safe_key): @@ -191,9 +200,10 @@ def apply_network(prev, safe_key): return impl( batch=recycled_batch, safe_key=safe_key) - - ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) - ret["prev"] = get_prev(ret) + + prev = batch.pop("prev") + ret = apply_network(prev=prev, safe_key=safe_key) + ret["prev"] = get_prev(ret, use_dgram="prev_dgram" in prev) if not return_representations: del ret['representations'] @@ -315,8 +325,12 @@ def __call__(self, batch, safe_key=None): mask_2d = mask_2d.astype(dtype) if c.recycle_pos: - prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) - dgram = modules.dgram_from_positions(prev_pseudo_beta, 
**self.config.prev_pos) + if "prev_dgram" in batch: + dgram = batch["prev_dgram"] + else: + prev_pseudo_beta = modules.pseudo_beta_fn(batch['aatype'], batch['prev_pos'], None) + dgram = modules.dgram_from_positions(prev_pseudo_beta, **self.config.prev_pos) + dgram = dgram.astype(dtype) pair_activations += common_modules.Linear(c.pair_channel, name='prev_pos_linear')(dgram) diff --git a/colabdesign/af/alphafold/model/prng.py b/colabdesign/af/alphafold/model/prng.py index 64f348c9..b313f679 100644 --- a/colabdesign/af/alphafold/model/prng.py +++ b/colabdesign/af/alphafold/model/prng.py @@ -43,7 +43,7 @@ def split(self, num_keys=2): self._assert_not_used() self._used = True new_keys = jax.random.split(self._key, num_keys) - return jax.tree_map(SafeKey, tuple(new_keys)) + return jax.tree_util.tree_map(SafeKey, tuple(new_keys)) def duplicate(self, num_keys=2): self._assert_not_used() diff --git a/colabdesign/af/alphafold/model/quat_affine.py b/colabdesign/af/alphafold/model/quat_affine.py index 9ebcd20f..f21b4a96 100644 --- a/colabdesign/af/alphafold/model/quat_affine.py +++ b/colabdesign/af/alphafold/model/quat_affine.py @@ -301,8 +301,8 @@ def apply_to_point(self, point, extra_dims=0): translation = self.translation for _ in range(extra_dims): expand_fn = functools.partial(jnp.expand_dims, axis=-1) - rotation = jax.tree_map(expand_fn, rotation) - translation = jax.tree_map(expand_fn, translation) + rotation = jax.tree_util.tree_map(expand_fn, rotation) + translation = jax.tree_util.tree_map(expand_fn, translation) rot_point = apply_rot_to_vec(rotation, point) return [ @@ -327,8 +327,8 @@ def invert_point(self, transformed_point, extra_dims=0): translation = self.translation for _ in range(extra_dims): expand_fn = functools.partial(jnp.expand_dims, axis=-1) - rotation = jax.tree_map(expand_fn, rotation) - translation = jax.tree_map(expand_fn, translation) + rotation = jax.tree_util.tree_map(expand_fn, rotation) + translation = jax.tree_util.tree_map(expand_fn, translation) rot_point = [ transformed_point[0] - translation[0], diff --git a/colabdesign/af/alphafold/model/utils.py b/colabdesign/af/alphafold/model/utils.py index b59123c4..7970588d 100644 --- a/colabdesign/af/alphafold/model/utils.py +++ b/colabdesign/af/alphafold/model/utils.py @@ -85,14 +85,15 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): return (jnp.sum(mask * value, axis=axis) / (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) -def flat_params_to_haiku(params, fuse=None): +def flat_params_to_haiku(params, fuse=None, rm_templates=False): """Convert a dictionary of NumPy arrays to Haiku parameters.""" P = {} for path, array in params.items(): scope, name = path.split('//') - if scope not in P: - P[scope] = {} - P[scope][name] = jnp.array(array) + if not rm_templates or "template" not in scope: + if scope not in P: + P[scope] = {} + P[scope][name] = jnp.array(array) if fuse is not None: for a in ["evoformer_iteration", "extra_msa_stack", diff --git a/colabdesign/af/design.py b/colabdesign/af/design.py index 732b3d58..2653d1b4 100644 --- a/colabdesign/af/design.py +++ b/colabdesign/af/design.py @@ -3,8 +3,10 @@ import jax.numpy as jnp import numpy as np from colabdesign.af.alphafold.common import residue_constants +from colabdesign.af.utils import dgram_from_positions from colabdesign.shared.utils import copy_dict, update_dict, Key, dict_to_str, to_float, softmax, categorical, to_list, copy_missing + #################################################### # AF_DESIGN - design functions 
#################################################### @@ -22,7 +24,6 @@ #################################################### class _af_design: - def restart(self, seed=None, opt=None, weights=None, seq=None, mode=None, keep_history=False, reset_opt=True, **kwargs): ''' @@ -52,6 +53,8 @@ def restart(self, seed=None, opt=None, weights=None, self.set_weights(weights) # initialize sequence + if mode is None and not self._args["optimize_seq"]: + mode = "wildtype" self.set_seed(seed) self.set_seq(seq=seq, mode=mode, **kwargs) @@ -94,14 +97,14 @@ def run(self, num_recycles=None, num_models=None, sample_models=None, models=Non for n in model_nums: p = self._model_params[n] auxs.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop)) - auxs = jax.tree_map(lambda *x: np.stack(x), *auxs) + auxs = jax.tree_util.tree_map(lambda *x: np.stack(x), *auxs) # update aux (average outputs) def avg_or_first(x): if np.issubdtype(x.dtype, np.integer): return x[0] else: return x.mean(0) - self.aux = jax.tree_map(avg_or_first, auxs) + self.aux = jax.tree_util.tree_map(avg_or_first, auxs) self.aux["atom_positions"] = auxs["atom_positions"][0] self.aux["all"] = auxs @@ -112,19 +115,22 @@ def avg_or_first(x): self.aux["log"] = {**self.aux["losses"]} self.aux["log"]["plddt"] = 1 - self.aux["log"]["plddt"] for k in ["loss","i_ptm","ptm"]: self.aux["log"][k] = self.aux[k] - for k in ["hard","soft","temp"]: self.aux["log"][k] = self.opt[k] - # compute sequence recovery - if self.protocol in ["fixbb","partial"] or (self.protocol == "binder" and self._args["redesign"]): - if self.protocol == "partial": - aatype = self.aux["aatype"][...,self.opt["pos"]] - else: - aatype = self.aux["seq"]["pseudo"].argmax(-1) + if self._args["optimize_seq"]: + # keep track of sequence mode + for k in ["hard","soft","temp"]: self.aux["log"][k] = self.opt[k] + + # compute sequence recovery + if self.protocol in ["fixbb","partial"] or (self.protocol == "binder" and self._args["redesign"]): + if self.protocol == "partial": + aatype = self.aux["aatype"][...,self.opt["pos"]] + else: + aatype = self.aux["seq"]["pseudo"].argmax(-1) - mask = self._wt_aatype != -1 - true = self._wt_aatype[mask] - pred = aatype[...,mask] - self.aux["log"]["seqid"] = (true == pred).mean() + mask = self._wt_aatype != -1 + true = self._wt_aatype[mask] + pred = aatype[...,mask] + self.aux["log"]["seqid"] = (true == pred).mean() self.aux["log"] = to_float(self.aux["log"]) self.aux["log"].update({"recycles":int(self.aux["num_recycles"]), @@ -140,7 +146,7 @@ def _single(self, model_params, backprop=True): (loss, aux), grad = self._model["grad_fn"](*flags) else: loss, aux = self._model["fn"](*flags) - grad = jax.tree_map(np.zeros_like, self._params) + grad = jax.tree_util.tree_map(np.zeros_like, self._params) aux.update({"loss":loss,"grad":grad}) return aux @@ -158,27 +164,45 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): else: L = self._inputs["residue_index"].shape[0] - # intialize previous + # intialize previous inputs if "prev" not in self._inputs or a["clear_prev"]: prev = {'prev_msa_first_row': np.zeros([L,256]), - 'prev_pair': np.zeros([L,L,128])} - - if a["use_initial_guess"] and "batch" in self._inputs: - prev["prev_pos"] = self._inputs["batch"]["all_atom_positions"] + 'prev_pair': np.zeros([L,L,128])} + + # initialize coordinates + # TODO: add support for the 'partial' protocol + if "batch" in self._inputs: + ini_seq = self._inputs["batch"]["aatype"] + ini_pos = self._inputs["batch"]["all_atom_positions"] + + # via evoformer + if 
a["use_initial_guess"]: + # via distogram or positions + if a["use_dgram"] or a["use_dgram_pred"]: + prev["prev_dgram"] = dgram_from_positions(positions=ini_pos, + seq=ini_seq, num_bins=15, min_bin=3.25, max_bin=20.75) + else: + prev["prev_pos"] = ini_pos + else: + if a["use_dgram"] or a["use_dgram_pred"]: + prev["prev_dgram"] = np.zeros([L,L,15]) + else: + prev["prev_pos"] = np.zeros([L,37,3]) + + # via structure module + if a["use_initial_atom_pos"]: + self._inputs["initial_atom_pos"] = ini_pos + else: - prev["prev_pos"] = np.zeros([L,37,3]) - - if a["use_dgram"]: - # TODO: add support for initial_guess + use_dgram - prev["prev_dgram"] = np.zeros([L,L,64]) - - if a["use_initial_atom_pos"]: - if "batch" in self._inputs: - self._inputs["initial_atom_pos"] = self._inputs["batch"]["all_atom_positions"] + # if batch not defined, initialize with zeros + if a["use_dgram"] or a["use_dgram_pred"]: + prev["prev_dgram"] = np.zeros([L,L,15]) else: + prev["prev_pos"] = np.zeros([L,37,3]) + if a["use_initial_atom_pos"]: self._inputs["initial_atom_pos"] = np.zeros([L,37,3]) - self._inputs["prev"] = prev + self._inputs["prev"] = prev # decide which layers to compute gradients for cycles = (num_recycles + 1) mask = [0] * cycles @@ -195,12 +219,14 @@ def _recycle(self, model_params, num_recycles=None, backprop=True): aux = self._single(model_params, backprop=False) else: aux = self._single(model_params, backprop) - grad.append(jax.tree_map(lambda x:x*m, aux["grad"])) + grad.append(jax.tree_util.tree_map(lambda x:x*m, aux["grad"])) + + # update previous inputs self._inputs["prev"] = aux["prev"] if a["use_initial_atom_pos"]: - self._inputs["initial_atom_pos"] = aux["prev"]["prev_pos"] + self._inputs["initial_atom_pos"] = aux["atom_positions"] - aux["grad"] = jax.tree_map(lambda *x: np.stack(x).sum(0), *grad) + aux["grad"] = jax.tree_util.tree_map(lambda *x: np.stack(x).sum(0), *grad) aux["num_recycles"] = num_recycles return aux @@ -214,13 +240,14 @@ def step(self, lr_scale=1.0, num_recycles=None, self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models, models=models, backprop=backprop, callback=callback) - # modify gradients - if self.opt["norm_seq_grad"]: self._norm_seq_grad() + # modify gradients + if self._args["optimize_seq"] and self.opt["norm_seq_grad"]: + self._norm_seq_grad() self._state, self.aux["grad"] = self._optimizer(self._state, self.aux["grad"], self._params) # apply gradients lr = self.opt["learning_rate"] * lr_scale - self._params = jax.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) + self._params = jax.tree_util.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) # save results self._save_results(save_best=save_best, verbose=verbose) @@ -248,15 +275,22 @@ def _save_results(self, aux=None, save_best=False, verbose=True): if aux is None: aux = self.aux self._tmp["log"].append(aux["log"]) + + # update traj if (self._k % self._args["traj_iter"]) == 0: - # update traj - traj = {"seq": aux["seq"]["pseudo"], - "xyz": aux["atom_positions"][:,1,:], + traj = {"xyz": aux["atom_positions"][:,1,:], "plddt": aux["plddt"], "pae": aux["pae"]} + if self._args["optimize_seq"]: + traj["seq"] = aux["seq"]["pseudo"] + else: + traj["seq"] = self._inputs["msa_feat"][...,:20] + for k,v in traj.items(): + # rm traj (if max number reached) if len(self._tmp["traj"][k]) == self._args["traj_max"]: self._tmp["traj"][k].pop(0) + # add traj self._tmp["traj"][k].append(v) # save best @@ -279,15 +313,10 @@ def predict(self, seq=None, bias=None, return_aux=False, 
verbose=True, seed=None, **kwargs): '''predict structure for input sequence (if provided)''' - def load_settings(): - if "save" in self._tmp: - [self.opt, self._args, self._params, self._inputs] = self._tmp.pop("save") - - def save_settings(): - load_settings() - self._tmp["save"] = [copy_dict(x) for x in [self.opt, self._args, self._params, self._inputs]] - - save_settings() + # save current settings + if "save" in self._tmp: + [self.opt, self._args, self._params, self._inputs] = self._tmp.pop("save") + self._tmp["save"] = [copy_dict(x) for x in [self.opt, self._args, self._params, self._inputs]] # set seed if defined if seed is not None: self.set_seed(seed) @@ -295,14 +324,14 @@ def save_settings(): # set [seq]uence/[opt]ions if seq is not None: self.set_seq(seq=seq, bias=bias) self.set_opt(hard=hard, soft=soft, temp=temp, dropout=dropout, pssm_hard=True) - self.set_args(shuffle_first=False) # run self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models, models=models, backprop=False, **kwargs) if verbose: self._print_log("predict") - load_settings() + # load previous settings + [self.opt, self._args, self._params, self._inputs] = self._tmp.pop("save") # return (or save) results if return_aux: return self.aux diff --git a/colabdesign/af/inputs.py b/colabdesign/af/inputs.py index fe390a60..ae57ef1f 100644 --- a/colabdesign/af/inputs.py +++ b/colabdesign/af/inputs.py @@ -12,10 +12,19 @@ ############################################################################ class _af_inputs: - def _get_seq(self, inputs, aux, key=None): - params, opt = inputs["params"], inputs["opt"] + def set_seq(self, seq=None, mode=None, **kwargs): + if self._args["optimize_seq"]: + self._set_seq(seq=seq, mode=mode, **kwargs) + else: + seq,_ = self._set_seq(seq=seq, mode=mode, return_values=True, **kwargs) + update_seq(seq, self._inputs, use_jax=False) + update_aatype(seq, self._inputs, use_jax=False) + + def _update_seq(self, params, inputs, aux, key): '''get sequence features''' - seq = soft_seq(params["seq"], inputs["bias"], opt, key, num_seq=self._num, + opt = inputs["opt"] + k1, k2 = jax.random.split(key) + seq = soft_seq(params["seq"], inputs["bias"], opt, k1, num_seq=self._num, shuffle_first=self._args["shuffle_first"]) seq = self._fix_pos(seq) aux.update({"seq":seq, "seq_pseudo":seq["pseudo"]}) @@ -25,11 +34,24 @@ def _get_seq(self, inputs, aux, key=None): # concatenate target and binder sequence seq_target = jax.nn.one_hot(inputs["batch"]["aatype"][:self._target_len],self._args["alphabet_size"]) seq_target = jnp.broadcast_to(seq_target,(self._num, *seq_target.shape)) - seq = jax.tree_map(lambda x:jnp.concatenate([seq_target,x],1), seq) + seq = jax.tree_util.tree_map(lambda x:jnp.concatenate([seq_target,x],1), seq) if self.protocol in ["fixbb","hallucination","partial"] and self._args["copies"] > 1: - seq = jax.tree_map(lambda x:expand_copies(x, self._args["copies"], self._args["block_diag"]), seq) - + seq = jax.tree_util.tree_map(lambda x:expand_copies(x, self._args["copies"], self._args["block_diag"]), seq) + + # update sequence features + pssm = jnp.where(opt["pssm_hard"], seq["hard"], seq["pseudo"]) + if self._args["use_mlm"]: + shape = seq["pseudo"].shape[:2] + aux["mlm_mask"] = jax.random.bernoulli(k2,opt["mlm_dropout"],shape) + update_seq(seq, inputs, seq_pssm=pssm, mlm=aux["mlm_mask"], mask_target=self._args["mask_target"]) + else: + update_seq(seq, inputs, seq_pssm=pssm) + + # update amino acid sidechain identity + update_aatype(seq, inputs) + inputs["seq"] = aux["seq"] 
+ return seq def _fix_pos(self, seq, return_p=False): @@ -42,7 +64,7 @@ def _fix_pos(self, seq, return_p=False): seq_ref = jax.nn.one_hot(self._wt_aatype,self._args["alphabet_size"]) p = self.opt["fix_pos"] fix_seq = lambda x: x.at[...,p,:].set(seq_ref[...,p,:]) - seq = jax.tree_map(fix_seq, seq) + seq = jax.tree_util.tree_map(fix_seq, seq) if return_p: return seq, p return seq @@ -106,33 +128,59 @@ def _update_template(self, inputs, key): inputs[k] = inputs[k].at[...,5:].set(jnp.where(rm_sc[:,None],0,inputs[k][...,5:])) inputs[k] = jnp.where(rm[:,None],0,inputs[k]) -def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None): - '''update the sequence features''' - - if seq_1hot is None: seq_1hot = seq - if seq_pssm is None: seq_pssm = seq +def update_seq(seq, inputs, seq_1hot=None, seq_pssm=None, mlm=None, mask_target=False, use_jax=True): + '''update the sequence features''' + + _np = jnp if use_jax else np + if isinstance(seq, dict): + if seq_1hot is None: seq_1hot = seq["pseudo"] + if seq_pssm is None: seq_pssm = seq["pseudo"] + else: + if seq_1hot is None: seq_1hot = seq + if seq_pssm is None: seq_pssm = seq + target_feat = seq_1hot[0,:,:20] - seq_1hot = jnp.pad(seq_1hot,[[0,0],[0,0],[0,22-seq_1hot.shape[-1]]]) - seq_pssm = jnp.pad(seq_pssm,[[0,0],[0,0],[0,22-seq_pssm.shape[-1]]]) - msa_feat = jnp.zeros_like(inputs["msa_feat"]).at[...,0:22].set(seq_1hot).at[...,25:47].set(seq_pssm) + seq_1hot = _np.pad(seq_1hot,[[0,0],[0,0],[0,22-seq_1hot.shape[-1]]]) + seq_pssm = _np.pad(seq_pssm,[[0,0],[0,0],[0,22-seq_pssm.shape[-1]]]) + if use_jax: + msa_feat = jnp.zeros_like(inputs["msa_feat"]).at[...,0:22].set(seq_1hot).at[...,25:47].set(seq_pssm) + else: + msa_feat = np.zeros_like(inputs["msa_feat"]) + msa_feat[...,0:22] = seq_1hot + msa_feat[...,25:47] = seq_pssm + # masked language modeling (randomly mask positions) - if mlm is not None: - X = jax.nn.one_hot(22,23) - X = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X) - msa_feat = jnp.where(mlm[...,None],X,msa_feat) + if mlm is not None: + X = _np.eye(23)[22] + if use_jax: + Y = jnp.zeros(msa_feat.shape[-1]).at[...,:23].set(X).at[...,25:48].set(X) + else: + Y = np.zeros(msa_feat.shape[-1]) + Y[...,:23] = X + Y[...,25:48] = X + + msa_feat = _np.where(mlm[...,None],Y,msa_feat) + seq["pseudo"] = _np.where(mlm[...,None],X[:seq["pseudo"].shape[-1]],seq["pseudo"]) + if mask_target: + target_feat = _np.where(mlm[0,:,None],0,target_feat) inputs.update({"msa_feat":msa_feat, "target_feat":target_feat}) -def update_aatype(aatype, inputs): +def update_aatype(seq, inputs, use_jax=True): + _np = jnp if use_jax else np + if isinstance(seq,dict): + aatype = seq["pseudo"][0].argmax(-1) + else: + aatype = seq[0].argmax(-1) r = residue_constants a = {"atom14_atom_exists":r.restype_atom14_mask, "atom37_atom_exists":r.restype_atom37_mask, "residx_atom14_to_atom37":r.restype_atom14_to_atom37, "residx_atom37_to_atom14":r.restype_atom37_to_atom14} mask = inputs["seq_mask"][:,None] - inputs.update(jax.tree_map(lambda x:jnp.where(mask,jnp.asarray(x)[aatype],0),a)) + inputs.update(jax.tree_util.tree_map(lambda x:_np.where(mask,_np.asarray(x)[aatype],0),a)) inputs["aatype"] = aatype def expand_copies(x, copies, block_diag=True): diff --git a/colabdesign/af/loss.py b/colabdesign/af/loss.py index f8c28301..e62bd541 100644 --- a/colabdesign/af/loss.py +++ b/colabdesign/af/loss.py @@ -36,6 +36,7 @@ def _loss_binder(self, inputs, outputs, aux): '''get losses''' opt = inputs["opt"] mask = inputs["seq_mask"] + mask_2d = mask[:,None] * mask[None,:] 
zeros = jnp.zeros_like(mask) tL,bL = self._target_len, self._binder_len binder_id = zeros.at[-bL:].set(mask[-bL:]) @@ -72,8 +73,8 @@ def _loss_binder(self, inputs, outputs, aux): aux["losses"].update({ "rmsd": aln["rmsd"], - "dgram_cce": cce[-bL:].sum() / (mask[-bL:].sum() + 1e-8), - "fape": fape[-bL:].sum() / (mask[-bL:].sum() + 1e-8) + "dgram_cce": cce[-bL:].sum() / (mask_2d[-bL:].sum() + 1e-8), + "fape": fape[-bL:].sum() / (mask_2d[-bL:].sum() + 1e-8) }) else: @@ -91,7 +92,7 @@ def _loss_partial(self, inputs, outputs, aux): pos = (jnp.repeat(pos,C).reshape(-1,C) + jnp.arange(C) * L).T.flatten() def sub(x, axis=0): - return jax.tree_map(lambda y:jnp.take(y,pos,axis),x) + return jax.tree_util.tree_map(lambda y:jnp.take(y,pos,axis),x) copies = self._args["copies"] if self._args["homooligomer"] else 1 aatype = sub(inputs["aatype"]) @@ -158,7 +159,7 @@ def _loss_unsupervised(self, inputs, outputs, aux): else: mask_1d = inputs["seq_mask"] - seq_mask_2d = inputs["seq_mask"][:,None] * inputs["seq_mask"][None,:] + seq_mask_2d = mask_1d[:,None] * mask_1d[None,:] mask_2d = inputs["asym_id"][:,None] == inputs["asym_id"][None,:] masks = {"mask_1d":mask_1d, "mask_2d":jnp.where(seq_mask_2d,mask_2d,0)} @@ -173,8 +174,8 @@ def _loss_unsupervised(self, inputs, outputs, aux): } # define losses at interface - if self._args["copies"] > 1 and not self._args["repeat"]: - masks = {"mask_1d": mask_1d if self._args["homooligomer"] else inputs["seq_mask"], + if len(self._lengths) > 1: + masks = {"mask_1d": mask_1d, "mask_2d": jnp.where(seq_mask_2d,mask_2d == False,0)} losses.update({ "i_pae": get_pae_loss(outputs, **masks), @@ -409,17 +410,17 @@ def _get_pw_loss(true, pred, loss_fn, weights=None, copies=1, return_mtx=False): (L,C) = (length//copies, copies-1) # intra (L,L,F) - intra = jax.tree_map(lambda x:x[:L,:L], F) + intra = jax.tree_util.tree_map(lambda x:x[:L,:L], F) mtx, loss = loss_fn(**intra) # inter (C*L,L,F) - inter = jax.tree_map(lambda x:x[L:,:L], F) + inter = jax.tree_util.tree_map(lambda x:x[L:,:L], F) if C == 0: i_mtx, i_loss = loss_fn(**inter) else: # (C,L,L,F) - inter = jax.tree_map(lambda x:x.reshape(C,L,L,-1), inter) + inter = jax.tree_util.tree_map(lambda x:x.reshape(C,L,L,-1), inter) inter = {"t":inter["t"][:,None], # (C,1,L,L,F) "p":inter["p"][None,:], # (1,C,L,L,F) "m":inter["m"][:,None,:,:,0]} # (C,1,L,L) @@ -530,7 +531,7 @@ def get_seq_ent_loss(inputs): opt = inputs["opt"] x = inputs["seq"]["logits"] / opt["temp"] ent = -(jax.nn.softmax(x) * jax.nn.log_softmax(x)).sum(-1) - mask = inputs["seq_mask"][-x.shape[0]:] + mask = inputs["seq_mask"][-x.shape[1]:] if "fix_pos" in opt: if "pos" in opt: p = opt["pos"][opt["fix_pos"]] @@ -540,8 +541,11 @@ def get_seq_ent_loss(inputs): ent = (ent * mask).sum() / (mask.sum() + 1e-8) return {"seq_ent":ent.mean()} -def get_mlm_loss(outputs, mask, truth=None): +def get_mlm_loss(outputs, mask, truth=None, unbias=False): x = outputs["masked_msa"]["logits"][...,:20] + if unbias: + x_mean = (x * mask[...,None]).sum((0,1)) / (mask.sum() + 1e-8) + x = x - x_mean if truth is None: truth = jax.nn.softmax(x) ent = -(truth[...,:20] * jax.nn.log_softmax(x)).sum(-1) ent = (ent * mask).sum(-1) / (mask.sum() + 1e-8) diff --git a/colabdesign/af/model.py b/colabdesign/af/model.py index 37053bda..432d4f03 100644 --- a/colabdesign/af/model.py +++ b/colabdesign/af/model.py @@ -14,7 +14,7 @@ from colabdesign.af.loss import get_contact_map, get_seq_ent_loss, get_mlm_loss from colabdesign.af.utils import _af_utils from colabdesign.af.design import _af_design -from 
colabdesign.af.inputs import _af_inputs, update_seq, update_aatype +from colabdesign.af.inputs import _af_inputs ################################################################ # MK_DESIGN_MODEL - initialize model, and put it all together @@ -32,12 +32,15 @@ def __init__(self, self.protocol = protocol self._num = kwargs.pop("num_seq",1) - self._args = {"use_templates":use_templates, "use_multimer":use_multimer, "use_bfloat16":True, - "recycle_mode":"last", "use_mlm": False, "realign": True, + self._args = {"use_templates":use_templates, "num_templates":0, + "use_multimer":use_multimer, "use_bfloat16":True, + "optimize_seq":True, "recycle_mode":"last", + "use_mlm": False, "mask_target":False, "unbias_mlm": False, + "realign": True, "debug":debug, "repeat":False, "homooligomer":False, "copies":1, "optimizer":"sgd", "best_metric":"loss", "traj_iter":1, "traj_max":10000, - "clear_prev": True, "use_dgram":False, + "clear_prev": True, "use_dgram":False, "use_dgram_pred":False, "shuffle_first":True, "use_remat":True, "alphabet_size":20, "use_initial_guess":False, "use_initial_atom_pos":False} @@ -66,6 +69,9 @@ def __init__(self, if k in self._args: self._args[k] = kwargs.pop(k) if k in self.opt: self.opt[k] = kwargs.pop(k) + if self._args["use_templates"] and self._args["num_templates"] == 0: + self._args["num_templates"] = 1 + # collect callbacks self._callbacks = {"model": {"pre": kwargs.pop("pre_callback",None), "post":kwargs.pop("post_callback",None), @@ -89,20 +95,21 @@ def __init__(self, # configure AlphaFold ############################# if self._args["use_multimer"]: - self._cfg = config.model_config("model_1_multimer") - # TODO - self.opt["pssm_hard"] = True + self._cfg = config.model_config("model_1_multimer") + self.opt["pssm_hard"] = True # TODO else: self._cfg = config.model_config("model_1_ptm" if self._args["use_templates"] else "model_3_ptm") - + + if self._args["recycle_mode"] in ["average","first","last","sample"]: num_recycles = 0 else: num_recycles = self.opt["num_recycles"] self._cfg.model.num_recycle = num_recycles self._cfg.model.global_config.use_remat = self._args["use_remat"] - self._cfg.model.global_config.use_dgram = self._args["use_dgram"] + self._cfg.model.global_config.use_dgram_pred = self._args["use_dgram_pred"] self._cfg.model.global_config.bfloat16 = self._args["use_bfloat16"] + self._cfg.model.embeddings_and_evoformer.template.enabled = self._args["use_templates"] # load model_params if model_names is None: @@ -117,10 +124,9 @@ def __init__(self, self._model_params, self._model_names = [],[] for model_name in model_names: - params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, fuse=True) + params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, + fuse=True, rm_templates=not self._args["use_templates"]) if params is not None: - if not self._args["use_multimer"] and not self._args["use_templates"]: - params = {k:v for k,v in params.items() if "template" not in k} self._model_params.append(params) self._model_names.append(model_name) else: @@ -142,9 +148,8 @@ def _get_model(self, cfg, callback=None): # setup function to get gradients def _model(params, model_params, inputs, key): - inputs["params"] = params + opt = inputs["opt"] - aux = {} key = Key(key=key).get @@ -152,25 +157,15 @@ def _model(params, model_params, inputs, key): # INPUTS ####################################################################### # get sequence - seq = self._get_seq(inputs, aux, key()) - - # update sequence features - pssm = 
jnp.where(opt["pssm_hard"], seq["hard"], seq["pseudo"]) - if a["use_mlm"]: - shape = seq["pseudo"].shape[:2] - mlm = jax.random.bernoulli(key(),opt["mlm_dropout"],shape) - update_seq(seq["pseudo"], inputs, seq_pssm=pssm, mlm=mlm) + if a["optimize_seq"]: + seq = self._update_seq(params, inputs, aux, key()) else: - update_seq(seq["pseudo"], inputs, seq_pssm=pssm) - - # update amino acid sidechain identity - update_aatype(seq["pseudo"][0].argmax(-1), inputs) + # TODO + inputs["seq"] = seq = None # define masks inputs["msa_mask"] = jnp.where(inputs["seq_mask"],inputs["msa_mask"],0) - inputs["seq"] = aux["seq"] - # update template features inputs["mask_template_interchain"] = opt["template"]["rm_ic"] if a["use_templates"]: @@ -182,10 +177,14 @@ def _model(params, model_params, inputs, key): if "batch" not in inputs: inputs["batch"] = None + # optimize model params + if "model_params" in params: + model_params.update(params["model_params"]) + # pre callback for fn in self._callbacks["model"]["pre"]: - fn_args = {"inputs":inputs, "opt":opt, "aux":aux, - "seq":seq, "key":key(), "params":params} + fn_args = {"inputs":inputs, "opt":opt, "aux":aux, "seq":seq, + "key":key(), "params":params, "model_params":model_params} sub_args = {k:fn_args.get(k,None) for k in signature(fn).parameters} fn(**sub_args) @@ -216,13 +215,14 @@ def _model(params, model_params, inputs, key): self._get_loss(inputs=inputs, outputs=outputs, aux=aux) # sequence entropy loss - aux["losses"].update(get_seq_ent_loss(inputs)) - - # experimental masked-language-modeling - if a["use_mlm"]: - aux["mlm"] = outputs["masked_msa"]["logits"] - mask = jnp.where(inputs["seq_mask"],mlm,0) - aux["losses"].update(get_mlm_loss(outputs, mask=mask, truth=seq["pssm"])) + if a["optimize_seq"]: + aux["losses"].update(get_seq_ent_loss(inputs)) + # experimental masked-language-modeling + if a["use_mlm"]: + aux["mlm"] = outputs["masked_msa"]["logits"] + mask = jnp.where(inputs["seq_mask"],aux["mlm_mask"],0) + aux["losses"].update(get_mlm_loss(outputs, mask=mask, + truth=seq["pssm"], unbias=a["unbias_mlm"])) # run user defined callbacks for c in ["loss","post"]: diff --git a/colabdesign/af/prep.py b/colabdesign/af/prep.py index a8691d9b..f53eddd2 100644 --- a/colabdesign/af/prep.py +++ b/colabdesign/af/prep.py @@ -34,12 +34,13 @@ def _prep_model(self, **kwargs): self._opt = copy_dict(self.opt) self.restart(**kwargs) - def _prep_features(self, num_res, num_seq=None, num_templates=1): + def _prep_features(self, num_res, num_seq=None, num_templates=None): '''process features''' if num_seq is None: num_seq = self._num + if num_templates is None: num_templates = self._args["num_templates"] return prep_input_features(L=num_res, N=num_seq, T=num_templates) - def _prep_fixbb(self, pdb_filename, chain=None, + def _prep_fixbb(self, pdb_filename, chain="A", copies=1, repeat=False, homooligomer=False, rm_template=False, rm_template_seq=True, @@ -59,29 +60,28 @@ def _prep_fixbb(self, pdb_filename, chain=None, -ignore_missing=True - skip positions that have missing density (no CA coordinate) --------------------------------------------------- ''' + if isinstance(chain,str): chain = chain.split(",") + if homooligomer and copies == 1: copies = len(chain) + # prep features self._pdb = prep_pdb(pdb_filename, chain=chain, ignore_missing=ignore_missing, offsets=kwargs.pop("pdb_offsets",None), lengths=kwargs.pop("pdb_lengths",None)) self._len = self._pdb["residue_index"].shape[0] - self._lengths = [self._len] + self._lengths = self._pdb["lengths"] - # feat dims - num_seq = 
self._num - res_idx = self._pdb["residue_index"] # get [pos]itions of interests if fix_pos is not None and fix_pos != "": self._pos_info = prep_pos(fix_pos, **self._pdb["idx"]) self.opt["fix_pos"] = self._pos_info["pos"] - if homooligomer and chain is not None and copies == 1: - copies = len(chain.split(",")) - + num_seq = self._num + res_idx = self._pdb["residue_index"] + # repeat/homo-oligomeric support if copies > 1: - if repeat or homooligomer: self._len = self._len // copies if "fix_pos" in self.opt: @@ -97,35 +97,37 @@ def _prep_fixbb(self, pdb_filename, chain=None, res_idx = repeat_idx(res_idx[:self._len], copies) num_seq = (self._num * copies + 1) if block_diag else self._num - self.opt["weights"].update({"i_pae":0.0, "i_con":0.0}) self._args.update({"copies":copies, "repeat":repeat, "homooligomer":homooligomer, "block_diag":block_diag}) homooligomer = not repeat - else: - self._lengths = self._pdb["lengths"] # configure input features self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) self._inputs["residue_index"] = res_idx - self._inputs["batch"] = make_fixed_size(self._pdb["batch"], num_res=sum(self._lengths)) + self._inputs["batch"] = make_fixed_size(self._pdb["batch"], + num_res=sum(self._lengths), num_templates=self._args["num_templates"]) self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) # configure options/weights self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":0.0}) + if len(self._lengths) > 1: + self.opt["weights"].update({"i_pae":0.0, "i_con":0.0}) self._wt_aatype = self._inputs["batch"]["aatype"][:self._len] - # configure template [opt]ions - rm,L = {},sum(self._lengths) - for n,x in {"rm_template": rm_template, + # configure template masks + if self._args["use_templates"]: + rm_dict = {} + rm_opt = {"rm_template": rm_template, "rm_template_seq":rm_template_seq, - "rm_template_sc": rm_template_sc}.items(): - rm[n] = np.full(L,False) - if isinstance(x,str): - rm[n][prep_pos(x,**self._pdb["idx"])["pos"]] = True - else: - rm[n][:] = x - self.opt["template"]["rm_ic"] = rm_template_ic - self._inputs.update(rm) + "rm_template_sc": rm_template_sc} + for n,x in rm_opt.items(): + rm_dict[n] = np.full(sum(self._lengths),False) + if isinstance(x,str): + rm_dict[n][prep_pos(x,**self._pdb["idx"])["pos"]] = True + else: + rm_dict[n][:] = x + self._inputs.update(rm_dict) + self.opt["template"]["rm_ic"] = rm_template_ic self._prep_model(**kwargs) @@ -143,11 +145,17 @@ def _prep_hallucination(self, length=100, copies=1, repeat=False, **kwargs): (num_seq, block_diag) = (self._num * copies + 1, True) else: (num_seq, block_diag) = (self._num, False) + self._args.update({"repeat":repeat,"block_diag":block_diag,"copies":copies}) # prep features - self._len = length + if isinstance(length,list): + self._len = sum(length) + else: + self._len = length + length = [length] + res_idx = np.array([i + 50 * n - n for n, L in enumerate(length) for i in range(L)]) # set weights self.opt["weights"].update({"con":1.0}) @@ -161,15 +169,16 @@ def _prep_hallucination(self, length=100, copies=1, repeat=False, **kwargs): self._lengths = [self._len] * copies self.opt["weights"].update({"i_pae":0.0, "i_con":1.0}) self._args["homooligomer"] = True - res_idx = repeat_idx(np.arange(length), copies, offset=offset) + res_idx = repeat_idx(res_idx, copies, offset=offset) + homooligomer = True else: - self._lengths = [self._len] - res_idx = np.arange(length) + self._lengths = length + homooligomer = False # configure input features 
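# --- Editor-added sketch (illustrative, not part of the patch) ---------------------
# residue_index for a multi-chain hallucination, mirroring the list comprehension in
# _prep_hallucination above: chain n is shifted by 50*n - n, which AlphaFold reads as
# a chain break (its relative-position encoding saturates well below a 49-residue gap).
import numpy as np

length = [5, 4]   # two chains of illustrative lengths
res_idx = np.array([i + 50 * n - n for n, L in enumerate(length) for i in range(L)])
print(res_idx)    # [ 0  1  2  3  4 49 50 51 52]
# ------------------------------------------------------------------------------------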
self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) self._inputs["residue_index"] = res_idx - self._inputs.update(get_multi_id(self._lengths, homooligomer=True)) + self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) self._prep_model(**kwargs) @@ -200,70 +209,83 @@ def _prep_binder(self, pdb_filename, -ignore_missing=True - skip positions that have missing density (no CA coordinate) --------------------------------------------------- ''' - redesign = binder_chain is not None - rm_binder = not kwargs.pop("use_binder_template", not rm_binder) - self._args.update({"redesign":redesign}) - # get pdb info target_chain = kwargs.pop("chain",target_chain) # backward comp - chains = f"{target_chain},{binder_chain}" if redesign else target_chain - im = [True] * len(target_chain.split(",")) - if redesign: im += [ignore_missing] * len(binder_chain.split(",")) + if isinstance(target_chain,str): target_chain = target_chain.split(",") + if isinstance(binder_chain,str): binder_chain = binder_chain.split(",") + + # decide how to parse chains + if binder_chain is None or len(binder_chain) == 0: + # binder hallucination + redesign = False + chains = target_chain + ignore_missing = [True] * len(target_chain) + else: + # binder redesign + redesign = True + chains = target_chain + binder_chain + ignore_missing = [True] * len(target_chain) + [ignore_missing] * len(binder_chain) + self._args.update({"redesign":redesign}) - self._pdb = prep_pdb(pdb_filename, chain=chains, ignore_missing=im) - res_idx = self._pdb["residue_index"] + # parse pdb + self._pdb = prep_pdb(pdb_filename, chain=chains, ignore_missing=ignore_missing) + target_len = [(self._pdb["idx"]["chain"] == c).sum() for c in target_chain] - if redesign: - self._target_len = sum([(self._pdb["idx"]["chain"] == c).sum() for c in target_chain.split(",")]) - self._binder_len = sum([(self._pdb["idx"]["chain"] == c).sum() for c in binder_chain.split(",")]) - else: - self._target_len = self._pdb["residue_index"].shape[0] - self._binder_len = binder_len - res_idx = np.append(res_idx, res_idx[-1] + np.arange(binder_len) + 50) + # get lengths + if redesign: + binder_len = [(self._pdb["idx"]["chain"] == c).sum() for c in binder_chain] + elif not isinstance(binder_len,list): + binder_len = [binder_len] - self._len = self._binder_len - self._lengths = [self._target_len, self._binder_len] + + self._target_len = sum(target_len) + self._binder_len = self._len = sum(binder_len) + self._lengths = target_len + binder_len # gather hotspot info if hotspot is not None: self.opt["hotspot"] = prep_pos(hotspot, **self._pdb["idx"])["pos"] if redesign: - # binder redesign self._wt_aatype = self._pdb["batch"]["aatype"][self._target_len:] self.opt["weights"].update({"dgram_cce":1.0, "rmsd":0.0, "fape":0.0, "con":0.0, "i_con":0.0, "i_pae":0.0}) else: - # binder hallucination - self._pdb["batch"] = make_fixed_size(self._pdb["batch"], num_res=sum(self._lengths)) + self._pdb["batch"] = make_fixed_size(self._pdb["batch"], + num_res=sum(self._lengths), num_templates=self._args["num_templates"]) self.opt["weights"].update({"plddt":0.1, "con":0.0, "i_con":1.0, "i_pae":0.0}) # configure input features self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=1) - self._inputs["residue_index"] = res_idx self._inputs["batch"] = self._pdb["batch"] self._inputs.update(get_multi_id(self._lengths)) - # configure template rm masks - (T,L,rm) = (self._lengths[0],sum(self._lengths),{}) + # configure residue index + res_idx = 
self._pdb["residue_index"] + if not redesign: + binder_res_idx = np.array([i + 50 * n - n for n, L in enumerate(binder_len) for i in range(L)]) + res_idx = np.append(res_idx, res_idx[-1] + 50 + binder_res_idx) + self._inputs["residue_index"] = res_idx + + # configure template masks + rm_binder = not kwargs.pop("use_binder_template", not rm_binder) # backward comp + rm_dict = {} rm_opt = { "rm_template": {"target":rm_target, "binder":rm_binder}, "rm_template_seq":{"target":rm_target_seq,"binder":rm_binder_seq}, "rm_template_sc": {"target":rm_target_sc, "binder":rm_binder_sc} } for n,x in rm_opt.items(): - rm[n] = np.full(L,False) + rm_dict[n] = np.full(sum(self._lengths), False) for m,y in x.items(): if isinstance(y,str): - rm[n][prep_pos(y,**self._pdb["idx"])["pos"]] = True + rm_dict[n][prep_pos(y,**self._pdb["idx"])["pos"]] = True else: - if m == "target": rm[n][:T] = y - if m == "binder": rm[n][T:] = y - - # set template [opt]ions + if m == "target": rm_dict[n][:self._target_len] = y + if m == "binder": rm_dict[n][self._target_len:] = y + self._inputs.update(rm_dict) self.opt["template"]["rm_ic"] = rm_template_ic - self._inputs.update(rm) self._prep_model(**kwargs) @@ -288,8 +310,8 @@ def _prep_partial(self, pdb_filename, chain=None, length=None, ''' # prep features self._pdb = prep_pdb(pdb_filename, chain=chain, ignore_missing=ignore_missing, - offsets=kwargs.pop("pdb_offsets",None), - lengths=kwargs.pop("pdb_lengths",None)) + offsets=kwargs.pop("pdb_offsets",None), + lengths=kwargs.pop("pdb_lengths",None)) self._pdb["len"] = sum(self._pdb["lengths"]) @@ -341,7 +363,7 @@ def _prep_partial(self, pdb_filename, chain=None, length=None, # configure input features self._inputs = self._prep_features(num_res=sum(self._lengths), num_seq=num_seq) self._inputs["residue_index"] = res_idx - self._inputs["batch"] = jax.tree_map(lambda x:x[self._pdb["pos"]], self._pdb["batch"]) + self._inputs["batch"] = jax.tree_util.tree_map(lambda x:x[self._pdb["pos"]], self._pdb["batch"]) self._inputs.update(get_multi_id(self._lengths, homooligomer=homooligomer)) # configure options/weights @@ -372,10 +394,11 @@ def _prep_partial(self, pdb_filename, chain=None, length=None, self.opt["fix_pos"] = np.arange(self.opt["pos"].shape[0]) self._wt_aatype_sub = self._wt_aatype - self.opt["template"].update({"rm_ic":rm_template_ic}) - self._inputs.update({"rm_template": rm_template, - "rm_template_seq": rm_template_seq, - "rm_template_sc": rm_template_sc}) + if self._args["use_templates"]: + self.opt["template"].update({"rm_ic":rm_template_ic}) + self._inputs.update({"rm_template": rm_template, + "rm_template_seq": rm_template_seq, + "rm_template_sc": rm_template_sc}) self._prep_model(**kwargs) @@ -406,11 +429,7 @@ def add_cb(batch): batch["all_atom_mask"][...,cb] = (m[:,cb] + cb_mask) > 0 return {"atoms":batch["all_atom_positions"][:,cb],"mask":cb_mask} - if isinstance(chain,str) and "," in chain: - chains = chain.split(",") - elif not isinstance(chain,list): - chains = [chain] - + chains = chain.split(",") if isinstance(chain,str) else chain o,last = [],0 residue_idx, chain_idx = [],[] full_lengths = [] @@ -428,8 +447,8 @@ def add_cb(batch): im = ignore_missing[n] if isinstance(ignore_missing,list) else ignore_missing if im: - r = batch["all_atom_mask"][:,0] == 1 - batch = jax.tree_map(lambda x:x[r], batch) + r = batch["all_atom_mask"][:,residue_constants.atom_order["CA"]] == 1 + batch = jax.tree_util.tree_map(lambda x:x[r], batch) residue_index = batch["residue_index"] + last else: @@ -479,8 +498,8 @@ def 
make_fixed_size(feat, num_res, num_seq=1, num_templates=1): } for k,v in feat.items(): if k == "batch": - feat[k] = make_fixed_size(v, num_res) - else: + feat[k] = make_fixed_size(v, num_res, num_seq, num_templates) + elif k in shape_schema: shape = list(v.shape) schema = shape_schema[k] assert len(shape) == len(schema), ( @@ -529,7 +548,7 @@ def get_sc_pos(aa_ident, atoms_to_exclude=None): return {"pos":pos, "pos_alt":pos_alt, "non_amb":non_amb, "weight":w, "weight_non_amb":w_na[:,None]} -def prep_input_features(L, N=1, T=1, eN=1): +def prep_input_features(L, N=1, T=0, eN=1): ''' given [L]ength, [N]umber of sequences and number of [T]emplates return dictionary of blank features @@ -546,10 +565,6 @@ def prep_input_features(L, N=1, T=1, eN=1): 'seq_mask': np.ones(L), 'msa_mask': np.ones((N,L)), 'msa_row_mask': np.ones(N), - 'atom14_atom_exists': np.zeros((L,14)), - 'atom37_atom_exists': np.zeros((L,37)), - 'residx_atom14_to_atom37': np.zeros((L,14),int), - 'residx_atom37_to_atom14': np.zeros((L,37),int), 'residue_index': np.arange(L), 'extra_deletion_value': np.zeros((eN,L)), 'extra_has_deletion': np.zeros((eN,L)), @@ -557,6 +572,18 @@ def prep_input_features(L, N=1, T=1, eN=1): 'extra_msa_mask': np.zeros((eN,L)), 'extra_msa_row_mask': np.zeros(eN), + # for alphafold-ptm + 'atom14_atom_exists': np.zeros((L,14)), + 'atom37_atom_exists': np.zeros((L,37)), + 'residx_atom14_to_atom37': np.zeros((L,14),int), + 'residx_atom37_to_atom14': np.zeros((L,37),int), + + # for alphafold-multimer + 'asym_id': np.zeros(L), + 'sym_id': np.zeros(L), + 'entity_id': np.zeros(L)} + if T > 0: + inputs.update({ # for template inputs 'template_aatype': np.zeros((T,L),int), 'template_all_atom_mask': np.zeros((T,L,37)), @@ -564,12 +591,7 @@ def prep_input_features(L, N=1, T=1, eN=1): 'template_mask': np.zeros(T), 'template_pseudo_beta': np.zeros((T,L,3)), 'template_pseudo_beta_mask': np.zeros((T,L)), - - # for alphafold-multimer - 'asym_id': np.zeros(L), - 'sym_id': np.zeros(L), - 'entity_id': np.zeros(L), - 'all_atom_positions': np.zeros((N,37,3))} + }) return inputs def get_multi_id(lengths, homooligomer=False): diff --git a/colabdesign/af/utils.py b/colabdesign/af/utils.py index ca9dd725..6b64da08 100644 --- a/colabdesign/af/utils.py +++ b/colabdesign/af/utils.py @@ -8,7 +8,7 @@ from colabdesign.shared.utils import update_dict, Key from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb from colabdesign.shared.protein import renum_pdb_str -from colabdesign.af.alphafold.common import protein +from colabdesign.af.alphafold.common import protein, residue_constants #################################################### # AF_UTILS - various utils (save, plot, etc) @@ -39,7 +39,7 @@ def set_args(self, **kwargs): ''' set [arg]uments ''' - for k in ["best_metric", "traj_iter", "shuffle_first"]: + for k in ["best_metric", "traj_iter"]: if k in kwargs: self._args[k] = kwargs.pop(k) if "recycle_mode" in kwargs: @@ -83,7 +83,7 @@ def to_pdb_str(x, n=None): p_str = "" for n in range(p["atom_positions"].shape[0]): - p_str += to_pdb_str(jax.tree_map(lambda x:x[n],p), n+1) + p_str += to_pdb_str(jax.tree_util.tree_map(lambda x:x[n],p), n+1) p_str += "END\n" if filename is None: @@ -186,4 +186,19 @@ def plot_current_pdb(self, show_sidechains=False, show_mainchains=False, - color=["pLDDT","chain","rainbow"] ''' self.plot_pdb(show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, - color_HP=color_HP, size=size, animate=animate, get_best=False) \ No newline at end of file + 
color_HP=color_HP, size=size, animate=animate, get_best=False) + +def dgram_from_positions(positions, seq=None, num_bins=39, min_bin=3.25, max_bin=50.75): + if seq is None: + atoms = {k:positions[...,residue_constants.atom_order[k],:] for k in ["N","CA","C"]} + c = _np_get_cb(**atoms, use_jax=False) + else: + ca = positions[...,residue_constants.atom_order["CA"],:] + cb = positions[...,residue_constants.atom_order["CB"],:] + is_gly = seq==residue_constants.restype_order["G"] + c = np.where(is_gly[:,None],ca,cb) + dist = np.sqrt(np.square(c[None,:] - c[:,None]).sum(-1,keepdims=True)) + lower_breaks = np.linspace(min_bin, max_bin, num_bins) + lower_breaks = lower_breaks + upper_breaks = np.append(lower_breaks[1:],1e8) + return ((dist > lower_breaks) * (dist < upper_breaks)).astype(float) diff --git a/colabdesign/mpnn/model.py b/colabdesign/mpnn/model.py index b55768f8..c31e17bd 100644 --- a/colabdesign/mpnn/model.py +++ b/colabdesign/mpnn/model.py @@ -39,7 +39,7 @@ def __init__(self, model_name="v_48_020", 'dropout': dropout} self._model = RunModel(config) - self._model.params = jax.tree_map(np.array, checkpoint['model_state_dict']) + self._model.params = jax.tree_util.tree_map(np.array, checkpoint['model_state_dict']) self._setup() self.set_seed(seed) @@ -141,7 +141,7 @@ def sample(self, num=1, batch=1, temperature=0.1, rescore=False, **kwargs): O = [] for _ in range(num): O.append(self.sample_parallel(batch, temperature, rescore, **kwargs)) - return jax.tree_map(lambda *x:np.concatenate(x,0),*O) + return jax.tree_util.tree_map(lambda *x:np.concatenate(x,0),*O) def sample_parallel(self, batch=10, temperature=0.1, rescore=False, **kwargs): '''sample new sequence(s) in parallel''' @@ -152,7 +152,7 @@ def sample_parallel(self, batch=10, temperature=0.1, rescore=False, **kwargs): O = self._sample_parallel(keys, I, temperature, self._tied_lengths) if rescore: O = self._rescore_parallel(keys, I, O["S"], O["decoding_order"]) - O = jax.tree_map(np.array, O) + O = jax.tree_util.tree_map(np.array, O) # process outputs to human-readable form O.update(self._get_seq(O)) @@ -209,7 +209,7 @@ def score(self, seq=None, **kwargs): I["S"][p] = np.array([aa_order.get(aa,-1) for aa in seq]) I.update(kwargs) key = I.pop("key",self.key()) - O = jax.tree_map(np.array, self._score(**I, key=key)) + O = jax.tree_util.tree_map(np.array, self._score(**I, key=key)) O.update(self._get_score(I,O)) return O diff --git a/colabdesign/mpnn/weights/v_48_010_nomem.pkl b/colabdesign/mpnn/weights/v_48_010_nomem.pkl deleted file mode 100644 index e8add480..00000000 Binary files a/colabdesign/mpnn/weights/v_48_010_nomem.pkl and /dev/null differ diff --git a/colabdesign/mpnn/weights/v_48_020_nomem.pkl b/colabdesign/mpnn/weights/v_48_020_nomem.pkl deleted file mode 100644 index 2a8c17a5..00000000 Binary files a/colabdesign/mpnn/weights/v_48_020_nomem.pkl and /dev/null differ diff --git a/colabdesign/rf/utils.py b/colabdesign/rf/utils.py index 9b06f305..1b2b6306 100644 --- a/colabdesign/rf/utils.py +++ b/colabdesign/rf/utils.py @@ -203,16 +203,32 @@ def get_Ls(contigs): Ls.append(L) return Ls -def make_animation(pos, plddt=None, Ls=None, ref=0, line_w=2.0, dpi=100): +def nankabsch(a,b,**kwargs): + ok = np.isfinite(a).all(axis=1) & np.isfinite(b).all(axis=1) + a,b = a[ok],b[ok] + return _np_kabsch(a,b,**kwargs) + +def make_animation(pos, plddt=None, Ls=None, ref=0, line_w=2.0, + align_to_ref=False, verbose=False, dpi=100): if plddt is None: plddt = [None] * len(pos) # center inputs - pos = pos - pos[ref,None].mean(1,keepdims=True) + 
pos = pos - pos[ref].mean(0) # align to best view best_view = _np_kabsch(pos[ref], pos[ref], return_v=True, use_jax=False) - pos = np.asarray([p @ best_view for p in pos]) + if align_to_ref: + pos[ref] = pos[ref] @ best_view + # align to reference position + new_pos = [] + for p in pos: + p_mu = p.mean(0) + aln = _np_kabsch(p-p_mu, pos[ref], use_jax=False) + new_pos.append((p-p_mu) @ aln) + pos = np.asarray(new_pos) + else: + pos = np.asarray([p @ best_view for p in pos]) fig, (ax1) = plt.subplots(1) fig.set_figwidth(5) @@ -230,7 +246,8 @@ def make_animation(pos, plddt=None, Ls=None, ref=0, line_w=2.0, dpi=100): ax.axis(False) ims=[] - for pos_,plddt_ in zip(pos,plddt): + for k,(pos_,plddt_) in enumerate(zip(pos,plddt)): + text = f"{k}" if plddt_ is None: if Ls is None: img = plot_pseudo_3D(pos_, ax=ax1, line_w=line_w, zmin=z_min, zmax=z_max) @@ -239,7 +256,13 @@ def make_animation(pos, plddt=None, Ls=None, ref=0, line_w=2.0, dpi=100): img = plot_pseudo_3D(pos_, c=c, cmap=pymol_cmap, cmin=0, cmax=39, line_w=line_w, ax=ax1, zmin=z_min, zmax=z_max) else: img = plot_pseudo_3D(pos_, c=plddt_, cmin=50, cmax=90, line_w=line_w, ax=ax1, zmin=z_min, zmax=z_max) - ims.append([img]) + text += f" pLDDT={plddt_.mean():.1f}" + if verbose: + txt = plt.text(0.5, 1.01, text, horizontalalignment='center', + verticalalignment='bottom', transform=ax.transAxes) + ims.append([img,txt]) + else: + ims.append([img]) ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120) plt.close() diff --git a/colabdesign/seq/kmeans.py b/colabdesign/seq/kmeans.py index 03b7c496..b86191e8 100644 --- a/colabdesign/seq/kmeans.py +++ b/colabdesign/seq/kmeans.py @@ -85,7 +85,7 @@ def check(x): if n_init > 0: out = jax.vmap(single_run)(jax.random.split(key,n_init)) i = out["inertia"].argmin() - out = jax.tree_map(lambda x:x[i],out) + out = jax.tree_util.tree_map(lambda x:x[i],out) else: out = single_run(key) @@ -130,4 +130,4 @@ def kmeans_sample(msa, msa_weights, k=1, samples=None, seed=0): "sampled_labels":sampled_labels, "sampled_msa":sampled_msa} - return jax.tree_map(lambda x:np.asarray(x),o) \ No newline at end of file + return jax.tree_util.tree_map(lambda x:np.asarray(x),o) \ No newline at end of file diff --git a/colabdesign/seq/utils.py b/colabdesign/seq/utils.py index 579638f1..bfa86aaa 100644 --- a/colabdesign/seq/utils.py +++ b/colabdesign/seq/utils.py @@ -24,7 +24,7 @@ def parse_fasta(filename, a3m=False, stop=100000): else: header.append(line[1:]) sequence.append([]) - else: + elif len(sequence) > 0: if a3m: line = line.translate(rm_lc) else: line = line.upper() sequence[-1].append(line) diff --git a/colabdesign/shared/model.py b/colabdesign/shared/model.py index 38019b82..f78dbd5a 100644 --- a/colabdesign/shared/model.py +++ b/colabdesign/shared/model.py @@ -24,7 +24,7 @@ def set_weights(self, *args, **kwargs): update_dict(self._opt["weights"], *args, **kwargs) update_dict(self.opt["weights"], *args, **kwargs) - def set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, **kwargs): + def _set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, return_values=False, **kwargs): ''' set sequence params and bias ----------------------------------- @@ -71,7 +71,7 @@ def set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, ** wt_seq[self._wt_aatype == -1] = 0 if "pos" in self.opt and self.opt["pos"].shape[0] == wt_seq.shape[0]: seq = np.zeros(shape) - seq[...,self.opt["pos"],:] = wt_seq + seq[:,self.opt["pos"],:] = wt_seq else: seq = wt_seq @@ -88,8 +88,11 @@ def 
set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, ** if isinstance(seq[0], str): aa_dict = copy_dict(aa_order) if shape[-1] > 21: + aa_dict["X"] = 20 # add unk character aa_dict["-"] = 21 # add gap character - seq = np.asarray([[aa_dict.get(aa,-1) for aa in s] for s in seq]) + seq = np.asarray([[aa_dict.get(aa,20) for aa in s] for s in seq]) + else: + seq = np.asarray([[aa_dict.get(aa,-1) for aa in s] for s in seq]) else: seq = np.asarray(seq) else: @@ -104,7 +107,9 @@ def set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, ** b = b + seq * 1e7 if seq.ndim == 2: - x = np.pad(seq[None],[[0,shape[0]],[0,0],[0,0]]) + x = np.pad(seq[None],[[0,shape[0]-1],[0,0],[0,0]]) + elif shape[0] > seq.shape[0]: + x = np.pad(seq,[[0,shape[0]-seq.shape[0]],[0,0],[0,0]]) else: x = seq @@ -117,11 +122,14 @@ def set_seq(self, seq=None, mode=None, bias=None, rm_aa=None, set_state=True, ** else: y = x + y_gumbel - x = np.where(x.sum(-1,keepdims=True) == 1, x, y) + x = np.where(x.sum(-1,keepdims=True) == 1, x, y) - # set seq/bias/state - self._params["seq"] = x - self._inputs["bias"] = b + # set seq/bias + if return_values: + return (x,b) + else: + self._params["seq"] = x + self._inputs["bias"] = b def _norm_seq_grad(self): g = self.aux["grad"]["seq"] @@ -153,7 +161,7 @@ def set_optimizer(self, optimizer=None, learning_rate=None, norm_seq_grad=None, def update_grad(state, grad, params): updates, state = o.update(grad, state, params) - grad = jax.tree_map(lambda x:-x, updates) + grad = jax.tree_util.tree_map(lambda x:-x, updates) return state, grad self._optimizer = jax.jit(update_grad) diff --git a/colabdesign/shared/plot.py b/colabdesign/shared/plot.py index e0c31e6a..32dd3d08 100644 --- a/colabdesign/shared/plot.py +++ b/colabdesign/shared/plot.py @@ -316,7 +316,10 @@ def nankabsch(a,b,**kwargs): else: cmap = matplotlib.colors.ListedColormap(jalview_color_list[color_msa]) vmax = len(jalview_color_list[color_msa]) - 1 - ims[-1].append(ax2.imshow(seq[k].argmax(-1), animated=True, cmap=cmap, vmin=0, vmax=vmax, interpolation="none")) + msa_oh = seq[k][:,:,:20] + msa = msa_oh.argmax(-1).astype(float) + msa[msa_oh.sum(-1) == 0] = np.nan + ims[-1].append(ax2.imshow(msa, animated=True, cmap=cmap, vmin=0, vmax=vmax, interpolation="none")) if pae is not None: L = pae[k].shape[0] diff --git a/colabdesign/shared/prng.py b/colabdesign/shared/prng.py index b5747877..cb1eb996 100644 --- a/colabdesign/shared/prng.py +++ b/colabdesign/shared/prng.py @@ -21,7 +21,7 @@ def split(self, num_keys=2): self._assert_not_used() self._used = True new_keys = jax.random.split(self._key, num_keys) - return jax.tree_map(SafeKey, tuple(new_keys)) + return jax.tree_util.tree_map(SafeKey, tuple(new_keys)) def duplicate(self, num_keys=2): self._assert_not_used() diff --git a/colabdesign/shared/protein.py b/colabdesign/shared/protein.py index e75d4bbd..a11372e5 100644 --- a/colabdesign/shared/protein.py +++ b/colabdesign/shared/protein.py @@ -273,7 +273,7 @@ def _np_get_6D_binned(all_atom_positions, all_atom_mask, use_jax=None): ref = _np_get_6D(all_atom_positions, all_atom_mask, use_jax=False, for_trrosetta=True) - ref = jax.tree_map(jnp.squeeze,ref) + ref = jax.tree_util.tree_map(jnp.squeeze,ref) def mtx2bins(x_ref, start, end, nbins, mask): bins = np.linspace(start, end, nbins) diff --git a/colabdesign/shared/utils.py b/colabdesign/shared/utils.py index 31670ff8..4af37ddd 100644 --- a/colabdesign/shared/utils.py +++ b/colabdesign/shared/utils.py @@ -27,7 +27,7 @@ def set_dict(d, x, override=False): 
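# --- Editor-added sketch (illustrative, not part of the patch) ---------------------
# How the string-to-index encoding in _set_seq above behaves once "X" (unknown) and
# "-" (gap) are added for alphabets larger than 21. The 20-letter ordering below is a
# stand-in for colabdesign's shared aa_order and is an assumption of this sketch.
import numpy as np

aa_order = {aa: i for i, aa in enumerate("ARNDCQEGHILKMFPSTWYV")}
aa_dict = dict(aa_order)
alphabet_size = 22
if alphabet_size > 21:
    aa_dict["X"] = 20          # unknown residue
    aa_dict["-"] = 21          # gap character
    default = 20               # unknowns map to "X"
else:
    default = -1               # unknowns map to -1 in the narrower alphabet

seq = "ARND-XQ"
print(np.array([aa_dict.get(aa, default) for aa in seq]))   # [ 0  1  2  3 21 20  5]
# ------------------------------------------------------------------------------------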
elif isinstance(d[k],(np.ndarray,jnp.ndarray)): d[k] = np.asarray(v) elif isinstance(d[k], dict): - d[k] = jax.tree_map(lambda x: type(x)(v), d[k]) + d[k] = jax.tree_util.tree_map(lambda x: type(x)(v), d[k]) else: d[k] = type(d[k])(v) else: @@ -41,7 +41,7 @@ def set_dict(d, x, override=False): def copy_dict(x): '''deepcopy dictionary''' - return jax.tree_map(lambda y:y, x) + return jax.tree_util.tree_map(lambda y:y, x) def to_float(x): '''convert to float''' diff --git a/colabdesign/tr/model.py b/colabdesign/tr/model.py index 8262c924..28d77b04 100644 --- a/colabdesign/tr/model.py +++ b/colabdesign/tr/model.py @@ -54,12 +54,12 @@ def _get_model(self): def _get_loss(inputs, outputs): opt = inputs["opt"] aux = {"outputs":outputs, "losses":{}} - log_p = jax.tree_map(jax.nn.log_softmax, outputs) + log_p = jax.tree_util.tree_map(jax.nn.log_softmax, outputs) # bkg loss if self.protocol in ["hallucination","partial"]: - p = jax.tree_map(jax.nn.softmax, outputs) - log_q = jax.tree_map(jax.nn.log_softmax, inputs["6D_bkg"]) + p = jax.tree_util.tree_map(jax.nn.softmax, outputs) + log_q = jax.tree_util.tree_map(jax.nn.log_softmax, inputs["6D_bkg"]) aux["losses"]["bkg"] = {} for k in ["dist","omega","theta","phi"]: aux["losses"]["bkg"][k] = -(p[k]*(log_p[k]-log_q[k])).sum(-1).mean() @@ -68,7 +68,7 @@ def _get_loss(inputs, outputs): if self.protocol in ["fixbb","partial"]: if "pos" in opt: pos = opt["pos"] - log_p = jax.tree_map(lambda x:x[:,pos][pos,:], log_p) + log_p = jax.tree_util.tree_map(lambda x:x[:,pos][pos,:], log_p) q = inputs["6D"] aux["losses"]["cce"] = {} @@ -80,7 +80,7 @@ def _get_loss(inputs, outputs): # weighted loss w = opt["weights"] - tree_multi = lambda x,y: jax.tree_map(lambda a,b:a*b, x,y) + tree_multi = lambda x,y: jax.tree_util.tree_map(lambda a,b:a*b, x,y) losses = {k:(tree_multi(v,w[k]) if k in w else v) for k,v in aux["losses"].items()} loss = sum(jax.tree_leaves(losses)) return loss, aux @@ -98,7 +98,7 @@ def _model(params, model_params, inputs, key): seq_ref = jax.nn.one_hot(inputs["batch"]["aatype"],20) p = opt["fix_pos"] fix_seq = lambda x:x.at[...,p,:].set(seq_ref[...,p,:]) - seq = jax.tree_map(fix_seq, seq) + seq = jax.tree_util.tree_map(fix_seq, seq) inputs.update({"seq":seq["pseudo"][0], "prf":jnp.where(opt["use_pssm"],seq["pssm"],seq["pseudo"])[0]}) @@ -129,7 +129,7 @@ def prep_inputs(self, pdb_filename=None, chain=None, length=None, self._pos_info = prep_pos(pos, **pdb["idx"]) p = self._pos_info["pos"] aatype = self._inputs["batch"]["aatype"] - self._inputs["batch"] = jax.tree_map(lambda x:x[p], self._inputs["batch"]) + self._inputs["batch"] = jax.tree_util.tree_map(lambda x:x[p], self._inputs["batch"]) self.opt["pos"] = p if "fix_pos" in self.opt: sub_i,sub_p = [],[] @@ -164,7 +164,7 @@ def prep_inputs(self, pdb_filename=None, chain=None, length=None, for n in range(1,6): p = os.path.join(self._data_dir,os.path.join("bkgr_models",f"bkgr0{n}.npy")) self._inputs["6D_bkg"].append(self._bkg_model(get_model_params(p), key, self._len)) - self._inputs["6D_bkg"] = jax.tree_map(lambda *x:np.stack(x).mean(0), *self._inputs["6D_bkg"]) + self._inputs["6D_bkg"] = jax.tree_util.tree_map(lambda *x:np.stack(x).mean(0), *self._inputs["6D_bkg"]) # reweight the background self.opt["weights"]["bkg"] = dict(dist=1/6,omega=1/6,phi=2/6,theta=2/6) @@ -190,6 +190,8 @@ def set_opt(self, *args, **kwargs): update_dict(self.opt, *args, **kwargs) + def set_seq(self, seq=None, mode=None, **kwargs): + self._set_seq(seq=seq, mode=mode, **kwargs) def restart(self, seed=None, opt=None, weights=None, 
seq=None, reset_opt=True, **kwargs): @@ -233,12 +235,12 @@ def run(self, backprop=True): (loss,aux),grad = self._model["grad_fn"](*flags) else: loss,aux = self._model["fn"](*flags) - grad = jax.tree_map(np.zeros_like, self._params) + grad = jax.tree_util.tree_map(np.zeros_like, self._params) aux.update({"loss":loss, "grad":grad}) aux_all.append(aux) # average results - self.aux = jax.tree_map(lambda *x:np.stack(x).mean(0), *aux_all) + self.aux = jax.tree_util.tree_map(lambda *x:np.stack(x).mean(0), *aux_all) self.aux["model_num"] = model_num @@ -252,7 +254,7 @@ def step(self, backprop=True, callback=None, save_best=True, verbose=1): # apply gradients lr = self.opt["learning_rate"] - self._params = jax.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) + self._params = jax.tree_util.tree_map(lambda x,g:x-lr*g, self._params, self.aux["grad"]) # increment self._k += 1 @@ -292,7 +294,7 @@ def plot(self, mode="preds", dpi=100, get_best=True): elif mode == "bkg_feats": x = self._inputs["6D_bkg"] - x = jax.tree_map(np.asarray, x) + x = jax.tree_util.tree_map(np.asarray, x) plt.figure(figsize=(4*4,4), dpi=dpi) for n,k in enumerate(["theta","phi","dist","omega"]): @@ -308,7 +310,7 @@ def get_loss(self, k=None, get_best=True): return {k:self.get_loss(k, get_best=get_best) for k in aux["losses"].keys()} losses = aux["losses"][k] weights = aux["opt"]["weights"][k] - weighted_losses = jax.tree_map(lambda l,w:l*w, losses, weights) + weighted_losses = jax.tree_util.tree_map(lambda l,w:l*w, losses, weights) return float(sum(jax.tree_leaves(weighted_losses))) def af_callback(self, weight=1.0, seed=None): diff --git a/colabdesign/tr/trrosetta.py b/colabdesign/tr/trrosetta.py index 8a27b8da..3b631792 100644 --- a/colabdesign/tr/trrosetta.py +++ b/colabdesign/tr/trrosetta.py @@ -57,7 +57,7 @@ def block(x, params, dilation, key, rate=0.15): y = x for n in [0,1]: if n == 1: y = dropout(y, key, rate) - p = jax.tree_map(lambda x:x[n], params) + p = jax.tree_util.tree_map(lambda x:x[n], params) y = conv_2D(y, p, dilation) y = instance_norm(y, p) y = jax.nn.elu(y if n == 0 else (x+y)) @@ -68,7 +68,7 @@ def body(prev, sub_params): (x,key) = prev for n, dilation in enumerate([1,2,4,8,16]): key, sub_key = jax.random.split(key) - p = jax.tree_map(lambda x:x[n], sub_params) + p = jax.tree_util.tree_map(lambda x:x[n], sub_params) x = block(x, p, dilation, sub_key, rate) return (x,key), None return jax.lax.scan(body,(x,key),params)[0][0] @@ -111,5 +111,5 @@ def split(params): steps = min(len(params),len(labels)) return {labels[n]:np.squeeze(params[n::steps]) for n in range(steps)} params = {k:split(xaa[i:i+n]) for k,i,n in zip(layers,idx,num)} - params["resnet"] = jax.tree_map(lambda x:x.reshape(-1,5,2,*x.shape[1:]), params["resnet"]) + params["resnet"] = jax.tree_util.tree_map(lambda x:x.reshape(-1,5,2,*x.shape[1:]), params["resnet"]) return params \ No newline at end of file diff --git a/mpnn/examples/proteinmpnn_in_jax.ipynb b/mpnn/examples/proteinmpnn_in_jax.ipynb index 2966d179..5383fbe6 100644 --- a/mpnn/examples/proteinmpnn_in_jax.ipynb +++ b/mpnn/examples/proteinmpnn_in_jax.ipynb @@ -58,7 +58,7 @@ " import colabdesign\n", "except:\n", " os.system(\"pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1\")\n", - " os.system(\"ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign\")\n", + " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n", "\n", "from colabdesign.mpnn import mk_mpnn_model, clear_mem\n", "from colabdesign.shared.protein 
import pdb_to_string\n",
diff --git a/setup.py b/setup.py
index 5f1d80de..a070ad6e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 setup(
   name='colabdesign',
-  version='1.1.1',
+  version='1.2.0-beta.0',
   description='Making Protein Design accessible to all via Google Colab!',
   long_description="Making Protein Design accessible to all via Google Colab!",
   long_description_content_type='text/markdown',
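# --- Editor-added sketches (illustrative, not part of the patch) --------------------
# 1) Distance binning as in the dgram_from_positions helper added to
#    colabdesign/af/utils.py above: pairwise C-beta distances are one-hot encoded into
#    num_bins bins spanning [min_bin, max_bin]. Plain numpy, toy coordinates, and no
#    glycine special-casing.
import numpy as np

cb = np.array([[0., 0., 0.], [4., 0., 0.], [30., 0., 0.]])      # toy C-beta coords
dist = np.sqrt(np.square(cb[None, :] - cb[:, None]).sum(-1, keepdims=True))
lower = np.linspace(3.25, 50.75, 39)
upper = np.append(lower[1:], 1e8)
dgram = ((dist > lower) * (dist < upper)).astype(float)         # (3, 3, 39)
print(dgram.shape, dgram[0, 1].argmax())                        # bin holding the 4 A pair

# 2) NaN-aware superposition as in the nankabsch helper added to colabdesign/rf/utils.py
#    above: rows with non-finite coordinates are dropped before a standard Kabsch fit.
#    The self-contained kabsch below is a minimal stand-in for the library's _np_kabsch.
def kabsch(a, b):
    # optimal rotation mapping a onto b (assumes no translation between them)
    u, s, vh = np.linalg.svd(a.T @ b)
    u[:, -1] *= np.sign(np.linalg.det(u @ vh))
    return u @ vh

def nan_kabsch(a, b):
    ok = np.isfinite(a).all(1) & np.isfinite(b).all(1)
    return kabsch(a[ok], b[ok])

a = np.random.randn(10, 3)
a[3] = np.nan                                                   # a "missing" residue
t = np.pi / 7
R = np.array([[np.cos(t), -np.sin(t), 0.], [np.sin(t), np.cos(t), 0.], [0., 0., 1.]])
b = a @ R
print(np.allclose(np.nan_to_num(a @ nan_kabsch(a, b)), np.nan_to_num(b)))  # True

# 3) The leaf-wise update and model-averaging patterns switched to
#    jax.tree_util.tree_map throughout this patch (e.g. colabdesign/tr/model.py above):
import jax

params = {"seq": np.ones((1, 5, 20)), "bias": np.zeros(20)}
grads  = {"seq": np.full((1, 5, 20), 0.1), "bias": np.full(20, 0.5)}
params = jax.tree_util.tree_map(lambda x, g: x - 0.1 * g, params, grads)   # SGD step
aux_all = [{"loss": 1.0}, {"loss": 3.0}]
print(jax.tree_util.tree_map(lambda *x: np.stack(x).mean(0), *aux_all))    # {'loss': 2.0}
# -------------------------------------------------------------------------------------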