The pursuit of local, private, and performant large language models (LLMs) is a significant area of development. Moving from simple text generation to interactive, stateful conversation presents a primary challenge: efficiency. A naive chat implementation, which re-evaluates the entire conversation history for every new message, becomes unusably slow as the history grows. The key to performant chat is state management. This paper documents the iterative process of creating a robust script that leverages the model's internal Key-Value (KV) cache, working through challenges such as static cache allocation and chat-template handling to arrive at a functional and efficient solution.
A transformer model's performance relies on its ability to calculate attention scores between tokens. In a conversation, most tokens from previous turns remain constant, so re-calculating their attention states on every turn is computationally wasteful. The KV cache prevents this by storing each layer's keys and values in a cache object (`past_key_values`, or `pkv`). On subsequent inputs, the model receives only the new tokens together with this cache, reuses the stored state for old tokens, and computes state only for the new ones, which reduces processing time.
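A minimal sketch of the two strategies, assuming the Gemma 3 1B instruction-tuned checkpoint used later in this paper and simplified loading options:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
mod = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")

history_ids = tok("Hello, how are you today?", return_tensors="pt").input_ids

# First pass: process the history once and keep the per-layer keys/values.
with torch.no_grad():
    out = mod(input_ids=history_ids, use_cache=True)
pkv = out.past_key_values  # cache covering every history token

# Later pass: feed only the new tokens together with the saved cache; the model
# reuses the stored keys/values instead of re-encoding the old tokens.
new_ids = tok(" I have a follow-up question.", add_special_tokens=False,
              return_tensors="pt").input_ids
with torch.no_grad():
    out = mod(input_ids=new_ids, use_cache=True, past_key_values=pkv)
```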
The primary technical hurdle was managing the StaticCache, which requires the cache tensor to be allocated to its maximum size initially. Failure to do so results in an IndexError when the conversation length exceeds the initial allocation. A one-time dummy forward pass with an input tensor of the desired maximum length initializes the cache correctly. Additionally, instruction-tuned models require a strict chat template to ensure conversational behavior.
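The chat-template requirement is easy to inspect directly; a short sketch, assuming the Gemma 3 1B instruction-tuned tokenizer used elsewhere in this paper:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
chat = [{"role": "user", "content": "Define k = 5. What is k?"}]

ids = tok.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
print(tok.decode(ids[0]))  # shows the role/turn markers the template wraps around the message
```

Skipping this wrapping and feeding raw text to an instruction-tuned checkpoint degrades or breaks its conversational behavior.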
The final script, g99.py, is a concise implementation of the caching strategy, incorporating the three pillars: respecting the Static Cache, executing a Cache Priming Pass, and adhering to the Mandate of the Chat Template.
```python
tok = AutoTokenizer.from_pretrained(modpath)
mod = AutoModelForCausalLM.from_pretrained(modpath, torch_dtype=dtype, device_map={"": dev})
```
A dummy tensor of max_len (e.g., 1024) filled with the padding token ID is created. A single forward pass with use_cache=True populates the pkv object with correctly-sized, neutral cache tensors.
```python
max_len = 1024
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

with torch.no_grad():
    dummy_ids = torch.full((1, max_len), pad_id, dtype=torch.long, device=dev)
    pkv = mod(input_ids=dummy_ids, use_cache=True).past_key_values
```
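A quick sanity check can confirm the allocation; this is a sketch, and the exact structure of `pkv` depends on the transformers version and cache class:

```python
# Legacy caches are tuples where pkv[0] is layer 0's (key, value) pair; newer Cache
# objects expose per-layer tensors via key_cache. Either way, the key tensor's
# sequence dimension should typically equal max_len after the priming pass.
layer0_key = pkv[0][0] if isinstance(pkv, (tuple, list)) else pkv.key_cache[0]
print("layer-0 key cache shape:", tuple(layer0_key.shape))  # e.g., (1, num_kv_heads, 1024, head_dim)
```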
The script iterates through a list of turns, managing state and processing incrementally.
A `chat` list holds the conversation history in the required role/content format, and `cp` (cache position) tracks the current context length.
```python
chat.append({"role": "user", "content": t})
prompt_ids = tok.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(dev)
new_ids = prompt_ids[:, cp:]
new_len = new_ids.shape[1]
pos = torch.arange(cp, cp + new_len, device=dev)

with torch.no_grad():
    out = mod(input_ids=new_ids, use_cache=True, past_key_values=pkv, cache_position=pos)
cp += new_len
```
```python
for i in range(30):
    nxt = torch.argmax(out.logits[:, -1, :], dim=-1)  # greedy pick of the next token from the last forward pass
    pos = torch.tensor([cp], device=dev)              # absolute position of this generated token
    out = mod(input_ids=nxt.view(1, 1), use_cache=True, past_key_values=pkv, cache_position=pos)
    cp += 1                                           # every sampled token advances the global cursor
```
The complete model response is appended to the chat list, preparing the state for the next turn.
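A sketch of closing out a turn, assuming the decoded reply has been accumulated into `outstr` as in the full listing at the end of this paper:

```python
chat.append({"role": "model", "content": outstr})  # the reply becomes context for the next turn
# cp already points one past the last generated token, so the next user message is
# sliced with prompt_ids[:, cp:] and folded into the cache incrementally.
```

The full listing uses the role string "model" rather than "assistant" for the reply side, matching Gemma's turn markers.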
This implementation demonstrates an efficient, stateful conversational AI system. By leveraging the KV cache, priming it correctly, and adhering to the chat template, the script achieves performance improvements over naive re-processing. The performance difference between 1B and 4B model variants highlights the relationship between parameter count and reasoning capabilities.
The development of a performant, local, and stateful conversational AI faced non-obvious pitfalls. While the KV cache concept is well-known, its practical implementation for instruction-tuned models like Gemma 3 presented challenges that misled even experienced developers.
The promise of a local chatbot—offering privacy, no API costs, and fast responses—was compelling. However, the naive approach of appending each new user input to the history and re-feeding the entire history to the model became computationally untenable as the conversation grew.
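For contrast, a sketch of that naive loop (hypothetical turn strings; `mod`, `tok`, and `dev` as loaded elsewhere): every turn re-encodes and re-processes the entire history, so prompt processing cost grows with the full conversation length.

```python
# Naive baseline: no KV-cache reuse across turns.
chat = []
for user_msg in ["Define k = 5.", "Multiply k by 11.", "What was k?"]:
    chat.append({"role": "user", "content": user_msg})
    ids = tok.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(dev)
    out_ids = mod.generate(ids, max_new_tokens=60)  # the whole prompt is re-encoded every turn
    reply = tok.decode(out_ids[0, ids.shape[1]:], skip_special_tokens=True)
    chat.append({"role": "model", "content": reply})
```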
Workarounds imposed from the outside, such as forcing a different `attn_implementation`, failed, as the model's default behavior was immutable. The insight was that the errors stemmed from misunderstanding the model's architecture: success required guiding the model rather than imposing external solutions.
This meant relying on `apply_chat_template` so the model could interpret conversational roles correctly. The journey from failure to fluidity revealed that efficiency requires precise alignment with the model's internals. The three pillars—respecting the cache, priming it correctly, and using the chat template—transformed a broken script into a robust solution, highlighting the importance of understanding model architecture over fighting it.
Big 8 Element | Required Behavior | Gyrator Insight Ref(s) |
---|---|---|
KV Cache | Must be captured, updated, and aligned to every token and position. Immutable if reused. | 1, 4, 7, 11, 21, 30, 40, 44 |
Token Embeddings | Must remain in sync with tokenizer; out-of-sync embeddings cause hallucinations. | 2, 5, 42, 43, 50 |
Positional State | Absolute position must be correct and continuous. Matches history and cache shape. | 2, 3, 6, 16, 28, 33, 34, 35, 44 |
Input Token Buffer | Reflects every submitted token in correct order for context replay. | 12, 17, 22, 23, 27 |
Attention Mask | Matches token shape and preserves causal behavior. Affected by pad tokens. | 13, 14, 24, 49 |
Model Config Snapshot | Ensures invariant head counts, layer structure, and generation options. | 10, 20, 45, 52 |
LayerNorm State | Static, but adaptive variants require snapshot to avoid logit drift. | 10, 46 |
Attention Module Weights | Must not be tuned mid-inference to avoid state mismatch. | 41, 44, 53 |
+4 Element | Why It Matters | Insight Tie-ins |
---|---|---|
Logits | Validates replay; divergence indicates incorrect context restoration (see the sketch after this table). | 41, 46, 47 |
Tokenizer State | Must match training vocab and special token rules to preserve token identity. | 9, 12, 36, 42 |
KV Format Version | Format mismatches between model versions invalidate reuse attempts. | 1, 19, 35, 58 |
Prompt Injection Meta | Prefix tokens shape behavior; loss or duplication skews intent and position. | 25, 36, 48 |
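For the Logits element, a replay-validation sketch; this is a diagnostic only, uses the names from the earlier snippets, and should be run on a throwaway copy of the state, since the incremental call itself advances the cache:

```python
import torch

with torch.no_grad():
    full_logits = mod(input_ids=prompt_ids).logits[:, -1, :]  # re-run the whole history without a cache
    inc_logits = mod(input_ids=new_ids, use_cache=True, past_key_values=pkv,
                     cache_position=pos).logits[:, -1, :]     # cached incremental path

drift = (full_logits - inc_logits).abs().max().item()
print("max logit drift:", drift)  # near zero when context restoration is correct
```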
Each insight below captures a critical behavior, failure mode, or requirement in transformer-based context restoration. These are drawn from real-world debugging of the g99.py system and its live use of KV cache in multi-turn inference.
- A `pkv` held outside the model without explicit capture means you're using outdated memory.
- "`cache_position` must match absolute token position." This is not relative to the input slice, but to the full, cumulative position since cache start. Feeding a mismatched `cache_position` leads to deeply broken attention behavior without error.
- "`cp` must track tokens, not turns." Using turn count or line number is a fallacy—token count expands unpredictably, especially with template wrapping. Misestimating this creates irreversible misalignment.
- Each forward call returns an updated `past_key_values`. If not reassigned, continued use reflects past state—not the updated context.
- Assignments like `k=5` aren't stored semantically; the model only reuses textual patterns that look like memory.
- Failing to advance `cp` after sampling yields divergent cache behavior.
- "A `pkv` from pad tokens is misleading." Initializing with uniform tokens like `[PAD]*1024` doesn't simulate real attention. Rotary or ALiBi buffers may misconfigure based on dummy inputs.
- The model never errors on a stale `pkv` or position—output continues, but semantic coherence is lost subtly and irreversibly.
- `float16` on CPU coerces to `float32`, breaking numerical parity between init and sampling. These effects compound across turns.
- "`pkv` is not deep-copied by default." Standard Python assignment (`pkv2 = pkv`) keeps tensor references live—so any mutation during a forward pass affects the original unexpectedly (see the sketch after this list).
- Chat templates applied via `apply_chat_template()` inject formatting tokens that expand with each turn. Re-tokenizing history causes unpredictable sequence inflation.
- Special tokens such as `[PAD]`, `<bos>`, or unknown tokens interact with model internals in complex ways.
- "`cp` offsets are brittle." Using slicing tricks like `prompt_ids[:,cp:]` assumes token alignment from prompt regeneration—but this breaks unless token count is manually confirmed.
- "`cp` isn't advanced in-place." Generated tokens need their own `cache_position`. If `cp` is not incremented per sample, positional encoding and rotary drift silently.
- "`pkv` must reflect first real input, not dummy." Cache primed with dummy data leads to structurally invalid restoration later. First real prompt should establish cache structure.
- "Nothing validates `cache_position`." Model APIs accept incorrect position values silently. There's no validation layer—only broken inference downstream.
- "Rebind `pkv` every step." Each new token modifies the full cache state. If you don't rebind the return value of `mod(...)`, you're discarding continuity.
- If the model runs on `cuda` but the cache stays on `cpu`, PyTorch may silently cast or copy behind the scenes—breaking timing and tensor integrity.
- However long `past_key_values` are, the model's attention window (e.g., 2048 tokens) limits what it can see—older cache entries may be masked out.
- Resetting `cp` breaks alignment; setting `cp=0` mid-session is invalid.
- Some tokens (e.g., `<bos>`) may appear implicitly. These are not obvious in text but change token count.
- A symbol like `"m"` means nothing unless the transformer has seen it used consistently. It's not a variable until inferred.
- Instruction-tuned `-it` variants learn to recall symbolically. Non-instruction-tuned models don't.
- Running with `batch_size > 1` can suppress positional errors due to averaging—but will break when scaled down.
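One pitfall above, that `pkv` is not deep-copied, can be made concrete with a two-line sketch; note that a real deep copy duplicates every layer's key/value tensors and is correspondingly memory-hungry:

```python
import copy

pkv2 = pkv                     # shallow: both names reference the same underlying tensors
snapshot = copy.deepcopy(pkv)  # an actual frozen copy, unaffected by later forward passes
```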
The key objects and artifacts referenced throughout these insights:

- Cache Priming Pass: A one-time dummy forward pass using a tensor filled with the padding token (`pad_id`). This initializes the KV cache with correct dimensions, preventing errors during extended conversations. Using non-neutral tokens can poison the cache state. See Insights 7, 17, 49.
- `dummy_ids`: A tensor of shape (1, `max_len`) filled with `pad_id`, used in the Cache Priming Pass to initialize the KV cache. Ensures the cache is large enough for long conversations without resizing. See Insights 7, 17.
- `pad_id`: The tokenizer's padding token ID (falling back to `eos_token_id` if unavailable), used to fill `dummy_ids` for cache initialization. Not context-neutral, as it interacts with attention masks and positional encodings. See Insights 7, 49.
- `prompt_ids`: The full token tensor produced by applying the chat template to the conversation history; sliced into `new_ids` for incremental processing. See Insights 12, 15, 27.
- `new_ids`: A slice of `prompt_ids` (e.g., `prompt_ids[:,cp:]`) containing only the token IDs for the current turn's new input. Passed to the model to minimize computation. See Insights 15, 27.
- Chat Template: The tokenizer's formatting that wraps each turn in role markers (e.g., `<user>`, `<assistant>`). Mandatory for instruction-tuned models to ensure conversational behavior. See Insights 12, 25, 37.
- History Re-tokenization: Re-encoding the full `chat` list using `apply_chat_template` for each turn. Necessary for maintaining context but contributes to token inflation. See Insights 12, 23.
- Model Output (`out`): The forward-pass result containing logits and the updated `past_key_values`. Used for generating tokens and updating the KV cache. See Insights 4, 21.
- Device (`dev`): The compute device (`cpu` or `cuda`) for model and tensor operations. Mismatches cause silent failures or performance issues. See Insight 24.
- Precision (`dtype`): The numeric precision for model weights (e.g., `float16`, `bfloat16`). Precision mismatches lead to numerical drift, and CPU execution can silently override the configured `dtype`. See Insight 10.
- Positional State: The absolute token position, tracked via `cp` and `cache_position`. See Insights 2, 3, 34.
- Cache Position Counter: A single global cursor (`cp`) across all turns, growing with each input and generated token, not reset per turn. See Insights 3, 34.
- Conversation History (`chat`): An ordered list of role/content dictionaries (e.g., `{"role": "user", "content": "Hello"}`). Used to generate `prompt_ids` via the chat template. See Insights 12, 25.
- User Turn: Each new user message, appended to the `chat` list before processing. See Insight 25.

var | description | insight # |
---|---|---|
mod | Loaded transformer model (e.g., Gemma) with all weights, layers, and attention parameters. Provides the forward pass, generates logits, and mutates the KV cache. | 1, 4, 10, 21 |
tok | Tokenizer tied to the model. Applies chat template, encodes inputs, decodes outputs. Misalignment with model causes token drift and hallucination. | 9, 12, 36, 42 |
maxlen | Maximum length used to prefill dummy input, controlling StaticCache size. Must match intended conversation capacity. | 7, 32 |
padid | Used to fill dummy input. Non-neutral: affects attention mask and rotary even if unused semantically. | 7, 49 |
pkv | Captured past_key_values from dummy or live forward pass. Required for continuation. Rewritten each forward call. Not deep-copied. | 1, 4, 11, 21, 30, 40, 44 |
chat | Ordered list of role/content turns. Input to chat template. Expands nonlinearly. Used for prompt regeneration. | 12, 23, 25 |
turns | List of scripted prompts simulating multi-turn dialog. Source of user queries for the test loop. | 3 |
cp | Global token cursor. Tracks total token count across prompt and generation. Defines alignment for cache_position. | 2, 3, 6, 15, 27, 34 |
dev | Device assignment string ('cpu' or 'cuda'). Mismatch with tensors causes silent transfer or failure. | 24 |
dtype | Precision for model weights. Impacts memory, speed, and reproducibility. float16 on CPU coerces to float32. | 10 |
modpath | Filesystem or hub ID used to load the correct model and tokenizer. Must be consistent across components. | 52 |
mp | Dictionary mapping model nicknames to local or hub paths. Used to select model based on host or config. | — |
dp | Device alias map ('cpu' → 'cpu', 'gpu' → 'cuda'). Abstracts platform logic. | — |
tp | Map from string keys ('f16', etc.) to torch data types. Used for loading weights. | — |
rng | Shape tuple (1, maxlen) defining the dummy input tensor size. Matches cache allocation intent. | 7, 17 |
dummyids | Filled with `padid`, this tensor primes the cache during the initialization pass. | 7, 17 |
r | Result of the dummy forward. Contains logits and `past_key_values`. Used only to extract the cache. | 4 |
promptids | Full tokenized chat after applying chat template. Grows with each turn. Token count ≠ turn count. | 9, 12, 15, 27 |
newids | Slice of `promptids` from `cp:`. Only the new user message. Required for incremental cache extension. | 15, 27 |
newlen | Length of the current input slice. Used to update cursor and position encoding range. | 6 |
pos | Absolute position tensor aligned to `cp`. Needed for cache_position. Controls rotary phase. | 2, 6, 16, 28, 34 |
out | Result from model forward. Contains logits and updated `pkv`. Must be reassigned each call to preserve state. | 4, 21 |
outstr | Accumulated string of model output for this turn. Built token-by-token from decoded samples. | 14, 30 |
i | Loop counter for sampling. Limits number of generated tokens per turn to prevent runaway generation. | 4 |
t | Reused temp var: holds the current turn's prompt text and, inside the sampling loop, the logits tensor. Naming collision risk, but harmless here. | — |
nxt | Next token ID from argmax(logits). Fed back into model for next-step generation. | 6, 21 |
tokstr | Decoded text of `nxt`. Checked against end-of-turn marker to stop sampling early. | 30 |
v | 2D reshaped tensor of `nxt`. Required shape (1, 1) for model input compatibility. | 6 |
The complete g99.py listing:

```python
#!/media/krusty/gm/gm120/anaconda3/envs/apy/bin/python
import os,sys,time,socket  # Import core Python modules for file operations, system access, timing, and networking
sys.path.insert(0,"/webroot/lib")  # Prepend custom path for module resolution; allows loading local libs like 'plib' below
import plib  # Custom library, assumed project-specific, loaded from /webroot/lib
from transformers import AutoTokenizer,AutoModelForCausalLM  # HuggingFace interface: AutoTokenizer = text to token IDs; AutoModelForCausalLM = transformer for text generation
import torch  # PyTorch library provides tensor ops, model loading, GPU support, and KV cache infrastructure

def init():  # Initializes the transformer pipeline: loads tokenizer, model, sets device, precision, and prepares dummy cache
    global mod,tok,maxlen,padid,pkv,chat,turns,cp,dev,dtype,modpath  # Expose these as globals for use in forward, decoding, and chat state management
    mp={  # Model path dictionary. Keys are string sizes (1b, 4b, etc), values are either local snapshot paths or remote model hub identifiers
        "1b":"/home/krusty/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752",
        "4b":"/home/krusty/.cache/huggingface/hub/models--google--gemma-3-4b-it/snapshots/093f9f388b31de276ce2de164bdc2081324b9767",
        "9b":"google/gemma-3-9b-it",  # remote: this entry uses huggingface's repo format
        "27b":"google/gemma-3-27b-it"
    }
    dp={"cpu":"cpu","gpu":"cuda"}  # Device map for abstraction; simplifies logic later
    tp={"bf":torch.bfloat16,"f16":torch.float16,"f32":torch.float32}  # Abbreviation map from string dtype labels to PyTorch precision types
    # Machine-specific model loading: uses hostname to select model size, device type, and precision mode
    if socket.gethostname()=="machf":
        modpath=mp["1b"]  # Load smallest model for fast testing
        dev=dp["cpu"]  # Use CPU for local deterministic runs
        dtype=tp["f16"]  # Use float16 even on CPU (nonstandard, experimental)
    elif socket.gethostname()=="machh":
        modpath=mp["4b"]  # Load mid-size 4b model on GPU for higher throughput
        dev=dp["gpu"]  # Activate CUDA execution
        dtype=tp["f32"]  # Use full float precision (more accurate, slower)
    tok=AutoTokenizer.from_pretrained(modpath)  # Load the tokenizer. This defines vocabulary, special tokens, and chat templates
    mod=AutoModelForCausalLM.from_pretrained(modpath,torch_dtype=dtype,device_map={"":dev})  # Load transformer model with given precision and device override. This allows GPU memory control.
    maxlen=1024  # Define maximum context size for input tokens. Most LLMs have 1024 or 2048 as limit. Used for dummy init.
    padid=tok.pad_token_id  # Try to extract pad token from tokenizer config. This is used to fill dummy sequences or pad real ones.
    if padid is None: padid=tok.eos_token_id  # If tokenizer doesn't define padding token, fallback to EOS as a proxy. This affects dummy KV structure.
    with torch.no_grad():  # Disable gradients for this block. We are only warming up model state with dummy inputs
        rng=(1,maxlen)  # Create dummy input shape: batch of 1, 1024 tokens
        dummyids=torch.full(rng,padid,dtype=torch.long,device=dev)  # Fill dummy tensor with pad tokens. Model will treat this as a no-op input for memory allocation.
        r=mod(input_ids=dummyids,use_cache=True)  # Feed dummy input to model. This triggers KV cache allocation. No content, just side effect.
        pkv=r.past_key_values  # Capture the initialized KV cache. This state object is reused across forward calls for continued generation.
    chat=[]  # Empty list to store full conversation history as alternating role="user" and role="model"
    turns=[  # Hardcoded multi-turn simulated user dialog. Purpose: test long-term memory and reasoning across multiple instructions.
        "Let's start by defining a variable k equal to 5. What is k?",
        "Now set a new variable m equal to k multiplied by 11. What is m?",
        "If we increment k by 1, what is the new value of m?",
        "Okay, forget k and m. Let x = 100 and y = 25. What is x divided by y?",
        "Now, a new variable z is the product of x and y. Calculate z.",
        "If we subtract 500 from z, what is the result?",
        "What was the value of m from our earlier conversation?",
        "Final question: what was the first variable we defined in this entire chat?"
    ]
    cp=0  # Cache position counter. Tracks how many tokens have been sent to model so far. This value is critical for position alignment.

def atc():  # Returns input tensor from current chat history, formatted using the tokenizer's template logic
    r=tok.apply_chat_template(chat,add_generation_prompt=True,return_tensors="pt")  # Wraps conversation in system/user/model markers and converts to token tensor
    return r.to(dev)  # Sends tensor to correct compute device to match model weights

def dumoda():  # Perform model forward pass for multiple tokens at once (batch slice from prompt)
    return mod(input_ids=newids,use_cache=True,past_key_values=pkv,cache_position=pos)  # Provides position tensor and cache for accurate generation across turns

def dumodb():  # Perform model forward for a single decoded token
    v=nxt.view(1,1)  # Reshape scalar token ID into 2D batch form for model input
    return mod(input_ids=v,use_cache=True,past_key_values=pkv,cache_position=pos)  # Forward next token while preserving continuity via cached attention

init()  # Execute full setup sequence: model loading, tokenizer, dummy KV init, and conversation state reset

for t in turns:  # Iterate through user prompts, simulating turn-by-turn interaction
    t0=time.time()  # Start timing this turn for performance stats
    chat.append({"role":"user","content":t})  # Add user message to chat history. Used by tokenizer template logic.
    promptids=atc()  # Convert current chat history into model-ready token tensor using HuggingFace chat templates
    newids=promptids[:,cp:]  # Slice token tensor to include only newly added user prompt tokens
    newlen=newids.shape[1]  # Count number of new tokens generated in this turn
    pos=torch.arange(cp,cp+newlen,device=dev)  # Build position tensor to pass to model. Must align with full cumulative token index (not just turn-local)
    with torch.no_grad(): out=dumoda()  # Run model forward pass on new token segment. Output includes logits for next prediction and updated KV
    cp+=newlen  # Update global cursor to reflect how many total tokens were sent. Critical for maintaining cache alignment.
    outstr=""  # Start building model's reply string, token by token
    for i in range(60):  # Generate up to 60 output tokens, one at a time
        t=out.logits[:,-1,:]  # Extract logits for final position in sequence. These represent model's belief over next token.
        nxt=torch.argmax(t,dim=-1)  # Select the most likely token (greedy decode)
        tokstr=tok.decode(nxt)  # Convert token ID back into string
        if tokstr=="<end_of_turn>": break  # Stop generation when the end-of-turn special token is reached
        outstr+=tokstr  # Accumulate generated text
        pos=torch.tensor([cp],device=dev)  # Update position to reflect the single-token continuation point
        with torch.no_grad(): out=dumodb()  # Forward next token using updated KV state and pos
        cp+=1  # Increment cursor. Each sampled token must advance the absolute position
    chat.append({"role":"model","content":outstr})  # Save model reply in chat history to be used for next prompt
    print(outstr)  # Display model's full response for the turn
    print("Response time:",round(time.time()-t0,3),"sec")  # Show how long generation took for diagnostic purposes
```