
Designing a Cost-Effective RAG ETL Flow on AWS

By tvignoli · DevOps Folio · Published on September 15, 2024

Retrieval-Augmented Generation lives or dies by the freshness of the vector store. In production environments we cannot rely on ad-hoc scripts; we need a hardened ETL flow that ingests documents, cleans them, enriches metadata, generates embeddings, and exposes audit trails for compliance. After implementing RAG pipelines for banking, healthcare, and e-commerce clients processing millions of documents monthly, I've distilled the patterns that deliver deterministic behavior under unpredictable loads while keeping costs predictable.

Reference Architecture & Data Contracts

We begin with Amazon EventBridge rules that respond to new objects landing in S3 or to webhook events from ticketing systems. Each event starts a Step Functions execution whose first step is a Lambda that validates the payload against a JSON schema stored in the AWS Glue Schema Registry. Strict contracts prevent malformed documents from clogging the stream later. When the payload is valid, the state machine continues through a multi-stage pipeline: extraction, enrichment, embeddings, and load.
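A minimal sketch of the S3-to-EventBridge trigger, assuming EventBridge notifications are switched on for the bucket and a Step Functions state machine is the rule target. The rule name, the prefix filter, and the ARNs are placeholders, and the boto3 clients are passed in so the pattern logic can be checked offline. An input transformer (not shown) would still be needed to map `detail.object.key` onto the fields the validator expects.

```python
import json
from typing import Any, Dict


def build_s3_object_created_pattern(bucket: str, prefix: str = "") -> Dict[str, Any]:
    """EventBridge pattern matching S3 "Object Created" events for one bucket."""
    detail: Dict[str, Any] = {"bucket": {"name": [bucket]}}
    if prefix:
        # Prefix matching on the object key limits the rule to the ingest folder.
        detail["object"] = {"key": [{"prefix": prefix}]}
    return {"source": ["aws.s3"], "detail-type": ["Object Created"], "detail": detail}


def wire_ingest_rule(s3: Any, events: Any, bucket: str, prefix: str,
                     state_machine_arn: str, role_arn: str) -> str:
    """Enable EventBridge delivery on the bucket and route new objects to Step Functions."""
    # S3 publishes to EventBridge only once EventBridgeConfiguration is present.
    # Caution: this call replaces the bucket's entire notification configuration.
    s3.put_bucket_notification_configuration(
        Bucket=bucket,
        NotificationConfiguration={"EventBridgeConfiguration": {}},
    )
    rule_name = f"rag-ingest-{bucket}"[:64]  # EventBridge rule names max out at 64 chars
    events.put_rule(
        Name=rule_name,
        EventPattern=json.dumps(build_s3_object_created_pattern(bucket, prefix)),
        State="ENABLED",
    )
    # The role must allow events.amazonaws.com to call states:StartExecution.
    events.put_targets(
        Rule=rule_name,
        Targets=[{"Id": "rag-etl-state-machine", "Arn": state_machine_arn, "RoleArn": role_arn}],
    )
    return rule_name
```

Keeping the pattern builder separate from the API calls makes it easy to unit-test the filter before any infrastructure exists.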

The validation layer is critical. I've seen pipelines fail because a single malformed PDF corrupted the entire batch. Our validation Lambda checks document structure, file size limits (we cap at 50MB per document), MIME types, and required metadata fields. Invalid payloads are immediately routed to a dead-letter queue with detailed error context, allowing operations teams to triage without digging through CloudWatch logs.

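import json
from typing import Any, Dict

import boto3
from jsonschema import ValidationError, validate

glue = boto3.client('glue')

MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB limit
ALLOWED_MIME_TYPES = {'application/pdf', 'text/plain', 'text/markdown'}

def validate_payload(event: Dict[str, Any]) -> Dict[str, Any]:
    """Validate incoming document payload against the Glue Schema Registry schema."""
    # get_schema returns only schema metadata; the JSON Schema body lives on a schema version
    schema_version = glue.get_schema_version(
        SchemaId={'SchemaName': 'rag-document-schema', 'RegistryName': 'rag-registry'},
        SchemaVersionNumber={'LatestVersion': True}
    )
    
    try:
        validate(instance=event, schema=json.loads(schema_version['SchemaDefinition']))
        
        # Additional business rules, raised as ValueError so they are reported like schema errors
        if event['fileSize'] > MAX_FILE_SIZE:
            raise ValueError(f"File size {event['fileSize']} exceeds 50MB limit")
        
        if event['mimeType'] not in ALLOWED_MIME_TYPES:
            raise ValueError(f"Unsupported MIME type: {event['mimeType']}")
        
        return {'valid': True, 'payload': event}
    except (ValidationError, ValueError) as e:
        # Returned rather than raised so the orchestrator can route invalid payloads to the DLQ
        return {'valid': False, 'error': str(e), 'payload': event}

def lambda_handler(event, context):
    """Entry point for the rag-validate function."""
    return validate_payload(event)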
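
The Step Functions state machine below (Amazon States Language) wires these stages together. Keys ending in `.$` are evaluated as JSONPath against the state input; without that suffix the value is passed as a literal string. Large artifacts travel as S3 paths because Step Functions caps state payloads at 256 KB, so the definition assumes `rag-extract` returns `documentId`, `extractedPath`, and `enrichedPath`, and `rag-embed` returns an `embeddingsPath`.

{
  "Comment": "RAG ETL: validate, extract, enrich and embed in parallel, then load",
  "StartAt": "ValidatePayload",
  "States": {
    "ValidatePayload": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "Parameters": {
        "FunctionName": "rag-validate",
        "Payload": {
          "s3Key.$": "$.s3Key",
          "bucket.$": "$.bucket",
          "source.$": "$.source",
          "fileSize.$": "$.fileSize",
          "mimeType.$": "$.mimeType"
        }
      },
      "ResultSelector": { "outcome.$": "$.Payload" },
      "ResultPath": "$.validation",
      "Retry": [{
        "ErrorEquals": ["Lambda.ServiceException", "Lambda.AWSLambdaException", "Lambda.SdkClientException"],
        "IntervalSeconds": 2,
        "MaxAttempts": 3,
        "BackoffRate": 2.0
      }],
      "Catch": [{
        "ErrorEquals": ["States.ALL"],
        "Next": "SendToDLQ",
        "ResultPath": "$.error"
      }],
      "Next": "IsValid"
    },
    "IsValid": {
      "Type": "Choice",
      "Choices": [{
        "Variable": "$.validation.outcome.valid",
        "BooleanEquals": true,
        "Next": "Extract"
      }],
      "Default": "SendToDLQ"
    },
    "Extract": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "Parameters": {
        "FunctionName": "rag-extract",
        "Payload.$": "$"
      },
      "OutputPath": "$.Payload",
      "Next": "Enrich"
    },
    "Enrich": {
      "Type": "Parallel",
      "Branches": [
        {
          "StartAt": "ExtractMetadata",
          "States": {
            "ExtractMetadata": {
              "Type": "Task",
              "Resource": "arn:aws:states:::glue:startJobRun.sync",
              "Parameters": {
                "JobName": "rag-metadata-enrichment",
                "Arguments": {
                  "--input-path.$": "$.extractedPath",
                  "--output-path.$": "$.enrichedPath"
                }
              },
              "End": true
            }
          }
        },
        {
          "StartAt": "GenerateEmbeddings",
          "States": {
            "GenerateEmbeddings": {
              "Type": "Task",
              "Resource": "arn:aws:states:::lambda:invoke",
              "Parameters": {
                "FunctionName": "rag-embed",
                "Payload": {
                  "documentId.$": "$.documentId",
                  "inputPath.$": "$.extractedPath",
                  "model": "amazon.titan-embed-text-v2:0",
                  "dimensions": 1024
                }
              },
              "OutputPath": "$.Payload",
              "Retry": [{
                "ErrorEquals": ["Lambda.TooManyRequestsException"],
                "IntervalSeconds": 5,
                "MaxAttempts": 5,
                "BackoffRate": 2.0
              }],
              "End": true
            }
          }
        }
      ],
      "ResultPath": "$.enrichment",
      "Catch": [{
        "ErrorEquals": ["States.ALL"],
        "Next": "SendToDLQ",
        "ResultPath": "$.error"
      }],
      "Next": "Load"
    },
    "Load": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "Parameters": {
        "FunctionName": "rag-opensearch-indexer",
        "Payload": {
          "documentId.$": "$.documentId",
          "metadataPath.$": "$.enrichedPath",
          "embeddingsPath.$": "$.enrichment[1].embeddingsPath"
        }
      },
      "End": true
    },
    "SendToDLQ": {
      "Type": "Task",
      "Resource": "arn:aws:states:::sqs:sendMessage",
      "Parameters": {
        "QueueUrl": "${DLQ_URL}",
        "MessageBody.$": "States.JsonToString($)"
      },
      "End": true
    }
  }
}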

Document Processing & Embedding Strategy

Content-heavy documents run through Textract (table extraction) and Bedrock Titan Text models for quick summarisation. Amazon Comprehend detects each document's dominant language ahead of normalisation. Embeddings rely on a SageMaker Serverless Inference endpoint with provisioned concurrency, avoiding cold-start spikes during business hours.
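One way to express that endpoint, sketched with placeholder names and sizing rather than tuned values: SageMaker Serverless Inference takes a `ServerlessConfig` per production variant, and `ProvisionedConcurrency` keeps that many instances initialised. The helper only builds the keyword arguments for `create_endpoint_config`, so it can be reviewed without AWS access.

```python
from typing import Any, Dict


def serverless_endpoint_config(model_name: str, memory_mb: int = 4096,
                               max_concurrency: int = 20,
                               provisioned: int = 2) -> Dict[str, Any]:
    """Build kwargs for sagemaker.create_endpoint_config with a warm serverless variant."""
    # Serverless Inference only accepts these memory sizes (MB).
    allowed_memory = {1024, 2048, 3072, 4096, 5120, 6144}
    if memory_mb not in allowed_memory:
        raise ValueError(f"memory_mb must be one of {sorted(allowed_memory)}")
    if not 1 <= provisioned <= max_concurrency:
        raise ValueError("provisioned concurrency must be between 1 and max_concurrency")
    return {
        "EndpointConfigName": f"{model_name}-serverless",
        "ProductionVariants": [{
            "VariantName": "AllTraffic",
            "ModelName": model_name,
            "ServerlessConfig": {
                "MemorySizeInMB": memory_mb,
                "MaxConcurrency": max_concurrency,
                # Pre-initialised capacity that absorbs cold starts.
                "ProvisionedConcurrency": provisioned,
            },
        }],
    }
```

Call it as `sagemaker.create_endpoint_config(**serverless_endpoint_config("embed-model"))` with the boto3 `sagemaker` control-plane client (not `sagemaker-runtime`).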

The embedding strategy matters enormously for cost and quality. We use chunking with overlap (typically 200 tokens with 50-token overlap) to preserve context across boundaries. For structured documents like PDFs with tables, we extract tables separately and embed them as distinct chunks, preserving referential integrity through metadata links.

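import json
from typing import Dict, List, Tuple

import boto3
import tiktoken  # For token counting
from botocore.config import Config

# Adaptive client-side retries back off on throttling before a chunk counts as failed
sagemaker = boto3.client(
    'sagemaker-runtime',
    config=Config(retries={'max_attempts': 5, 'mode': 'adaptive'})
)
textract = boto3.client('textract')

class DocumentChunker:
    """Token-based chunking with overlap and metadata preservation."""
    
    def __init__(self, chunk_size: int = 200, overlap: int = 50):
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        self.chunk_size = chunk_size
        self.overlap = overlap
        # cl100k_base is OpenAI's tokenizer, so counts only approximate the embedding model's
        self.encoding = tiktoken.get_encoding("cl100k_base")
    
    def chunk_with_overlap(self, text: str, metadata: Dict) -> List[Dict]:
        """Split text into overlapping chunks preserving context."""
        tokens = self.encoding.encode(text)
        chunks = []
        step = self.chunk_size - self.overlap
        
        for i in range(0, len(tokens), step):
            end = min(i + self.chunk_size, len(tokens))
            chunk_text = self.encoding.decode(tokens[i:end])
            
            chunks.append({
                'text': chunk_text,
                'chunk_index': len(chunks),
                'start_token': i,
                'end_token': end,
                'metadata': {
                    **metadata,
                    'chunk_id': f"{metadata['document_id']}_chunk_{len(chunks)}"
                }
            })
            if end == len(tokens):
                break  # This window reached the end; avoid a redundant tail chunk
        
        return chunks
    
    def extract_tables(self, s3_key: str, bucket: str = 'rag-documents') -> List[Dict]:
        """Extract tables from a document in S3 using Textract.

        The synchronous AnalyzeDocument API handles single-page PDFs and images;
        multi-page PDFs need the asynchronous start_document_analysis instead.
        """
        response = textract.analyze_document(
            Document={'S3Object': {'Bucket': bucket, 'Name': s3_key}},
            FeatureTypes=['TABLES']
        )
        
        blocks = response['Blocks']
        tables = []
        for block in blocks:
            if block['BlockType'] == 'TABLE':
                tables.append({
                    'table_id': block['Id'],
                    'cells': self._extract_table_cells(block, blocks),
                    'metadata': {'source': s3_key, 'page': block.get('Page', 1)}
                })
        
        return tables
    
    @staticmethod
    def _extract_table_cells(table: Dict, blocks: List[Dict]) -> List[Dict]:
        """Resolve a TABLE block's CELL children into row/column/text records."""
        by_id = {b['Id']: b for b in blocks}
        cells = []
        for rel in table.get('Relationships', []):
            if rel['Type'] != 'CHILD':
                continue
            for cell in (by_id[i] for i in rel['Ids']):
                if cell['BlockType'] != 'CELL':
                    continue
                # A CELL's CHILD relationships point at WORD (and SELECTION_ELEMENT) blocks
                words = [
                    by_id[w]['Text']
                    for r in cell.get('Relationships', []) if r['Type'] == 'CHILD'
                    for w in r['Ids'] if by_id[w]['BlockType'] == 'WORD'
                ]
                cells.append({'row': cell['RowIndex'], 'column': cell['ColumnIndex'],
                              'text': ' '.join(words)})
        return cells
    
    def generate_embeddings_batch(self, chunks: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """Embed chunks one by one; returns (embedded, failed) so failures can go to the DLQ."""
        embeddings, failed = [], []
        
        for chunk in chunks:
            try:
                # Request/response shape assumes a Titan-style JSON contract on the endpoint
                response = sagemaker.invoke_endpoint(
                    EndpointName='titan-embed-endpoint',
                    ContentType='application/json',
                    Body=json.dumps({
                        'inputText': chunk['text'],
                        'dimensions': 1024
                    })
                )
                
                embedding = json.loads(response['Body'].read())['embedding']
                
                embeddings.append({
                    **chunk,
                    'embedding': embedding,
                    'embedding_model': 'amazon.titan-embed-text-v2:0',
                    'embedding_dimensions': len(embedding)
                })
            except Exception as e:
                # Log and continue; the caller routes failed chunks to the DLQ
                print(f"Failed to embed chunk {chunk['metadata']['chunk_id']}: {e}")
                failed.append(chunk)
        
        return embeddings, failed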
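Tables extracted this way still need to become text before they can be embedded as distinct chunks. A sketch, assuming each cell arrives as a `{'row', 'column', 'text'}` record built from Textract `CELL` blocks (1-based indexes): the table is rendered as pipe-delimited rows and shares `document_id` with the prose chunks, which is the metadata link back to its parent. The `chunk_type` field and the `_table_` ID scheme are illustrative choices, not part of any AWS API.

```python
from typing import Dict, List


def table_to_chunk(table_id: str, cells: List[Dict], document_id: str, page: int) -> Dict:
    """Render Textract-style cells as a pipe-delimited text chunk linked to its document."""
    if not cells:
        raise ValueError("table has no cells")
    n_rows = max(c["row"] for c in cells)
    n_cols = max(c["column"] for c in cells)
    # Textract row/column indexes are 1-based; missing cells stay empty strings.
    grid = [["" for _ in range(n_cols)] for _ in range(n_rows)]
    for c in cells:
        grid[c["row"] - 1][c["column"] - 1] = c["text"]
    return {
        "text": "\n".join(" | ".join(row) for row in grid),
        "metadata": {
            "document_id": document_id,  # shared with prose chunks: the parent link
            "chunk_id": f"{document_id}_table_{table_id}",
            "chunk_type": "table",
            "page": page,
        },
    }
```

Keeping each table in one chunk, rather than letting the token splitter cut it mid-row, preserves the row/column relationships the embedding needs.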

Real-world Use Cases & Production Lessons

Case study #1: A European fintech aggregates 200k PDF disclosures daily across multiple regulatory jurisdictions. The challenge: data residency requirements meant documents from EU customers couldn't leave the EU, but we needed a unified search experience. Solution: We split the workload across three EU regions (eu-west-1, eu-central-1, eu-north-1) using regional S3 buckets, and centralized search in a single OpenSearch Serverless collection in eu-west-1 that all three regional pipelines write to. Glue jobs write Parquet + Snappy artifacts to S3, and we keep 30-day rolling windows in hot storage while archiving older embeddings to Glacier Instant Retrieval. Result: 41% storage savings ($2,300/month reduction) and SLA-compliant search latency under 200ms p95.

The key insight here was using Step Functions' Map state to parallelize regional processing while maintaining a single source of truth. Each regional branch processes documents independently, but all write to the same OpenSearch collection through VPC endpoints, ensuring low latency while respecting compliance boundaries.

import boto3
from datetime import datetime, timedelta, timezone

s3 = boto3.client('s3')

def archive_old_embeddings(bucket: str = 'rag-embeddings-prod',
                           prefix: str = 'embeddings/',
                           archive_prefix: str = 'archive/'):
    """Archive embeddings older than 30 days to Glacier Instant Retrieval."""
    # S3 LastModified values are timezone-aware (UTC), so compare against an aware datetime
    now = datetime.now(timezone.utc)
    cutoff_date = now - timedelta(days=30)
    
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            if obj['LastModified'] < cutoff_date:
                # Swap only the leading prefix, not every occurrence inside the key
                archive_key = archive_prefix + obj['Key'][len(prefix):]
                s3.copy_object(
                    CopySource={'Bucket': bucket, 'Key': obj['Key']},
                    Bucket=bucket,
                    Key=archive_key,
                    StorageClass='GLACIER_IR',
                    # Without REPLACE, S3 copies the source metadata and ignores Metadata=
                    MetadataDirective='REPLACE',
                    Metadata={
                        'original-date': obj['LastModified'].isoformat(),
                        'archived-date': now.isoformat()
                    }
                )
                # Delete from hot storage only after the copy succeeded
                s3.delete_object(Bucket=bucket, Key=obj['Key'])
                print(f"Archived {obj['Key']} to Glacier IR as {archive_key}")
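If archived objects can keep their original keys, an S3 Lifecycle rule performs the same transition without a scheduled job or copy requests. A sketch, assuming the bucket and prefix from the function above; the rule ID is a placeholder, and `put_bucket_lifecycle_configuration` replaces any rules already on the bucket, so merge them first in a real deployment. Glacier Instant Retrieval also bills a 90-day minimum storage duration.

```python
from typing import Any, Dict


def embeddings_lifecycle(prefix: str = "embeddings/", days: int = 30) -> Dict[str, Any]:
    """Lifecycle configuration moving objects under `prefix` to Glacier IR after `days`."""
    return {
        "Rules": [{
            "ID": "rag-embeddings-to-glacier-ir",
            "Filter": {"Prefix": prefix},
            "Status": "Enabled",
            "Transitions": [{"Days": days, "StorageClass": "GLACIER_IR"}],
        }]
    }


def apply_lifecycle(s3: Any, bucket: str = "rag-embeddings-prod") -> None:
    """Apply the rule with a boto3 S3 client (overwrites existing lifecycle rules)."""
    s3.put_bucket_lifecycle_configuration(
        Bucket=bucket, LifecycleConfiguration=embeddings_lifecycle()
    )
```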

Case study #2: An ITSM provider needs near-real-time ingestion of incident tickets from Jira, ServiceNow, and PagerDuty. The requirement: tickets must be searchable within 5 minutes of creation. The challenge: webhook storms during incidents could overwhelm the pipeline. Solution: EventBridge triggers the pipeline every five minutes with batching, and Step Functions isolates the embedding branch so that a failure there doesn't block metadata enrichment. Failed embeddings land on an SQS DLQ and are replayed with exponential backoff, without redeploying the whole stack. We also set Lambda reserved concurrency (50 concurrent executions) to prevent cascading failures.
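Both throttles in this case study map onto single API calls: a scheduled EventBridge rule for the five-minute batches and `put_function_concurrency` for the reserved concurrency cap. A sketch with injected boto3 clients; the rule name, function name, and ARNs are hypothetical.

```python
from typing import Any, Dict

RULE_NAME = "rag-itsm-microbatch"


def configure_ticket_ingest(events: Any, lambda_client: Any, function_name: str,
                            state_machine_arn: str, role_arn: str,
                            reserved: int = 50) -> Dict[str, Any]:
    """Start the batch pipeline every five minutes and cap the webhook receiver's concurrency."""
    # Scheduled rule: each run processes whatever tickets accumulated since the last one.
    events.put_rule(Name=RULE_NAME, ScheduleExpression="rate(5 minutes)", State="ENABLED")
    events.put_targets(
        Rule=RULE_NAME,
        Targets=[{"Id": "itsm-batch", "Arn": state_machine_arn, "RoleArn": role_arn}],
    )
    # Reserved concurrency is also a ceiling: during a webhook storm at most
    # `reserved` instances run and further invocations are throttled.
    lambda_client.put_function_concurrency(
        FunctionName=function_name, ReservedConcurrentExecutions=reserved
    )
    return {"rule": RULE_NAME, "function": function_name, "reserved": reserved}
```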

The critical pattern here is circuit breaker logic. When embedding failures exceed a threshold (we use 10% failure rate over 5 minutes), we automatically switch to a degraded mode: metadata is still enriched and indexed, but embeddings are queued for batch reprocessing during off-peak hours. This ensures the system remains operational even when Bedrock or SageMaker experience issues.

import os
from collections import deque
from datetime import datetime, timedelta
from typing import Optional

import boto3

sqs = boto3.client('sqs')
cloudwatch = boto3.client('cloudwatch')

# Queue URLs are deployment-specific and injected through environment variables
DLQ_URL = os.environ['DLQ_URL']
MAIN_QUEUE = os.environ['MAIN_QUEUE_URL']
MANUAL_REVIEW_QUEUE = os.environ['MANUAL_REVIEW_QUEUE_URL']

class CircuitBreaker:
    """Circuit breaker for embedding failures.

    State lives in memory, so every Lambda execution environment keeps its own
    breaker; persist it (for example in DynamoDB) if all workers must agree.
    """
    
    def __init__(self, failure_threshold: float = 0.1, window_minutes: int = 5,
                 cooldown_seconds: int = 60, probe_successes: int = 8):
        self.failure_threshold = failure_threshold
        self.window = timedelta(minutes=window_minutes)
        self.cooldown = timedelta(seconds=cooldown_seconds)
        self.probe_successes = probe_successes
        self.events = deque(maxlen=100)  # (timestamp, succeeded) pairs
        self.state = 'CLOSED'  # CLOSED, OPEN, HALF_OPEN
        self.opened_at: Optional[datetime] = None
    
    def record_failure(self):
        """Record a failure event."""
        self.events.append((datetime.now(), False))
        if self.state == 'HALF_OPEN':
            self._open("probe request failed")  # any probe failure re-opens the circuit
        else:
            self._check_state()
    
    def record_success(self):
        """Record a success event."""
        self.events.append((datetime.now(), True))
        self._check_state()
    
    def _check_state(self):
        """Open on a high failure rate; close after enough successful probes."""
        if self.state == 'HALF_OPEN':
            # Events were cleared on entering HALF_OPEN, so these are all probe successes
            if len(self.events) >= self.probe_successes:
                self.state = 'CLOSED'
                self._publish_metric('CircuitBreakerClosed', 1)
                print("Circuit breaker CLOSED - system recovered")
            return
        
        window_start = datetime.now() - self.window
        recent = [ok for ts, ok in self.events if ts >= window_start]
        if len(recent) < 10:
            return  # Not enough data
        
        failure_rate = recent.count(False) / len(recent)
        if self.state == 'CLOSED' and failure_rate >= self.failure_threshold:
            self._open(f"failure rate {failure_rate:.2%}")
    
    def _open(self, reason: str):
        """Trip the breaker and remember when, so the cooldown can elapse."""
        self.state = 'OPEN'
        self.opened_at = datetime.now()
        self._publish_metric('CircuitBreakerOpened', 1)
        print(f"Circuit breaker OPENED - {reason}")
    
    def should_allow_request(self) -> bool:
        """Block requests while OPEN; after the cooldown, let probe requests through."""
        if self.state == 'OPEN' and datetime.now() - self.opened_at >= self.cooldown:
            self.state = 'HALF_OPEN'
            self.events.clear()  # count only probe outcomes from here on
            print("Circuit breaker HALF_OPEN - testing recovery")
        return self.state != 'OPEN'
    
    def _publish_metric(self, metric_name: str, value: float):
        """Publish CloudWatch metric (timestamp defaults to the time of receipt)."""
        cloudwatch.put_metric_data(
            Namespace='RAG/ETL',
            MetricData=[{'MetricName': metric_name, 'Value': value}]
        )

def replay_from_dlq(limit: int = 10, max_retries: int = 3):
    """Replay failed messages from the DLQ with retry tracking."""
    messages = sqs.receive_message(
        QueueUrl=DLQ_URL,
        MaxNumberOfMessages=min(limit, 10),  # SQS returns at most 10 messages per call
        AttributeNames=['SentTimestamp'],
        MessageAttributeNames=['RetryCount']
    )
    
    for msg in messages.get('Messages', []):
        # Track replays with our own attribute: ApproximateReceiveCount counts receives,
        # not replays, and starts over whenever the message is re-sent
        retry_count = int(msg.get('MessageAttributes', {}).get('RetryCount', {}).get('StringValue', '0'))
        
        if retry_count >= max_retries:
            # Move to manual review queue
            sqs.send_message(
                QueueUrl=MANUAL_REVIEW_QUEUE,
                MessageBody=msg['Body'],
                MessageAttributes={
                    'RetryCount': {'StringValue': str(retry_count), 'DataType': 'Number'}
                }
            )
            sqs.delete_message(QueueUrl=DLQ_URL, ReceiptHandle=msg['ReceiptHandle'])
            continue
        
        # Replay to main queue with exponential backoff delay
        delay_seconds = min(2 ** retry_count, 900)  # SQS caps DelaySeconds at 15 minutes
        
        sqs.send_message(
            QueueUrl=MAIN_QUEUE,
            MessageBody=msg['Body'],
            DelaySeconds=delay_seconds,
            MessageAttributes={
                'RetryCount': {'StringValue': str(retry_count + 1), 'DataType': 'Number'},
                'OriginalTimestamp': {'StringValue': msg['Attributes']['SentTimestamp'], 'DataType': 'String'}
            }
        )
        sqs.delete_message(QueueUrl=DLQ_URL, ReceiptHandle=msg['ReceiptHandle'])
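The degraded mode can be sketched as a routing function that always indexes metadata and embeds only while the breaker allows it. It relies only on the breaker's `should_allow_request`, `record_success`, and `record_failure` methods (the interface of the `CircuitBreaker` class above); the `embed`, `index`, and `defer` callables are hypothetical stand-ins for the Bedrock or SageMaker call, the OpenSearch write, and the off-peak reprocessing queue.

```python
from typing import Any, Callable, Dict, List


def process_with_degradation(doc: Dict[str, Any], breaker: Any,
                             embed: Callable[[str], List[float]],
                             index: Callable[[Dict[str, Any]], None],
                             defer: Callable[[Dict[str, Any]], None]) -> str:
    """Index metadata always; embed only while the breaker allows it."""
    record: Dict[str, Any] = {"document_id": doc["document_id"], "metadata": doc["metadata"]}
    if breaker.should_allow_request():
        try:
            record["embedding"] = embed(doc["text"])
            breaker.record_success()
        except Exception:
            breaker.record_failure()
            defer(doc)  # retry the embedding in the off-peak batch
    else:
        defer(doc)  # circuit open: skip the embedding call entirely
    index(record)  # metadata stays searchable either way
    return "full" if "embedding" in record else "degraded"
```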

Cost Controls, Scaling, and Observability

Two rules keep bills predictable: micro-batching (max 5 minutes of payload per run) and aggressive tagging. Every resource carries `env`, `rag-etl`, `customer`, and `cost-center` tags, activated as cost allocation tags in the Billing console, so AWS Cost Explorer can pivot by feature. Lambda reserved concurrency protects against stampedes caused by third-party webhooks. For the vector store, start with OpenSearch Serverless (2 OCUs baseline) and move to Aurora PostgreSQL + pgvector when you need relational joins or tenant isolation.
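Tags are easiest to enforce at creation time, but existing resources can be back-filled with the Resource Groups Tagging API, which accepts up to 20 ARNs per `tag_resources` call. A sketch, assuming a boto3 `resourcegroupstaggingapi` client and the four tag keys listed above; partial failures come back in `FailedResourcesMap` rather than as exceptions.

```python
from typing import Any, Dict, List

REQUIRED_TAGS = ("env", "rag-etl", "customer", "cost-center")


def tag_in_batches(tagging: Any, arns: List[str], tags: Dict[str, str]) -> Dict[str, dict]:
    """Apply cost tags to many resources in batches of 20; return per-ARN failures."""
    missing = [k for k in REQUIRED_TAGS if not tags.get(k)]
    if missing:
        raise ValueError(f"missing required tags: {missing}")
    failed: Dict[str, dict] = {}
    for start in range(0, len(arns), 20):
        resp = tagging.tag_resources(ResourceARNList=arns[start:start + 20], Tags=tags)
        # Partial failures are reported per ARN instead of raised.
        failed.update(resp.get("FailedResourcesMap", {}))
    return failed
```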

Cost optimization is an ongoing discipline. We use AWS Cost Anomaly Detection to alert when daily spend exceeds baseline by 20%. Every Lambda function has CloudWatch alarms for duration and memory usage (memory metrics come from Lambda Insights), and we regularly right-size based on actual metrics. For high-volume workloads, consider provisioned concurrency for critical Lambda functions to eliminate cold starts. It is billed per GB-second for as long as it is configured, whether or not it serves traffic, so add it only after profiling shows cold starts are actually impacting the SLA.

import boto3

ce = boto3.client('ce')  # Cost Explorer

def analyze_rag_costs(start_date: str, end_date: str) -> dict:
    """Analyze RAG ETL costs by service (dates are YYYY-MM-DD; end_date is exclusive)."""
    request = {
        'TimePeriod': {'Start': start_date, 'End': end_date},
        'Granularity': 'DAILY',
        'Metrics': ['UnblendedCost'],
        'GroupBy': [
            {'Type': 'DIMENSION', 'Key': 'SERVICE'},
            {'Type': 'TAG', 'Key': 'rag-etl'}
        ],
        'Filter': {
            'Tags': {
                'Key': 'rag-etl',
                'Values': ['true']
            }
        }
    }
    
    costs_by_service = {}
    while True:
        response = ce.get_cost_and_usage(**request)
        for result in response['ResultsByTime']:
            for group in result['Groups']:
                service = group['Keys'][0]  # Keys follow GroupBy order: [SERVICE, 'rag-etl$true']
                cost = float(group['Metrics']['UnblendedCost']['Amount'])
                costs_by_service[service] = costs_by_service.get(service, 0) + cost
        # Long periods with many groups are paginated
        token = response.get('NextPageToken')
        if not token:
            break
        request['NextPageToken'] = token
    
    # Identify optimization opportunities
    total_cost = sum(costs_by_service.values())
    print(f"Total RAG ETL cost: ${total_cost:.2f}")
    if total_cost == 0:
        return costs_by_service  # Nothing tagged in this period; avoid dividing by zero
    
    for service, cost in sorted(costs_by_service.items(), key=lambda x: x[1], reverse=True):
        percentage = (cost / total_cost) * 100
        print(f"{service}: ${cost:.2f} ({percentage:.1f}%)")
        
        # Recommendations
        if service == 'Amazon SageMaker' and percentage > 30:
            print("  → Compare against Bedrock on-demand pricing for embeddings")
        elif service == 'AWS Lambda' and percentage > 25:
            print("  → Review memory allocation and consider arm64 (Graviton2), ~20% lower GB-second price")
        elif service == 'Amazon OpenSearch Service' and percentage > 20:
            print("  → Evaluate moving to Aurora PostgreSQL + pgvector for cost savings")
    
    return costs_by_service
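The anomaly alert described earlier can be set up with the Cost Explorer anomaly APIs. A sketch with placeholder names and an email subscriber, alerting when an anomaly's total impact reaches 20% of expected spend. A `DIMENSIONAL` monitor on `SERVICE` covers the whole account; scoping it to `rag-etl` resources would need a `CUSTOM` monitor with a tag-based specification instead.

```python
from typing import Any


def create_rag_anomaly_alert(ce: Any, email: str, impact_pct: int = 20) -> str:
    """Create a per-service anomaly monitor plus a daily email alert above `impact_pct`."""
    monitor_arn = ce.create_anomaly_monitor(AnomalyMonitor={
        "MonitorName": "rag-etl-services",
        "MonitorType": "DIMENSIONAL",
        "MonitorDimension": "SERVICE",
    })["MonitorArn"]
    subscription = ce.create_anomaly_subscription(AnomalySubscription={
        "SubscriptionName": "rag-etl-daily",
        "MonitorArnList": [monitor_arn],
        "Subscribers": [{"Address": email, "Type": "EMAIL"}],
        "Frequency": "DAILY",  # email subscribers receive DAILY or WEEKLY digests
        # Alert when an anomaly's impact is at least impact_pct percent of expected spend.
        "ThresholdExpression": {"Dimensions": {
            "Key": "ANOMALY_TOTAL_IMPACT_PERCENTAGE",
            "MatchOptions": ["GREATER_THAN_OR_EQUAL"],
            "Values": [str(impact_pct)],
        }},
    })
    return subscription["SubscriptionArn"]
```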

Instrument everything with CloudWatch Embedded Metric Format (EMF) so you can chart ingestion latency, embedding cost per document, and failure rate per datasource. Store provenance metadata (source URL, checksum, embedding model version, processing timestamp) inside OpenSearch fields so your RAG layer can trace answers back to origin—a must-have when auditors ask "why did the assistant mention X?" or when debugging hallucination issues.

import json
import time
from datetime import datetime, timezone
from typing import Any, Dict

class RAGMetrics:
    """CloudWatch EMF metrics for RAG ETL pipeline (one JSON line per metric on stdout)."""
    
    @staticmethod
    def emit_ingestion_metric(document_id: str, latency_ms: float, 
                              datasource: str, success: bool):
        """Emit ingestion latency metric."""
        metric = {
            '_aws': {
                'CloudWatchMetrics': [{
                    'Namespace': 'RAG/ETL',
                    'Metrics': [{
                        'MetricName': 'IngestionLatency',
                        'Unit': 'Milliseconds'
                    }],
                    # Every unique dimension combination is a billed custom metric, so
                    # DocumentId stays a searchable log property rather than a dimension
                    'Dimensions': [['DataSource', 'Status']]
                }],
                'Timestamp': int(time.time() * 1000)
            },
            'DocumentId': document_id,
            'DataSource': datasource,
            'Status': 'Success' if success else 'Failure',
            'IngestionLatency': latency_ms
        }
        print(json.dumps(metric))
    
    @staticmethod
    def emit_embedding_cost(document_id: str, cost_usd: float, 
                            model: str, dimensions: int):
        """Emit embedding cost metric."""
        metric = {
            '_aws': {
                'CloudWatchMetrics': [{
                    'Namespace': 'RAG/ETL',
                    'Metrics': [{
                        'MetricName': 'EmbeddingCost',
                        'Unit': 'None'
                    }],
                    'Dimensions': [['Model', 'EmbeddingDimensions']]
                }],
                'Timestamp': int(time.time() * 1000)
            },
            'DocumentId': document_id,
            'Model': model,
            'EmbeddingDimensions': str(dimensions),  # EMF dimension values must be strings
            'EmbeddingCost': cost_usd
        }
        print(json.dumps(metric))
    
    @staticmethod
    def store_provenance_metadata(document_id: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Build the provenance record stored alongside each chunk in OpenSearch."""
        provenance = {
            'document_id': document_id,
            'source_url': metadata.get('source_url'),
            'checksum': metadata.get('checksum'),
            'embedding_model': metadata.get('embedding_model'),
            'model_version': metadata.get('model_version'),
            'processing_timestamp': datetime.now(timezone.utc).isoformat(),
            'processing_pipeline_version': metadata.get('pipeline_version', '1.0.0'),
            'datasource': metadata.get('datasource'),
            'extraction_method': metadata.get('extraction_method'),
            'chunk_count': metadata.get('chunk_count', 0)
        }
        
        # Index this record with each chunk so RAG responses can be traced back to source
        # documents; OpenSearch has no per-document TTL, so enforce retention per index
        return provenance
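For provenance fields to be filterable next to the vectors, the index mapping has to declare them. A sketch of a k-NN index body, assuming 1024-dimensional vectors, the Faiss HNSW engine, and the provenance field names used in `store_provenance_metadata`; send it as the body of the index-creation request with whichever OpenSearch client you use. L2 is the Faiss engine's default space, and for unit-normalised vectors it ranks results the same way as cosine similarity.

```python
from typing import Any, Dict


def rag_index_body(dimension: int = 1024) -> Dict[str, Any]:
    """k-NN index settings and mappings: one vector field plus provenance fields."""
    keyword_fields = ["document_id", "source_url", "checksum", "embedding_model",
                      "model_version", "processing_pipeline_version", "datasource",
                      "extraction_method"]
    # Keyword (not analysed text) so audits can filter on exact values.
    properties: Dict[str, Any] = {f: {"type": "keyword"} for f in keyword_fields}
    properties.update({
        "embedding": {
            "type": "knn_vector",
            "dimension": dimension,
            "method": {"name": "hnsw", "engine": "faiss", "space_type": "l2"},
        },
        "text": {"type": "text"},
        "processing_timestamp": {"type": "date"},
        "chunk_count": {"type": "integer"},
    })
    return {"settings": {"index": {"knn": True}}, "mappings": {"properties": properties}}
```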

Advanced Patterns: Multi-Tenancy & Security

For SaaS providers serving multiple customers, tenant isolation is non-negotiable. We implement row-level security in OpenSearch using document-level security (DLS) from fine-grained access control, or route each tenant to a separate index. The latter approach scales better but requires index management automation. For compliance-heavy industries (healthcare, finance), we add encryption at rest using KMS customer-managed keys and enable VPC endpoints so that traffic stays on the AWS network.
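With index-per-tenant routing, every writer and reader should derive the index name the same way. A sketch with a hypothetical `rag-tenant-` prefix: OpenSearch index names must be lowercase and cannot contain characters such as spaces, `/`, `*`, `?`, `"`, `<`, `>`, `|`, `,` or `#`, so the tenant ID is validated against a conservative pattern instead of being silently rewritten, which could map two tenants onto one index.

```python
import re

# Conservative tenant ID pattern: lowercase letters, digits, hyphens; 1 to 63 chars.
_TENANT_ID = re.compile(r"[a-z0-9][a-z0-9-]{0,62}")


def tenant_index(tenant_id: str, prefix: str = "rag-tenant-") -> str:
    """Map a tenant ID to its dedicated index, rejecting IDs that are not index-safe."""
    if not _TENANT_ID.fullmatch(tenant_id):
        # Rejecting instead of normalising avoids two tenants sharing one index.
        raise ValueError(f"tenant_id {tenant_id!r} is not a valid index suffix")
    return f"{prefix}{tenant_id}"
```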

Security best practices include: (1) IAM roles with least-privilege access, (2) Secrets Manager for API keys and database credentials, (3) VPC endpoints for all AWS service calls to avoid internet egress, (4) CloudTrail logging for all API calls, and (5) regular security audits using AWS Security Hub. We also implement data loss prevention (DLP) scanning using Amazon Macie to detect sensitive data before indexing.

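import json
import os
from typing import Dict, List

import boto3
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

kms = boto3.client('kms')
secrets = boto3.client('secretsmanager')
macie = boto3.client('macie2')
sts = boto3.client('sts')

class SecureRAGPipeline:
    """Security-hardened RAG pipeline with DLP and encryption."""
    
    def __init__(self, kms_key_id: str):
        self.kms_key_id = kms_key_id  # Default key for non-tenant data
    
    def scan_for_sensitive_data(self, s3_key: str, bucket: str) -> Dict:
        """Start an asynchronous Macie scan for PII/PHI before processing."""
        response = macie.create_classification_job(
            jobType='ONE_TIME',
            s3JobDefinition={
                'bucketDefinitions': [{
                    'accountId': sts.get_caller_identity()['Account'],
                    'buckets': [bucket]
                }],
                'scoping': {
                    'includes': {
                        'and': [{
                            'simpleScopeTerm': {
                                # OBJECT_KEY conditions only support prefix matching
                                'comparator': 'STARTS_WITH',
                                'key': 'OBJECT_KEY',
                                'values': [s3_key]
                            }
                        }]
                    }
                }
            },
            name=f'dlp-scan-{s3_key.replace("/", "-")}'
        )
        
        # The job runs asynchronously; poll describe_classification_job or consume Macie findings
        return {
            'job_id': response['jobId'],
            'status': 'PENDING'
        }
    
    def encrypt_embedding(self, embedding: List[float], tenant_id: str) -> Dict[str, bytes]:
        """Envelope-encrypt an embedding with a tenant-specific KMS key.

        kms.encrypt accepts at most 4 KB of plaintext, less than a JSON-encoded
        1024-dim vector, so a KMS data key encrypts the vector locally with AES-GCM.
        Encrypted vectors suit archival copies; k-NN search needs the plaintext index.
        """
        data_key = kms.generate_data_key(KeyId=f'alias/rag-tenant-{tenant_id}', KeySpec='AES_256')
        nonce = os.urandom(12)  # 96-bit nonce, unique per encryption
        ciphertext = AESGCM(data_key['Plaintext']).encrypt(
            nonce, json.dumps(embedding).encode('utf-8'), tenant_id.encode('utf-8')
        )
        return {
            'ciphertext': ciphertext,
            'nonce': nonce,
            'encrypted_data_key': data_key['CiphertextBlob']  # Recover with kms.decrypt
        }
    
    def get_tenant_credentials(self, tenant_id: str) -> Dict:
        """Retrieve tenant-specific credentials from Secrets Manager."""
        secret_name = f'rag/tenant/{tenant_id}/credentials'
        response = secrets.get_secret_value(SecretId=secret_name)
        return json.loads(response['SecretString'])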
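Least privilege for the Secrets Manager lookup above means each tenant's worker role can read only that tenant's secret. A sketch of such a policy, assuming the `rag/tenant/{tenant_id}/credentials` naming used in `get_tenant_credentials`; the trailing `-??????` matches the six random characters Secrets Manager appends to secret ARNs.

```python
from typing import Any, Dict


def tenant_secret_policy(tenant_id: str, region: str, account_id: str) -> Dict[str, Any]:
    """IAM policy allowing a tenant worker to read only its own credentials secret."""
    secret_arn = (f"arn:aws:secretsmanager:{region}:{account_id}:secret:"
                  f"rag/tenant/{tenant_id}/credentials-??????")
    return {
        "Version": "2012-10-17",
        "Statement": [{
            "Sid": "ReadOwnTenantSecret",
            "Effect": "Allow",
            "Action": ["secretsmanager:GetSecretValue"],
            "Resource": secret_arn,
        }],
    }
```

Serialise the result with `json.dumps` before passing it as a `PolicyDocument` to IAM.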

Performance Tuning & Monitoring

Production RAG pipelines require careful performance tuning. Key metrics to monitor: (1) end-to-end latency (target: <5 minutes for 10MB document), (2) embedding generation rate (target: >100 documents/minute), (3) OpenSearch indexing throughput (target: >1000 docs/second), and (4) error rate (target: <0.1%). We use CloudWatch dashboards with automated alarms that trigger PagerDuty alerts when thresholds are breached.
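As a sketch of one of those alarms, the helper below builds `put_metric_alarm` arguments for p95 `IngestionLatency` in the `RAG/ETL` namespace against the five-minute target. It assumes the metric is published with exactly `DataSource` and `Status` dimensions (CloudWatch alarms match the full dimension set) and that the hypothetical SNS topic forwards to PagerDuty.

```python
from typing import Any, Dict


def latency_alarm(datasource: str, topic_arn: str,
                  threshold_ms: float = 5 * 60 * 1000) -> Dict[str, Any]:
    """Kwargs for cloudwatch.put_metric_alarm on p95 ingestion latency per datasource."""
    return {
        "AlarmName": f"rag-etl-ingestion-latency-{datasource}",
        "Namespace": "RAG/ETL",
        "MetricName": "IngestionLatency",
        "Dimensions": [
            {"Name": "DataSource", "Value": datasource},
            {"Name": "Status", "Value": "Success"},
        ],
        "ExtendedStatistic": "p95",  # percentiles use ExtendedStatistic, not Statistic
        "Period": 300,
        "EvaluationPeriods": 3,
        "DatapointsToAlarm": 2,  # 2 of the last 3 five-minute windows must breach
        "Threshold": threshold_ms,
        "ComparisonOperator": "GreaterThanThreshold",
        "TreatMissingData": "notBreaching",
        "AlarmActions": [topic_arn],
    }
```

Apply it with `cloudwatch.put_metric_alarm(**latency_alarm("jira", topic_arn))`; the "2 of 3" evaluation avoids paging on a single slow batch.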

Common bottlenecks: (1) Textract processing for large PDFs (mitigate with asynchronous StartDocumentAnalysis jobs that publish completion to SNS), (2) SageMaker endpoint cold starts (mitigate with provisioned concurrency or warm-up scripts), (3) OpenSearch indexing bottlenecks (mitigate with the bulk API and a proper sharding strategy), and (4) Lambda memory limits (right-size based on CloudWatch Logs Insights queries showing actual memory usage).
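The bulk API from point (3) takes newline-delimited JSON: an action line, then the document, for each item, with a trailing newline. A minimal sketch that builds that body for chunk records; the index name is a placeholder, and a document `_id` is set only when the chunk carries a `chunk_id`, which keeps replays idempotent.

```python
import json
from typing import Dict, Iterable


def bulk_body(index: str, chunks: Iterable[Dict]) -> str:
    """Build an NDJSON _bulk payload: one action line, then the document, per chunk."""
    lines = []
    for chunk in chunks:
        action: Dict[str, Dict[str, str]] = {"index": {"_index": index}}
        chunk_id = chunk.get("metadata", {}).get("chunk_id")
        if chunk_id:
            # Re-indexing under the same _id overwrites, so DLQ replays don't duplicate.
            action["index"]["_id"] = chunk_id
        lines.append(json.dumps(action))
        lines.append(json.dumps(chunk))
    if not lines:
        raise ValueError("bulk request needs at least one document")
    # The bulk API requires a newline after the final line.
    return "\n".join(lines) + "\n"
```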

    Thomas Vignoli - Senior DevOps Engineer Portfolio