Skip to content

Commit a9007a7

Browse files
authored
Merge branch 'pgvector:master' into master
2 parents f2cd351 + 82c976f commit a9007a7

File tree

3 files changed

+38
-6
lines changed

3 files changed

+38
-6
lines changed

pgvector/django/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@
44
from .indexes import IvfflatIndex, HnswIndex
55
from .sparsevec import SparsevecField
66
from .vector import VectorField
7-
from ..utils import SparseVec
7+
from ..utils import HalfVec, SparseVec
88

9-
__all__ = ['VectorExtension', 'VectorField', 'HalfvecField', 'SparsevecField', 'IvfflatIndex', 'HnswIndex', 'L2Distance', 'MaxInnerProduct', 'CosineDistance', 'L1Distance']
9+
__all__ = [
10+
'VectorExtension',
11+
'VectorField',
12+
'HalfvecField',
13+
'SparsevecField',
14+
'IvfflatIndex',
15+
'HnswIndex',
16+
'L2Distance',
17+
'MaxInnerProduct',
18+
'CosineDistance',
19+
'L1Distance',
20+
'HalfVec',
21+
'SparseVec'
22+
]

pgvector/django/functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from django.db.models import FloatField, Func, Value
2-
from ..utils import Vector
2+
from ..utils import Vector, HalfVec, SparseVec
33

44

55
class DistanceBase(Func):
66
output_field = FloatField()
77

88
def __init__(self, expression, vector, **extra):
99
if not hasattr(vector, 'resolve_expression'):
10-
vector = Value(Vector.to_db(vector))
10+
if isinstance(vector, HalfVec):
11+
vector = Value(HalfVec.to_db(vector))
12+
elif isinstance(vector, SparseVec):
13+
vector = Value(SparseVec.to_db(vector))
14+
else:
15+
vector = Value(Vector.to_db(vector))
1116
super().__init__(expression, vector, **extra)
1217

1318

tests/test_django.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from math import sqrt
99
import numpy as np
1010
import pgvector.django
11-
from pgvector.django import VectorExtension, VectorField, HalfvecField, SparsevecField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, SparseVec
11+
from pgvector.django import VectorExtension, VectorField, HalfvecField, SparsevecField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HalfVec, SparseVec
1212
from unittest import mock
1313

1414
settings.configure(
@@ -92,7 +92,7 @@ def create_items():
9292
[1, 1, 2]
9393
]
9494
for i, v in enumerate(vectors):
95-
item = Item(id=i + 1, embedding=v)
95+
item = Item(id=i + 1, embedding=v, half_embedding=v, sparse_embedding=SparseVec.from_dense(v))
9696
item.save()
9797

9898

@@ -156,6 +156,20 @@ def test_l1_distance(self):
156156
assert [v.id for v in items] == [1, 3, 2]
157157
assert [v.distance for v in items] == [0, 1, 3]
158158

159+
def test_halfvec_l2_distance(self):
160+
create_items()
161+
distance = L2Distance('half_embedding', HalfVec([1, 1, 1]))
162+
items = Item.objects.annotate(distance=distance).order_by(distance)
163+
assert [v.id for v in items] == [1, 3, 2]
164+
assert [v.distance for v in items] == [0, 1, sqrt(3)]
165+
166+
def test_sparsevec_l2_distance(self):
167+
create_items()
168+
distance = L2Distance('sparse_embedding', SparseVec.from_dense([1, 1, 1]))
169+
items = Item.objects.annotate(distance=distance).order_by(distance)
170+
assert [v.id for v in items] == [1, 3, 2]
171+
assert [v.distance for v in items] == [0, 1, sqrt(3)]
172+
159173
def test_filter(self):
160174
create_items()
161175
distance = L2Distance('embedding', [1, 1, 1])

0 commit comments

Comments
 (0)