`timescale 1ps / 1ps

// layer_norm: integer layer normalization over frames of N consecutive
// D_W_ACC-bit input samples.
//
// Per frame the datapath computes:
//   qsum     = sum(qin)                         (acc)
//   qsum_sq  = sum((qin >>> shift)^2)           (mac)
//   qmean    = (qsum * N_INV) >>> FP_BITS       (N_INV ~= 2**FP_BITS / N)
//   varr     = qsum_sq - ((qsum * qmean) >>> 2*shift)
//   std      = sqrt(varr) <<< shift             (sqrt submodule)
//   factor   = 2**MAX_BITS / std                (div submodule)
//   qout     = (((qin - qmean) * factor) >>> 1) + bias
// The input samples are replayed from an N-deep shift register (sreg_qin) so
// they can be centred once qmean is known. bias is delayed by BIAS_N cycles
// to line up with qout.
//
// NOTE(review): sreg_qin / sreg_r / sreg_bias shift every cycle. The stage
// alignment therefore assumes in_valid is held high for an entire frame with
// no gaps. Confirm with the integration testbench.
// NOTE(review): `enable` is only forwarded to the mac/acc/sqrt/div submodules.
// The registers in this module ignore it.
module layer_norm #(
    parameter integer D_W      = 8,
    parameter integer D_W_ACC  = 32,
    parameter integer N        = 768,
    // Fixed-point reciprocal of N, scaled by 2**FP_BITS (default 2**30 / 768).
    parameter signed [D_W_ACC-1:0] N_INV = 1398101,
    parameter integer FP_BITS  = 30,
    // The normalization factor is 2**MAX_BITS / std. Requires MAX_BITS < D_W_ACC.
    parameter integer MAX_BITS = 31
) (
    input  logic                          clk,
    input  logic                          rst,
    input  logic                          enable,
    input  logic                          in_valid,
    input  logic signed [D_W_ACC-1:0]     qin,
    input  logic signed [D_W_ACC-1:0]     bias,
    input  logic [$clog2(D_W_ACC)-1:0]    shift,
    output logic                          out_valid,
    output logic signed [D_W_ACC-1:0]     qout
);

    // 2**MAX_BITS as a D_W_ACC-bit vector. An `integer`-typed (1 << MAX_BITS)
    // overflows the 32-bit signed integer range when MAX_BITS = 31.
    localparam logic [D_W_ACC-1:0] DIVIDEND = {{(D_W_ACC-1){1'b0}}, 1'b1} << MAX_BITS;

    // Delay from the start of the r window to factor being ready. Presumably
    // 1 (varr reg) + 16 (sqrt) + 1 (std reg) + 32 (div) = 50. Confirm against
    // the sqrt/div latencies. Requires N >= R_N, because r_counter wraps at N-1.
    localparam integer R_N    = 50;
    // Delay applied to bias so that it meets qout_mul at the output adder.
    localparam integer BIAS_N = N+32+16+2+6;

    // ---------------------------------------------------------------- input
    logic signed [D_W_ACC-1:0]              qin_reg;      // feeds the acc (plain sum)
    logic                                   qin_valid;
    logic signed [D_W_ACC-1:0]              qin_counter;  // position within the frame
    logic signed [D_W_ACC-1:0]              qshift;       // feeds the mac (sum of squares)

    // ------------------------------------------------------- accumulators
    logic                                   mac_acc_init; // restarts acc/mac at a frame boundary
    logic signed [2*D_W_ACC-1:0]            qsum_sq;
    logic signed [3:0][2*D_W_ACC-1:0]       qsum_sq_d;
    // NOTE(review): the sum of N D_W_ACC-bit samples is kept in D_W_ACC bits.
    // This assumes the inputs are small enough not to overflow it.
    logic signed [D_W_ACC-1:0]              qsum;
    logic signed [D_W_ACC-1:0]              qsum_d1;
    logic signed [D_W_ACC-1:0]              qsum_d2;
    logic                                   qsum_valid;   // frame total ready (both qsum and qsum_sq)
    logic [4:0]                             qsum_valid_d;

    // ------------------------------------------------------ mean/variance
    logic signed [2*D_W_ACC-1:0]            qmul;
    logic signed [D_W_ACC-1:0]              qmean;
    logic signed [D_W_ACC-1:0]              qmean_reg;    // qmean held for the whole frame
    logic signed [2*D_W_ACC-1:0]            qmean_mul;
    logic signed [D_W_ACC-1:0]              qmean_sq;
    logic [D_W_ACC-1:0]                     varr;

    // ------------------------------------------------------- sqrt / divide
    logic                                   var_sqrt_valid;
    logic [D_W_ACC/2-1:0]                   var_sqrt;
    logic [D_W_ACC-1:0]                     std;
    logic                                   std_valid;
    logic [D_W_ACC-1:0]                     factor;
    logic [D_W_ACC-1:0]                     factor_r;     // factor latched for the current frame
    logic                                   factor_valid; // currently unused

    // ------------------------------------------------- centred input / out
    logic signed [D_W_ACC-1:0]              r;            // qin - qmean
    logic                                   r_valid;
    logic signed [D_W_ACC-1:0]              r_counter;
    logic [R_N-1:0]                         r_valid_d;
    logic signed [D_W_ACC-1:0]              r_out;        // r delayed by R_N cycles
    logic signed [D_W_ACC-1:0]              qout_mul;
    logic                                   qout_mul_valid;
    logic signed [D_W_ACC-1:0]              bias_out;

    // replayed input samples (sreg_qin output) plus a short delay line
    logic signed [D_W_ACC-1:0]              qin_buf;
    logic signed [3:0][D_W_ACC-1:0]         qin_buf_d;

    always_ff @(posedge clk) begin
        if (rst) begin
            qin_reg        <= '0;
            qin_valid      <= 1'b0;
            qin_counter    <= '0;
            qshift         <= '0;
            mac_acc_init   <= 1'b1;
            qsum_valid     <= 1'b0;
            qsum_valid_d   <= '0;
            varr           <= '0;
            std            <= '0;
            std_valid      <= 1'b0;
            r              <= '0;
            r_valid        <= 1'b0;
            r_counter      <= '0;
            r_valid_d      <= '0;
            qout_mul_valid <= 1'b0;
            out_valid      <= 1'b0;
        end else begin
            // ---- input register ---------------------------------------
            qin_valid <= in_valid;
            if (in_valid) begin
                qin_reg <= qin;
                qshift  <= qin >>> shift;
            end else begin
                // Feed zeros to both accumulators on idle cycles, so that
                // neither re-accumulates a stale sample.
                qin_reg <= '0;
                qshift  <= '0;
            end

            // ---- frame counter: flag the last sample of each frame ----
            if (qin_valid) begin
                if (qin_counter == N-1) begin
                    qin_counter  <= '0;
                    mac_acc_init <= 1'b1;
                    qsum_valid   <= 1'b1;
                end else begin
                    qin_counter  <= qin_counter + 1;
                    mac_acc_init <= 1'b0;
                    qsum_valid   <= 1'b0;
                end
            end else begin
                qsum_valid <= 1'b0;
            end

            // ---- mean: qsum * N_INV in fixed point --------------------
            qmul  <= qsum * N_INV;
            qmean <= qmul >>> FP_BITS;
            if (qsum_valid_d[1])
                qmean_reg <= qmean;

            // ---- variance numerator: sum_sq - (sum * mean) >> 2*shift --
            // qsum is delayed two cycles to line up with qmean.
            qsum_d1   <= qsum;
            qsum_d2   <= qsum_d1;
            qmean_mul <= qsum_d2 * qmean;
            // TODO: for timing, make this shift amount a constant.
            qmean_sq  <= qmean_mul >>> (2*shift);
            varr      <= qsum_sq_d[3] - qmean_sq;

            // ---- centred input ----------------------------------------
            r <= qin_buf_d[3] - qmean_reg;

            // ---- delay lines (element 0 receives the newest value) ----
            // TODO: map the input replay buffer to BRAM.
            qin_buf_d    <= {qin_buf_d[2:0],    qin_buf};
            qsum_valid_d <= {qsum_valid_d[3:0], qsum_valid};
            qsum_sq_d    <= {qsum_sq_d[2:0],    qsum_sq};
            r_valid_d    <= {r_valid_d[R_N-2:0], r_valid};

            // ---- r window: N cycles, restarted by each new frame mean -
            // Ordering matters: a later nonblocking assignment to r_counter
            // wins.
            if (qsum_valid_d[2]) begin
                r_valid   <= 1'b1;
                r_counter <= '0;
            end
            if (r_valid) begin
                if (r_counter == N-1) begin
                    if (!qsum_valid_d[2]) begin
                        r_counter <= '0;
                        r_valid   <= 1'b0;
                    end
                end else begin
                    r_counter <= r_counter + 1;
                end
            end

            // Capture this frame's factor just before the first delayed r
            // sample reaches the multiplier.
            if (r_counter == R_N-1)
                factor_r <= factor;

            // ---- std and output ---------------------------------------
            std       <= var_sqrt <<< shift;
            std_valid <= var_sqrt_valid;

            qout_mul_valid <= r_valid_d[R_N-1];
            qout_mul       <= r_out * factor_r;
            qout           <= (qout_mul >>> 1) + bias_out;
            out_valid      <= qout_mul_valid;
        end
    end

    // Sum of squares of the shifted inputs.
    mac #(
        .D_W     (D_W_ACC),
        .D_W_ACC (2*D_W_ACC)
    ) mac (
        .clk        (clk),
        .rst        (rst),
        .enable     (enable),
        .initialize (mac_acc_init),
        .a          (qshift),
        .b          (qshift),
        .result     (qsum_sq)
    );

    // Plain sum of the inputs.
    acc #(
        .D_W     (D_W_ACC),
        .D_W_ACC (D_W_ACC)
    ) acc (
        .clk        (clk),
        .rst        (rst),
        .enable     (enable),
        .initialize (mac_acc_init),
        .in_data    (qin_reg),
        .result     (qsum)
    );

    // Square root of the variance numerator.
    sqrt #(
        .D_W (D_W_ACC)
    ) sqrt (
        .clk       (clk),
        .rst       (rst),
        .enable    (enable),
        .in_valid  (qsum_valid_d[4]),
        .qin       (varr),
        .out_valid (var_sqrt_valid),
        .qout      (var_sqrt)
    );

    // factor = 2**MAX_BITS / std
    div #(
        .D_W (D_W_ACC)
    ) div (
        .clk       (clk),
        .rst       (rst),
        .enable    (enable),
        .in_valid  (std_valid),
        .divisor   (std),
        .dividend  (DIVIDEND),
        .quotient  (factor),
        .out_valid (factor_valid)
    );

    // Replays each input sample N cycles later, once the frame mean is known.
    sreg #(
        .D_W   (D_W_ACC),
        .DEPTH (N)
    ) sreg_qin (
        .clk      (clk),
        .rst      (rst),
        .shift_en (1'b1),
        .data_in  (qin),
        .data_out (qin_buf)
    );

    // Holds centred samples while std/factor are computed (R_N cycles).
    sreg #(
        .D_W   (D_W_ACC),
        .DEPTH (R_N)
    ) sreg_r (
        .clk      (clk),
        .rst      (rst),
        .shift_en (1'b1),
        .data_in  (r),
        .data_out (r_out)
    );

    // Delays bias to line up with qout_mul.
    sreg #(
        .D_W   (D_W_ACC),
        .DEPTH (BIAS_N)
    ) sreg_bias (
        .clk      (clk),
        .rst      (rst),
        .shift_en (1'b1),
        .data_in  (bias),
        .data_out (bias_out)
    );

endmodule