Building Production AI Systems: Architecture Patterns for Senior Engineers
By Gary Huynh (@gary_atruedev)
From POC to Production: The AI Architecture Challenge
After 18 months of building production AI systems, I've learned one critical lesson: the gap between a working AI prototype and a production-ready system is massive. Getting ChatGPT to answer questions is trivial; building a system that serves millions of requests daily with sub-second latency, 99.99% uptime, and predictable costs is where the real engineering begins.
This article is for senior architects and engineers who need to build AI systems that actually work at scale. We'll cover the architecture patterns, trade-offs, and battle-tested solutions I've implemented across multiple production deployments.
What You'll Learn
Distributed AI Architectures
- Model serving patterns for different latency requirements
- Load balancing strategies for GPU clusters
- Failover and redundancy patterns
Feature Store Architecture
- Real-time vs batch feature engineering
- Feature consistency across training and serving
- Storage optimization strategies
A/B Testing for ML Models
- Safe rollout strategies for new models
- Statistical significance in ML experiments
- Avoiding common experimentation pitfalls
Cost Optimization at Scale
- Token management strategies
- Intelligent caching patterns
- Multi-tier serving architectures
Production Resilience
- Circuit breakers for AI services
- Graceful degradation patterns
- Monitoring and observability
Quick Architecture Reference
Choose Your Pattern:
- Real-time predictions (< 100ms) → Online Serving Pattern
- Batch processing (hours-days) → Batch Serving Pattern
- Streaming data (< 1s) → Streaming Pattern
- Cost-sensitive workloads → Cost Optimization
- High availability requirements → Circuit Breakers
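If it helps to see that routing as code, here is a minimal selection sketch. The enum, class name, and thresholds are illustrative assumptions that mirror the list above, not part of any particular framework:
import java.time.Duration;

// Minimal pattern-selection sketch mirroring the quick reference above.
// ServingPattern and the thresholds are illustrative, not a library API.
public class ServingPatternSelector {

    public enum ServingPattern { ONLINE, STREAMING, BATCH }

    public ServingPattern select(Duration latencyBudget, boolean boundedOfflineDataset) {
        if (boundedOfflineDataset) {
            return ServingPattern.BATCH;        // scheduled, hours-to-days scoring
        }
        if (latencyBudget.toMillis() < 100) {
            return ServingPattern.ONLINE;       // strict per-request SLA
        }
        if (latencyBudget.toMillis() < 1_000) {
            return ServingPattern.STREAMING;    // sub-second event processing
        }
        return ServingPattern.BATCH;            // latency-tolerant work defaults to batch
    }
}
In practice the decision also weighs cost and availability requirements, but latency budget and whether the dataset is bounded are usually the first cut.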
The Production AI Stack
Before diving into specific patterns, let's understand the full production AI stack:
// Production AI System Components
public class ProductionAIStack {
// Layer 1: Infrastructure
private final LoadBalancer modelLoadBalancer;
private final CircuitBreaker aiCircuitBreaker;
private final MetricsCollector metricsCollector;
// Layer 2: Model Management
private final ModelRegistry modelRegistry;
private final ModelVersionController versionController;
private final FeatureStore featureStore;
// Layer 3: Serving Layer
private final ModelServingCluster servingCluster;
private final RequestRouter requestRouter;
private final ResponseCache responseCache;
// Layer 4: Application Layer
private final AIOrchestrator orchestrator;
private final FallbackHandler fallbackHandler;
private final CostOptimizer costOptimizer;
}
Distributed AI Architectures
The Model Serving Challenge
In production, you're not serving one model - you're serving multiple models, multiple versions, across multiple regions. Here's how to architect for this reality:
@Component
public class DistributedModelServingArchitecture {
private final ConsistentHashRing modelHashRing;
private final HealthChecker healthChecker;
private final LoadBalancerStrategy loadBalancer;
private final MetricsCollector metricsCollector;
// Distributed model serving with consistent hashing
public ModelServingNode selectServingNode(String modelId, String requestId) {
// Use consistent hashing for sticky sessions
int hash = HashUtils.murmur3(modelId + requestId);
ModelServingNode primaryNode = modelHashRing.getNode(hash);
// Check node health
if (healthChecker.isHealthy(primaryNode)) {
return primaryNode;
}
// Failover to secondary node
ModelServingNode secondaryNode = modelHashRing.getNextNode(hash);
if (healthChecker.isHealthy(secondaryNode)) {
metricsCollector.recordFailover(modelId, primaryNode, secondaryNode);
return secondaryNode;
}
// Circuit breaker pattern
throw new NoHealthyNodesException("No healthy nodes available for model: " + modelId);
}
// Load balancing strategies
public enum LoadBalancerStrategy {
ROUND_ROBIN {
@Override
public ModelServingNode selectNode(List<ModelServingNode> nodes) {
return nodes.get(counter.getAndIncrement() % nodes.size());
}
},
LEAST_CONNECTIONS {
@Override
public ModelServingNode selectNode(List<ModelServingNode> nodes) {
return nodes.stream()
.min(Comparator.comparing(node -> node.getActiveConnections()))
.orElseThrow();
}
},
WEIGHTED_RESPONSE_TIME {
@Override
public ModelServingNode selectNode(List<ModelServingNode> nodes) {
// Select based on p99 response times
double totalWeight = nodes.stream()
.mapToDouble(node -> 1.0 / node.getP99ResponseTime())
.sum();
double random = Math.random() * totalWeight;
double weightSum = 0;
for (ModelServingNode node : nodes) {
weightSum += 1.0 / node.getP99ResponseTime();
if (random <= weightSum) {
return node;
}
}
return nodes.get(nodes.size() - 1);
}
};
// Shared counter used by the ROUND_ROBIN strategy
private static final AtomicInteger counter = new AtomicInteger();
// Each strategy supplies its own node-selection logic
public abstract ModelServingNode selectNode(List<ModelServingNode> nodes);
}
}
GPU Cluster Management
For heavy AI workloads, GPU management becomes critical:
@Service
public class GPUClusterManager {
private final GPUResourcePool resourcePool;
private final PriorityQueue<GPURequest> requestQueue;
private final ScheduledExecutorService scheduler;
@PostConstruct
public void initialize() {
// Initialize GPU monitoring
scheduler.scheduleAtFixedRate(this::monitorGPUHealth, 0, 30, TimeUnit.SECONDS);
scheduler.scheduleAtFixedRate(this::rebalanceWorkloads, 0, 5, TimeUnit.MINUTES);
}
public CompletableFuture<ModelInferenceResult> submitInferenceRequest(
ModelInferenceRequest request) {
// Priority-based scheduling
Priority priority = calculatePriority(request);
GPURequest gpuRequest = new GPURequest(request, priority);
// Try immediate allocation
Optional<GPUResource> availableGPU = resourcePool.tryAcquire(
request.getRequiredMemoryGB(),
request.getRequiredComputeUnits()
);
if (availableGPU.isPresent()) {
return executeOnGPU(availableGPU.get(), request);
}
// Queue for later execution
CompletableFuture<ModelInferenceResult> future = new CompletableFuture<>();
gpuRequest.setFuture(future);
requestQueue.offer(gpuRequest);
// Set timeout
scheduler.schedule(() -> {
if (!future.isDone()) {
future.completeExceptionally(
new TimeoutException("GPU allocation timeout")
);
requestQueue.remove(gpuRequest);
}
}, request.getTimeoutSeconds(), TimeUnit.SECONDS);
return future;
}
private void rebalanceWorkloads() {
// Analyze GPU utilization
Map<GPUResource, Double> utilizationMap = resourcePool.getUtilizationMap();
// Identify imbalanced GPUs
double avgUtilization = utilizationMap.values().stream()
.mapToDouble(Double::doubleValue)
.average()
.orElse(0.0);
utilizationMap.forEach((gpu, utilization) -> {
if (utilization > avgUtilization * 1.3) {
// GPU is overloaded, migrate workloads
migrateWorkloadsFrom(gpu);
} else if (utilization < avgUtilization * 0.7) {
// GPU is underutilized, accept more workloads
acceptWorkloadsTo(gpu);
}
});
}
}
Feature Store Architecture
Real-time vs Batch Features
A production feature store must handle both real-time and batch features efficiently:
@Component
public class HybridFeatureStore {
private final RedisTemplate<String, Feature> realtimeStore;
private final ParquetFeatureRepository batchStore;
private final FeatureRegistry registry;
private final FeatureComputer featureComputer;
// Unified feature retrieval interface
public FeatureVector getFeatures(String entityId, List<String> featureNames) {
FeatureVector.Builder builder = FeatureVector.builder();
// Parallel feature retrieval
CompletableFuture<?>[] futures = featureNames.stream()
.map(featureName -> retrieveFeatureAsync(entityId, featureName, builder))
.toArray(CompletableFuture[]::new);
// Wait for all features with timeout
try {
CompletableFuture.allOf(futures).get(100, TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
logger.warn("Feature retrieval timeout for entity: {}", entityId);
// Continue with partial features
} catch (InterruptedException | ExecutionException e) {
logger.warn("Feature retrieval failed for entity: {}", entityId, e);
}
return builder.build();
}
private CompletableFuture<Void> retrieveFeatureAsync(
String entityId,
String featureName,
FeatureVector.Builder builder) {
return CompletableFuture.runAsync(() -> {
FeatureMetadata metadata = registry.getMetadata(featureName);
switch (metadata.getType()) {
case REALTIME:
retrieveRealtimeFeature(entityId, featureName, builder);
break;
case BATCH:
retrieveBatchFeature(entityId, featureName, builder);
break;
case ON_DEMAND:
computeOnDemandFeature(entityId, featureName, builder);
break;
}
});
}
// Real-time feature with write-through cache
public void updateRealtimeFeature(String entityId, String featureName, Object value) {
String key = buildKey(entityId, featureName);
Feature feature = Feature.builder()
.name(featureName)
.value(value)
.timestamp(System.currentTimeMillis())
.build();
// Write to Redis with TTL
realtimeStore.opsForValue().set(key, feature, 1, TimeUnit.HOURS);
// Async write to persistent store
CompletableFuture.runAsync(() -> {
batchStore.persistFeature(entityId, feature);
});
// Emit feature update event
eventBus.publish(new FeatureUpdateEvent(entityId, featureName, value));
}
// Feature versioning and lineage
@Entity
public class FeatureLineage {
@Id
private String id;
private String featureName;
private String version;
private String computationLogic;
private List<String> dependencies;
private LocalDateTime createdAt;
private Map<String, Object> metadata;
// Track feature computation DAG
public boolean isCompatibleWith(String otherVersion) {
// Version compatibility logic
return SemanticVersion.parse(version)
.isCompatibleWith(SemanticVersion.parse(otherVersion));
}
}
}
Feature Store Performance Optimization
@Configuration
public class FeatureStoreOptimization {
// Column-oriented storage for batch features
@Bean
public ParquetFeatureStore batchFeatureStore() {
return ParquetFeatureStore.builder()
.compressionCodec(CompressionCodecName.SNAPPY)
.blockSize(128 * 1024 * 1024) // 128MB blocks
.pageSize(1024 * 1024) // 1MB pages
.enableDictionary(true)
.enableBloomFilter(true)
.build();
}
// Optimized feature retrieval with projection pushdown
public class OptimizedFeatureRetriever {
public FeatureDataset retrieveFeatures(FeatureQuery query) {
// Build optimized query plan
QueryPlan plan = QueryPlan.builder()
.projection(query.getRequestedFeatures())
.predicate(query.getFilterPredicate())
.enableColumnPruning(true)
.enablePredicatePushdown(true)
.build();
// Execute with partition pruning
List<Partition> relevantPartitions = partitionPruner.prune(
query.getTimeRange(),
query.getEntityIds()
);
// Parallel partition reading
return relevantPartitions.parallelStream()
.map(partition -> readPartition(partition, plan))
.reduce(FeatureDataset::merge)
.orElse(FeatureDataset.empty());
}
// Feature materialization optimization
public void materializeFeatures(MaterializationJob job) {
// Incremental materialization
Instant lastMaterialization = job.getLastRunTimestamp();
Instant now = Instant.now();
// Compute only changed features
Set<String> changedEntities = changeDetector.getChangedEntities(
lastMaterialization,
now
);
if (changedEntities.isEmpty()) {
logger.info("No changes detected, skipping materialization");
return;
}
// Batch process with optimal batch size
int optimalBatchSize = calculateOptimalBatchSize(changedEntities.size());
Lists.partition(new ArrayList<>(changedEntities), optimalBatchSize)
.parallelStream()
.forEach(batch -> materializeBatch(batch, job));
}
}
}
Model Serving Patterns
Online Serving Pattern
For real-time inference with strict SLA requirements:
@Service
public class OnlineModelServingService {
private final ModelCache modelCache;
private final RequestBatcher requestBatcher;
private final LatencyMonitor latencyMonitor;
private final ScheduledExecutorService scheduler;
// Optimized online serving with batching
public CompletableFuture<PredictionResult> predict(PredictionRequest request) {
// Start latency tracking
long startTime = System.nanoTime();
// Check cache first
Optional<PredictionResult> cached = checkCache(request);
if (cached.isPresent()) {
latencyMonitor.recordCacheHit(System.nanoTime() - startTime);
return CompletableFuture.completedFuture(cached.get());
}
// Add to batch
CompletableFuture<PredictionResult> future = new CompletableFuture<>();
requestBatcher.addRequest(request, future);
// Set timeout for SLA compliance
ScheduledFuture<?> timeout = scheduler.schedule(() -> {
if (!future.isDone()) {
future.completeExceptionally(new SLAViolationException("Prediction timeout"));
requestBatcher.removeRequest(request);
}
}, 50, TimeUnit.MILLISECONDS); // 50ms SLA
// Cancel timeout on completion
future.whenComplete((result, error) -> {
timeout.cancel(false);
latencyMonitor.recordPrediction(System.nanoTime() - startTime);
});
return future;
}
// Dynamic batching for optimal throughput
@Component
public class DynamicRequestBatcher {
private final Queue<BatchRequest> pendingRequests = new ConcurrentLinkedQueue<>();
private final AtomicInteger batchSize = new AtomicInteger(32);
private final ScheduledExecutorService executor;
@PostConstruct
public void init() {
// Adaptive batch processing
executor.scheduleWithFixedDelay(this::processBatch, 0, 5, TimeUnit.MILLISECONDS);
// Dynamic batch size adjustment
executor.scheduleAtFixedRate(this::adjustBatchSize, 0, 1, TimeUnit.SECONDS);
}
private void processBatch() {
List<BatchRequest> batch = new ArrayList<>();
BatchRequest request;
// Collect requests up to batch size
while (batch.size() < batchSize.get() &&
(request = pendingRequests.poll()) != null) {
batch.add(request);
}
if (!batch.isEmpty()) {
executeBatchPrediction(batch);
}
}
private void adjustBatchSize() {
// Monitor queue depth and latency
int queueDepth = pendingRequests.size();
double avgLatency = latencyMonitor.getAvgLatency();
if (queueDepth > 1000 && avgLatency < 30) {
// Increase batch size to improve throughput
batchSize.updateAndGet(size -> Math.min(size * 2, 256));
} else if (avgLatency > 45) {
// Decrease batch size to reduce latency
batchSize.updateAndGet(size -> Math.max(size / 2, 8));
}
metrics.gauge("model.serving.batch.size", batchSize.get());
}
}
}
Batch Serving Pattern
For high-throughput offline processing:
@Component
public class BatchModelServingService {
private final SparkSession spark;
private final ModelBroadcastManager broadcastManager;
private final CheckpointManager checkpointManager;
// Distributed batch inference with Spark
public void runBatchInference(BatchInferenceJob job) {
// Load and broadcast model
ModelArtifact model = loadModel(job.getModelId(), job.getModelVersion());
Broadcast<ModelArtifact> broadcastModel = spark.sparkContext()
.broadcast(model, classTag(ModelArtifact.class));
// Configure optimal partitioning
int optimalPartitions = calculateOptimalPartitions(job);
// Load dataset with partition optimization
Dataset<Row> inputData = spark.read()
.option("mergeSchema", "false")
.option("pushdown", "true")
.parquet(job.getInputPath())
.repartition(optimalPartitions);
// Prepare features
Dataset<Row> features = prepareFeatures(inputData, job.getFeatureConfig());
// Run inference with checkpointing
Dataset<Row> predictions = features
.mapPartitions(
new PartitionInferenceFunction(broadcastModel, job.getConfig()),
RowEncoder.apply(getPredictionSchema())
)
.checkpoint(); // Enable fault tolerance
// Write results with optimization
predictions.write()
.mode(SaveMode.Overwrite)
.option("compression", "snappy")
.partitionBy(job.getPartitionColumns())
.parquet(job.getOutputPath());
// Clean up broadcast variable
broadcastModel.destroy();
// Update job metrics
updateJobMetrics(job, predictions);
}
// Partition-level inference for efficiency
static class PartitionInferenceFunction
implements MapPartitionsFunction<Row, Row> {
private final Broadcast<ModelArtifact> broadcastModel;
private final InferenceConfig config;
private transient ThreadLocal<ModelRuntime> modelRuntime;
@Override
public Iterator<Row> call(Iterator<Row> partition) {
// Initialize model once per partition
if (modelRuntime == null) {
modelRuntime = ThreadLocal.withInitial(() ->
createRuntime(broadcastModel.value(), config)
);
}
ModelRuntime runtime = modelRuntime.get();
List<Row> results = new ArrayList<>();
// Process in mini-batches for memory efficiency
List<Row> miniBatch = new ArrayList<>(config.getMiniBatchSize());
while (partition.hasNext()) {
miniBatch.add(partition.next());
if (miniBatch.size() >= config.getMiniBatchSize() || !partition.hasNext()) {
List<Row> predictions = runtime.predict(miniBatch);
results.addAll(predictions);
miniBatch.clear();
}
}
return results.iterator();
}
}
// Optimal partition calculation
private int calculateOptimalPartitions(BatchInferenceJob job) {
long inputSizeBytes = getDatasetSize(job.getInputPath());
int availableExecutors = spark.sparkContext().getExecutorIds().size();
int coresPerExecutor = spark.sparkContext().getConf().getInt("spark.executor.cores", 4);
// Target partition size: 128MB
long targetPartitionSize = 128 * 1024 * 1024;
int dataPartitions = (int) Math.ceil(inputSizeBytes / (double) targetPartitionSize);
// Consider parallelism
int parallelismPartitions = availableExecutors * coresPerExecutor * 2;
return Math.max(dataPartitions, parallelismPartitions);
}
}
Streaming Serving Pattern
For real-time stream processing:
@Service
public class StreamingModelServingService {
private final KafkaStreams streams;
private final ModelStateStore modelStateStore;
private final WindowedAggregator aggregator;
// Kafka Streams topology for streaming inference
public Topology buildStreamingTopology() {
StreamsBuilder builder = new StreamsBuilder();
// Input stream with exactly-once semantics
KStream<String, InputEvent> inputStream = builder.stream(
"input-events",
Consumed.with(Serdes.String(), inputEventSerde)
.withTimestampExtractor(new EventTimeExtractor())
);
// Feature enrichment from state store
KStream<String, EnrichedEvent> enrichedStream = inputStream
.transformValues(
() -> new FeatureEnrichmentTransformer(featureStoreClient),
"feature-state-store"
);
// Windowed aggregation for feature computation
KTable<Windowed<String>, FeatureAggregate> windowedFeatures = enrichedStream
.groupByKey()
.windowedBy(TimeWindows.of(Duration.ofMinutes(5)).grace(Duration.ofMinutes(1)))
.aggregate(
FeatureAggregate::new,
(key, value, aggregate) -> aggregate.update(value),
Materialized.<String, FeatureAggregate>as("windowed-features-store")
.withKeySerde(Serdes.String())
.withValueSerde(featureAggregateSerde)
);
// Model inference with state management
KStream<String, PredictionEvent> predictions = enrichedStream
.join(
windowedFeatures.toStream(),
(enrichedEvent, featureAggregate) ->
ModelInput.builder()
.event(enrichedEvent)
.aggregateFeatures(featureAggregate)
.build(),
JoinWindows.of(Duration.ofSeconds(10)),
StreamJoined.with(Serdes.String(), enrichedEventSerde, featureAggregateSerde)
)
.transformValues(
() -> new ModelInferenceTransformer(modelStateStore),
"model-state-store"
);
// Output with error handling
predictions
.split()
.branch(
(key, prediction) -> prediction.isSuccessful(),
Branched.withConsumer(stream ->
stream.to("prediction-results", Produced.with(Serdes.String(), predictionSerde))
)
)
.branch(
(key, prediction) -> !prediction.isSuccessful(),
Branched.withConsumer(stream ->
stream.to("prediction-errors", Produced.with(Serdes.String(), predictionSerde))
)
);
return builder.build();
}
// Stateful model inference transformer
static class ModelInferenceTransformer
implements ValueTransformerWithKey<String, ModelInput, PredictionEvent> {
private ProcessorContext context;
private KeyValueStore<String, ModelState> stateStore;
private final ModelStateStore modelStateStore;
private ModelRuntime runtime;
ModelInferenceTransformer(ModelStateStore modelStateStore) {
this.modelStateStore = modelStateStore;
}
@Override
public void init(ProcessorContext context) {
this.context = context;
this.stateStore = context.getStateStore("model-state-store");
// Load the currently active runtime from the shared model state store (accessor assumed for this sketch)
this.runtime = modelStateStore.currentRuntime();
// Schedule periodic model updates
context.schedule(
Duration.ofMinutes(5),
PunctuationType.WALL_CLOCK_TIME,
this::updateModel
);
}
@Override
public PredictionEvent transform(String key, ModelInput input) {
try {
// Get or initialize state
ModelState state = stateStore.get(key);
if (state == null) {
state = ModelState.initialize();
}
// Run inference with state
PredictionResult result = runtime.predictWithState(input, state);
// Update state
state.update(result);
stateStore.put(key, state);
// Emit metrics
context.metrics().sensor("inference-latency")
.record(result.getLatencyMs());
return PredictionEvent.success(key, result);
} catch (Exception e) {
logger.error("Inference failed for key: {}", key, e);
return PredictionEvent.failure(key, e.getMessage());
}
}
private void updateModel(long timestamp) {
// Check for model updates
Optional<ModelVersion> newVersion = modelRegistry.getLatestVersion();
if (newVersion.isPresent() && !newVersion.get().equals(runtime.getVersion())) {
logger.info("Updating model to version: {}", newVersion.get());
runtime.updateModel(newVersion.get());
}
}
}
}
A/B Testing Framework for ML Models
Experimentation Infrastructure
@Component
public class MLExperimentationFramework {
private final ExperimentRegistry registry;
private final TrafficSplitter trafficSplitter;
private final MetricsCollector metricsCollector;
private final StatisticalAnalyzer analyzer;
// Define and configure experiments
public Experiment createExperiment(ExperimentConfig config) {
// Validate experiment configuration
validateExperimentConfig(config);
// Create experiment with proper isolation
Experiment experiment = Experiment.builder()
.id(generateExperimentId())
.name(config.getName())
.hypothesis(config.getHypothesis())
.startTime(Instant.now())
.endTime(Instant.now().plus(config.getDuration()))
.trafficAllocation(config.getTrafficAllocation())
.successMetrics(config.getSuccessMetrics())
.guardrailMetrics(config.getGuardrailMetrics())
.build();
// Register experiment
registry.register(experiment);
// Initialize metric collection
metricsCollector.initializeExperiment(experiment);
return experiment;
}
// Traffic splitting with consistent assignment
public ModelVersion assignModelVersion(String userId, String experimentId) {
Experiment experiment = registry.getExperiment(experimentId);
// Check if user is already assigned
Optional<Assignment> existingAssignment =
assignmentStore.getAssignment(userId, experimentId);
if (existingAssignment.isPresent()) {
return existingAssignment.get().getModelVersion();
}
// Consistent hash-based assignment
double hashValue = Math.abs(consistentHash(userId + experimentId)) / (double) Long.MAX_VALUE;
// Determine variant based on traffic split
ModelVersion assignedVersion = null;
double cumulativeProbability = 0.0;
for (ExperimentVariant variant : experiment.getVariants()) {
cumulativeProbability += variant.getTrafficPercentage();
if (hashValue <= cumulativeProbability) {
assignedVersion = variant.getModelVersion();
break;
}
}
// Users outside the experiment's traffic allocation default to the control model
if (assignedVersion == null) {
assignedVersion = experiment.getControlVersion();
}
// Store assignment
Assignment assignment = new Assignment(userId, experimentId, assignedVersion);
assignmentStore.save(assignment);
// Emit assignment event
eventBus.publish(new UserAssignmentEvent(userId, experimentId, assignedVersion));
return assignedVersion;
}
// Real-time experiment monitoring
@Scheduled(fixedDelay = 60000) // Every minute
public void monitorExperiments() {
List<Experiment> activeExperiments = registry.getActiveExperiments();
for (Experiment experiment : activeExperiments) {
ExperimentMetrics metrics = metricsCollector.getMetrics(experiment.getId());
// Check guardrail metrics
for (GuardrailMetric guardrail : experiment.getGuardrailMetrics()) {
if (isGuardrailViolated(guardrail, metrics)) {
logger.error("Guardrail violated for experiment: {}", experiment.getId());
stopExperiment(experiment, "Guardrail violation: " + guardrail.getName());
break;
}
}
// Check for statistical significance
if (hasEnoughData(metrics)) {
StatisticalResult result = analyzer.analyze(experiment, metrics);
if (result.isSignificant()) {
logger.info("Experiment {} reached statistical significance",
experiment.getId());
if (result.getWinner().equals(experiment.getTreatmentVersion())) {
// Treatment wins
graduateExperiment(experiment);
} else {
// Control wins
concludeExperiment(experiment, "Control performed better");
}
}
}
}
}
// Statistical analysis for experiments
@Component
public class BayesianStatisticalAnalyzer implements StatisticalAnalyzer {
public StatisticalResult analyze(Experiment experiment, ExperimentMetrics metrics) {
// Get conversion data for each variant
Map<ModelVersion, ConversionData> variantData =
metrics.getConversionDataByVariant();
// Bayesian A/B testing
double alphaPrior = 1.0;
double betaPrior = 1.0;
// Calculate posterior distributions
Map<ModelVersion, BetaDistribution> posteriors = new HashMap<>();
for (Map.Entry<ModelVersion, ConversionData> entry : variantData.entrySet()) {
ConversionData data = entry.getValue();
double alpha = alphaPrior + data.getConversions();
double beta = betaPrior + data.getTrials() - data.getConversions();
posteriors.put(entry.getKey(), new BetaDistribution(alpha, beta));
}
// Calculate probability of each variant being best
Map<ModelVersion, Double> winProbabilities =
calculateWinProbabilities(posteriors);
// Determine if we have a clear winner
ModelVersion bestVariant = winProbabilities.entrySet().stream()
.max(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
.orElseThrow();
double bestProbability = winProbabilities.get(bestVariant);
// Check for practical significance
double lift = calculateLift(variantData, experiment.getControlVersion(), bestVariant);
return StatisticalResult.builder()
.isSignificant(bestProbability > 0.95)
.winner(bestVariant)
.winProbability(bestProbability)
.lift(lift)
.confidenceInterval(calculateConfidenceInterval(variantData, bestVariant))
.build();
}
private Map<ModelVersion, Double> calculateWinProbabilities(
Map<ModelVersion, BetaDistribution> posteriors) {
int numSamples = 10000;
Map<ModelVersion, Integer> winCounts = new HashMap<>();
// Monte Carlo simulation
for (int i = 0; i < numSamples; i++) {
ModelVersion winner = null;
double maxSample = -1;
for (Map.Entry<ModelVersion, BetaDistribution> entry : posteriors.entrySet()) {
double sample = entry.getValue().sample();
if (sample > maxSample) {
maxSample = sample;
winner = entry.getKey();
}
}
winCounts.merge(winner, 1, Integer::sum);
}
// Convert to probabilities
Map<ModelVersion, Double> probabilities = new HashMap<>();
for (Map.Entry<ModelVersion, Integer> entry : winCounts.entrySet()) {
probabilities.put(entry.getKey(), entry.getValue() / (double) numSamples);
}
return probabilities;
}
}
}
Circuit Breakers and Fallback Strategies
Resilient AI Systems
@Component
public class AICircuitBreaker {
private final CircuitBreakerRegistry circuitBreakerRegistry;
private final RetryRegistry retryRegistry;
private final BulkheadRegistry bulkheadRegistry;
private final TimeLimiterRegistry timeLimiterRegistry;
@PostConstruct
public void initialize() {
// Configure circuit breaker for each AI service
CircuitBreakerConfig config = CircuitBreakerConfig.custom()
.failureRateThreshold(50) // Open circuit if 50% of requests fail
.waitDurationInOpenState(Duration.ofSeconds(30))
.slowCallRateThreshold(50) // Consider calls slow if 50% exceed threshold
.slowCallDurationThreshold(Duration.ofSeconds(2))
.permittedNumberOfCallsInHalfOpenState(10)
.slidingWindowSize(100)
.slidingWindowType(SlidingWindowType.COUNT_BASED)
.recordExceptions(IOException.class, TimeoutException.class)
.ignoreExceptions(BusinessException.class)
.build();
circuitBreakerRegistry.addConfiguration("ai-service", config);
}
// Wrapped AI service call with full resilience
public CompletableFuture<AIResponse> callAIService(AIRequest request) {
String serviceName = request.getServiceName();
// Get or create circuit breaker
CircuitBreaker circuitBreaker = circuitBreakerRegistry
.circuitBreaker(serviceName, "ai-service");
// Get or create retry
Retry retry = retryRegistry.retry(serviceName, RetryConfig.custom()
.maxAttempts(3)
.waitDuration(Duration.ofMillis(100))
.retryOnException(e -> isRetryable(e))
.retryExceptions(IOException.class, TimeoutException.class)
.ignoreExceptions(ValidationException.class)
.build()
);
// Get or create bulkhead for concurrency control
Bulkhead bulkhead = bulkheadRegistry.bulkhead(serviceName, BulkheadConfig.custom()
.maxConcurrentCalls(50)
.maxWaitDuration(Duration.ofMillis(500))
.build()
);
// Get or create time limiter
TimeLimiter timeLimiter = timeLimiterRegistry.timeLimiter(serviceName,
TimeLimiterConfig.custom()
.timeoutDuration(Duration.ofSeconds(5))
.cancelRunningFuture(true)
.build()
);
// Decorate the supplier with all resilience patterns
Supplier<CompletableFuture<AIResponse>> decoratedSupplier =
Decorators.ofSupplier(() -> executeAICall(request))
.withCircuitBreaker(circuitBreaker)
.withRetry(retry)
.withBulkhead(bulkhead)
.decorate();
// Execute with time limiting
return timeLimiter.executeCompletionStage(
executorService,
decoratedSupplier
).toCompletableFuture()
.exceptionally(throwable -> {
// Fallback logic
return handleFallback(request, throwable);
});
}
// Intelligent fallback strategies
private AIResponse handleFallback(AIRequest request, Throwable throwable) {
logger.warn("AI service call failed, executing fallback", throwable);
// Determine fallback strategy based on request type
FallbackStrategy strategy = determineFallbackStrategy(request);
switch (strategy) {
case CACHED_RESPONSE:
return getCachedResponse(request);
case SIMPLIFIED_MODEL:
return useSimplifiedModel(request);
case RULE_BASED:
return applyRuleBasedLogic(request);
case GRACEFUL_DEGRADATION:
return provideGracefulDegradation(request);
case QUEUE_FOR_RETRY:
queueForAsyncRetry(request);
return AIResponse.pending(request.getId());
default:
return AIResponse.error("Service temporarily unavailable");
}
}
// Adaptive fallback selection
private FallbackStrategy determineFallbackStrategy(AIRequest request) {
// Check if we have a recent cached response
if (responseCache.hasRecentResponse(request, Duration.ofMinutes(5))) {
return FallbackStrategy.CACHED_RESPONSE;
}
// Check if simplified model is available
if (request.acceptsSimplifiedModel() && simplifiedModelAvailable()) {
return FallbackStrategy.SIMPLIFIED_MODEL;
}
// Check if rule-based logic can handle this request
if (ruleEngine.canHandle(request.getType())) {
return FallbackStrategy.RULE_BASED;
}
// For non-critical requests, queue for retry
if (!request.isCritical()) {
return FallbackStrategy.QUEUE_FOR_RETRY;
}
// Default to graceful degradation
return FallbackStrategy.GRACEFUL_DEGRADATION;
}
// Health monitoring and auto-recovery
@Component
public class AIHealthMonitor {
private final Map<String, ServiceHealth> serviceHealthMap = new ConcurrentHashMap<>();
@Scheduled(fixedDelay = 10000) // Every 10 seconds
public void monitorHealth() {
for (String serviceName : getMonitoredServices()) {
CompletableFuture.runAsync(() -> checkServiceHealth(serviceName))
.orTimeout(5, TimeUnit.SECONDS)
.exceptionally(throwable -> {
updateServiceHealth(serviceName, ServiceHealth.UNKNOWN);
return null;
});
}
}
private void checkServiceHealth(String serviceName) {
try {
// Perform health check
HealthCheckResult result = performHealthCheck(serviceName);
// Update circuit breaker state based on health
if (result.isHealthy()) {
// Try to close circuit if it's open
CircuitBreaker cb = circuitBreakerRegistry.circuitBreaker(serviceName);
if (cb.getState() == CircuitBreaker.State.OPEN) {
cb.transitionToHalfOpenState();
}
updateServiceHealth(serviceName, ServiceHealth.HEALTHY);
} else {
updateServiceHealth(serviceName, ServiceHealth.UNHEALTHY);
}
} catch (Exception e) {
logger.error("Health check failed for service: {}", serviceName, e);
updateServiceHealth(serviceName, ServiceHealth.UNHEALTHY);
}
}
private void updateServiceHealth(String serviceName, ServiceHealth health) {
ServiceHealth previousHealth = serviceHealthMap.put(serviceName, health);
if (previousHealth != health) {
// Emit health change event
eventBus.publish(new ServiceHealthChangeEvent(serviceName, previousHealth, health));
// Update metrics
metrics.gauge("ai.service.health", health.getValue(),
"service", serviceName);
}
}
}
}
Cost Optimization Patterns
Multi-Tier Model Serving
@Service
public class CostOptimizedModelServing {
private final ModelTierManager tierManager;
private final CostCalculator costCalculator;
private final PerformanceMonitor performanceMonitor;
// Tiered model serving based on request characteristics
public AIResponse serveRequest(AIRequest request) {
// Analyze request to determine appropriate tier
ModelTier tier = determineTier(request);
// Track cost
CostTracker costTracker = new CostTracker(request.getId());
try {
switch (tier) {
case EDGE:
return serveFromEdge(request, costTracker);
case SMALL_GPU:
return serveFromSmallGPU(request, costTracker);
case LARGE_GPU:
return serveFromLargeGPU(request, costTracker);
case SPECIALIZED:
return serveFromSpecialized(request, costTracker);
default:
throw new UnsupportedOperationException("Unknown tier: " + tier);
}
} finally {
// Record cost metrics
costTracker.recordCost();
}
}
private ModelTier determineTier(AIRequest request) {
// Rule-based tier selection
if (request.getMaxLatencyMs() < 10) {
return ModelTier.EDGE;
}
if (request.getComplexity() == Complexity.LOW &&
request.getMaxCost() < 0.001) {
return ModelTier.EDGE;
}
if (request.getComplexity() == Complexity.MEDIUM) {
return ModelTier.SMALL_GPU;
}
if (request.requiresSpecializedModel()) {
return ModelTier.SPECIALIZED;
}
return ModelTier.LARGE_GPU;
}
// Cost-aware batch processing
@Component
public class CostAwareBatchProcessor {
private final PriorityQueue<BatchJob> jobQueue = new PriorityQueue<>(
Comparator.comparing(this::calculateCostEfficiency).reversed()
);
public void submitBatchJob(BatchJob job) {
// Calculate optimal batch size based on cost
int optimalBatchSize = calculateOptimalBatchSize(job);
job.setBatchSize(optimalBatchSize);
// Calculate optimal execution time (spot instances, off-peak)
Instant optimalExecutionTime = calculateOptimalExecutionTime(job);
job.setScheduledTime(optimalExecutionTime);
jobQueue.offer(job);
}
@Scheduled(fixedDelay = 60000) // Every minute
public void processJobs() {
Instant now = Instant.now();
List<BatchJob> readyJobs = new ArrayList<>();
// Collect jobs ready for execution
while (!jobQueue.isEmpty() &&
jobQueue.peek().getScheduledTime().isBefore(now)) {
readyJobs.add(jobQueue.poll());
}
if (!readyJobs.isEmpty()) {
// Group jobs by resource requirements
Map<ResourceProfile, List<BatchJob>> groupedJobs =
readyJobs.stream()
.collect(Collectors.groupingBy(this::getResourceProfile));
// Execute groups on appropriate resources
groupedJobs.forEach(this::executeJobGroup);
}
}
private void executeJobGroup(ResourceProfile profile, List<BatchJob> jobs) {
// Select most cost-effective resource
ComputeResource resource = selectResource(profile);
// Execute jobs with cost tracking
resource.execute(jobs, costTracker -> {
// Monitor cost in real-time
if (costTracker.getCurrentCost() > costTracker.getBudget() * 0.8) {
logger.warn("Approaching budget limit for job group");
// Implement cost control measures
implementCostControls(jobs);
}
});
}
private int calculateOptimalBatchSize(BatchJob job) {
// Balance between per-request cost and batch processing overhead
double fixedCost = 0.10; // Fixed cost per batch
double perItemCost = 0.001; // Cost per item
// Optimal batch size = sqrt(2 * fixedCost * totalItems / perItemCost)
int optimalSize = (int) Math.sqrt(
2 * fixedCost * job.getTotalItems() / perItemCost
);
// Apply constraints
return Math.max(
job.getMinBatchSize(),
Math.min(optimalSize, job.getMaxBatchSize())
);
}
}
// Spot instance management for cost optimization
@Component
public class SpotInstanceManager {
private final EC2Client ec2Client;
private final Map<String, SpotFleet> activeFleets = new ConcurrentHashMap<>();
public void provisionSpotFleet(WorkloadProfile workload) {
// Analyze spot price history
DescribeSpotPriceHistoryResponse priceHistory = ec2Client.describeSpotPriceHistory(
DescribeSpotPriceHistoryRequest.builder()
.instanceTypes(workload.getSuitableInstanceTypes())
.availabilityZones(getTargetZones())
.startTime(Instant.now().minus(Duration.ofDays(7)))
.build()
);
// Select optimal instance types and AZs
List<SpotFleetLaunchSpecification> launchSpecs =
optimizeSpotFleetConfiguration(priceHistory, workload);
// Create spot fleet request
RequestSpotFleetResponse response = ec2Client.requestSpotFleet(
RequestSpotFleetRequest.builder()
.spotFleetRequestConfig(
SpotFleetRequestConfig.builder()
.allocationStrategy(AllocationStrategy.CAPACITY_OPTIMIZED)
.targetCapacity(workload.getTargetCapacity())
.launchSpecifications(launchSpecs)
.terminateInstancesWithExpiration(true)
.type(FleetType.MAINTAIN)
.validFrom(Instant.now())
.validUntil(workload.getEndTime())
.build()
)
.build()
);
// Monitor fleet
SpotFleet fleet = new SpotFleet(response.spotFleetRequestId(), workload);
activeFleets.put(fleet.getId(), fleet);
// Set up interruption handling
setupInterruptionHandling(fleet);
}
private void setupInterruptionHandling(SpotFleet fleet) {
// Poll for interruption notices
scheduledExecutor.scheduleAtFixedRate(() -> {
checkForInterruptions(fleet);
}, 0, 5, TimeUnit.SECONDS);
}
private void checkForInterruptions(SpotFleet fleet) {
// Check CloudWatch Events for spot interruption warnings
List<SpotInterruptionWarning> warnings =
getSpotInterruptionWarnings(fleet.getId());
for (SpotInterruptionWarning warning : warnings) {
logger.warn("Spot interruption warning for instance: {}",
warning.getInstanceId());
// Gracefully drain workload
drainWorkloadFromInstance(warning.getInstanceId());
// Request replacement capacity
requestReplacementCapacity(fleet, warning.getInstanceId());
}
}
}
}
Observability and Monitoring
Comprehensive AI Observability
@Configuration
public class AIObservabilityConfiguration {
@Bean
public MeterRegistry aiMeterRegistry() {
return new CompositeMeterRegistry()
.add(new PrometheusMeterRegistry(PrometheusConfig.DEFAULT))
.add(new DatadogMeterRegistry(datadogConfig(), Clock.SYSTEM))
.add(new CloudWatchMeterRegistry(cloudWatchConfig(), Clock.SYSTEM));
}
@Bean
public AIMetricsCollector aiMetricsCollector(MeterRegistry registry) {
return new AIMetricsCollector(registry);
}
@Component
public class AIMetricsCollector {
private final MeterRegistry registry;
private final Map<String, Timer.Sample> activeRequests = new ConcurrentHashMap<>();
// Model inference metrics
public void recordInference(InferenceEvent event) {
// Latency by model and version
registry.timer("ai.inference.latency",
"model", event.getModelName(),
"version", event.getModelVersion(),
"status", event.getStatus().toString()
).record(event.getLatency(), TimeUnit.MILLISECONDS);
// Token usage for LLMs
if (event.getTokenUsage() != null) {
registry.counter("ai.tokens.input",
"model", event.getModelName()
).increment(event.getTokenUsage().getInputTokens());
registry.counter("ai.tokens.output",
"model", event.getModelName()
).increment(event.getTokenUsage().getOutputTokens());
}
// GPU utilization
if (event.getGpuMetrics() != null) {
registry.gauge("ai.gpu.utilization",
Tags.of("gpu_id", event.getGpuMetrics().getGpuId()),
event.getGpuMetrics().getUtilization()
);
registry.gauge("ai.gpu.memory.used",
Tags.of("gpu_id", event.getGpuMetrics().getGpuId()),
event.getGpuMetrics().getMemoryUsedMB()
);
}
// Model-specific metrics
recordModelSpecificMetrics(event);
}
// Feature engineering metrics
public void recordFeatureEngineering(FeatureEngineeringEvent event) {
// Feature computation time
registry.timer("ai.features.computation.time",
"feature_set", event.getFeatureSetName(),
"source", event.getSource().toString()
).record(event.getComputationTime(), TimeUnit.MILLISECONDS);
// Feature quality metrics
event.getFeatureQualityMetrics().forEach((featureName, quality) -> {
registry.gauge("ai.features.quality",
Tags.of("feature", featureName),
quality.getScore()
);
// Missing value rate
registry.gauge("ai.features.missing_rate",
Tags.of("feature", featureName),
quality.getMissingRate()
);
});
}
// Data pipeline metrics
public void recordDataPipeline(DataPipelineEvent event) {
// Pipeline stage latency
registry.timer("ai.pipeline.stage.latency",
"pipeline", event.getPipelineName(),
"stage", event.getStageName()
).record(event.getStageLatency(), TimeUnit.MILLISECONDS);
// Data quality checks
registry.counter("ai.pipeline.quality.checks",
"pipeline", event.getPipelineName(),
"status", event.getQualityCheckStatus().toString()
).increment();
// Schema violations
if (event.getSchemaViolations() > 0) {
registry.counter("ai.pipeline.schema.violations",
"pipeline", event.getPipelineName()
).increment(event.getSchemaViolations());
}
}
// Model drift detection
@Scheduled(fixedDelay = 300000) // Every 5 minutes
public void detectModelDrift() {
for (ModelDeployment deployment : getActiveDeployments()) {
DriftMetrics drift = calculateDrift(deployment);
// Feature drift
registry.gauge("ai.model.drift.features",
Tags.of("model", deployment.getModelName(),
"version", deployment.getVersion()),
drift.getFeatureDriftScore()
);
// Prediction drift
registry.gauge("ai.model.drift.predictions",
Tags.of("model", deployment.getModelName(),
"version", deployment.getVersion()),
drift.getPredictionDriftScore()
);
// Performance drift
registry.gauge("ai.model.drift.performance",
Tags.of("model", deployment.getModelName(),
"version", deployment.getVersion()),
drift.getPerformanceDriftScore()
);
// Alert if drift exceeds threshold
if (drift.requiresIntervention()) {
alertingService.sendAlert(
Alert.builder()
.severity(AlertSeverity.WARNING)
.title("Model drift detected")
.description(String.format(
"Model %s v%s shows significant drift: %.2f",
deployment.getModelName(),
deployment.getVersion(),
drift.getOverallDriftScore()
))
.build()
);
}
}
}
}
// Distributed tracing for AI pipelines
@Component
public class AITracingInterceptor {
private final Tracer tracer;
public <T> T traceAIOperation(String operationName, Supplier<T> operation) {
Span span = tracer.nextSpan()
.name(operationName)
.tag("ai.operation.type", extractOperationType(operationName))
.start();
try (Tracer.SpanInScope ws = tracer.withSpanInScope(span)) {
T result = operation.get();
// Add result metadata to span
if (result instanceof AIResponse) {
AIResponse aiResponse = (AIResponse) result;
span.tag("ai.model.name", aiResponse.getModelName())
.tag("ai.model.version", aiResponse.getModelVersion())
.tag("ai.confidence", String.valueOf(aiResponse.getConfidence()));
}
return result;
} catch (Exception e) {
span.error(e);
throw e;
} finally {
span.end();
}
}
// Trace complex AI pipelines
public void traceAIPipeline(AIPipeline pipeline) {
Span pipelineSpan = tracer.nextSpan()
.name("ai.pipeline." + pipeline.getName())
.start();
try (Tracer.SpanInScope ws = tracer.withSpanInScope(pipelineSpan)) {
for (PipelineStage stage : pipeline.getStages()) {
Span stageSpan = tracer.nextSpan()
.name("ai.pipeline.stage." + stage.getName())
.start();
try (Tracer.SpanInScope stageScope = tracer.withSpanInScope(stageSpan)) {
// Execute stage with detailed tracing
StageResult result = stage.execute();
// Add stage-specific tags
stageSpan.tag("stage.input.size", String.valueOf(result.getInputSize()))
.tag("stage.output.size", String.valueOf(result.getOutputSize()))
.tag("stage.duration.ms", String.valueOf(result.getDurationMs()));
} catch (Exception e) {
stageSpan.error(e);
throw e;
} finally {
stageSpan.end();
}
}
} finally {
pipelineSpan.end();
}
}
}
}
Architecture Decision Matrices
Model Serving Architecture Decision Matrix
When choosing between different model serving patterns, consider these key factors:
Online Serving
- Latency: < 100ms response time
- Throughput: 1K-100K requests per second
- Infrastructure Cost: High (always-on servers)
- Complexity: Medium
- Best For: Real-time recommendations, fraud detection, instant predictions
- Scaling: Horizontal auto-scaling
- Failure Impact: Immediate user impact
Batch Serving
- Latency: Hours to days (scheduled jobs)
- Throughput: Millions of predictions per batch
- Infrastructure Cost: Low (pay only when running)
- Complexity: Low
- Best For: Report generation, data labeling, periodic scoring
- Scaling: Vertical for batch size
- Failure Impact: Delayed processing
Streaming Serving
- Latency: < 1 second
- Throughput: 10K-1M events per second
- Infrastructure Cost: Medium
- Complexity: High
- Best For: IoT processing, live analytics, real-time monitoring
- Scaling: Partitioned streams
- Failure Impact: Risk of data loss
Storage Architecture Decision Matrix
Choose your storage solution based on these characteristics:
Redis
- Primary Use: Feature cache, model cache
- Latency: < 1ms (in-memory)
- Cost: High
- Scalability: Limited by memory
- When to Use: Hot data, session storage, real-time features
DynamoDB
- Primary Use: User features, prediction storage
- Latency: < 10ms
- Cost: Medium
- Scalability: Unlimited
- When to Use: User profiles, prediction history, flexible schemas
S3 + Parquet
- Primary Use: Training data, batch features
- Latency: > 100ms
- Cost: Low
- Scalability: Unlimited
- When to Use: Historical data, large datasets, cold storage
PostgreSQL
- Primary Use: Metadata, experiment tracking
- Latency: < 50ms
- Cost: Medium
- Scalability: Vertical scaling
- When to Use: Structured data, ACID requirements, complex queries
Vector Databases (Pinecone/Weaviate)
- Primary Use: Vector embeddings for similarity search
- Latency: < 50ms
- Cost: High
- Scalability: Horizontal
- When to Use: Semantic search, RAG systems, recommendations
Kafka
- Primary Use: Event streaming
- Latency: < 10ms
- Cost: Medium
- Scalability: Horizontal
- When to Use: Event sourcing, real-time pipelines, audit logs
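To make those trade-offs concrete, here is a hedged sketch of a tiered read path across three of the stores above: Redis for hot features, DynamoDB for warm per-entity lookups, and S3/Parquet for cold history. The FeatureSource interface and the wiring are assumptions for illustration; Feature is the same value type used in the feature store examples earlier:
import java.util.Optional;

// Illustrative tiered read path; the FeatureSource abstraction and its implementations
// are assumed for this sketch rather than taken from a specific SDK.
public class TieredFeatureReader {

    public interface FeatureSource {
        Optional<Feature> get(String entityId, String featureName);
    }

    private final FeatureSource redisHotCache;    // < 1ms, hot real-time features
    private final FeatureSource dynamoWarmStore;  // < 10ms, per-entity features
    private final FeatureSource parquetColdLake;  // > 100ms, historical/batch features

    public TieredFeatureReader(FeatureSource redisHotCache,
                               FeatureSource dynamoWarmStore,
                               FeatureSource parquetColdLake) {
        this.redisHotCache = redisHotCache;
        this.dynamoWarmStore = dynamoWarmStore;
        this.parquetColdLake = parquetColdLake;
    }

    public Optional<Feature> read(String entityId, String featureName, boolean latencyTolerant) {
        Optional<Feature> hot = redisHotCache.get(entityId, featureName);
        if (hot.isPresent()) {
            return hot;                           // served sub-millisecond from cache
        }
        Optional<Feature> warm = dynamoWarmStore.get(entityId, featureName);
        if (warm.isPresent() || !latencyTolerant) {
            return warm;                          // single-digit-millisecond key-value lookup
        }
        // Cold path: only for latency-tolerant callers such as backfills and batch scoring.
        return parquetColdLake.get(entityId, featureName);
    }
}
Online request paths should stop at the warm tier; only offline jobs and backfills should ever fall through to the lake.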
Performance Benchmarks and Trade-offs
Real-World Performance Metrics
Based on production deployments across multiple organizations:
Inference Latency by Model Type
- Small Language Models (< 1B parameters): p50: 15ms, p99: 45ms
- Medium Language Models (1-10B parameters): p50: 80ms, p99: 250ms
- Large Language Models (> 10B parameters): p50: 300ms, p99: 1200ms
- Computer Vision (ResNet-50): p50: 8ms, p99: 25ms
- Computer Vision (YOLO v5): p50: 22ms, p99: 60ms
Infrastructure Utilization
- GPU Utilization Target: 70-85% (higher risks throttling)
- CPU Utilization for Feature Engineering: 60-70%
- Memory Usage for Caching: 80% max (prevent OOM)
- Network Bandwidth: 50% sustained, 80% peak
Cost Optimization Results
- Spot Instance Savings: 60-80% vs on-demand
- Model Quantization Savings: 40-60% with < 2% accuracy loss
- Intelligent Batching: 30-50% throughput improvement
- Multi-tier Serving: 45-70% cost reduction
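One caveat when reading those numbers: the levers compound multiplicatively on the spend they touch rather than adding up, and in practice they overlap (multi-tier serving already shifts traffic onto cheaper models, for example). A rough, purely illustrative back-of-the-envelope using mid-range figures:
// Purely illustrative arithmetic; assumes the three levers are independent, which they rarely are.
public class CostSavingsEstimate {
    public static void main(String[] args) {
        double baseline = 100_000.0;                            // hypothetical monthly inference spend
        double afterSpot = baseline * (1 - 0.70);               // ~70% spot savings     -> 30,000
        double afterQuantization = afterSpot * (1 - 0.50);      // ~50% quantization     -> 15,000
        double afterTiering = afterQuantization * (1 - 0.55);   // ~55% multi-tier       -> 6,750
        double totalReduction = 1 - afterTiering / baseline;    // ~93%, not 70% + 50% + 55%
        System.out.printf("Remaining spend: $%,.0f (%.0f%% total reduction)%n",
                afterTiering, totalReduction * 100);
    }
}
Treat any combined estimate like this as an upper bound; measured savings on a real workload mix will be lower.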
Conclusion: From Architecture to Production
Building production AI systems requires far more than just model training. The patterns and architectures we've covered form the foundation of systems that can handle millions of requests while maintaining reliability, performance, and cost efficiency.
Key takeaways for architects:
- Design for Failure: AI models will fail. Build resilient systems with circuit breakers, fallbacks, and graceful degradation.
- Optimize Ruthlessly: The difference between a profitable and unprofitable AI system often comes down to architecture optimization.
- Monitor Everything: AI systems have unique observability challenges. Instrument comprehensively from day one.
- Plan for Scale: Design your architecture to handle 10x growth without major refactoring.
- Balance Innovation and Stability: Use proven patterns while staying flexible enough to adopt new AI capabilities.
The journey from POC to production is challenging, but with the right architecture patterns, you can build AI systems that deliver real business value at scale.
Next in Series
In Part 3, we'll dive deep into building production-ready RAG (Retrieval-Augmented Generation) systems, covering vector databases, embedding strategies, and hybrid search architectures.
Have questions about production AI architectures? Reach out at gary@atruedev.com or connect on LinkedIn.