#!/usr/bin/env python3
"""
debug_deltanet.py — Capture per-layer activations from Qwen3.5-0.8B in PyTorch.
Saves numpy files to /tmp/tq_ref/ for comparison with the C engine.
Each file contains the hidden state after that layer's full processing
(attention/deltanet + FFN + residual).
"""
import torch
import numpy as np
import os
import warnings
import contextlib
import io
import sys
# Suppress noisy warnings from transformers
warnings.filterwarnings('ignore')
TOKEN_ID = 9419 # "Hello" — single token for easy debugging
print(f"Loading Qwen3.5-0.8B (token_id={TOKEN_ID})...")
with contextlib.redirect_stderr(io.StringIO()):
    from transformers import AutoModelForCausalLM, AutoConfig
    config = AutoConfig.from_pretrained('Qwen/Qwen3.5-0.8B', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        'Qwen/Qwen3.5-0.8B',
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
model.eval()
# Print model structure summary
print(f"Model type: {type(model).__name__}")
lm = model.model
print(f"Number of layers: {len(lm.layers)}")
# Print config details relevant to our C implementation
# Qwen3.5 uses nested config: config.text_config has the details
tc = config.text_config if hasattr(config, 'text_config') else config
print(f"\nModel config:")
print(f" hidden_size: {tc.hidden_size}")
print(f" num_attention_heads: {tc.num_attention_heads}")
print(f" num_key_value_heads: {tc.num_key_value_heads}")
print(f" head_dim: {tc.head_dim}")
print(f" intermediate_size: {tc.intermediate_size}")
print(f" vocab_size: {tc.vocab_size}")
print(f" rms_norm_eps: {tc.rms_norm_eps}")
print(f" layer_types: {tc.layer_types}")
print(f" linear_num_key_heads: {tc.linear_num_key_heads}")
print(f" linear_key_head_dim: {tc.linear_key_head_dim}")
print(f" linear_value_head_dim: {tc.linear_value_head_dim}")
print(f" linear_conv_kernel_dim: {tc.linear_conv_kernel_dim}")
print(f" attn_output_gate: {tc.attn_output_gate}")
if hasattr(tc, 'rope_parameters'):
    print(f" rope_parameters: {tc.rope_parameters}")
if hasattr(tc, 'partial_rotary_factor'):
    print(f" partial_rotary_factor: {tc.partial_rotary_factor}")
# Inspect layer types
for i, layer in enumerate(lm.layers):
    layer_type = "unknown"
    if hasattr(layer, 'self_attn') and layer.self_attn is not None:
        attn_cls = type(layer.self_attn).__name__
        layer_type = f"self_attn({attn_cls})"
    if hasattr(layer, 'linear_attn') and layer.linear_attn is not None:
        la_cls = type(layer.linear_attn).__name__
        layer_type = f"linear_attn({la_cls})"
    if i < 3 or i >= len(lm.layers) - 1:
        print(f" layer {i:2d}: {layer_type}")
    elif i == 3:
        print(f" ...")
# Register hooks to capture activations at each stage
activations = {}
def make_hook(name):
    def hook_fn(module, input, output):
        if isinstance(output, tuple):
            activations[name] = output[0].detach().clone()
        elif isinstance(output, torch.Tensor):
            activations[name] = output.detach().clone()
        else:
            # BaseModelOutputWithPast or similar
            if hasattr(output, 'last_hidden_state'):
                activations[name] = output.last_hidden_state.detach().clone()
    return hook_fn
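# Note (assumption based on the single-token input used here): each hooked module
# typically returns a tensor, or a tuple whose first element is a tensor, shaped
# [batch=1, seq_len=1, feature_dim], so the squeeze() in the save loop below
# reduces the saved arrays to 1-D vectors of length feature_dim.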
# Hook embedding
lm.embed_tokens.register_forward_hook(make_hook('embed'))
# Hook each transformer layer (captures output AFTER residual connections)
for i in range(len(lm.layers)):
    lm.layers[i].register_forward_hook(make_hook(f'layer{i:02d}'))
# Hook final norm
lm.norm.register_forward_hook(make_hook('final_norm'))
# Also capture intermediate states within layer 0 for detailed debugging
# Hook the attention/linear_attn sub-module of first few layers
for i in range(min(4, len(lm.layers))):
    layer = lm.layers[i]
    if hasattr(layer, 'linear_attn') and layer.linear_attn is not None:
        layer.linear_attn.register_forward_hook(make_hook(f'layer{i:02d}_linear_attn'))
    if hasattr(layer, 'self_attn') and layer.self_attn is not None:
        layer.self_attn.register_forward_hook(make_hook(f'layer{i:02d}_self_attn'))
    # Hook the MLP
    if hasattr(layer, 'mlp') and layer.mlp is not None:
        layer.mlp.register_forward_hook(make_hook(f'layer{i:02d}_mlp'))
    # Hook input_layernorm (pre-attention norm)
    if hasattr(layer, 'input_layernorm'):
        layer.input_layernorm.register_forward_hook(make_hook(f'layer{i:02d}_attn_norm'))
    # Hook post_attention_layernorm (pre-FFN norm)
    if hasattr(layer, 'post_attention_layernorm'):
        layer.post_attention_layernorm.register_forward_hook(make_hook(f'layer{i:02d}_ffn_norm'))
# Run forward pass
print(f"\nRunning forward pass with token_id={TOKEN_ID}...")
input_ids = torch.tensor([[TOKEN_ID]])
with torch.no_grad():
    out = model(input_ids, use_cache=False)
# Save activations
os.makedirs('/tmp/tq_ref', exist_ok=True)
print(f"\nActivations captured ({len(activations)} entries):")
print(f"{'Name':<30} {'Shape':<20} {'Mean':>12} {'Std':>12} {'[0:5]'}")
print("-" * 100)
for name in sorted(activations.keys()):
    t = activations[name]
    d = t.squeeze().float().numpy()
    np.save(f'/tmp/tq_ref/{name}.npy', d)
    vals = d.flatten()[:5]
    vals_str = ', '.join(f'{v:.4f}' for v in vals)
    print(f'{name:<30} {str(d.shape):<20} {d.mean():>12.6f} {d.std():>12.6f} [{vals_str}]')
# Save logits
logits = out.logits[0, -1, :].float().numpy()
np.save('/tmp/tq_ref/logits.npy', logits)
top_id = logits.argmax()
print(f'\nLogits: shape={logits.shape}, top_id={top_id}, top_val={logits[top_id]:.4f}')
print(f'logits[0:5] = [{", ".join(f"{v:.4f}" for v in logits[:5])}]')
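# Optional, illustrative only: the top logit id could be decoded back to text with
# the model's tokenizer, e.g.
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('Qwen/Qwen3.5-0.8B', trust_remote_code=True)
#   print('top token:', repr(tok.decode([int(top_id)])))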
# Also save embedding weights for first token to verify loading
embed_weight = lm.embed_tokens.weight[TOKEN_ID].detach().float().numpy()
np.save('/tmp/tq_ref/embed_weight_token.npy', embed_weight)
print(f'\nEmbed weight for token {TOKEN_ID}: [0:5]=[{", ".join(f"{v:.6f}" for v in embed_weight[:5])}]')
# Save some key DeltaNet weights for layer 0 to verify weight loading
layer0 = lm.layers[0]
if hasattr(layer0, 'linear_attn') and layer0.linear_attn is not None:
    la = layer0.linear_attn
    print(f"\nLayer 0 DeltaNet weight shapes:")
    for wname in ['A_log', 'dt_bias', 'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b',
                  'conv1d', 'norm', 'out_proj']:
        parts = wname.split('.')
        obj = la
        for p in parts:
            obj = getattr(obj, p, None)
            if obj is None:
                break
        if obj is not None:
            if hasattr(obj, 'weight'):
                w = obj.weight
            elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
                w = obj
            else:
                w = None
            if w is not None:
                wd = w.detach().clone().float().numpy()
                np.save(f'/tmp/tq_ref/l0_{wname.replace(".", "_")}.npy', wd)
                print(f" {wname}: shape={wd.shape} [0:3]={wd.flatten()[:3]}")
print(f"\nAll files saved to /tmp/tq_ref/")
print("Run the C comparison tool next to identify divergence point.")