Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class VersionedOnlineReadNotSupported(FeastError):
def __init__(self, store_name: str, version: int):
super().__init__(
f"Versioned feature reads (@v{version}) are not yet supported by {store_name}. "
f"Currently only SQLite supports version-qualified feature references. "
f"Currently only SQLite and FAISS support version-qualified feature references. "
)


Expand Down
127 changes: 84 additions & 43 deletions sdk/python/feast/infra/online_stores/faiss_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,31 @@ def teardown(self):
self.entity_keys = {}


def _table_id(project: str, table: FeatureView, enable_versioning: bool = False) -> str:
    """Build the storage key for a feature view: ``<project>_<name>``.

    When ``enable_versioning`` is true and a positive version can be
    resolved, the key becomes ``<project>_<name>_v<version>``.

    Args:
        project: Project name used as the key prefix.
        table: Feature view whose ``name`` (and, when versioning is enabled,
            ``projection.version_tag`` / ``current_version_number``) is read.
        enable_versioning: Whether a version suffix may be appended.

    Returns:
        The table key string.
    """
    base_name = table.name
    if not enable_versioning:
        return f"{project}_{base_name}"

    # The projection's version_tag (set by version-qualified refs like @v2)
    # takes precedence over the feature view's own current_version_number.
    resolved = getattr(table.projection, "version_tag", None)
    if resolved is None:
        resolved = getattr(table, "current_version_number", None)

    # A missing or non-positive version leaves the key unversioned.
    if resolved is not None and resolved > 0:
        return f"{project}_{table.name}_v{resolved}"
    return f"{project}_{base_name}"


class FaissOnlineStore(OnlineStore):
_index: Optional[faiss.IndexIVFFlat] = None
_in_memory_store: InMemoryStore = InMemoryStore()
_config: Optional[FaissOnlineStoreConfig] = None
_logger: logging.Logger = logging.getLogger(__name__)

def _get_index(self, config: RepoConfig) -> faiss.IndexIVFFlat:
if self._index is None or self._config is None:
raise ValueError("Index is not initialized")
return self._index
def __init__(self):
    """Initialise an empty store with per-table-key index and key maps."""
    super().__init__()
    # No store configuration is held until one is assigned.
    self._config: Optional[FaissOnlineStoreConfig] = None
    # Both maps are keyed by table key and start out empty.
    self._in_memory_stores: Dict[str, InMemoryStore] = {}
    self._indices: Dict[str, faiss.IndexIVFFlat] = {}

def _get_index(self, table_key: str) -> Optional[faiss.IndexIVFFlat]:
    """Return the index stored under ``table_key``, or None if there is none."""
    known_indices = self._indices
    if table_key in known_indices:
        return known_indices[table_key]
    return None

def update(
self,
Expand All @@ -63,32 +78,45 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
feature_views = tables_to_keep
if not feature_views:
return

feature_names = [f.name for f in feature_views[0].features]
dimension = len(feature_names)

self._config = FaissOnlineStoreConfig(**config.online_store.dict())
if self._index is None or not partial:
quantizer = faiss.IndexFlatL2(dimension)
self._index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
self._index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(np.float32)
)
self._in_memory_store = InMemoryStore()
versioning = config.registry.enable_online_feature_view_versioning

for table in tables_to_delete:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
self._in_memory_stores.pop(table_key, None)

for table in tables_to_keep:
table_key = _table_id(config.project, table, versioning)
feature_names = [f.name for f in table.features]
dimension = len(feature_names)

if table_key not in self._indices or not partial:
quantizer = faiss.IndexFlatL2(dimension)
index = faiss.IndexIVFFlat(quantizer, dimension, self._config.nlist)
index.train(
np.random.rand(self._config.nlist * 100, dimension).astype(
np.float32
)
)
self._indices[table_key] = index
self._in_memory_stores[table_key] = InMemoryStore()

self._in_memory_store.update(feature_names, {})
self._in_memory_stores[table_key].update(feature_names, {})

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
self._index = None
self._in_memory_store.teardown()
versioning = config.registry.enable_online_feature_view_versioning
for table in tables:
table_key = _table_id(config.project, table, versioning)
self._indices.pop(table_key, None)
store = self._in_memory_stores.pop(table_key, None)
if store is not None:
store.teardown()

def online_read(
self,
Expand All @@ -97,23 +125,28 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
return [(None, None)] * len(entity_keys)

results: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []
for entity_key in entity_keys:
serialized_key = serialize_entity_key(
entity_key, config.entity_key_serialization_version
).hex()
idx = self._in_memory_store.entity_keys.get(serialized_key, -1)
idx = in_memory_store.entity_keys.get(serialized_key, -1)
if idx == -1:
results.append((None, None))
else:
feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))
feature_dict = {
name: ValueProto(double_val=value)
for name, value in zip(
self._in_memory_store.feature_names, feature_vector
in_memory_store.feature_names, feature_vector
)
}
results.append((None, feature_dict))
Expand All @@ -128,8 +161,16 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
if self._index is None:
self._logger.warning("Index is not initialized. Skipping write operation.")
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)
in_memory_store = self._in_memory_stores.get(table_key)

if index is None or in_memory_store is None:
self._logger.warning(
"Index for table '%s' is not initialized. Skipping write operation.",
table_key,
)
return

feature_vectors = []
Expand All @@ -142,7 +183,7 @@ def online_write_batch(
feature_vector = np.array(
[
feature_dict[name].double_val
for name in self._in_memory_store.feature_names
for name in in_memory_store.feature_names
],
dtype=np.float32,
)
Expand All @@ -153,21 +194,17 @@ def online_write_batch(
feature_vectors_array = np.array(feature_vectors)

existing_indices = [
self._in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
in_memory_store.entity_keys.get(sk, -1) for sk in serialized_keys
]
mask = np.array(existing_indices) != -1
if np.any(mask):
self._index.remove_ids(
np.array([idx for idx in existing_indices if idx != -1])
)
index.remove_ids(np.array([idx for idx in existing_indices if idx != -1]))

new_indices = np.arange(
self._index.ntotal, self._index.ntotal + len(feature_vectors_array)
)
self._index.add(feature_vectors_array)
new_indices = np.arange(index.ntotal, index.ntotal + len(feature_vectors_array))
index.add(feature_vectors_array)

for sk, idx in zip(serialized_keys, new_indices):
self._in_memory_store.entity_keys[sk] = idx
in_memory_store.entity_keys[sk] = idx

if progress:
progress(len(data))
Expand All @@ -189,12 +226,16 @@ def retrieve_online_documents(
Optional[ValueProto],
]
]:
if self._index is None:
versioning = config.registry.enable_online_feature_view_versioning
table_key = _table_id(config.project, table, versioning)
index = self._get_index(table_key)

if index is None:
self._logger.warning("Index is not initialized. Returning empty result.")
return []

query_vector = np.array(embedding, dtype=np.float32).reshape(1, -1)
distances, indices = self._index.search(query_vector, top_k)
distances, indices = index.search(query_vector, top_k)

results: List[
Tuple[
Expand All @@ -209,7 +250,7 @@ def retrieve_online_documents(
if idx == -1:
continue

feature_vector = self._index.reconstruct(int(idx))
feature_vector = index.reconstruct(int(idx))

timestamp = Timestamp()
timestamp.GetCurrentTime()
Expand Down
9 changes: 8 additions & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,16 @@ def get_online_features(

def _check_versioned_read_support(self, grouped_refs):
"""Raise an error if versioned reads are attempted on unsupported stores."""
try:
from feast.infra.online_stores.faiss_online_store import FaissOnlineStore
except ImportError:
FaissOnlineStore = None
from feast.infra.online_stores.sqlite import SqliteOnlineStore

if isinstance(self, SqliteOnlineStore):
supported = [SqliteOnlineStore]
if FaissOnlineStore is not None:
supported.append(FaissOnlineStore)
if isinstance(self, tuple(supported)):
return
for table, _ in grouped_refs:
version_tag = getattr(table.projection, "version_tag", None)
Expand Down
Loading