0
0
Fork 0
mirror of https://github.com/matrix-construct/construct synced 2024-06-10 22:18:54 +02:00

ircd::simt: Abstract the three-piece softmax and the related mean state.

This commit is contained in:
Jason Volk 2021-05-02 14:51:49 -07:00
parent 5e91d51e6a
commit 3e9c2d1b56
4 changed files with 90 additions and 38 deletions

View file

@ -11,6 +11,16 @@
#pragma once
#define HAVE_IRCD_SIMT_MEAN_H
/// Averaging state; this is for computing running averages.
/// XXX eventually: intended to become a general accumulation primitive.
struct ircd_math_mean
{
float
last, ///< Last addend pushed into the accumulator.
mean, ///< Computed mean (sum divided by the count of addends).
sum[4]; ///< Summand spread; caller-selected partial sums combined on read. TODO XXX
};
#ifdef __OPENCL_C_VERSION__
/// Compute average of all elements in the input. The result is broadcast
/// to all elements of the output.

25
include/ircd/simt/samax.h Normal file
View file

@ -0,0 +1,25 @@
// Matrix Construct
//
// Copyright (C) Matrix Construct Developers, Authors & Contributors
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice is present in all copies. The
// full license for this software is available in the LICENSE file.
#pragma once
#define HAVE_IRCD_SIMT_SAMAX_H
/// Softargmax state.
///
/// In FP32 environments, naively implementing softmax as expressed in
/// literature will overflow. Researchers have mitigated this at the cost
/// of some extra state and passes.
struct ircd_math_samax
{
float
mu, ///< Maximum input value; subtracted before exp() for stability.
sum, ///< Sum of the shifted exponentials exp(x - mu).
lambda; ///< Reciprocal of sum; multiplied through to normalize.
};

View file

@ -18,3 +18,4 @@
#include "mean.h"
#include "norm.h"
#include "rand.h"
#include "samax.h"

View file

@ -153,6 +153,39 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
}
inline void
__attribute__((always_inline))
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_task *const ctrl,
                         __constant const struct ircd_gpt_opts *const opts,
                         __local float self[][12])
{
	// This work-item's lane in the shared tile; wn rows are reduced per lane.
	const uint li = get_local_id(0);
	const uint wn = get_num_groups(0);

	// Three-piece softmax state; lambda is implicitly zero-initialized
	// here and only computed after the summation pass.
	struct ircd_math_samax st =
	{
		.mu = -10000.0f,
		.sum = 0.0f,
	};

	// Pass 1: find the lane maximum (for numerical stability).
	for(uint row = 0; row < wn; ++row)
		st.mu = max(st.mu, self[row][li]);

	// Pass 2: exponentiate with the maximum subtracted out.
	for(uint row = 0; row < wn; ++row)
		self[row][li] = exp(self[row][li] - st.mu);

	// Pass 3: accumulate the shifted exponentials.
	__attribute__((opencl_unroll_hint))
	for(uint row = 0; row < wn; ++row)
		st.sum += self[row][li];

	st.lambda = 1.0f / st.sum;

	// Normalize in place by the reciprocal sum.
	__attribute__((opencl_unroll_hint))
	for(uint row = 0; row < wn; ++row)
		self[row][li] *= st.lambda;
}
inline void
__attribute__((always_inline))
ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
@ -201,24 +234,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
}
// Three-piece softmax
float mu = -10000.0f;
for(uint i = 0; i < wn; ++i)
mu = max(mu, self[i][li]);
for(uint i = 0; i < wn; ++i)
self[i][li] = exp(self[i][li] - mu);
float sum = 0.0f;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
sum += self[i][li];
const float
lambda = 1.0f / sum;
__attribute__((opencl_unroll_hint))
for(uint i = 0; i < wn; ++i)
self[i][li] *= lambda;
ircd_gpt_attn_self_samax(ctrl, opts, self);
}
// Propagate to full width for value dot prod.
@ -529,7 +545,7 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_task *const ctrl,
ircd_simt_reduce_max_flldr(share);
if(li == 0)
share4[li] = ctrl->samax_mu = share[li];
share4[li] = ctrl->samax.mu = share[li];
ircd_simt_broadcast_f4lldr(share4);
@ -560,8 +576,8 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_task *const ctrl,
const float
sum = ircd_simt_reduce_add_f4(share4[li]);
share4[li][0] = ctrl->samax_sum = sum;
share4[li][1] = ctrl->samax_lambda = 1.0f / sum;
share4[li][0] = ctrl->samax.sum = sum;
share4[li][1] = ctrl->samax.lambda = 1.0f / sum;
}
ircd_simt_broadcast_f4lldr(share4);
@ -658,25 +674,25 @@ ircd_gpt_lm_result(__global struct ircd_gpt_task *const ctrl,
const float
test_lsm = logexp[opts->label],
loss = 0.0f - log(test_lsm * ctrl->samax_lambda),
loss = 0.0f - log(test_lsm * ctrl->samax.lambda),
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
cert = (logsm[token] - logsm[next_token]) / logsm[token],
loss_sum = ctrl->loss_sum[0] + ctrl->loss_sum[1] + ctrl->loss_sum[2] + loss,
perp_sum = ctrl->perp_sum[0] + ctrl->perp_sum[1] + ctrl->perp_sum[2] + perp,
cert_sum = ctrl->cert_sum[0] + ctrl->cert_sum[1] + ctrl->cert_sum[2] + cert,
loss_mean = loss_sum / (ctrl->epoch + 1.0f),
perp_mean = perp_sum / (ctrl->epoch + 1.0f),
cert_mean = cert_sum / (ctrl->epoch + 1.0f);
loss_sum = ctrl->loss.sum[0] + ctrl->loss.sum[1] + ctrl->loss.sum[2] + loss,
perp_sum = ctrl->perp.sum[0] + ctrl->perp.sum[1] + ctrl->perp.sum[2] + perp,
cert_sum = ctrl->cert.sum[0] + ctrl->cert.sum[1] + ctrl->cert.sum[2] + cert,
loss_mean = loss_sum / (ctrl->epic.epoch + 1.0f),
perp_mean = perp_sum / (ctrl->epic.epoch + 1.0f),
cert_mean = cert_sum / (ctrl->epic.epoch + 1.0f);
ctrl->loss = loss;
ctrl->loss_sum[sum_sel] += loss;
ctrl->loss_mean = loss_mean;
ctrl->perp = perp;
ctrl->perp_sum[sum_sel] += perp;
ctrl->perp_mean = perp_mean;
ctrl->cert = cert;
ctrl->cert_sum[sum_sel] += cert;
ctrl->cert_mean = cert_mean;
ctrl->loss.last = loss;
ctrl->loss.sum[sum_sel] += loss;
ctrl->loss.mean = loss_mean;
ctrl->perp.last = perp;
ctrl->perp.sum[sum_sel] += perp;
ctrl->perp.mean = perp_mean;
ctrl->cert.last = cert;
ctrl->cert.sum[sum_sel] += cert;
ctrl->cert.mean = cert_mean;
}
__kernel void
@ -725,7 +741,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_task *const ctrl,
const float4
param = param_[li],
grad = ctrl->loss_mean,
grad = ctrl->loss.mean,
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
exp_avg = step? exp_avg_[li]: 0.0f,
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,