Custom Backends¶
Add support for MongoDB, Elasticsearch, or custom data stores by implementing the Rollbackable protocol.
Overview¶
VenomQA uses adapters to communicate with external systems (databases, caches, message queues). Each adapter implements checkpoint/rollback semantics so exploration can branch cleanly.
The Rollbackable Protocol¶
All system adapters must implement the Rollbackable protocol:
Python
from typing import Protocol, runtime_checkable, Any
from venomqa.sandbox import Observation
SystemCheckpoint = Any
@runtime_checkable
class Rollbackable(Protocol):
"""Protocol for systems that can checkpoint and rollback."""
def checkpoint(self, name: str) -> SystemCheckpoint:
"""Save current state and return checkpoint data."""
...
def rollback(self, checkpoint: SystemCheckpoint) -> None:
"""Restore state to a previous checkpoint."""
...
def observe(self) -> Observation:
"""Get current state as an Observation."""
...
Minimal Implementation¶
A simple in-memory adapter:
Python
import copy
from venomqa.sandbox import Observation
class InMemoryAdapter:
"""Simple in-memory adapter for testing."""
def __init__(self, initial_data: dict | None = None):
self._data = initial_data or {}
self._snapshots: dict[str, dict] = {}
def checkpoint(self, name: str) -> str:
snapshot = copy.deepcopy(self._data)
self._snapshots[name] = snapshot
return name
def rollback(self, checkpoint: str) -> None:
self._data = copy.deepcopy(self._snapshots[checkpoint])
def observe(self) -> Observation:
return Observation.create(
system="memory",
data={"keys": list(self._data.keys()), "count": len(self._data)},
)
# Custom methods for your application
def get(self, key: str) -> Any:
return self._data.get(key)
def set(self, key: str, value: Any) -> None:
self._data[key] = value
def delete(self, key: str) -> None:
self._data.pop(key, None)
Using Custom Adapters¶
Register your adapter with World:
Python
from venomqa import Agent, World, Action, Invariant, Severity
from myapp.adapters import InMemoryAdapter
adapter = InMemoryAdapter(initial_data={"users": []})
world = World(
api=api,
systems={"storage": adapter},
)
agent = Agent(
world=world,
actions=actions,
invariants=invariants,
)
result = agent.explore()
Access the adapter in actions:
Python
def create_user(api, context):
storage = context.world.systems["storage"]
user_id = str(uuid.uuid4())
storage.set(f"user:{user_id}", {"id": user_id, "created": datetime.utcnow()})
context.set("user_id", user_id)
return api.post("/users", json={"id": user_id})
Example: MongoDB Adapter¶
Full implementation for MongoDB:
Python
from __future__ import annotations
import copy
from datetime import datetime
from typing import Any
from venomqa.sandbox import Observation
class MongoDBAdapter:
"""MongoDB adapter with checkpoint/rollback via collection snapshots.
This adapter creates in-memory snapshots of collections for rollback.
Suitable for small-to-medium datasets. For large datasets, consider
using MongoDB transactions or database clones.
Example:
from pymongo import MongoClient
client = MongoClient("mongodb://localhost:27017")
db = client["testdb"]
adapter = MongoDBAdapter(
db=db,
collections=["users", "orders", "products"],
)
world = World(api=api, systems={"mongo": adapter})
"""
def __init__(
self,
db,
collections: list[str],
checkpoint_mode: str = "memory",
):
"""Initialize MongoDB adapter.
Args:
db: PyMongo database instance.
collections: List of collection names to track.
checkpoint_mode: "memory" for in-memory snapshots,
"drop_restore" for temp collections (slower, less memory).
"""
self.db = db
self.collections = collections
self.checkpoint_mode = checkpoint_mode
self._snapshots: dict[str, dict[str, list[dict]]] = {}
self._counter = 0
def checkpoint(self, name: str) -> str:
"""Create a snapshot of tracked collections."""
self._counter += 1
checkpoint_id = f"{name}_{self._counter}"
if self.checkpoint_mode == "memory":
snapshot = {}
for coll_name in self.collections:
snapshot[coll_name] = list(self.db[coll_name].find())
self._snapshots[checkpoint_id] = snapshot
else: # drop_restore mode
for coll_name in self.collections:
temp_name = f"_temp_{checkpoint_id}_{coll_name}"
self.db[coll_name].aggregate([{"$out": temp_name}])
self._snapshots[checkpoint_id] = {"mode": "drop_restore"}
return checkpoint_id
def rollback(self, checkpoint: str) -> None:
"""Restore collections to checkpoint state."""
snapshot = self._snapshots[checkpoint]
if snapshot.get("mode") == "drop_restore":
# Restore from temp collections
for coll_name in self.collections:
temp_name = f"_temp_{checkpoint}_{coll_name}"
self.db[coll_name].drop()
self.db[temp_name].aggregate([{"$out": coll_name}])
self.db[temp_name].drop()
else:
# Restore from memory snapshot
for coll_name in self.collections:
self.db[coll_name].delete_many({})
if coll_name in snapshot and snapshot[coll_name]:
self.db[coll_name].insert_many(snapshot[coll_name])
def observe(self) -> Observation:
"""Get observation data for state hashing."""
data = {}
for coll_name in self.collections:
data[f"{coll_name}_count"] = self.db[coll_name].count_documents({})
# Add some document hashes for better state discrimination
for coll_name in self.collections[:3]: # Limit for performance
sample = list(self.db[coll_name].find().limit(10))
if sample:
data[f"{coll_name}_sample_ids"] = [str(doc.get("_id", "")) for doc in sample]
return Observation.create(
system="mongodb",
data=data,
metadata={"checkpoint_mode": self.checkpoint_mode},
)
def query(self, collection: str, filter: dict | None = None) -> list[dict]:
"""Query a collection (helper method)."""
return list(self.db[collection].find(filter or {}))
def insert(self, collection: str, document: dict) -> str:
"""Insert a document and return its ID."""
result = self.db[collection].insert_one(document)
return str(result.inserted_id)
def update(self, collection: str, filter: dict, update: dict) -> int:
"""Update documents and return count."""
result = self.db[collection].update_many(filter, {"$set": update})
return result.modified_count
def delete(self, collection: str, filter: dict) -> int:
"""Delete documents and return count."""
result = self.db[collection].delete_many(filter)
return result.deleted_count
Usage:
Python
from pymongo import MongoClient
from venomqa import Agent, World, Action, Invariant, Severity
from myapp.adapters.mongodb import MongoDBAdapter
client = MongoClient("mongodb://localhost:27017")
db = client["testdb"]
adapter = MongoDBAdapter(
db=db,
collections=["users", "orders", "products"],
)
world = World(api=api, systems={"db": adapter})
def create_order(api, context):
db = context.world.systems["db"]
order_id = db.insert("orders", {
"user_id": context.get("user_id"),
"amount": 100,
"status": "pending",
"created_at": datetime.utcnow(),
})
context.set("order_id", order_id)
return api.post("/orders", json={"id": order_id})
def fulfill_order(api, context):
order_id = context.get("order_id")
db = context.world.systems["db"]
db.update("orders", {"_id": ObjectId(order_id)}, {"status": "fulfilled"})
return api.post(f"/orders/{order_id}/fulfill")
Example: Elasticsearch Adapter¶
Python
from __future__ import annotations
import copy
from typing import Any
from venomqa.sandbox import Observation
class ElasticsearchAdapter:
"""Elasticsearch adapter with scroll-based snapshots.
Uses scroll API to snapshot documents, restores via bulk delete/insert.
Example:
from elasticsearch import Elasticsearch
client = Elasticsearch(["http://localhost:9200"])
adapter = ElasticsearchAdapter(
client=client,
indices=["users", "orders"],
)
world = World(api=api, systems={"es": adapter})
"""
def __init__(
self,
client,
indices: list[str],
batch_size: int = 1000,
):
"""Initialize Elasticsearch adapter.
Args:
client: Elasticsearch client instance.
indices: List of index names to track.
batch_size: Batch size for scroll operations.
"""
self.client = client
self.indices = indices
self.batch_size = batch_size
self._snapshots: dict[str, dict[str, list[dict]]] = {}
self._counter = 0
def checkpoint(self, name: str) -> str:
"""Create a snapshot of tracked indices."""
self._counter += 1
checkpoint_id = f"{name}_{self._counter}"
snapshot = {}
for index in self.indices:
snapshot[index] = self._snapshot_index(index)
self._snapshots[checkpoint_id] = snapshot
return checkpoint_id
def _snapshot_index(self, index: str) -> list[dict]:
"""Snapshot all documents from an index."""
documents = []
response = self.client.search(
index=index,
body={"query": {"match_all": {}}},
size=self.batch_size,
scroll="2m",
)
scroll_id = response.get("_scroll_id")
hits = response["hits"]["hits"]
while hits:
documents.extend(hits)
response = self.client.scroll(scroll_id=scroll_id, scroll="2m")
scroll_id = response.get("_scroll_id")
hits = response["hits"]["hits"]
if scroll_id:
try:
self.client.clear_scroll(scroll_id=scroll_id)
except Exception:
pass
return documents
def rollback(self, checkpoint: str) -> None:
"""Restore indices to checkpoint state."""
snapshot = self._snapshots[checkpoint]
for index in self.indices:
# Delete all documents
self.client.delete_by_query(
index=index,
body={"query": {"match_all": {}}},
conflicts="proceed",
)
# Restore documents
if snapshot[index]:
actions = []
for doc in snapshot[index]:
actions.append({
"_index": index,
"_id": doc["_id"],
"_source": doc["_source"],
})
from elasticsearch.helpers import bulk
bulk(self.client, actions)
def observe(self) -> Observation:
"""Get observation data for state hashing."""
data = {}
for index in self.indices:
count = self.client.count(index=index)["count"]
data[f"{index}_count"] = count
# Sample some IDs
if count > 0:
sample = self.client.search(
index=index,
body={"query": {"match_all": {}}},
size=5,
_source=False,
)
data[f"{index}_sample_ids"] = [hit["_id"] for hit in sample["hits"]["hits"]]
return Observation.create(
system="elasticsearch",
data=data,
)
def search(self, index: str, query: dict) -> list[dict]:
"""Search an index (helper method)."""
response = self.client.search(index=index, body=query)
return response["hits"]["hits"]
def index_document(self, index: str, doc: dict, id: str | None = None) -> str:
"""Index a document and return its ID."""
body = doc.copy()
response = self.client.index(index=index, id=id, body=body)
return response["_id"]
def delete_document(self, index: str, id: str) -> bool:
"""Delete a document by ID."""
try:
self.client.delete(index=index, id=id)
return True
except Exception:
return False
Example: Redis Adapter (Advanced)¶
Redis adapter using Lua scripts for atomic snapshots:
Python
from __future__ import annotations
from typing import Any
from venomqa.sandbox import Observation
class RedisAdapter:
"""Redis adapter with atomic dump/restore.
Uses Redis DUMP/RESTORE for atomic checkpoint/rollback.
Handles strings, hashes, sets, and sorted sets.
Example:
import redis
client = redis.Redis(host="localhost", port=6379, db=0)
adapter = RedisAdapter(
client=client,
key_prefix="myapp:",
)
world = World(api=api, systems={"redis": adapter})
"""
def __init__(
self,
client,
key_prefix: str = "",
scan_count: int = 1000,
):
"""Initialize Redis adapter.
Args:
client: Redis client instance.
key_prefix: Only track keys with this prefix.
scan_count: Batch size for SCAN operations.
"""
self.client = client
self.key_prefix = key_prefix
self.scan_count = scan_count
self._snapshots: dict[str, dict[str, bytes]] = {}
self._counter = 0
def _get_keys(self) -> list[str]:
"""Get all keys matching prefix."""
keys = []
cursor = 0
pattern = f"{self.key_prefix}*" if self.key_prefix else "*"
while True:
cursor, batch = self.client.scan(cursor, match=pattern, count=self.scan_count)
keys.extend(batch)
if cursor == 0:
break
return keys
def checkpoint(self, name: str) -> str:
"""Dump all keys atomically."""
self._counter += 1
checkpoint_id = f"{name}_{self._counter}"
keys = self._get_keys()
snapshot = {}
for key in keys:
dump = self.client.dump(key)
if dump:
snapshot[key] = dump
ttl = self.client.ttl(key)
snapshot[f"__ttl__{key}"] = ttl
self._snapshots[checkpoint_id] = snapshot
return checkpoint_id
def rollback(self, checkpoint: str) -> None:
"""Restore all keys from dump."""
snapshot = self._snapshots[checkpoint]
# Delete current keys
keys = self._get_keys()
if keys:
self.client.delete(*keys)
# Restore from snapshot
for key, dump in snapshot.items():
if key.startswith("__ttl__"):
continue
ttl = snapshot.get(f"__ttl__{key}", -1)
if ttl == -2: # Key had no expiry
ttl = 0
try:
self.client.restore(key, ttl, dump, replace=True)
except Exception:
pass
def observe(self) -> Observation:
"""Get observation data for state hashing."""
keys = self._get_keys()
data = {
"key_count": len(keys),
"keys": sorted(keys)[:50], # Sample for hashing
}
# Add type counts
type_counts = {}
for key in keys[:100]:
key_type = self.client.type(key)
type_counts[key_type] = type_counts.get(key_type, 0) + 1
data["type_counts"] = type_counts
return Observation.create(
system="redis",
data=data,
)
def get(self, key: str) -> Any:
"""Get a value (helper method)."""
return self.client.get(self.key_prefix + key)
def set(self, key: str, value: Any, ttl: int | None = None) -> None:
"""Set a value (helper method)."""
full_key = self.key_prefix + key
if ttl:
self.client.setex(full_key, ttl, value)
else:
self.client.set(full_key, value)
def delete(self, key: str) -> None:
"""Delete a key (helper method)."""
self.client.delete(self.key_prefix + key)
Best Practices¶
1. Minimize Snapshot Size¶
Python
def checkpoint(self, name: str) -> str:
snapshot = {}
for coll in self.collections:
# Only snapshot essential fields
snapshot[coll] = [
{"_id": doc["_id"], "status": doc.get("status")}
for doc in self.db[coll].find()
]
return snapshot
2. Handle Large Datasets¶
Python
class LargeDatasetAdapter:
def __init__(self, db, collections: list[str], max_docs: int = 10000):
self.max_docs = max_docs
def checkpoint(self, name: str) -> str:
# Warn if dataset is too large
total_docs = sum(self.db[c].count_documents({}) for c in self.collections)
if total_docs > self.max_docs:
warnings.warn(f"Large dataset ({total_docs} docs). Consider sharding.")
3. Use Efficient Queries¶
Python
def observe(self) -> Observation:
# Use aggregation for counts instead of fetching all docs
pipeline = [{"$count": "total"}]
counts = {c: list(self.db[c].aggregate(pipeline))[0]["total"]
for c in self.collections}
4. Clean Up Resources¶
5. Provide Useful Observations¶
Python
def observe(self) -> Observation:
return Observation.create(
system="mongodb",
data={
"users_count": 5,
"users_statuses": {"active": 3, "pending": 2}, # Good for dedup
"orders_count": 10,
},
)
Troubleshooting¶
"Rollback didn't restore state"¶
Ensure you're deep-copying data:
Python
import copy
def checkpoint(self, name: str) -> str:
# Bad: shallow copy
snapshot = self._data.copy()
# Good: deep copy
snapshot = copy.deepcopy(self._data)
"State explosion with adapter"¶
Add meaningful fields to observations:
Python
def observe(self) -> Observation:
return Observation.create(
system="mydb",
data={
"order_count": self.count("orders"),
"max_order_id": self.max_id("orders"), # Helps distinguish states
},
)
"Checkpoint too slow"¶
Consider lazy checkpointing: