What's a fast, portable way of retrieving embeddings?
import hashlib
import json
import numpy as np
import os
import pickle
import sqlite3
# Synthetic corpus: n random 100-char strings, each paired with its SHA-256
# hex digest and a random float32 embedding vector.
n = 10000  # number of records in the corpus
c = 1000   # number of lookup queries used by the retrieval benchmarks below

_alphabet = list("abcdefghijklmnopqrstuvwxyz0123456789")


def _random_record():
    """Return one record: random text, its SHA-256 digest, and a float32 embedding."""
    text = "".join(np.random.choice(_alphabet, 100))
    return {
        "text": text,
        "hash": hashlib.sha256(text.encode()).hexdigest(),
        "embeddings": np.random.rand(1000).astype(np.float32),
    }


data = [_random_record() for _ in range(n)]
`.tobytes()` and `.frombuffer()` are the fastest (1.9 ms):

blobs = [d["embeddings"].tobytes() for d in data]
%timeit [np.frombuffer(p, dtype=np.float32) for p in blobs]
1.92 ms ± 21.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
blobs = [pickle.dumps(d["embeddings"]) for d in data]
%timeit [pickle.loads(p) for p in blobs]
32.3 ms ± 409 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
blobs = [pickle.dumps(d["embeddings"].tolist()) for d in data]
%timeit [pickle.loads(p) for p in blobs]
268 ms ± 3.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
blobs = [json.dumps(d["embeddings"].tolist()) for d in data]
%timeit [np.array(json.loads(p)) for p in blobs]
2.38 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Persist every embedding into SQLite: one row per text, the vector stored
# as a raw float32 blob, with an index on the text column for fast lookups.
db = sqlite3.connect(".embeddings.db")
with db:  # commits the transaction on successful exit
    db.execute("DROP TABLE IF EXISTS embeddings")
    db.execute("CREATE TABLE embeddings (text TEXT, data BLOB)")
    rows = ((d["text"], d["embeddings"].tobytes()) for d in data)
    db.executemany("INSERT INTO embeddings (text, data) VALUES (?, ?)", rows)
    db.execute("CREATE INDEX idx_embeddings_text ON embeddings (text)")
db.close()
# Re-open the database and draw c random lookup keys for the benchmarks below.
conn = sqlite3.connect(".embeddings.db")
cur = conn.cursor()
_picks = (np.random.randint(0, n) for _ in range(c))
texts = [data[i]["text"] for i in _picks]
# Retrieve embeddings one by one
%timeit embeddings = [np.frombuffer(cur.execute("SELECT data FROM embeddings WHERE text = ?", (text,)).fetchone()[0], dtype=np.float32) for text in texts]
# Retrieve all embeddings in a single query
%timeit embeddings = [np.frombuffer(row[0], dtype=np.float32) for row in cur.execute("SELECT data FROM embeddings WHERE text IN ({})".format(",".join(["?"]*c)), texts)]
conn.close()
One by one: 33.8 ms ± 298 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Single query: 8.63 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Cache each embedding on disk as its own pickle file, named by the SHA-256
# digest of its text so lookups need no index.
# exist_ok=True is race-free and replaces the LBYL os.path.exists() check.
os.makedirs(".embeddings", exist_ok=True)
for row in data:
    path = os.path.join(".embeddings", row["hash"])
    with open(path, "wb") as f:
        pickle.dump(row["embeddings"], f)
# Draw c random lookup keys for the file-based retrieval benchmark below.
texts = [data[np.random.randint(0, n)]["text"] for _ in range(c)]
%timeit embeddings = [pickle.load(open(os.path.join(".embeddings", hashlib.sha256(text.encode()).hexdigest()), "rb")) for text in texts]
46.1 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)