`timescale 1ps/1ps

// -----------------------------------------------------------------------------
// softmax
//
// Streaming integer softmax over vectors of N elements (one element per clock
// on both the input and the output side).
//
// Data path as wired in this module:
//   1. `max`       tracks a running maximum of `qin`; the value is latched into
//                  `qmax` on the cycle the N-th valid input is counted.
//   2. `sreg_qin`  delays `qin` (DEPTH = N) so that `qhat = qin_buf - qmax`
//                  can be formed once the maximum is available.
//   3. `exp`       evaluates the exponential of `qhat` using the coefficients
//                  qb / qc / qln2 / qln2_inv.
//   4. The exp result is requantized: multiply by `Sreq`, round, and shift
//                  right by FP_BITS to obtain `qreq`.
//   5. `acc`       accumulates `qreq`; after N valid values the sum is latched
//                  into `qsum`.
//   6. `div`       computes (1 << MAX_BITS) / qsum, held in `factor`.
//   7. `sreg_qreq` delays `qreq` (DEPTH = 2*N+1); the delayed value is
//                  multiplied by `factor_buf` and shifted right by
//                  (MAX_BITS - OUT_BITS) to produce `qout`.
//
// The latencies of max / exp / acc / div / sreg are defined outside this file;
// the delay-line depths used here presumably match them -- TODO confirm when
// changing any of those modules.
//
// NOTE(review): `enable` is only forwarded to max / exp / acc / div. The
// registers inside this module do not gate on it -- confirm this is intended.
// -----------------------------------------------------------------------------
module softmax
#(
    parameter integer D_W      = 8,   // output data width (int8)
    parameter integer D_W_ACC  = 32,  // input/coeff data width (int32)
    parameter integer N        = 32,  // length of the input vector
    parameter integer FP_BITS  = 30,  // # fractional bits for exp, Sreq
    parameter integer MAX_BITS = 30,  // used for final dividend and shift
    parameter integer OUT_BITS = 6    // used for final shift
)
(
    input  logic                      clk,
    input  logic                      rst,
    input  logic                      enable,

    // Streaming input data (1 element per clock)
    input  logic                      in_valid,
    input  logic signed [D_W_ACC-1:0] qin,       // softmax input (int32)

    // Coefficients for exp
    input  logic signed [D_W_ACC-1:0] qb,
    input  logic signed [D_W_ACC-1:0] qc,
    input  logic signed [D_W_ACC-1:0] qln2,
    input  logic signed [D_W_ACC-1:0] qln2_inv,

    // Requantization coefficient for exponent output
    input  logic        [D_W_ACC-1:0] Sreq,

    // Streaming output data (1 element per clock)
    output logic                      out_valid,
    output logic signed [D_W-1:0]     qout       // softmax output (int8)
);

    // citation: gpt suggested declaring these constants as localparams
    // instead of computing them in logic.
    localparam integer SHIFT      = MAX_BITS - OUT_BITS; // final right shift
    localparam integer DIVIDEND   = (1 << MAX_BITS);     // numerator fed to div
    localparam integer QREQ_DELAY = 2*N + 1;             // depth of the qreq delay line

    // ------------------------------------------------------------------
    // Input stage
    // ------------------------------------------------------------------
    logic signed [D_W_ACC-1:0]   qin_buf;        // qin delayed by sreg_qin
    logic        [N-1:0]         qin_buf_valid;  // in_valid delay line
    logic signed [D_W_ACC-1:0]   qin_counter;    // counts valid inputs, 0..N-1

    // max module
    logic                        max_init;
    logic signed [D_W_ACC-1:0]   max_out_reg;
    logic signed [D_W_ACC-1:0]   qmax;           // latched maximum
    logic signed [D_W_ACC-1:0]   qhat;           // qin_buf - qmax
    logic                        qhat_valid;

    // exp module
    logic                        qexp_out_valid;
    logic signed [D_W_ACC-1:0]   qexp;
    logic signed [2*D_W_ACC-1:0] qexp64;         // qexp * Sreq, full width
    logic                        qexp64_valid;
    logic signed [D_W_ACC-1:0]   qreq;           // requantized exp
    logic                        qreq_valid;
    logic signed [D_W_ACC-1:0]   qreq_buf;       // qreq delayed by sreg_qreq
    logic        [QREQ_DELAY-1:0] qreq_valid_buf; // qreq_valid delay line

    // acc module
    logic signed [D_W_ACC-1:0]   qreq_counter;   // counts valid qreq, 0..N-1
    logic                        acc_init;
    logic signed [D_W_ACC-1:0]   acc_out;
    logic signed [D_W_ACC-1:0]   qsum;           // latched sum of N qreq values
    logic                        qsum_valid;

    // div module
    logic                        div_in_valid;
    logic signed [D_W_ACC-1:0]   div_out;
    logic                        div_out_valid;
    logic signed [D_W_ACC-1:0]   factor;         // latest quotient from div
    logic signed [D_W_ACC-1:0]   factor_buf;     // factor held for the output multiply
    logic        [N-1:0]         qsum_valid_buf; // qsum_valid delay line
    logic signed [D_W_ACC-1:0]   qmul;
    logic                        qmul_valid;

    always_ff @(posedge clk) begin
        if (rst) begin
            // counters
            qin_counter    <= 0;
            qreq_counter   <= 0;

            // numbers
            max_init       <= 1;
            acc_init       <= 1;
            qmax           <= 0;
            qhat           <= 0;
            qexp64         <= 0;
            qreq           <= 0;
            qsum           <= 0;
            factor         <= 0;
            factor_buf     <= 0;
            qmul           <= 0;
            qout           <= 0;

            // valids
            qin_buf_valid  <= 0;
            qhat_valid     <= 0;
            div_in_valid   <= 0;
            qexp64_valid   <= 0;
            qreq_valid     <= 0;
            qreq_valid_buf <= 0;
            qsum_valid     <= 0;
            qsum_valid_buf <= 0;
            qmul_valid     <= 0;
            out_valid      <= 0;
        end else begin
            // Count valid inputs; on the N-th one latch the maximum and
            // re-arm the max module for the next vector.
            if (in_valid) begin
                if (qin_counter == N-1) begin
                    qin_counter <= 0;
                    max_init    <= 1;
                    qmax        <= max_out_reg;
                end else begin
                    qin_counter <= qin_counter + 1;
                    max_init    <= 0;
                end
            end

            // Shift in_valid through an N-bit delay line. The size cast makes
            // the drop of the oldest bit explicit.
            qin_buf_valid <= N'({qin_buf_valid, in_valid});

            qhat       <= qin_buf - qmax;
            qhat_valid <= qin_buf_valid[N-1];

            // Sreq is unsigned. Zero-extend it by one bit and mark it signed
            // so the product is evaluated as a signed multiply; a plain
            // `qexp * Sreq` would be unsigned and zero-extend a negative qexp.
            qexp64       <= qexp * $signed({1'b0, Sreq});
            qexp64_valid <= qexp_out_valid;

            // >> FP_BITS does truncated quantization,
            // + (1 << (FP_BITS-1)) does the rounding.
            // `>>` is a logical shift; with the default widths the bits kept
            // after truncation to D_W_ACC are the same as for `>>>`.
            qreq       <= (qexp64 + (1 << (FP_BITS-1))) >> FP_BITS;
            qreq_valid <= qexp64_valid;

            // Count valid qreq values; flag the sum as ready after the N-th.
            if (qreq_valid) begin
                if (qreq_counter == N-1) begin
                    acc_init     <= 1;
                    qreq_counter <= 0;
                    qsum_valid   <= 1;
                end else begin
                    qreq_counter <= qreq_counter + 1;
                    acc_init     <= 0;
                    qsum_valid   <= 0;
                end
            end else begin
                acc_init   <= 1;
                qsum_valid <= 0;
            end

            qreq_valid_buf <= QREQ_DELAY'({qreq_valid_buf, qreq_valid});

            if (qsum_valid)
                qsum <= acc_out;

            qsum_valid_buf <= N'({qsum_valid_buf, qsum_valid});

            // Hand the quotient over to the output multiplier N cycles after
            // qsum_valid.
            if (qsum_valid_buf[N-1])
                factor_buf <= factor;

            div_in_valid <= qsum_valid;

            if (div_out_valid)
                factor <= div_out;

            qmul_valid <= qreq_valid_buf[QREQ_DELAY-1];
            qmul       <= factor_buf * qreq_buf;

            // With the default SHIFT (24) and D_W (8) this selects
            // qmul[31:24].
            qout      <= qmul >> SHIFT;
            out_valid <= qmul_valid;
        end
    end

    // ------------------------------------------------------------------
    // submodule instantiation
    // ------------------------------------------------------------------
    max #(
        .D_W(D_W_ACC)
    ) max (
        .clk       (clk),
        .rst       (rst),
        .enable    (enable),
        .initialize(max_init),
        .in_data   (qin),
        .result    (max_out_reg)
    );

    exp #(
        .D_W    (D_W_ACC),
        .FP_BITS(FP_BITS)
    ) exp (
        .clk      (clk),
        .rst      (rst),
        .enable   (enable),
        .in_valid (qhat_valid),
        .qin      (qhat),
        .qb       (qb),
        .qc       (qc),
        .qln2     (qln2),
        .qln2_inv (qln2_inv),
        .out_valid(qexp_out_valid),
        .qout     (qexp)
    );

    acc #(
        .D_W(D_W_ACC)
    ) acc (
        .clk       (clk),
        .rst       (rst),
        .enable    (enable),
        .initialize(acc_init),
        .in_data   (qreq),
        .result    (acc_out)
    );

    div #(
        .D_W(D_W_ACC)
    ) div (
        .clk      (clk),
        .rst      (rst),
        .enable   (enable),
        .in_valid (div_in_valid),
        .divisor  (qsum),
        .dividend (DIVIDEND),
        .quotient (div_out),
        .out_valid(div_out_valid)
    );

    // Delay line for the raw inputs (always shifting).
    sreg #(
        .D_W  (D_W_ACC),
        .DEPTH(N)
    ) sreg_qin (
        .clk     (clk),
        .rst     (rst),
        .shift_en(1'b1),
        .data_in (qin),
        .data_out(qin_buf)
    );

    // Delay line for the requantized exponentials (always shifting).
    sreg #(
        .D_W  (D_W_ACC),
        .DEPTH(QREQ_DELAY)
    ) sreg_qreq (
        .clk     (clk),
        .rst     (rst),
        .shift_en(1'b1),
        .data_in (qreq),
        .data_out(qreq_buf)
    );

endmodule