code | description | insight # |
---|---|---|
#!/media/krusty/gm/gm120/anaconda3/envs/apy/bin/python | Shebang for Python interpreter in Anaconda env `apy` | — |
import os,sys,time,socket | Standard library imports: file ops, system, timing, networking | — |
sys.path.insert(0,"/webroot/lib") | Prioritize custom library path `/webroot/lib` | — |
import plib | User-defined or local module `plib` assumed to be in inserted path | — |
from transformers import AutoTokenizer,AutoModelForCausalLM | Import Hugging Face classes for the tokenizer and causal LM | 3 |
import torch | Import PyTorch for tensor and model operations | — |
def init(): | Initializes model, tokenizer, dummy KV cache, and metadata (reconstructed in a sketch after the table) | 1,3,4 |
global mod, tok, maxlen, padid, pkv, chat, turns, cp, dev, dtype, modpath | Expose critical objects and settings globally across session | — |
mp={...} | Model paths for local and remote Gemma variants | — |
"4b":"/home/krusty/.cache/huggingface/hub/models--google--gemma-3-4b-it/..." | Local snapshot directory for 4b variant | — |
"9b":"google/gemma-3-9b-it" | Remote hub-based identifier | — |
"27b":"google/gemma-3-27b-it" | Additional model identifier for higher capacity | — |
dp={"cpu":"cpu","gpu":"cuda"} | Simple device mapping string | — |
tp={"bf":torch.bfloat16,"f16":torch.float16,"f32":torch.float32} | Abbreviated types to corresponding torch dtypes | — |
if socket.gethostname()=="machf":... | Machine-specific config: match hostname to default model, device, type | — |
tok=AutoTokenizer.from_pretrained(modpath) | Instantiate tokenizer using resolved path | 3 |
mod=AutoModelForCausalLM.from_pretrained(...) | Load model with correct type/device map | 3 |
maxlen=1024 | Define maximum token length (fits model window) | 3 |
padid=tok.pad_token_id | Attempt to use the tokenizer's pad token ID | 1 |
if padid is None: | Fallback check: no pad token defined | 1 |
padid=tok.eos_token_id | Substitute pad with EOS token | 1 |
with torch.no_grad(): | Disable gradient tracking for the one-time cache warm-up | 4 |
rng=(1,maxlen) | Input shape for dummy token forward | 4 |
dummyids=torch.full(...) | Fill dummy tensor with pad tokens to trigger cache prep | 4 |
r=mod(input_ids=dummyids,use_cache=True) | Single forward pass to populate `past_key_values` | 4 |
pkv=r.past_key_values | Capture initial KV state from dummy run | 1,4 |
chat=[] | Initialize an empty conversation history | 3 |
turns=[...] | Simulated multi-turn user messages | 3 |
cp=0 | Initialize token cursor at 0 | 2 |
def atc(): | Encodes chat history with template formatting (see the helper sketch after the table) | 3 |
r=tok.apply_chat_template(...) | Produce prompt using chat template | 3 |
return r.to(dev) | Ensure prompt is on same device as model | 3 |
def dumoda(): | Model forward using current newids and cached pkv | 4 |
return mod(input_ids=newids,...) | Forward pass over the new prompt slice, reusing the cached state | 4 |
def dumodb(): | Model forward for a single token input | 4 |
v=nxt.view(1,1) | Reshape single token to batch form | 4 |
return mod(input_ids=v,...) | Feed next token with cache and position | 4 |
init() | Run setup for tokenizer, model, and dummy cache | 1,4 |
for t in turns: | Iterate over user turns (full loop sketched after the table) | 3 |
t0=time.time() | Start timing the response | — |
chat.append({"role":"user","content":t}) | Add user message to history | 3 |
promptids=atc() | Tokenize full prompt using current chat | 3 |
newids=promptids[:,cp:] | Slice prompt to new portion | 2 |
newlen=newids.shape[1] | Count number of new tokens | 2 |
pos=torch.arange(cp,cp+newlen,...) | Absolute cache position for each token | 2 |
with torch.no_grad(): | Disable gradient tracking for inference | 4 |
out=dumoda() | Generate logits using past context and new slice | 4 |
cp+=newlen | Advance cursor by number of input tokens | 2 |
outstr="" | Start building model reply string | 3 |
for i in range(60): | Limit response to 60 tokens max | 4 |
t=out.logits[:,-1,:] | Select last-position logits (reusing the name `t`, which shadows the outer turn string) | 4 |
nxt=torch.argmax(t,dim=-1) | Choose most likely token | 4 |
tokstr=tok.decode(nxt) | Convert token to readable string | 3 |
if tokstr=="<end_of_turn>": break | Stop generation on end marker | 4 |
outstr+=tokstr | Append decoded token to string | 3 |
pos=torch.tensor([cp],device=dev) | Update single-token cache position | 2 |
with torch.no_grad(): | Safe inference block | 4 |
out=dumodb() | Single-token model forward | 4 |
cp+=1 | Advance cursor by one token | 2 |
chat.append({"role":"model","content":outstr}) | Append model reply to chat | 3 |
print(outstr) | Print model response string | — |
print("Response time:",round(time.time()-t0,3),"sec") | Show time spent on generating the turn | — |