diff --git a/examples/Embedding_long_inputs.ipynb b/examples/Embedding_long_inputs.ipynb index dcc9d07..3e02f6b 100644 --- a/examples/Embedding_long_inputs.ipynb +++ b/examples/Embedding_long_inputs.ipynb @@ -200,11 +200,13 @@ "\n", "def len_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):\n", " chunk_embeddings = []\n", + " chunk_lens = []\n", " for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):\n", " chunk_embeddings.append(get_embedding(chunk, model=model))\n", + " chunk_lens.append(len(chunk))\n", "\n", " if average:\n", - " chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=[len(c) for c in chunk_embeddings])\n", + " chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)\n", " chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) # normalizes length to 1\n", " chunk_embeddings = chunk_embeddings.tolist()\n", " return chunk_embeddings"