mirror of
https://github.com/matrix-construct/construct
synced 2024-06-10 22:18:54 +02:00
ircd::simt: Abstract the three-piece softmax, mean state related.
This commit is contained in:
parent
5e91d51e6a
commit
3e9c2d1b56
|
@ -11,6 +11,16 @@
|
|||
#pragma once
|
||||
#define HAVE_IRCD_SIMT_MEAN_H
|
||||
|
||||
/// Averaging state; intended for computing running averages (XXX eventually).
///
/// Plain aggregate of six floats: the most recent sample, the mean computed
/// for it, and four accumulator slots.
struct ircd_math_mean
{
	float
	last,      ///< Last addend.
	mean,      ///< Computed mean.
	sum[4];    ///< Summand spread. TODO XXX
};
|
||||
|
||||
#ifdef __OPENCL_C_VERSION__
|
||||
/// Compute average of all elements in the input. The result is broadcast
|
||||
/// to all elements of the output.
|
||||
|
|
25
include/ircd/simt/samax.h
Normal file
25
include/ircd/simt/samax.h
Normal file
|
@ -0,0 +1,25 @@
|
|||
// Matrix Construct
|
||||
//
|
||||
// Copyright (C) Matrix Construct Developers, Authors & Contributors
|
||||
// Copyright (C) 2016-2021 Jason Volk <jason@zemos.net>
|
||||
//
|
||||
// Permission to use, copy, modify, and/or distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice is present in all copies. The
|
||||
// full license for this software is available in the LICENSE file.
|
||||
|
||||
#pragma once
|
||||
#define HAVE_IRCD_SIMT_SAMAX_H
|
||||
|
||||
/// Softargmax state.
///
/// In FP32 environments, naively implementing softmax as expressed in
/// literature will overflow. Researchers have mitigated this at the cost
/// of some extra state and passes.
///
/// The three members are the intermediate values carried between those
/// passes: the maximum subtracted from every input before exponentiation,
/// the sum of the resulting exponentials, and the reciprocal of that sum
/// which scales each exponential into its final value.
struct ircd_math_samax
{
	float
	mu,        ///< Maximum input value; subtracted before exp().
	sum,       ///< Sum of exp(x - mu) over all inputs.
	lambda;    ///< Normalizer, 1.0 / sum.
};
|
|
@ -18,3 +18,4 @@
|
|||
#include "mean.h"
|
||||
#include "norm.h"
|
||||
#include "rand.h"
|
||||
#include "samax.h"
|
||||
|
|
|
@ -153,6 +153,39 @@ ircd_gpt_ffnn(__global const struct ircd_gpt_task *const ctrl,
|
|||
ircd_gpt_sgemv(token->word, buf->fcon, proj_bias, proj_weight, height, width, 1);
|
||||
}
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_attn_self_samax(__global const struct ircd_gpt_task *const ctrl,
|
||||
__constant const struct ircd_gpt_opts *const opts,
|
||||
__local float self[][12])
|
||||
{
|
||||
const uint
|
||||
li = get_local_id(0),
|
||||
wn = get_num_groups(0);
|
||||
|
||||
struct ircd_math_samax samax =
|
||||
{
|
||||
.mu = -10000.0f,
|
||||
.sum = 0.0f,
|
||||
};
|
||||
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
samax.mu = max(samax.mu, self[i][li]);
|
||||
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] = exp(self[i][li] - samax.mu);
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
samax.sum += self[i][li];
|
||||
|
||||
samax.lambda = 1.0f / samax.sum;
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] *= samax.lambda;
|
||||
}
|
||||
|
||||
inline void
|
||||
__attribute__((always_inline))
|
||||
ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
||||
|
@ -201,24 +234,7 @@ ircd_gpt_attn_self(__global const struct ircd_gpt_task *const ctrl,
|
|||
}
|
||||
|
||||
// Three-piece softmax
|
||||
float mu = -10000.0f;
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
mu = max(mu, self[i][li]);
|
||||
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] = exp(self[i][li] - mu);
|
||||
|
||||
float sum = 0.0f;
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
sum += self[i][li];
|
||||
|
||||
const float
|
||||
lambda = 1.0f / sum;
|
||||
|
||||
__attribute__((opencl_unroll_hint))
|
||||
for(uint i = 0; i < wn; ++i)
|
||||
self[i][li] *= lambda;
|
||||
ircd_gpt_attn_self_samax(ctrl, opts, self);
|
||||
}
|
||||
|
||||
// Propagate to full width for value dot prod.
|
||||
|
@ -529,7 +545,7 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_task *const ctrl,
|
|||
ircd_simt_reduce_max_flldr(share);
|
||||
|
||||
if(li == 0)
|
||||
share4[li] = ctrl->samax_mu = share[li];
|
||||
share4[li] = ctrl->samax.mu = share[li];
|
||||
|
||||
ircd_simt_broadcast_f4lldr(share4);
|
||||
|
||||
|
@ -560,8 +576,8 @@ ircd_gpt_lm_logsm(__global struct ircd_gpt_task *const ctrl,
|
|||
const float
|
||||
sum = ircd_simt_reduce_add_f4(share4[li]);
|
||||
|
||||
share4[li][0] = ctrl->samax_sum = sum;
|
||||
share4[li][1] = ctrl->samax_lambda = 1.0f / sum;
|
||||
share4[li][0] = ctrl->samax.sum = sum;
|
||||
share4[li][1] = ctrl->samax.lambda = 1.0f / sum;
|
||||
}
|
||||
|
||||
ircd_simt_broadcast_f4lldr(share4);
|
||||
|
@ -658,25 +674,25 @@ ircd_gpt_lm_result(__global struct ircd_gpt_task *const ctrl,
|
|||
|
||||
const float
|
||||
test_lsm = logexp[opts->label],
|
||||
loss = 0.0f - log(test_lsm * ctrl->samax_lambda),
|
||||
loss = 0.0f - log(test_lsm * ctrl->samax.lambda),
|
||||
perp = (1.0f - logsm[token]) * native_log2(opts->logits),
|
||||
cert = (logsm[token] - logsm[next_token]) / logsm[token],
|
||||
loss_sum = ctrl->loss_sum[0] + ctrl->loss_sum[1] + ctrl->loss_sum[2] + loss,
|
||||
perp_sum = ctrl->perp_sum[0] + ctrl->perp_sum[1] + ctrl->perp_sum[2] + perp,
|
||||
cert_sum = ctrl->cert_sum[0] + ctrl->cert_sum[1] + ctrl->cert_sum[2] + cert,
|
||||
loss_mean = loss_sum / (ctrl->epoch + 1.0f),
|
||||
perp_mean = perp_sum / (ctrl->epoch + 1.0f),
|
||||
cert_mean = cert_sum / (ctrl->epoch + 1.0f);
|
||||
loss_sum = ctrl->loss.sum[0] + ctrl->loss.sum[1] + ctrl->loss.sum[2] + loss,
|
||||
perp_sum = ctrl->perp.sum[0] + ctrl->perp.sum[1] + ctrl->perp.sum[2] + perp,
|
||||
cert_sum = ctrl->cert.sum[0] + ctrl->cert.sum[1] + ctrl->cert.sum[2] + cert,
|
||||
loss_mean = loss_sum / (ctrl->epic.epoch + 1.0f),
|
||||
perp_mean = perp_sum / (ctrl->epic.epoch + 1.0f),
|
||||
cert_mean = cert_sum / (ctrl->epic.epoch + 1.0f);
|
||||
|
||||
ctrl->loss = loss;
|
||||
ctrl->loss_sum[sum_sel] += loss;
|
||||
ctrl->loss_mean = loss_mean;
|
||||
ctrl->perp = perp;
|
||||
ctrl->perp_sum[sum_sel] += perp;
|
||||
ctrl->perp_mean = perp_mean;
|
||||
ctrl->cert = cert;
|
||||
ctrl->cert_sum[sum_sel] += cert;
|
||||
ctrl->cert_mean = cert_mean;
|
||||
ctrl->loss.last = loss;
|
||||
ctrl->loss.sum[sum_sel] += loss;
|
||||
ctrl->loss.mean = loss_mean;
|
||||
ctrl->perp.last = perp;
|
||||
ctrl->perp.sum[sum_sel] += perp;
|
||||
ctrl->perp.mean = perp_mean;
|
||||
ctrl->cert.last = cert;
|
||||
ctrl->cert.sum[sum_sel] += cert;
|
||||
ctrl->cert.mean = cert_mean;
|
||||
}
|
||||
|
||||
__kernel void
|
||||
|
@ -725,7 +741,7 @@ ircd_gpt_prop_elem(__global const struct ircd_gpt_task *const ctrl,
|
|||
|
||||
const float4
|
||||
param = param_[li],
|
||||
grad = ctrl->loss_mean,
|
||||
grad = ctrl->loss.mean,
|
||||
alpha[2] = { 1.0f - opts->beta[0], 1.0f - opts->beta[1], },
|
||||
exp_avg = step? exp_avg_[li]: 0.0f,
|
||||
exp_avg_sqr = step? exp_avg_sqr_[li]: 0.0f,
|
||||
|
|
Loading…
Reference in a new issue