DLM: Whale-4B-Base


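I interviewed at an LLM startup. They had released a small open-source model, so I took a look at the architecture they use. (I would have liked to continue the interview process, but I had already found a place I was happy with, so I decided not to continue...)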

Basic Architecture

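At its core, the architecture uses a classic Transformer DiT (Diffusion Transformer) for sequence modeling. The overall flow is roughly as follows: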

Main

flowchart TD

    %% ===== Input =====
    P["Prompt text"]
    T["Timesteps t"]

    %% ===== Token pipeline =====
    P --> TOK["Tokenizer"]

    TOK --> IDS["input_ids<br/>prompt tokens + MASK tokens"]

    IDS --> EMB["TokenEmbedding"]

    EMB --> X["x: token hidden states"]

    %% ===== Time conditioning =====
    T --> TEMB["TimestepEmbedding"]

    TEMB --> C["c: conditioning vector"]

    %% ===== Transformer backbone =====
    X --> BLOCKS["DiT Blocks × depth"]

    C --> BLOCKS

    %% ===== Final conditioning =====
    BLOCKS --> FINAL["FinalAdaLN"]

    C --> FINAL

    %% ===== LM head =====
    FINAL --> H["hidden states"]

    H --> HEAD["lm_head Linear"]

    HEAD --> LOGITS["logits<br/>B × T × vocab_size"]

    %% ===== Sampling =====
    LOGITS --> SAMPLER["Diffusion Sampler<br/>standard / jump / GIDD"]

    IDS --> SAMPLER
    T --> SAMPLER

    %% ===== Token update =====
    SAMPLER --> UPDATE["Update editable tokens only<br/>keep prompt prefix fixed"]

    UPDATE --> OUT["Generated text"]

    %% ===== Styles =====

    %% Raw input
    classDef input fill:#ECECEC,stroke:#666,color:#000;

    %% Token processing
    classDef token fill:#E8F0FE,stroke:#4A90E2,color:#000;

    %% Conditioning
    classDef cond fill:#F3E5F5,stroke:#8E44AD,color:#000;

    %% Core transformer compute
    classDef core fill:#D5F5E3,stroke:#27AE60,color:#000;

    %% Sampling process
    classDef sample fill:#FFF3CD,stroke:#F39C12,color:#000;

    %% Output
    classDef output fill:#FDEDEC,stroke:#C0392B,color:#000;

    class P,T input;

    class TOK,IDS,EMB,X token;

    class TEMB,C cond;

    class BLOCKS,FINAL,H,HEAD,LOGITS core;

    class SAMPLER,UPDATE sample;

    class OUT output;
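The post does not spell out the `TimestepEmbedding` node. A minimal sketch, assuming it follows the original DiT recipe: sinusoidal features of size `timestep_freq_dim` (256 in the config), which would then go through a small MLP to produce the conditioning vector `c`. The `max_period` value, the cos-then-sin ordering and the scaling of `t` are assumptions borrowed from DiT; the MLP itself is omitted.

```python
import math
from typing import List

def sinusoidal_timestep_features(t: float, freq_dim: int = 256,
                                 max_period: float = 10000.0) -> List[float]:
    """DiT-style sinusoidal timestep features (assumed, not confirmed for Whale-4B-Base)."""
    half = freq_dim // 2
    # Geometrically spaced frequencies, from 1 down to roughly 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    # First half cosines, second half sines, as in DiT's TimestepEmbedder.
    feats = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if freq_dim % 2:
        feats.append(0.0)  # pad odd sizes with a zero, again following DiT
    return feats
```

For example, `sinusoidal_timestep_features(0.5)` returns 256 floats whose first entry is `cos(0.5)`; in the full model these features would feed the MLP that outputs `c`.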

DiT block

flowchart TD

    %% ===== Inputs =====
    X[x input]
    C[c conditioning]
    ROPE[RoPE]

    %% ===== AdaLN condition branch =====
    C --> ADA["adaLN MLP<br/>SiLU + Linear"]
    ADA --> SPLIT["Split into 6 vectors"]

    SPLIT --> SMSA["shift_msa"]
    SPLIT --> SCSA["scale_msa"]
    SPLIT --> GMSA["gate_msa"]

    SPLIT --> SMLP["shift_mlp"]
    SPLIT --> SCMLP["scale_mlp"]
    SPLIT --> GMLP["gate_mlp"]

    %% ===== Attention branch =====
    X --> N1["RMSNorm"]

    N1 --> MOD1["AdaLN Modulate<br/>x * (1+scale) + shift"]

    SMSA --> MOD1
    SCSA --> MOD1

    MOD1 --> ATTN["Self Attention"]

    ROPE --> ATTN

    ATTN --> GM1["Multiply gate_msa"]

    GMSA --> GM1

    X --> RES1["Residual Add"]

    GM1 --> RES1

    %% ===== MLP branch =====
    RES1 --> N2["RMSNorm"]

    N2 --> MOD2["AdaLN Modulate<br/>x * (1+scale) + shift"]

    SMLP --> MOD2
    SCMLP --> MOD2

    MOD2 --> FFN["SwiGLU FFN"]

    FFN --> GM2["Multiply gate_mlp"]

    GMLP --> GM2

    RES1 --> RES2["Residual Add"]

    GM2 --> RES2

    RES2 --> OUT["Block Output"]

    %% ===== Styles =====

    %% Norm
    classDef norm fill:#E8F0FE,stroke:#4A90E2,color:#000;

    %% Conditioning
    classDef cond fill:#F3E5F5,stroke:#8E44AD,color:#000;

    %% AdaLN modulation
    classDef mod fill:#FFF3CD,stroke:#F39C12,color:#000;

    %% Core transformer compute
    classDef core fill:#D5F5E3,stroke:#27AE60,color:#000;

    %% Residual / gating
    classDef residual fill:#FDEDEC,stroke:#C0392B,color:#000;

    %% Output
    classDef output fill:#ECECEC,stroke:#666,color:#000;

    class N1,N2 norm;
    class ADA,SPLIT,SMSA,SCSA,GMSA,SMLP,SCMLP,GMLP,C cond;
    class MOD1,MOD2 mod;
    class ATTN,FFN,ROPE core;
    class GM1,GM2,RES1,RES2 residual;
    class OUT output;
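The block diagram condenses into a few lines of code. This is a minimal pure-Python sketch under several assumptions: one conditioning vector per sequence (so the same six vectors modulate every token), RMSNorm without a learned weight, and `attn`/`ffn` passed in as stand-in callables (RoPE would live inside `attn`). The split order follows the diagram; whether Whale-4B-Base orders its adaLN output the same way is an assumption, and all function names are hypothetical.

```python
import math
from typing import Callable, List

Vec = List[float]
Seq = List[Vec]

def rms_norm(v: Vec, eps: float = 1e-6) -> Vec:
    # RMSNorm without a learned weight (the real block may have one).
    rms = math.sqrt(sum(a * a for a in v) / len(v) + eps)
    return [a / rms for a in v]

def modulate(v: Vec, shift: Vec, scale: Vec) -> Vec:
    # AdaLN modulation from the diagram: x * (1 + scale) + shift
    return [a * (1.0 + sc) + sh for a, sh, sc in zip(v, shift, scale)]

def split6(ada_out: Vec) -> List[Vec]:
    # Split the adaLN MLP output (length 6 * hidden) into six vectors:
    # shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp.
    h = len(ada_out) // 6
    return [ada_out[i * h:(i + 1) * h] for i in range(6)]

def dit_block(x: Seq, ada_out: Vec,
              attn: Callable[[Seq], Seq], ffn: Callable[[Vec], Vec]) -> Seq:
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = split6(ada_out)
    # Attention branch: RMSNorm -> modulate -> attention -> gate -> residual add.
    h = attn([modulate(rms_norm(v), shift_msa, scale_msa) for v in x])
    x = [[a + g * b for a, g, b in zip(v, gate_msa, hv)] for v, hv in zip(x, h)]
    # MLP branch: RMSNorm -> modulate -> SwiGLU FFN -> gate -> residual add.
    out = []
    for v in x:
        f = ffn(modulate(rms_norm(v), shift_mlp, scale_mlp))
        out.append([a + g * b for a, g, b in zip(v, gate_mlp, f)])
    return out
```

With all gates set to zero, both branches contribute nothing and the block is an identity map, which is why DiT-style models often zero-initialize the adaLN projection (the "adaLN-Zero" trick from the DiT paper).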

config

model:
  vocab_size: 64512
  hidden_size: 2048
  attn_dim: 3072
  ffn_dim: 7168
  depth: 48
  num_heads: 24
  head_dim: 128
  max_seq_len: 4096
  timestep_freq_dim: 256
  rope_theta: 10000.0
  cond_dim: 256
  dropout: 0.0
  attn_dropout: 0.0

diffusion:
  mask_token_id: 14
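As a sanity check on the config, the parameter count can be estimated from these numbers. The sketch assumes standard multi-head attention with separate Q/K/V/O projections (consistent with num_heads × head_dim = 24 × 128 = 3072 = attn_dim), a three-matrix SwiGLU FFN, an adaLN projection from `cond_dim` to 6 × `hidden_size` per block (and 2 × `hidden_size` for the final layer), no biases, and an untied `lm_head`. None of these details is confirmed by the config; `estimate_params` is a hypothetical helper.

```python
def estimate_params(vocab_size: int = 64512, hidden_size: int = 2048,
                    attn_dim: int = 3072, ffn_dim: int = 7168, depth: int = 48,
                    num_heads: int = 24, head_dim: int = 128, cond_dim: int = 256,
                    tie_embeddings: bool = False) -> int:
    assert num_heads * head_dim == attn_dim  # 24 * 128 == 3072 in the config
    attn = 3 * hidden_size * attn_dim + attn_dim * hidden_size  # Q, K, V in; O out
    ffn = 3 * hidden_size * ffn_dim                             # SwiGLU: gate, up, down
    adaln = cond_dim * 6 * hidden_size                          # six modulation vectors
    blocks = depth * (attn + ffn + adaln)
    final_adaln = cond_dim * 2 * hidden_size                    # shift + scale (assumed)
    embed = vocab_size * hidden_size
    lm_head = 0 if tie_embeddings else vocab_size * hidden_size
    return blocks + final_adaln + embed + lm_head
```

Under these assumptions the total is 3,738,173,440 parameters (about 3.74B), or about 3.61B with tied embeddings, which is roughly consistent with the "4B" in the model name. Norm weights, biases and the timestep MLP are ignored.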

Sampling

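The model supports two ways of generating text. The first is heuristic: at each step, $x_t$ is fed into the model to predict $x_0$, the highest-confidence tokens are fixed, and the process repeats. The second is GIDD (Generalized Interpolating Discrete Diffusion), which performs genuine discrete diffusion sampling.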
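Both samplers share the same outer loop: start from the prompt followed by MASK tokens, step from $t=1$ towards $t=0$, and only ever overwrite the generated positions. A minimal sketch of that loop, assuming a linear timestep schedule; `init_sequence`, `apply_update`, `generate` and `step_fn` are hypothetical names, and `step_fn` stands in for "run the model, then apply one sampler step".

```python
from typing import Callable, List, Tuple

MASK_TOKEN_ID = 14  # diffusion.mask_token_id in the config above

def init_sequence(prompt_ids: List[int], gen_len: int) -> Tuple[List[int], List[bool]]:
    """Initial x_t: the prompt followed by gen_len MASK tokens, plus an editable flag per position."""
    tokens = list(prompt_ids) + [MASK_TOKEN_ID] * gen_len
    editable = [False] * len(prompt_ids) + [True] * gen_len
    return tokens, editable

def apply_update(tokens: List[int], proposal: List[int], editable: List[bool]) -> List[int]:
    # Only editable positions accept the sampler's proposal; the prompt prefix never changes.
    return [p if e else t for t, p, e in zip(tokens, proposal, editable)]

def generate(prompt_ids: List[int], gen_len: int, num_steps: int,
             step_fn: Callable[[List[int], float, float], List[int]]) -> List[int]:
    tokens, editable = init_sequence(prompt_ids, gen_len)
    # Go from t = 1 (fully masked) to t = 0 (clean); linear spacing is an assumption.
    times = [1.0 - k / num_steps for k in range(num_steps + 1)]
    for t, s in zip(times[:-1], times[1:]):
        # step_fn(tokens, t, s) proposes the state at the next, cleaner time s.
        tokens = apply_update(tokens, step_fn(tokens, t, s), editable)
    return tokens
```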

Jump

flowchart TD

    X0["x_t<br/>prompt + MASK tokens"]

    T["timestep t"] --> MODEL
    X0 --> MODEL

    MODEL["LangDiT model"] --> LOGITS["token logits"]

    LOGITS --> SAMPLE["Sample candidate tokens"]

    SAMPLE --> CONF["Compute confidence"]

    CONF --> PICK["Pick highest-confidence MASK positions"]

    PICK --> FILL["Fill selected MASK tokens"]

    %% Jump phase
    FILL --> JUDGE{"Late jump phase?"}

    JUDGE -- yes --> LOWCONF["Find low-confidence generated tokens"]

    LOWCONF --> ALT["Resample alternative tokens"]

    ALT --> UPDATE["Update uncertain tokens"]

    JUDGE -- no --> UPDATE

    UPDATE --> FIX["Keep prompt prefix fixed"]

    FIX --> NEXT["x_s<br/>more refined tokens"]

    NEXT --> LOOP{"More timesteps?"}

    LOOP -- yes --> X0

    LOOP -- no --> OUT["Final generated text"]

    %% ===== Styles =====

    classDef state fill:#E8F0FE,stroke:#4A90E2,color:#000;
    classDef model fill:#D5F5E3,stroke:#27AE60,color:#000;
    classDef heuristic fill:#FFF3CD,stroke:#F39C12,color:#000;
    classDef output fill:#FDEDEC,stroke:#C0392B,color:#000;

    class X0,NEXT state;
    class MODEL,LOGITS model;
    class SAMPLE,CONF,PICK,FILL,LOWCONF,ALT,UPDATE,FIX heuristic;
    class OUT output;
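One iteration of the Jump loop can be sketched as a single update function. This is a hypothetical illustration rather than the model's actual sampler: it picks candidates greedily (argmax) instead of sampling, commits a fixed number `num_to_fill` of the most confident MASK positions per step, and, when `late_phase` is set, revisits earlier generated tokens whose probability has dropped below `revise_threshold`. These names and the fixed-count/threshold rules are assumptions; `probs[i]` is the model's predicted distribution at position `i`.

```python
from typing import List, Sequence

MASK_ID = 14  # diffusion.mask_token_id from the config

def argmax_excluding_mask(p: Sequence[float]) -> int:
    # Never propose the MASK token itself as a candidate.
    return max((v for v in range(len(p)) if v != MASK_ID), key=lambda v: p[v])

def jump_step(tokens: List[int], probs: List[Sequence[float]], editable: List[bool],
              num_to_fill: int, late_phase: bool = False,
              revise_threshold: float = 0.3) -> List[int]:
    tokens = list(tokens)
    # 1) Candidate token and confidence for every editable MASK position.
    candidates = []
    for i, tok in enumerate(tokens):
        if editable[i] and tok == MASK_ID:
            best = argmax_excluding_mask(probs[i])
            candidates.append((probs[i][best], i, best))
    # 2) Commit the num_to_fill most confident candidates.
    candidates.sort(reverse=True)
    filled = set()
    for _, i, best in candidates[:num_to_fill]:
        tokens[i] = best
        filled.add(i)
    # 3) Late "jump" phase: earlier generated tokens that the current prediction
    #    now finds unlikely are swapped for the model's preferred alternative.
    if late_phase:
        for i, tok in enumerate(tokens):
            if (editable[i] and tok != MASK_ID and i not in filled
                    and probs[i][tok] < revise_threshold):
                tokens[i] = argmax_excluding_mask(probs[i])
    return tokens
```

The `editable` check comes first in every condition, so prompt tokens are never inspected or changed, matching the "Keep prompt prefix fixed" node.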

GIDD

flowchart TD

    XT["x_t<br/>current noisy tokens"]

    T["current timestep t"]
    S["next timestep s"]

    XT --> MODEL
    T --> MODEL

    MODEL["LangDiT model"] --> P0["Predict p_theta(x0 | x_t)"]

    %% Forward distributions
    P0 --> QT["Construct q_t<br/>alpha_t * p0 + noise"]

    P0 --> QS["Construct q_s<br/>alpha_s * p0 + noise"]

    T --> QT
    S --> QS

    %% Posterior
    QT --> POST["Compute reverse posterior<br/>q(x_s | x_t, x0)"]

    QS --> POST
    XT --> POST

    %% Sampling
    POST --> SAMPLE["Sample x_s from posterior"]

    SAMPLE --> FIX["Keep prompt prefix fixed"]

    FIX --> XS["x_s<br/>cleaner token state"]

    XS --> LOOP{"More timesteps?"}

    LOOP -- yes --> XT

    LOOP -- no --> OUT["Final generated text"]

    %% ===== Styles =====

    classDef state fill:#E8F0FE,stroke:#4A90E2,color:#000;
    classDef model fill:#D5F5E3,stroke:#27AE60,color:#000;
    classDef diffusion fill:#F3E5F5,stroke:#8E44AD,color:#000;
    classDef posterior fill:#FFF3CD,stroke:#F39C12,color:#000;
    classDef output fill:#FDEDEC,stroke:#C0392B,color:#000;

    class XT,XS state;

    class MODEL,P0 model;

    class QT,QS diffusion;

    class POST,SAMPLE,FIX posterior;

    class OUT output;
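One reverse step of the GIDD diagram can be sketched per position. The sketch makes a simplifying assumption: the noise (mixing) distribution $\pi$ is held fixed over time, whereas GIDD allows it to change with $t$. Here `p0` plays the role of $p_\theta(x_0 \mid x_t)$, the forward marginals are $q_s = \alpha_s p_0 + (1-\alpha_s)\pi$ and $q_t = \alpha_t p_0 + (1-\alpha_t)\pi$, and the posterior is $q(x_s = v \mid x_t) = q(x_t \mid x_s = v)\, q_s(v) / q_t(x_t)$. With $\pi$ a one-hot on the MASK token this reduces to the usual masked-diffusion update. All function names are hypothetical.

```python
import random
from typing import List, Sequence

def reverse_posterior(z_t: int, p0: Sequence[float], pi: Sequence[float],
                      alpha_t: float, alpha_s: float) -> List[float]:
    """Distribution over x_s given the current token z_t and the predicted p0.

    Assumes a time-independent mixing distribution pi and 0 <= alpha_t < alpha_s <= 1.
    """
    alpha_ts = alpha_t / alpha_s  # probability a token survives unchanged from s to t
    # Forward marginals built from the prediction: q = alpha * p0 + (1 - alpha) * pi
    q_s = [alpha_s * p + (1.0 - alpha_s) * n for p, n in zip(p0, pi)]
    q_t = alpha_t * p0[z_t] + (1.0 - alpha_t) * pi[z_t]
    if q_t <= 0.0:
        raise ValueError("z_t has zero probability under the predicted marginal")
    post = []
    for v, qs_v in enumerate(q_s):
        # Transition q(x_t = z_t | x_s = v): keep v w.p. alpha_ts, else redraw from pi.
        trans = alpha_ts * (1.0 if v == z_t else 0.0) + (1.0 - alpha_ts) * pi[z_t]
        post.append(trans * qs_v / q_t)  # Bayes' rule; q_t is the normaliser
    return post

def gidd_step(tokens: List[int], p0_per_pos: List[Sequence[float]], editable: List[bool],
              pi: Sequence[float], alpha_t: float, alpha_s: float,
              rng: random.Random) -> List[int]:
    out = list(tokens)
    for i, z in enumerate(tokens):
        if editable[i]:  # the prompt prefix is never resampled
            post = reverse_posterior(z, p0_per_pos[i], pi, alpha_t, alpha_s)
            out[i] = rng.choices(range(len(post)), weights=post, k=1)[0]
    return out
```

In a toy vocabulary of four tokens where index 3 is MASK (`pi = [0, 0, 0, 1]`), a masked position with `p0 = [0.5, 0.3, 0.2, 0.0]`, $\alpha_t = 0.2$ and $\alpha_s = 0.6$ gets the posterior `[0.25, 0.15, 0.1, 0.5]`: it stays masked with probability $(1-\alpha_s)/(1-\alpha_t) = 0.5$ and otherwise draws from `p0`. An already unmasked token keeps its value.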