Skip to content

Commit

Permalink
Merge pull request #331 from mims-harvard/geneformer_server
Browse files Browse the repository at this point in the history
Geneformer server -- extract cell type specific gene embeddings
  • Loading branch information
amva13 authored Nov 12, 2024
2 parents 2daa6ca + f77297b commit e379e0d
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def get_ensembl_id_from_chembl_id(chembl_id):
return str(e)


def quant_layers(model):
    """Return the number of indexed layers found in ``model``'s parameter names.

    Scans the names yielded by ``model.named_parameters()`` for a
    ``layer.<N>`` component (e.g. ``encoder.layer.5.attention.weight``) and
    returns the highest ``N`` plus one, i.e. the layer count when layers are
    numbered from zero.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        int: ``max(N) + 1`` over every parameter name carrying a layer index.

    Raises:
        ValueError: If no parameter name contains a ``layer.<N>`` component.
    """
    # Local import keeps the helper self-contained regardless of what the
    # enclosing module already imports.
    import re

    # Require a numeric segment right after "layer." so that names which merely
    # contain the substring "layer" (e.g. "final_layer_norm.weight") are
    # skipped instead of crashing the index extraction.
    layer_index = re.compile(r"layer\.(\d+)(?:\.|$)")

    matches = (layer_index.search(name) for name, _ in model.named_parameters())
    layer_nums = [int(found.group(1)) for found in matches if found is not None]

    if not layer_nums:
        # Same exception type max() would raise on an empty list, but with a
        # message that points at the actual cause.
        raise ValueError(
            "no parameter name contains a 'layer.<index>' component; "
            "cannot determine the number of layers")

    return max(layer_nums) + 1


class TestModelServer(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -146,7 +154,16 @@ def testGeneformerTokenizer(self):
# build an attention mask
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
out.append(model(batch, attention_mask=attention_mask))
outputs = model(batch,
attention_mask=attention_mask,
output_hidden_states=True)
layer_to_quant = quant_layers(model) + (
-1
) # TODO note this can be parametrized to either 0 (extract last embedding layer) or -1 (second-to-last which is more generalized)
embs_i = outputs.hidden_states[layer_to_quant]
# There are "cls", "cell", and "gene" embeddings. We capture only the "gene" embeddings, which are cell-type specific; to obtain a "cell" embedding, you would average the unmasked gene embeddings for each cell.
embs = embs_i
out.append(embs)
if ctr == 2:
break
ctr += 1
Expand All @@ -159,6 +176,9 @@ def testGeneformerTokenizer(self):
out
) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
len(out), ctr + 1, out)
print(
"Geneformer ran sucessfully. Find batch embedding example here:\n {}"
.format(out[0]))

def tearDown(self):
try:
Expand Down

0 comments on commit e379e0d

Please sign in to comment.