Python GraphQL with Strawberry: Code-First API Design
Strawberry is a modern Python GraphQL library that leverages Python type hints to define the schema. Unlike schema-first tools that generate code from SDL, Strawberry uses Python dataclasses and decorators — the schema is derived from your Python types, ensuring they stay in sync. It integrates natively with FastAPI and Django, supports async resolvers, subscriptions, dataloaders, and custom directives. This guide builds a complete GraphQL API with authentication, N+1 prevention, and real-time subscriptions.
Table of Contents
Setup and Basic Schema
Strawberry uses Python's dataclass syntax with @strawberry.type decorators. Fields become GraphQL fields automatically, with types inferred from Python type hints. Optional fields map to nullable GraphQL fields. The schema is built from a Query type (required), an optional Mutation type, and an optional Subscription type.
pip install strawberry-graphql[fastapi] uvicorn sqlalchemy asyncpg
import dataclasses
from datetime import datetime
from typing import Annotated, Optional, Union

import strawberry
from strawberry.fastapi import BaseContext
@strawberry.type
class User:
    """A user account as exposed through the GraphQL API."""

    id: strawberry.ID
    name: str
    email: str
    # NOTE(review): a UserRole enum is defined below, but this field is a plain
    # str; switching would change the GraphQL field type — confirm intent.
    role: str
    created_at: datetime
    bio: Optional[str] = None  # Optional -> nullable GraphQL field; defaults to None
@strawberry.type
class Post:
    """A blog post, with an optional nested ``author``."""

    id: strawberry.ID
    title: str
    content: str
    published: bool
    author_id: strawberry.ID
    created_at: datetime
    # Plain data field, not a resolver: it stays None unless the caller passes
    # an ``author`` when constructing the Post.
    author: Optional["User"] = None
@strawberry.type
class PageInfo:
    """Pagination metadata returned next to a page of items."""

    has_next_page: bool
    has_previous_page: bool
    total_count: int
@strawberry.type
class UserConnection:
    """A page of users plus pagination metadata.

    Relay-inspired but simplified: users are returned directly in ``items``
    rather than as edges with per-item cursors.
    """

    items: list[User]
    page_info: PageInfo
# Enums
import strawberry
from enum import Enum
@strawberry.enum
class UserRole(Enum):
    """Allowed user roles; each member's value is its lowercase name."""

    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"
@strawberry.type
class Query:
    """Root query type for the basic schema."""

    @strawberry.field
    async def user(self, id: strawberry.ID) -> Optional[User]:
        """Look up a single user by ID via ``get_user_by_id``; the result is nullable."""
        found = await get_user_by_id(id)
        return found

    @strawberry.field
    async def users(self, limit: int = 20, offset: int = 0) -> list[User]:
        """Return a window of users, delegating the paging to ``get_users``."""
        page = await get_users(limit=limit, offset=offset)
        return page


# Only the query root is mandatory; mutation and subscription roots are optional.
schema = strawberry.Schema(
    query=Query,
)
Queries and Resolvers
Resolvers are async Python functions that return data for each field. Strawberry passes the resolver's return value as the parent object to nested field resolvers. Context (the request, database session, current user) is passed via Strawberry's Info parameter — available to any resolver that declares it.
import strawberry
from strawberry.types import Info
from typing import Optional, Annotated
# Custom context type
@dataclasses.dataclass
class Context(BaseContext):
    """Per-request state handed to every resolver as ``info.context``.

    This is a plain dataclass rather than a ``@strawberry.type`` because the
    context is never part of the GraphQL schema. It subclasses ``BaseContext``
    because Strawberry's FastAPI router accepts only a ``dict`` or a
    ``BaseContext`` instance from ``context_getter``.
    """

    user_id: Optional[str]  # None for anonymous requests
    db: "AsyncSession"  # database session used by this request's resolvers

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ does not call BaseContext.__init__;
        # run it so whatever attributes BaseContext initialises are present.
        super().__init__()
async def get_context(request) -> Context:
    """Build the per-request Context from the incoming request.

    The user ID is the ``sub`` claim of a ``Bearer`` token in the
    ``Authorization`` header, the same claim the FastAPI context getter reads.
    A missing header, a different auth scheme, or a token whose verification
    raises all produce an anonymous context (``user_id=None``) instead of
    failing the whole request.
    """
    import logging

    # Split "Bearer <token>" properly instead of blindly stripping a substring,
    # so headers using another scheme are not passed on as tokens.
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    token = credentials.strip()
    user_id = None
    if scheme.lower() == "bearer" and token:
        try:
            # assumes verify_jwt returns the decoded claims mapping — confirm
            user_id = verify_jwt(token)["sub"]
        except Exception:
            # Best-effort auth: invalid tokens downgrade to anonymous, but are logged.
            logging.getLogger(__name__).debug("Ignoring invalid bearer token", exc_info=True)
    return Context(user_id=user_id, db=get_db_session())
@strawberry.type
class Query:
    """Root query type whose resolvers read request state from ``info.context``."""

    @strawberry.field
    async def me(self, info: Info[Context, None]) -> Optional[User]:
        """Return the current user's row, or None for anonymous requests."""
        if not info.context.user_id:
            return None
        return await info.context.db.get(UserModel, info.context.user_id)

    @strawberry.field
    async def posts(
        self,
        info: Info[Context, None],
        first: int = 20,
        after: Optional[str] = None,
        filter: Optional["PostFilter"] = None,
    ) -> "PostConnection":
        """Return one page of posts, newest first.

        Args:
            first: Maximum number of posts in the page; values below 0 count as 0.
            after: Cursor for the page. Here it is simply the stringified offset
                of the page's first row ("0", "20", ...); None or "" means the start.
            filter: Optional constraints on ``published`` and ``author_id``.

        Returns:
            A PostConnection whose ``page_info.total_count`` is the number of
            posts matching ``filter``, not just the size of this page.

        Raises:
            ValueError: if ``after`` is not a non-negative integer string.
        """
        from sqlalchemy import func, select

        limit = max(first, 0)
        try:
            offset = int(after or 0)
        except ValueError:
            raise ValueError(f"Invalid cursor: {after!r}") from None
        if offset < 0:
            raise ValueError(f"Invalid cursor: {after!r}")

        query = select(PostModel)
        if filter:
            if filter.published is not None:
                query = query.where(PostModel.published == filter.published)
            if filter.author_id:
                query = query.where(PostModel.author_id == filter.author_id)

        db = info.context.db
        # Count every matching row, independent of the page window.
        total = await db.scalar(select(func.count()).select_from(query.subquery()))

        # Fetch one row beyond the page: its presence means a next page exists.
        result = await db.execute(
            query.order_by(PostModel.created_at.desc()).limit(limit + 1).offset(offset)
        )
        rows = result.scalars().all()
        page = rows[:limit]

        # Build Post explicitly. An ORM instance's __dict__ also holds SQLAlchemy's
        # _sa_instance_state, which Post(**p.__dict__) would reject.
        return PostConnection(
            items=[
                Post(
                    id=str(p.id),
                    title=p.title,
                    content=p.content,
                    published=p.published,
                    author_id=str(p.author_id),
                    created_at=p.created_at,
                )
                for p in page
            ],
            page_info=PageInfo(
                has_next_page=len(rows) > limit,
                has_previous_page=offset > 0,
                total_count=total or 0,
            ),
        )
@strawberry.input
class PostFilter:
    """Optional constraints for the ``posts`` query; a None field applies no filter."""

    published: Optional[bool] = None
    author_id: Optional[strawberry.ID] = None
Mutations and Input Types
Mutations use @strawberry.input for input types and return rich result types that include both success data and errors. The union-based error handling pattern (returning either a success type or an error type) is cleaner than raising exceptions and maps better to GraphQL's type system.
@strawberry.input
class CreatePostInput:
    """Arguments for ``create_post``; ``published`` defaults to False."""

    title: str
    content: str
    published: bool = False
@strawberry.input
class UpdatePostInput:
    """Partial update for a post.

    Every field defaults to ``strawberry.UNSET``, so a field the client omitted
    can be told apart from one explicitly sent as null.
    """

    title: Optional[str] = strawberry.UNSET
    content: Optional[str] = strawberry.UNSET
    published: Optional[bool] = strawberry.UNSET
@strawberry.type
class PostCreated:
    """Success member of ``CreatePostResult``: the newly created post."""

    post: Post
@strawberry.type
class ValidationError:
    """One failed check: the offending input ``field`` and a readable ``message``."""

    field: str
    message: str
@strawberry.type
class CreatePostError:
    """Failure member of ``CreatePostResult``: every validation error found."""

    errors: list[ValidationError]
# GraphQL union returned by create_post: either PostCreated or CreatePostError.
# Declared with typing.Annotated, the form Strawberry documents; passing the
# member types directly to strawberry.union(...) is deprecated.
CreatePostResult = Annotated[
    Union[PostCreated, CreatePostError],
    strawberry.union("CreatePostResult"),
]
@strawberry.type
class Mutation:
    """Write operations; create_post reports failures as typed results."""

    @strawberry.mutation
    async def create_post(
        self,
        input: CreatePostInput,
        info: Info[Context, None],
    ) -> CreatePostResult:
        """Create a post authored by the current user.

        Returns:
            PostCreated with the stored post, or CreatePostError listing the
            failed checks (including a missing login). Validation problems are
            returned, not raised.
        """
        if not info.context.user_id:
            return CreatePostError(errors=[
                ValidationError(field="auth", message="Authentication required")
            ])

        # Validate the trimmed values, which are also the values that get stored.
        title = input.title.strip()
        content = input.content.strip()
        errors = []
        if len(title) < 3:
            errors.append(ValidationError(field="title", message="Title must be at least 3 characters"))
        if len(content) < 10:
            errors.append(ValidationError(field="content", message="Content too short"))
        if errors:
            return CreatePostError(errors=errors)

        db = info.context.db
        post_model = PostModel(
            title=title,
            content=content,
            published=input.published,
            author_id=info.context.user_id,
        )
        db.add(post_model)
        await db.commit()
        # commit() expires loaded attributes by default, and an AsyncSession
        # cannot lazily reload them on attribute access. Refresh explicitly so
        # id and any database-generated created_at are readable below.
        await db.refresh(post_model)
        return PostCreated(post=Post(
            id=str(post_model.id),
            title=post_model.title,
            content=post_model.content,
            published=post_model.published,
            author_id=str(post_model.author_id),
            created_at=post_model.created_at,
        ))

    @strawberry.mutation
    async def delete_post(self, id: strawberry.ID, info: Info) -> bool:
        """Delete a post if the caller is its author.

        Returns False, rather than raising, when the caller is anonymous, the
        post does not exist, or the post belongs to someone else.
        """
        user_id = info.context.user_id
        if not user_id:
            return False
        db = info.context.db
        post = await db.get(PostModel, id)
        # Compare as strings. The ORM may hand back a non-str author_id while the
        # context carries a str, and then != would always be True.
        if post is None or str(post.author_id) != str(user_id):
            return False
        await db.delete(post)
        await db.commit()
        return True
DataLoaders: Solving N+1
The N+1 problem occurs when fetching a list of posts triggers a separate database query for each post's author. DataLoaders batch multiple individual lookups into a single query within the same event loop tick. Strawberry's DataLoader implementation is async and integrates naturally with any async database driver.
from strawberry.dataloader import DataLoader
from typing import Sequence
async def load_users_by_ids(keys: Sequence[str]) -> list[Optional[User]]:
    """Batch-load users for a DataLoader with a single query for all ``keys``.

    Returns a list aligned with ``keys``, as the DataLoader contract requires:
    element ``i`` is the ``UserModel`` row whose ID matches ``keys[i]``, or None
    if there is no such row.
    """
    from sqlalchemy import select

    async with get_db_session() as db:
        # One statement for the whole batch: SELECT ... WHERE users.id IN (...)
        result = await db.execute(
            select(UserModel).where(UserModel.id.in_(keys))
        )
        users_by_id = {str(u.id): u for u in result.scalars()}
    # Look up by str(key) too, so int and str IDs refer to the same row.
    return [users_by_id.get(str(key)) for key in keys]
# Create DataLoaders per-request (not global, to avoid stale data)
@dataclasses.dataclass
class Context(BaseContext):
    """Per-request resolver context, including a request-scoped DataLoader.

    This is a plain dataclass rather than a ``@strawberry.type`` because the
    context is not a GraphQL type. It subclasses ``BaseContext`` because
    Strawberry's FastAPI router accepts only a ``dict`` or a ``BaseContext``
    instance from ``context_getter``.
    """

    user_id: Optional[str]  # None for anonymous requests
    db: "AsyncSession"
    # default_factory gives every Context its own DataLoader (and cache), so
    # loaded users are not shared across requests.
    user_loader: DataLoader = dataclasses.field(
        default_factory=lambda: DataLoader(load_fn=load_users_by_ids)
    )

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ does not call BaseContext.__init__;
        # run it so whatever attributes BaseContext initialises are present.
        super().__init__()
@strawberry.type
class Post:
    """Post type whose ``author`` is resolved on demand through the DataLoader."""

    id: strawberry.ID
    title: str
    content: str
    # published/created_at are included because the mutations construct Post
    # with these keyword arguments.
    published: bool
    author_id: strawberry.ID
    created_at: datetime

    @strawberry.field
    async def author(self, info: Info[Context, None]) -> Optional[User]:
        """Resolve the author through the per-request ``user_loader``.

        The DataLoader collects the ``author`` lookups made for one response and
        passes them to its batch function together.
        """
        # str() keeps DataLoader cache keys uniform when author_id arrives as an int.
        return await info.context.user_loader.load(str(self.author_id))
# With 100 posts, this generates 2 queries total:
# 1. SELECT * FROM posts LIMIT 100
# 2. SELECT * FROM users WHERE id IN (...the distinct author ids, at most 100...)
# Without DataLoader: 101 queries (1 for the posts + 1 per post's author)
Subscriptions: Real-Time Data
Strawberry supports GraphQL subscriptions over WebSocket. Subscriptions are async generators that yield values whenever an event occurs. They work with any pub/sub backend — Redis Pub/Sub, in-memory queues, or async generators. Clients receive updates in real time without polling.
import asyncio
from typing import AsyncGenerator
# Simple in-memory pub/sub for demo (use Redis in production).
# Maps a channel name to the queues of its currently active subscribers.
_subscribers: dict[str, list[asyncio.Queue]] = {}


def publish(channel: str, data: dict):
    """Deliver ``data`` to every current subscriber of ``channel``.

    Synchronous and non-blocking: subscriber queues are unbounded, so
    ``put_nowait`` never waits. Note that this also means a slow consumer's
    backlog can grow without limit. Publishing to a channel with no
    subscribers does nothing.
    """
    for queue in _subscribers.get(channel, []):
        queue.put_nowait(data)


async def subscribe(channel: str) -> AsyncGenerator[dict, None]:
    """Yield each message published on ``channel`` once iteration has started.

    Each subscriber gets its own queue. When the generator is closed, the queue
    is unregistered, and the channel entry is removed once its last subscriber
    leaves. Without that cleanup, per-user channels such as
    ``notifications:<id>`` would accumulate forever.
    """
    queue: asyncio.Queue = asyncio.Queue()
    _subscribers.setdefault(channel, []).append(queue)
    try:
        while True:
            yield await queue.get()
    finally:
        listeners = _subscribers.get(channel)
        if listeners is not None:
            listeners.remove(queue)
            if not listeners:
                del _subscribers[channel]
@strawberry.type
class PostEvent:
    """Payload streamed by ``post_events``: what happened, and to which post."""

    action: str  # "created" | "updated" | "deleted"
    post: Post
@strawberry.type
class Subscription:
    """Root subscription type backed by the in-memory pub/sub helpers."""

    @strawberry.subscription
    async def post_events(self, info: Info) -> AsyncGenerator[PostEvent, None]:
        """Stream a PostEvent for every message published on the "posts" channel."""
        async for event in subscribe("posts"):
            yield PostEvent(
                action=event["action"],
                post=Post(**event["post"])
            )

    @strawberry.subscription
    async def user_notifications(
        self, info: Info, user_id: strawberry.ID
    ) -> AsyncGenerator[str, None]:
        """Stream notification messages for ``user_id``.

        Only the authenticated user may subscribe to their own channel. Any
        other caller gets a PermissionError before anything is streamed, so
        clients can no longer read other users' notifications.

        NOTE(review): relies on the context getter populating ``user_id`` for
        WebSocket connections too — confirm.
        """
        current = info.context.user_id
        if current is None or str(current) != str(user_id):
            raise PermissionError("You can only subscribe to your own notifications")
        async for msg in subscribe(f"notifications:{user_id}"):
            yield msg["message"]


# Wire up schema
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)
Authentication and Permissions
Strawberry supports field-level permissions via permission classes and schema-level directives. Permission classes are the cleanest approach: each class checks one condition, and several can be listed on the same field. When a check returns False, Strawberry raises a PermissionError carrying that class's message, which GraphQL returns as an error alongside partial data.
from strawberry.permission import BasePermission
from strawberry.types import Info
class IsAuthenticated(BasePermission):
    """Grant access whenever the request context carries a user ID."""

    message = "You must be logged in"

    def has_permission(self, source, info: Info, **kwargs) -> bool:
        current_user = info.context.user_id
        return current_user is not None
class IsAdmin(BasePermission):
    """Allow access only to authenticated users whose role is "admin"."""

    message = "Admin role required"

    async def has_permission(self, source, info: Info, **kwargs) -> bool:
        """Return True when the caller's user row exists and has role "admin"."""
        user_id = info.context.user_id
        if not user_id:
            return False
        user = await info.context.db.get(UserModel, user_id)
        # Explicit bool: ``user and ...`` evaluates to None when no row is found.
        return user is not None and user.role == "admin"
class OwnsPost(BasePermission):
    """Allow the operation only when the caller authored the post named by ``id``."""

    message = "You can only modify your own posts"

    async def has_permission(self, source, info: Info, id: str, **kwargs) -> bool:
        """Check ownership of the post identified by the field's ``id`` argument."""
        user_id = info.context.user_id
        if not user_id:
            # Anonymous callers own nothing, so skip the database lookup.
            return False
        post = await info.context.db.get(PostModel, id)
        # Explicit bool: ``post and ...`` evaluates to None for a missing post.
        return post is not None and str(post.author_id) == str(user_id)
@strawberry.type
class Query:
    """Queries gated by permission classes."""

    @strawberry.field(permission_classes=[IsAuthenticated])
    async def my_posts(self, info: Info) -> list[Post]:
        """List the caller's own posts (login required)."""
        ctx = info.context
        return await get_posts_by_author(ctx.user_id, ctx.db)

    @strawberry.field(permission_classes=[IsAdmin])
    async def all_users(self, info: Info) -> list[User]:
        """List every user (admin only)."""
        ctx = info.context
        return await get_all_users(ctx.db)
@strawberry.type
class Mutation:
    """Mutations protected by permission classes."""

    @strawberry.mutation(permission_classes=[IsAuthenticated, OwnsPost])
    async def update_post(self, id: strawberry.ID, input: UpdatePostInput, info: Info) -> Post:
        """Partially update a post the caller owns.

        Fields of ``input`` left as ``strawberry.UNSET`` are kept unchanged.
        Explicit nulls are ignored as well, because title, content and
        published are non-nullable on ``Post``.

        Raises:
            ValueError: if the post no longer exists when the update runs.
        """
        db = info.context.db
        post = await db.get(PostModel, id)
        if post is None:
            # OwnsPost checked existence, but the row may have been deleted since.
            raise ValueError(f"Post {id} not found")
        # Previously only title was applied; content and published were silently dropped.
        for field_name in ("title", "content", "published"):
            value = getattr(input, field_name)
            if value is not strawberry.UNSET and value is not None:
                setattr(post, field_name, value)
        await db.commit()
        # commit() expires attributes and AsyncSession cannot lazy-load them, so
        # reload explicitly. Post is then built field by field, because
        # post.__dict__ also holds SQLAlchemy's _sa_instance_state.
        await db.refresh(post)
        return Post(
            id=str(post.id),
            title=post.title,
            content=post.content,
            published=post.published,
            author_id=str(post.author_id),
            created_at=post.created_at,
        )
FastAPI Integration
Strawberry provides a GraphQLRouter that mounts into FastAPI, giving you automatic Swagger docs alongside the GraphQL endpoint and GraphiQL IDE. The router handles both standard HTTP POST and WebSocket connections for subscriptions.
from fastapi import FastAPI, Depends, Request
from strawberry.fastapi import GraphQLRouter
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
async def get_context(request: Request) -> Context:
    """FastAPI context getter: build a fresh Context for every request.

    Authentication is best-effort. The user ID is the ``sub`` claim of a
    ``Bearer`` token in the ``Authorization`` header. A missing header, a
    different auth scheme, or a token whose verification raises all yield an
    anonymous context (``user_id=None``) instead of failing the request.

    A new DataLoader is created for each request so that cached users do not
    go stale across requests.
    """
    import logging

    # NOTE(review): FastAPI injects ``Request`` only on HTTP routes; confirm this
    # getter also works for the WebSocket (subscription) endpoint, where browser
    # clients usually send credentials in connection params rather than headers.
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    token = credentials.strip()
    user_id = None
    if scheme.lower() == "bearer" and token:
        try:
            payload = verify_jwt(token)
            user_id = payload["sub"]
        except Exception:
            # Deliberately non-fatal (invalid token -> anonymous), but not silent.
            logging.getLogger(__name__).debug("Ignoring invalid bearer token", exc_info=True)
    return Context(
        user_id=user_id,
        db=get_db_session(),
        user_loader=DataLoader(load_fn=load_users_by_ids),
    )
# Executable schema combining all three root operation types.
schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    subscription=Subscription,
)

app = FastAPI()

# A single router serves HTTP queries/mutations and WebSocket subscriptions
# (graphql-transport-ws protocol only); GraphiQL is served at the mount path.
graphql_app = GraphQLRouter(
    schema,
    context_getter=get_context,
    subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
    graphiql=True,
)
app.include_router(graphql_app, prefix="/graphql")
# Test with curl:
# curl -X POST http://localhost:8000/graphql \
# -H "Content-Type: application/json" \
# -H "Authorization: Bearer <your-jwt>" \
# -d '{"query": "{ me { id name email } }"}'
Production tip: add query depth limiting (for example the strawberry.extensions.QueryDepthLimiter extension) to prevent deeply nested queries from overloading your database. Enable query logging during development to spot N+1 patterns early.