Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/llama_stack/core/routers/vector_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the root directory of this source tree.

import asyncio
import re
import uuid
from typing import Annotated, Any

Expand Down Expand Up @@ -44,6 +45,17 @@
logger = get_logger(name=__name__, category="core::routers")


def validate_collection_name(collection_name: str) -> None:
    """Validate a user-supplied collection name before it is used as a physical storage name.

    Args:
        collection_name: The collection name provided by the caller.

    Raises:
        ValueError: If the name is empty or contains any character outside
            the set [a-zA-Z0-9_-].
    """
    if not collection_name:
        raise ValueError("collection_name cannot be empty")

    # Use re.fullmatch instead of re.match with a `$` anchor: `$` matches just
    # before a trailing newline, so "name\n" would otherwise slip through.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", collection_name):
        raise ValueError(
            f"collection_name '{collection_name}' contains invalid characters. "
            "Only alphanumeric characters, hyphens, and underscores are allowed."
        )


class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""

Expand Down Expand Up @@ -160,22 +172,40 @@ async def openai_create_vector_store(
else:
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]

# Extract and validate collection_name if provided
collection_name = extra.get("collection_name")
if collection_name:
validate_collection_name(collection_name)
provider_vector_store_id = collection_name
logger.debug(f"Using custom collection name: {collection_name}")
else:
# Fall back to auto-generated UUID for backward compatibility
provider_vector_store_id = f"vs_{uuid.uuid4()}"

# Always generate a unique vector_store_id for internal routing
vector_store_id = f"vs_{uuid.uuid4()}"

registered_vector_store = await self.routing_table.register_vector_store(
vector_store_id=vector_store_id,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=provider_id,
provider_vector_store_id=vector_store_id,
provider_vector_store_id=provider_vector_store_id,
vector_store_name=params.name,
)
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)

# Update model_extra with registered values so provider uses the already-registered vector_store
if params.model_extra is None:
params.model_extra = {}
params.model_extra["vector_store_id"] = vector_store_id # Pass canonical UUID to Provider
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_store.provider_id

# Add collection_name to metadata so users can see what was used
if params.metadata is None:
params.metadata = {}
params.metadata["provider_vector_store_id"] = provider_vector_store_id
if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None:
Expand Down
18 changes: 15 additions & 3 deletions src/llama_stack/providers/inline/vector_io/faiss/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ async def initialize(self) -> None:
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store,
await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
self.inference_api,
)
self.cache[vector_store.identifier] = index
Expand Down Expand Up @@ -239,7 +243,11 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
# Store in cache
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
index=await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
inference_api=self.inference_api,
)

Expand Down Expand Up @@ -272,7 +280,11 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
vector_store = VectorStore.model_validate_json(vector_store_data)
index = VectorStoreWithIndex(
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
index=await FaissIndex.create(
vector_store.embedding_dimension,
self.kvstore,
vector_store.provider_resource_id or vector_store.identifier,
),
inference_api=self.inference_api,
)
self.cache[vector_store_id] = index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ async def initialize(self) -> None:
for db_json in stored_vector_stores:
vector_store = VectorStore.model_validate_json(db_json)
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
vector_store.embedding_dimension,
self.config.db_path,
vector_store.provider_resource_id or vector_store.identifier,
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

Expand All @@ -425,7 +427,9 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:

# Create and cache the index
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
vector_store.embedding_dimension,
self.config.db_path,
vector_store.provider_resource_id or vector_store.identifier,
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

Expand All @@ -448,7 +452,7 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
index=SQLiteVecIndex(
dimension=vector_store.embedding_dimension,
db_path=self.config.db_path,
bank_id=vector_store.identifier,
bank_id=vector_store.provider_resource_id or vector_store.identifier,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,11 @@ async def openai_create_vector_store(
extra_body = params.model_extra or {}
metadata = params.metadata or {}

provider_vector_store_id = extra_body.get("provider_vector_store_id")
# Get the canonical UUID from router (or generate if called directly without router)
vector_store_id = extra_body.get("vector_store_id") or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

# Get the physical storage name (custom collection name or fallback to UUID)
provider_vector_store_id = extra_body.get("provider_vector_store_id") or vector_store_id

# Use embedding info from metadata if available, otherwise from extra_body
if metadata.get("embedding_model"):
Expand All @@ -381,8 +385,6 @@ async def openai_create_vector_store(

# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra_body.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_store_id (allow override, else generate)
vector_store_id = provider_vector_store_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

if embedding_model is None:
raise ValueError("embedding_model is required")
Expand All @@ -396,11 +398,11 @@ async def openai_create_vector_store(

# call to the provider to create any index, etc.
vector_store = VectorStore(
identifier=vector_store_id,
identifier=vector_store_id, # Canonical UUID for routing
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=vector_store_id,
provider_resource_id=provider_vector_store_id, # Physical storage name (custom or UUID)
vector_store_name=params.name,
)
await self.register_vector_store(vector_store)
Expand Down
80 changes: 80 additions & 0 deletions tests/integration/vector_io/test_openai_vector_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,3 +1698,83 @@ def get_field(obj, field):
assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True"
assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list"
assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False"


@vector_provider_wrapper
def test_openai_vector_store_custom_collection_name(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Verify that a user-supplied collection name is honored at creation time."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    # Request a store whose physical backing uses a caller-chosen name.
    extra_body = {
        "embedding_model": embedding_model_id,
        "provider_id": vector_io_provider_id,
        "collection_name": "my_custom_collection",
    }
    store = compat_client_with_empty_stores.vector_stores.create(
        name="Test Custom Collection",
        extra_body=extra_body,
    )

    # The canonical store id remains a generated "vs_..." identifier, while
    # the custom collection name is surfaced back through metadata.
    assert store is not None
    assert store.id.startswith("vs_")
    metadata = store.metadata
    assert "provider_vector_store_id" in metadata
    assert metadata["provider_vector_store_id"] == "my_custom_collection"


@vector_provider_wrapper
def test_openai_vector_store_collection_name_validation(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Every malformed collection name should be rejected at creation time."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    def attempt_create(bad_name):
        # Helper: issue a create call using a name that must be rejected.
        compat_client_with_empty_stores.vector_stores.create(
            name="Test Invalid",
            extra_body={
                "embedding_model": embedding_model_id,
                "provider_id": vector_io_provider_id,
                "collection_name": bad_name,
            },
        )

    # Spaces, path separators, symbols, and the empty string are all invalid.
    for bad_name in ("with spaces", "with/slashes", "with@special", ""):
        with pytest.raises((BadRequestError, ValueError)):
            attempt_create(bad_name)


@vector_provider_wrapper
def test_openai_vector_store_collection_name_with_data(
    compat_client_with_empty_stores, client_with_models, sample_chunks, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """End-to-end check: insert into and search a store backed by a custom collection."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    store = compat_client_with_empty_stores.vector_stores.create(
        name="Data Test Collection",
        extra_body={
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
            "collection_name": "test_data_collection",
        },
    )

    # Write through the llama-stack client, read back via the OpenAI-compat client.
    client_with_models.vector_io.insert(vector_store_id=store.id, chunks=sample_chunks[:2])

    results = compat_client_with_empty_stores.vector_stores.search(
        vector_store_id=store.id,
        query="What is Python?",
        max_num_results=2,
    )

    assert results is not None
    assert len(results.data) > 0
    assert results.data[0].attributes["document_id"] == "doc1"
Loading