Python GraphQL with Strawberry: Code-First API Design

Strawberry is a modern Python GraphQL library that uses Python type hints to define the schema. With schema-first tools, you write the SDL by hand and bind resolvers to it. Strawberry is code-first instead: you declare types with dataclass-style classes and decorators, and the schema is generated from those Python types, so the two cannot drift apart. It integrates with FastAPI and Django and supports async resolvers, subscriptions, dataloaders, and custom directives. This guide builds a complete GraphQL API with authentication, N+1 prevention, and real-time subscriptions.

Setup and Basic Schema

Strawberry uses Python's dataclass syntax with @strawberry.type decorators. Fields become GraphQL fields automatically, with types inferred from Python type hints. Optional fields map to nullable GraphQL fields. The schema is built from a Query type (required), an optional Mutation type, and an optional Subscription type.

pip install strawberry-graphql[fastapi] uvicorn sqlalchemy asyncpg
import strawberry
from typing import Optional
from datetime import datetime

@strawberry.type
class User:
    id: strawberry.ID
    name: str
    email: str
    role: str
    created_at: datetime
    bio: Optional[str] = None

@strawberry.type
class Post:
    id: strawberry.ID
    title: str
    content: str
    published: bool
    author_id: strawberry.ID
    created_at: datetime
    author: Optional["User"] = None  # plain field here; the DataLoaders section adds a resolver

@strawberry.type
class PageInfo:
    has_next_page: bool
    has_previous_page: bool
    total_count: int

@strawberry.type
class UserConnection:
    """Simplified, Relay-inspired pagination (full Relay connections use edges and cursors)."""
    items: list[User]
    page_info: PageInfo

# Enums
from enum import Enum

@strawberry.enum
class UserRole(Enum):
    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"

@strawberry.type
class Query:
    @strawberry.field
    async def user(self, id: strawberry.ID) -> Optional[User]:
        return await get_user_by_id(id)

    @strawberry.field
    async def users(self, limit: int = 20, offset: int = 0) -> list[User]:
        return await get_users(limit=limit, offset=offset)

schema = strawberry.Schema(query=Query)
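The mapping rules described above can be checked with a toy, stdlib-only re-implementation. It is not Strawberry's code: gql_type, to_camel, and fields_sdl are hypothetical helpers that cover only scalars, Optional[X], and list[X], and the User dataclass stands in for the @strawberry.type version. It shows a plain annotation becoming non-null (trailing !), Optional dropping the !, list[X] becoming [X!]!, and snake_case names becoming camelCase, which is Strawberry's default naming.

```python
import typing
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Union, get_args, get_origin

# Built-in scalar names; DateTime mirrors Strawberry's datetime scalar
SCALARS = {str: "String", int: "Int", float: "Float", bool: "Boolean", datetime: "DateTime"}

def to_camel(name: str) -> str:
    """created_at -> createdAt (Strawberry camel-cases field names by default)."""
    head, *rest = name.split("_")
    return head + "".join(part.capitalize() for part in rest)

def gql_type(hint) -> str:
    """Render a Python annotation as a GraphQL type reference (toy version)."""
    if get_origin(hint) is Union:
        args = [a for a in get_args(hint) if a is not type(None)]
        if len(args) != 1:
            raise NotImplementedError("only Optional[X] unions are handled here")
        return gql_type(args[0]).removesuffix("!")  # Optional -> nullable
    if get_origin(hint) is list:
        return f"[{gql_type(get_args(hint)[0])}]!"
    name = SCALARS.get(hint, getattr(hint, "__name__", repr(hint)))
    return f"{name}!"  # plain annotations are non-null

def fields_sdl(cls) -> list[str]:
    """One 'fieldName: Type' line per annotated attribute, in declaration order."""
    return [f"{to_camel(n)}: {gql_type(h)}" for n, h in typing.get_type_hints(cls).items()]

@dataclass
class User:  # plain dataclass standing in for the @strawberry.type above
    name: str
    created_at: datetime
    bio: Optional[str] = None

print(fields_sdl(User))
```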

Queries and Resolvers

Resolvers are Python functions, sync or async, that return the data for a field. The value a resolver returns becomes the parent object (self) for the resolvers of its nested fields. Request-scoped context (the request, the database session, the current user) reaches resolvers through Strawberry's Info parameter, which is available to any resolver that declares it.

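import strawberry
from dataclasses import dataclass
from typing import Optional
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.types import Info

# Custom context: a plain dataclass, not a @strawberry.type (it is never exposed in the schema)
@dataclass
class Context:
    user_id: Optional[str]
    db: AsyncSession

async def get_context(request) -> Context:
    token = request.headers.get("Authorization", "").replace("Bearer ", "")
    user_id = verify_jwt(token) if token else None  # verify_jwt / get_db_session: your own helpers
    return Context(user_id=user_id, db=get_db_session())

# Defined before Query so the resolver annotations need no forward references
@strawberry.input
class PostFilter:
    published: Optional[bool] = None
    author_id: Optional[strawberry.ID] = None

@strawberry.type
class PostConnection:
    items: list[Post]
    page_info: PageInfo

def to_post(p) -> Post:
    # Illustrative helper. Copy columns explicitly: Post(**p.__dict__) fails because
    # an ORM instance's __dict__ also contains SQLAlchemy's _sa_instance_state
    return Post(
        id=str(p.id), title=p.title, content=p.content, published=p.published,
        author_id=str(p.author_id), created_at=p.created_at,
    )

@strawberry.type
class Query:
    @strawberry.field
    async def me(self, info: Info[Context, None]) -> Optional[User]:
        if not info.context.user_id:
            return None
        return await info.context.db.get(UserModel, info.context.user_id)

    @strawberry.field
    async def posts(
        self,
        info: Info[Context, None],
        first: int = 20,
        after: Optional[str] = None,  # here: a plain offset passed as a string
        filter: Optional[PostFilter] = None,
    ) -> PostConnection:
        offset = int(after or 0)
        query = select(PostModel)
        if filter:
            if filter.published is not None:
                query = query.where(PostModel.published == filter.published)
            if filter.author_id:
                query = query.where(PostModel.author_id == filter.author_id)
        # total_count must count every matching row, not just the rows on this page
        total = await info.context.db.scalar(select(func.count()).select_from(query.subquery()))
        results = await info.context.db.execute(
            query.order_by(PostModel.created_at.desc()).limit(first).offset(offset)
        )
        posts = results.scalars().all()
        return PostConnection(
            items=[to_post(p) for p in posts],
            page_info=PageInfo(
                has_next_page=offset + len(posts) < total,
                has_previous_page=offset > 0,
                total_count=total,
            ),
        )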
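The posts resolver above passes a raw offset as after. Relay-style APIs usually hand out opaque cursors instead, so clients cannot come to depend on the encoding. A minimal stdlib sketch, assuming an offset-based backend; encode_cursor, decode_cursor, and page_info are hypothetical helpers, and here a cursor marks where the next page starts (a simplification: in Relay a cursor identifies an item).

```python
import base64
from dataclasses import dataclass
from typing import Optional

@dataclass
class PageInfo:
    has_next_page: bool
    has_previous_page: bool
    total_count: int

def encode_cursor(offset: int) -> str:
    """Opaque cursor: URL-safe base64 of 'offset:<n>'."""
    return base64.urlsafe_b64encode(f"offset:{offset}".encode()).decode()

def decode_cursor(cursor: Optional[str]) -> int:
    """Inverse of encode_cursor; no cursor means 'start at the beginning'."""
    if cursor is None:
        return 0
    # Malformed base64 or non-UTF-8 bytes also surface as ValueError subclasses
    prefix, _, value = base64.urlsafe_b64decode(cursor.encode()).decode().partition(":")
    if prefix != "offset" or not value.isdigit():
        raise ValueError(f"invalid cursor: {cursor!r}")
    return int(value)

def page_info(offset: int, page_len: int, total: int) -> PageInfo:
    # Compare against the real total: a len(page) == first check would
    # wrongly report a next page when the last page is exactly full
    return PageInfo(
        has_next_page=offset + page_len < total,
        has_previous_page=offset > 0,
        total_count=total,
    )

first, total = 20, 45
offset = decode_cursor(encode_cursor(40))
page_len = min(first, total - offset)  # only 5 rows remain
info = page_info(offset, page_len, total)
next_cursor = encode_cursor(offset + page_len) if info.has_next_page else None
```

Returning next_cursor (or None) alongside page_info lets clients page forward without knowing that offsets are involved.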

Mutations and Input Types

Mutations take @strawberry.input types as arguments and return result types that carry either the created data or structured errors. Returning a union of a success type and an error type puts expected failures (validation, missing authentication) into the schema, so clients handle them with typed fragments instead of parsing the top-level errors array. Raised exceptions remain the right tool for unexpected failures.

@strawberry.input
class CreatePostInput:
    title: str
    content: str
    published: bool = False

@strawberry.input
class UpdatePostInput:
    title: Optional[str] = strawberry.UNSET
    content: Optional[str] = strawberry.UNSET
    published: Optional[bool] = strawberry.UNSET

@strawberry.type
class PostCreated:
    post: Post

@strawberry.type
class ValidationError:
    field: str
    message: str

@strawberry.type
class CreatePostError:
    errors: list[ValidationError]

from typing import Annotated, Union

CreatePostResult = Annotated[Union[PostCreated, CreatePostError], strawberry.union("CreatePostResult")]

@strawberry.type
class Mutation:
    @strawberry.mutation
    async def create_post(
        self,
        input: CreatePostInput,
        info: Info[Context, None],
    ) -> CreatePostResult:
        if not info.context.user_id:
            return CreatePostError(errors=[
                ValidationError(field="auth", message="Authentication required")
            ])

        errors = []
        if len(input.title.strip()) < 3:
            errors.append(ValidationError(field="title", message="Title must be at least 3 characters"))
        if len(input.content.strip()) < 10:
            errors.append(ValidationError(field="content", message="Content too short"))

        if errors:
            return CreatePostError(errors=errors)

        post_model = PostModel(
            title=input.title.strip(),
            content=input.content.strip(),
            published=input.published,
            author_id=info.context.user_id,
        )
        info.context.db.add(post_model)
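        await info.context.db.commit()
        # With the default expire_on_commit=True, commit expires loaded attributes;
        # refresh reloads them so the async session never attempts an implicit lazy load
        await info.context.db.refresh(post_model)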
        return PostCreated(post=Post(
            id=str(post_model.id), title=post_model.title,
            content=post_model.content, published=post_model.published,
            author_id=post_model.author_id, created_at=post_model.created_at
        ))

    @strawberry.mutation
    async def delete_post(self, id: strawberry.ID, info: Info) -> bool:
        post = await info.context.db.get(PostModel, id)
        if not post or str(post.author_id) != info.context.user_id:
            return False
        await info.context.db.delete(post)
        await info.context.db.commit()
        return True
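strawberry.UNSET lets a mutation tell an omitted field apart from an explicit null. A stdlib sketch of that partial-update logic, using a local UNSET sentinel as a stand-in for Strawberry's; UNSET, UpdatePostInput, and apply_patch here are illustrative names, and a dict stands in for the ORM row.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional

class _Unset:
    """Local sentinel meaning 'the client omitted this field' (stand-in for strawberry.UNSET)."""

    def __repr__(self) -> str:
        return "UNSET"

UNSET: Any = _Unset()

@dataclass
class UpdatePostInput:  # plain-dataclass mirror of the @strawberry.input above
    title: Optional[str] = UNSET
    content: Optional[str] = UNSET
    published: Optional[bool] = UNSET

def apply_patch(row: dict, patch: UpdatePostInput) -> dict:
    """Return a copy of row with only the fields the client actually sent."""
    updated = dict(row)
    for f in fields(patch):
        value = getattr(patch, f.name)
        if value is not UNSET:  # explicit None is still written: the client asked for null
            updated[f.name] = value
    return updated

row = {"title": "Old title", "content": "Some body text", "published": False}
print(apply_patch(row, UpdatePostInput(title="New title")))
```

Whether an explicit null should be accepted for a column like title is a separate validation decision; the sketch only shows how the two cases stay distinguishable.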

DataLoaders: Solving N+1

The N+1 problem occurs when fetching a list of posts triggers a separate database query for each post's author. DataLoaders batch multiple individual lookups into a single query within the same event loop tick. Strawberry's DataLoader implementation is async and integrates naturally with any async database driver.

from strawberry.dataloader import DataLoader
from typing import Sequence

async def load_users_by_ids(keys: Sequence[str]) -> list[Optional[User]]:
    """Batch load multiple users in one query instead of N queries."""
    # One SQL query for all requested IDs:
    # SELECT * FROM users WHERE id = ANY($1)
    from sqlalchemy import select
    async with get_db_session() as db:
        result = await db.execute(
            select(UserModel).where(UserModel.id.in_(keys))
        )
        users_by_id = {str(u.id): u for u in result.scalars()}
    # Return in the same order as keys (required by DataLoader contract)
    return [users_by_id.get(key) for key in keys]

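# Create DataLoaders per request (not globally): a shared loader's cache would serve stale data across requests
from dataclasses import dataclass, field

@dataclass  # plain dataclass: the context is never part of the GraphQL schema
class Context:
    user_id: Optional[str]
    db: "AsyncSession"
    user_loader: DataLoader = field(default_factory=lambda: DataLoader(load_fn=load_users_by_ids))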

@strawberry.type
class Post:
    id: strawberry.ID
    title: str
    content: str
    published: bool
    author_id: strawberry.ID
    created_at: datetime

    @strawberry.field
    async def author(self, info: Info[Context, None]) -> Optional[User]:
        """Uses DataLoader — batched automatically for all posts in response."""
        return await info.context.user_loader.load(self.author_id)

# With 100 posts, this generates 2 queries total:
# 1. SELECT * FROM posts LIMIT 100
# 2. SELECT * FROM users WHERE id = ANY([...100 ids...])
# Without DataLoader: 101 queries
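To see why the loads above collapse into one query, here is a toy loader built on plain asyncio. It is not Strawberry's implementation, only the same idea: the first load() in a batch schedules a dispatch task, every load() made before that task runs joins the batch, and repeated keys share one cached future. MiniLoader and fake_batch are hypothetical.

```python
import asyncio
from typing import Any, Awaitable, Callable, Hashable, Sequence

BatchFn = Callable[[list], Awaitable[Sequence[Any]]]

class MiniLoader:
    """Toy batching loader: one batch_fn call per burst of load() calls."""

    def __init__(self, batch_fn: BatchFn) -> None:
        self.batch_fn = batch_fn
        self.cache: dict[Hashable, asyncio.Future] = {}  # key -> shared future
        self.pending: list[Hashable] = []  # keys waiting for the next batch
        self.task = None  # keeps a reference to the scheduled dispatch task
        self.batch_calls = 0

    def load(self, key: Hashable) -> asyncio.Future:
        if key in self.cache:  # duplicate key: reuse the same future
            return self.cache[key]
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self.cache[key] = future
        self.pending.append(key)
        if len(self.pending) == 1:
            # First key of a new batch: the dispatch task runs only after the
            # resolvers already scheduled on the loop have had their turn
            self.task = loop.create_task(self._dispatch())
        return future

    async def _dispatch(self) -> None:
        keys, self.pending = self.pending, []
        self.batch_calls += 1
        try:
            values = await self.batch_fn(keys)
            if len(values) != len(keys):
                raise ValueError("batch_fn must return one value per key, in key order")
        except Exception as exc:
            for key in keys:  # fail every waiter and drop the entries so a retry is possible
                self.cache.pop(key).set_exception(exc)
            return
        for key, value in zip(keys, values):
            self.cache[key].set_result(value)

async def demo():
    calls = []

    async def fake_batch(keys):  # stands in for one "WHERE id IN (...)" query
        calls.append(list(keys))
        return [f"user-{k}" for k in keys]

    loader = MiniLoader(fake_batch)

    async def resolve_author(author_id):  # plays the role of Post.author above
        return await loader.load(author_id)

    authors = await asyncio.gather(*(resolve_author(a) for a in [1, 2, 1, 3]))
    return authors, calls

authors, calls = asyncio.run(demo())
```

Four resolver calls produce a single batch call with the three distinct keys, in the order they were first requested.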

Subscriptions: Real-Time Data

Strawberry supports GraphQL subscriptions over WebSocket. A subscription resolver is an async generator that yields a value each time an event occurs, and it can be fed by any pub/sub backend, such as Redis Pub/Sub or in-memory asyncio queues. Clients receive updates in real time without polling.

import asyncio
from typing import AsyncGenerator

# Simple in-memory pub/sub for demo (use Redis in production)
_subscribers: dict[str, list[asyncio.Queue]] = {}

def publish(channel: str, data: dict):
    for queue in _subscribers.get(channel, []):
        queue.put_nowait(data)

async def subscribe(channel: str) -> AsyncGenerator:
    queue = asyncio.Queue()
    _subscribers.setdefault(channel, []).append(queue)
    try:
        while True:
            yield await queue.get()
    finally:
        _subscribers[channel].remove(queue)

@strawberry.type
class PostEvent:
    action: str  # "created" | "updated" | "deleted"
    post: Post

@strawberry.type
class Subscription:
    @strawberry.subscription
    async def post_events(self, info: Info) -> AsyncGenerator[PostEvent, None]:
        """Stream real-time post events to the client."""
        async for event in subscribe("posts"):
            yield PostEvent(
                action=event["action"],
                post=Post(**event["post"])
            )

    @strawberry.subscription
    async def user_notifications(
        self, info: Info, user_id: strawberry.ID
    ) -> AsyncGenerator[str, None]:
        """Stream notifications for a specific user."""
        async for msg in subscribe(f"notifications:{user_id}"):
            yield msg["message"]

# Wire up schema
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)
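Two details of the in-memory pub/sub above are easy to miss. A subscriber's queue is registered only when its generator first runs, so events published before that are dropped. And breaking out of async for does not by itself run the generator's finally block, so the queue can stay registered until the generator is closed. A stdlib-only sketch (take and demo are hypothetical helpers) that closes the generator explicitly with aclose():

```python
import asyncio
from typing import AsyncGenerator

_subscribers: dict[str, list[asyncio.Queue]] = {}

def publish(channel: str, data: dict) -> None:
    for queue in _subscribers.get(channel, []):
        queue.put_nowait(data)

async def subscribe(channel: str) -> AsyncGenerator[dict, None]:
    queue: asyncio.Queue = asyncio.Queue()
    _subscribers.setdefault(channel, []).append(queue)  # runs on first iteration only
    try:
        while True:
            yield await queue.get()
    finally:
        _subscribers[channel].remove(queue)

async def take(channel: str, n: int) -> list[dict]:
    """Consume n events, then close the generator so its finally block unregisters the queue."""
    received = []
    gen = subscribe(channel)
    try:
        async for event in gen:
            received.append(event)
            if len(received) == n:
                break
    finally:
        await gen.aclose()
    return received

async def demo():
    consumer = asyncio.create_task(take("posts", 2))
    await asyncio.sleep(0)  # let the consumer start so its queue is registered
    publish("posts", {"action": "created", "id": 1})
    publish("posts", {"action": "deleted", "id": 1})
    events = await consumer
    return events, len(_subscribers["posts"])

events, remaining = asyncio.run(demo())
```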

Authentication and Permissions

Strawberry supports field-level permissions via permission classes (and, for custom setups, schema directives). Permission classes are the cleanest approach: each class checks one condition, and listing several in permission_classes requires all of them to pass. When a check returns False, Strawberry rejects the field with that class's message, which GraphQL reports in the errors array alongside any partial data.

from strawberry.permission import BasePermission
from strawberry.types import Info

class IsAuthenticated(BasePermission):
    message = "You must be logged in"

    def has_permission(self, source, info: Info, **kwargs) -> bool:
        return info.context.user_id is not None

class IsAdmin(BasePermission):
    message = "Admin role required"

    async def has_permission(self, source, info: Info, **kwargs) -> bool:
        if not info.context.user_id:
            return False
        user = await info.context.db.get(UserModel, info.context.user_id)
        return user is not None and user.role == "admin"

class OwnsPost(BasePermission):
    message = "You can only modify your own posts"

    async def has_permission(self, source, info: Info, id: str, **kwargs) -> bool:
        post = await info.context.db.get(PostModel, id)
        return post is not None and str(post.author_id) == info.context.user_id

@strawberry.type
class Query:
    @strawberry.field(permission_classes=[IsAuthenticated])
    async def my_posts(self, info: Info) -> list[Post]:
        return await get_posts_by_author(info.context.user_id, info.context.db)

    @strawberry.field(permission_classes=[IsAdmin])
    async def all_users(self, info: Info) -> list[User]:
        return await get_all_users(info.context.db)

@strawberry.type
class Mutation:
    @strawberry.mutation(permission_classes=[IsAuthenticated, OwnsPost])
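    async def update_post(self, id: strawberry.ID, input: UpdatePostInput, info: Info) -> Post:
        post = await info.context.db.get(PostModel, id)
        for name in ("title", "content", "published"):  # UNSET = omitted by the client
            if (value := getattr(input, name)) is not strawberry.UNSET:
                setattr(post, name, value)
        await info.context.db.commit()
        await info.context.db.refresh(post)  # reload attributes expired by the commit
        # Return the ORM object: Strawberry reads Post's fields from it by attribute name
        # (Post(**post.__dict__) would fail on SQLAlchemy's _sa_instance_state key)
        return post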
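A simplified, stdlib-only model of how several permission classes compose. In this model, checks run in list order, sync and async checks can be mixed, and the first failure's message is returned. FakeContext, its post_owners map, and first_denial are hypothetical stand-ins, not Strawberry's API, and the two permission classes are trimmed copies of the ones above. The point is that ordering IsAuthenticated before OwnsPost lets the cheap check short-circuit before any ownership lookup.

```python
import asyncio
import inspect
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class FakeContext:
    """Hypothetical stand-in for the request Context."""
    user_id: Optional[str]
    post_owners: dict[str, str] = field(default_factory=dict)  # post id -> author id

class IsAuthenticated:
    message = "You must be logged in"

    def has_permission(self, source, ctx: FakeContext, **kwargs) -> bool:
        return ctx.user_id is not None

class OwnsPost:
    message = "You can only modify your own posts"

    async def has_permission(self, source, ctx: FakeContext, id: str, **kwargs) -> bool:
        await asyncio.sleep(0)  # pretend this is a database lookup
        return ctx.post_owners.get(id) == ctx.user_id

async def first_denial(permission_classes, ctx, **field_args) -> Optional[str]:
    """Run checks in list order; return the first failing message, or None if all pass."""
    for permission_class in permission_classes:
        result = permission_class().has_permission(None, ctx, **field_args)
        if inspect.isawaitable(result):  # async checks return a coroutine
            result = await result
        if not result:
            return permission_class.message  # later checks never run
    return None

owners = {"p1": "u1"}
print(asyncio.run(first_denial([IsAuthenticated, OwnsPost], FakeContext("u2", owners), id="p1")))
```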

FastAPI Integration

Strawberry provides a GraphQLRouter that mounts into FastAPI like any other router. The GraphQL endpoint and the GraphiQL IDE live at the mount path, and the rest of the app (including FastAPI's Swagger docs for REST routes) keeps working alongside them. The router handles both standard HTTP requests and WebSocket connections for subscriptions.

from fastapi import FastAPI, Request
from strawberry.fastapi import GraphQLRouter
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL

async def get_context(request: Request) -> Context:
    token = request.headers.get("Authorization", "").replace("Bearer ", "")
    user_id = None
    if token:
        try:
            payload = verify_jwt(token)
            user_id = payload["sub"]
        except Exception:
            pass  # invalid or expired token: treat the request as anonymous
    db = get_db_session()
    return Context(
        user_id=user_id,
        db=db,
        user_loader=DataLoader(load_fn=load_users_by_ids)
    )

schema = strawberry.Schema(
    query=Query,
    mutation=Mutation,
    subscription=Subscription,
)

graphql_app = GraphQLRouter(
    schema,
    context_getter=get_context,
    subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL],
    graphiql=True,  # Enable GraphiQL IDE at /graphql
)

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")

# Test with curl:
# curl -X POST http://localhost:8000/graphql \
#   -H "Content-Type: application/json" \
#   -H "Authorization: Bearer " \
#   -d '{"query": "{ me { id name email } }"}'
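The get_context above strips the token with replace("Bearer ", ""), which passes any other scheme through unchanged (a Basic credential would be handed to verify_jwt) and misses a lowercase bearer, even though HTTP authentication schemes are case-insensitive. A stricter parser, sketched with the standard library only; parse_bearer is a hypothetical helper you would call from get_context.

```python
from typing import Optional

def parse_bearer(authorization: Optional[str]) -> Optional[str]:
    """Extract the token from an 'Authorization: Bearer <token>' header value.

    Returns None for a missing header, a different scheme, or an empty token.
    """
    if not authorization:
        return None
    scheme, _, token = authorization.strip().partition(" ")
    if scheme.lower() != "bearer":  # authentication schemes are case-insensitive
        return None
    token = token.strip()
    return token or None

# Inside get_context this would replace the str.replace call:
#     token = parse_bearer(request.headers.get("Authorization"))
print(parse_bearer("Bearer abc.def.ghi"), parse_bearer("Basic dXNlcjpwdw=="))
```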

N+1 in production: Add a DataLoader for every nested relationship that can be queried on a list. Cap query depth with the QueryDepthLimiter extension (strawberry.extensions.QueryDepthLimiter) so deeply nested queries cannot overload your database; depth limits do not measure breadth, so consider cost or complexity limits as well if clients can request large lists. Enable SQL query logging during development to spot N+1 patterns early.
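Query logging is easy to turn into a test. As a stdlib-only illustration, the sketch uses sqlite3's set_trace_callback in place of your real database's statement logging (with SQLAlchemy, echo=True or an engine event would play this role) and counts SELECTs for a naive per-post author lookup versus one batched IN query over the same 100 posts. make_db, count_selects, naive, and batched are hypothetical helpers.

```python
import sqlite3
from typing import Callable

def make_db(n_posts: int = 100, n_users: int = 10) -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
    conn.execute("CREATE TABLE posts (id INTEGER PRIMARY KEY, author_id INTEGER)")
    conn.executemany("INSERT INTO users VALUES (?, ?)", [(i, f"user{i}") for i in range(n_users)])
    conn.executemany("INSERT INTO posts VALUES (?, ?)", [(i, i % n_users) for i in range(n_posts)])
    return conn

def count_selects(conn: sqlite3.Connection, work: Callable[[sqlite3.Connection], None]) -> int:
    """Run work(conn) with statement tracing on and count the SELECTs it issued."""
    statements: list[str] = []
    conn.set_trace_callback(statements.append)
    try:
        work(conn)
    finally:
        conn.set_trace_callback(None)
    return sum(1 for sql in statements if sql.lstrip().upper().startswith("SELECT"))

def naive(conn: sqlite3.Connection) -> None:
    posts = conn.execute("SELECT id, author_id FROM posts").fetchall()
    for _, author_id in posts:  # one extra query per post: the N+1 pattern
        conn.execute("SELECT name FROM users WHERE id = ?", (author_id,)).fetchone()

def batched(conn: sqlite3.Connection) -> None:
    posts = conn.execute("SELECT id, author_id FROM posts").fetchall()
    ids = sorted({author_id for _, author_id in posts})
    placeholders = ",".join("?" * len(ids))  # the shape a DataLoader batch function issues
    conn.execute(f"SELECT id, name FROM users WHERE id IN ({placeholders})", ids).fetchall()

conn = make_db()
print(count_selects(conn, naive), count_selects(conn, batched))
```

An assertion on the count in a test suite catches an N+1 regression before it reaches production.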