Activation functions #85
base: main
Conversation
@yonatankarni could you add a short description of what the contents of this PR are (i.e. an overview of sorts, as there are many changes)?
@SkBlaz Done (not sure why I can't simply reply, had to edit the comment :| )
@yonatankarni the deep branch was already merged to main, so it would be better to close this PR and open one that merges straight to main.
@adischw thanks for the heads up (not sure why I can't simply reply, had to edit the comment :| )
no need to open a new PR - I simply updated this one by changing the base and force-pushing the updated revision
done (replying again, this time hopefully the right way)
@yonatankarni it seems there is an issue with
@SkBlaz yes, this is due to a merge with new incoming changes from main, I will fix it shortly.
    input: graph::BlockPtrOutput,
) -> Result<graph::BlockPtrOutput, Box<dyn Error>> {
    let num_inputs = bg.get_num_output_values(vec![&input]);
    assert!(num_inputs != 0);
Should this be debug_assert!?
frankly I just copy-pasted this from block_relu; that's why I said we might want to consider eliminating code repetition.
as to whether this should be an assert or a debug_assert - since we run FW from release builds, and it seems @andraztori intended for this to fail in case of bad wiring, I think we should keep it a regular assert for now.
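For reference, a minimal sketch of the trade-off being discussed; the function name and message below are made up, not the PR's code:

fn check_wiring(num_inputs: usize) {
    // assert! is always compiled in, so bad wiring still aborts in --release builds
    assert!(num_inputs != 0, "block wired with zero inputs");

    // debug_assert! is compiled out when debug assertions are disabled (the default
    // for release builds), so this check would silently disappear in production
    debug_assert!(num_inputs != 0);
}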
unsafe {
    for i in 0..self.num_inputs as usize {
        let x = *pb.tape.get_unchecked_mut(self.input_offset + i);
        if x < 0.0 {
The < and <= are intentional, right? (one strict)
the first section is for the value of leakyRELU, the second is for the derivative.
RELU/leakyRELU(x) is defined to equal 0 at x=0, so I use "<" to avoid a redundant multiplication.
however, the derivative of RELU and leakyRELU isn't defined at 0, and I read somewhere (see the link below) that the convention is to set it to 1 for x=0 - that's why there is a difference (https://stats.stackexchange.com/questions/333394/what-is-the-derivative-of-the-relu-activation-function)
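A minimal sketch of the two definitions being discussed (value vs. derivative at x = 0); the names, the ALPHA constant and the exact branch layout are illustrative, not the PR's actual code:

const ALPHA: f32 = 0.01; // illustrative negative-side slope

fn leaky_relu(x: f32) -> f32 {
    // value: leaky_relu(0) = 0 either way, so a strict "<" skips a redundant multiplication at x = 0
    if x < 0.0 { ALPHA * x } else { x }
}

fn leaky_relu_derivative(x: f32) -> f32 {
    // derivative: not defined at x = 0; following the convention in the linked answer it is set to 1 there
    if x < 0.0 { ALPHA } else { 1.0 }
}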
}

fn get_num_output_slots(&self) -> usize {
    1
This method's name does not reflect the contents -> why is this not a constant of the object?
this method is an implementation of the abstract method from the BlockTrait trait
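Roughly, the relationship looks like this (a sketch - the real BlockTrait has more methods, and the struct name is a stand-in):

trait BlockTrait {
    fn get_num_output_slots(&self) -> usize;
}

struct BlockLeakyRelu;

impl BlockTrait for BlockLeakyRelu {
    // the value is constant, but exposing it through the trait method lets callers
    // treat every block uniformly as a dyn BlockTrait
    fn get_num_output_slots(&self) -> usize {
        1
    }
}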
    input: graph::BlockPtrOutput,
) -> Result<graph::BlockPtrOutput, Box<dyn Error>> {
    let num_inputs = bg.get_num_output_values(vec![&input]);
    assert!(num_inputs != 0);
debug_assert!?
see my reply to the same comment on block_leaky_relu
fn allocate_and_init_weights(&mut self, _mi: &model_instance::ModelInstance) {}

fn get_num_output_slots(&self) -> usize {
    1
The constant function once more (same comment as in the previous example)
see my reply to the previous comment on this function
}

fn set_input_offset(&mut self, input: graph::InputSlot, offset: usize) {
    assert!(input.get_input_index() == 0);
The assert here and above - are they necessary?
see first reply (copy-pasted from block_relu, and my bet is that it's best to have them there to catch issues early)
for i in 0..self.num_inputs as usize {
    let x = *pb.tape.get_unchecked_mut(self.input_offset + i);

// for now using libm tanh computation. once we establish a baseline,
This is a good idea - having a full dependency just for computing tanh seems like a lot? (might have missed other uses)
no, I added it just for that. the package seems pretty small though (~43K), but I can check if fast approximations perform well enough (https://math.stackexchange.com/questions/107292/rapid-approximation-of-tanhx)
Ok. In case we'd want to roll our own, this one should do a pretty decent trick perhaps:
// rational (continued-fraction style) approximation of tanh(x); accurate near zero,
// but the output should be clamped to [-1, 1] for large |x|
float fast_tanh(float x) {
    float x2 = x * x;
    float a = x * (135135.0f + x2 * (17325.0f + x2 * (378.0f + x2)));
    float b = 135135.0f + x2 * (62370.0f + x2 * (3150.0f + x2 * 28.0f));
    return a / b;
}
cool! I hope I will get to test that as well.
I speculate @andraztori's approach is to replace those with lookup tables (which don't involve any computation...), but I think we'll get there only after we see the value - so I will proceed with testing
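For what it's worth, a sketch of what such a lookup table could look like; the table size, input range and names here are assumptions, not the project's code:

const TABLE_SIZE: usize = 4096;
const RANGE: f32 = 8.0; // tanh is effectively saturated outside [-RANGE, RANGE]

struct TanhTable {
    values: Vec<f32>,
}

impl TanhTable {
    fn new() -> Self {
        // precompute tanh over [-RANGE, RANGE] once, at table construction time
        let values = (0..TABLE_SIZE)
            .map(|i| (-RANGE + 2.0 * RANGE * i as f32 / (TABLE_SIZE - 1) as f32).tanh())
            .collect();
        TanhTable { values }
    }

    fn lookup(&self, x: f32) -> f32 {
        // clamp into the table range and index - no per-call transcendental computation
        let t = ((x + RANGE) / (2.0 * RANGE)).clamp(0.0, 1.0);
        self.values[(t * (TABLE_SIZE - 1) as f32) as usize]
    }
}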
}
if layernorm == NNLayerNorm::AfterRelu {

if layernorm == NNLayerNorm::AfterActivation {
Nice work! Btw, worth seeing a speed comparison with just relu in this case - probably not critical, but would be interesting to see
sure, I will test it for a regular sequential training scenario
assert_epsilon!(slearn2(&mut bg, &fb, &mut pb, true), 2.0); // leaky_relu doesn't learn
}

fn test_simple_negative() {
here and in other files, I believe you meant to add a #[test] for this test case
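For illustration, the attribute being suggested (the body below is a placeholder, not the actual test):

#[test]
fn test_simple_negative() {
    // placeholder body - the point is only that without #[test], `cargo test` never runs this function
    assert_eq!(2 + 2, 4);
}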
use block_helpers::slearn2;
use block_misc::Observe;

fn fb_vec() -> feature_buffer::FeatureBuffer {
here and in all other files where this applies - if you have a function that's being used only in the unit tests, please add a #[cfg(test)] above it so cargo doesn't flag it as unused
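For illustration (the helper name and body below are stand-ins; the same would apply to e.g. fb_vec above):

#[cfg(test)]
fn make_test_input() -> Vec<f32> {
    // compiled only for `cargo test`, so regular builds don't warn about an unused function
    vec![0.0; 4]
}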
As implemented in the "deep" branch, the deep layers can use either RELU activation or none (no activation function).
Conveniently, the activation function type ("relu"/"none") is already governed by a command line argument; for instance, for a 3rd layer with width 25 and RELU activation we add the command line args:
--nn 2:width:025 --nn 2:activation:relu
In this PR I add additional activation functions for the deep layers - "leaky_relu", "tanh" and "sigmoid" - which can be controlled in the same manner (e.g. --nn 2:activation:tanh).
Now that we have 4 activation functions, it seems to me we can do better in terms of code re-use / eliminating repetition between them, but I'm not sure which approach to take, so if you have concrete suggestions this is a good time and place to bring them up.
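One possible direction for the code re-use question, purely as a sketch - the Activation enum and its methods below are hypothetical and not part of this PR:

#[derive(Clone, Copy)]
enum Activation {
    Relu,
    LeakyRelu { alpha: f32 },
    Tanh,
    Sigmoid,
}

impl Activation {
    // value of the activation at x
    fn forward(self, x: f32) -> f32 {
        match self {
            Activation::Relu => if x < 0.0 { 0.0 } else { x },
            Activation::LeakyRelu { alpha } => if x < 0.0 { alpha * x } else { x },
            Activation::Tanh => x.tanh(),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
        }
    }

    // derivative of the activation at x (relu/leaky_relu use the x = 0 convention discussed above)
    fn derivative(self, x: f32) -> f32 {
        match self {
            Activation::Relu => if x < 0.0 { 0.0 } else { 1.0 },
            Activation::LeakyRelu { alpha } => if x < 0.0 { alpha } else { 1.0 },
            Activation::Tanh => 1.0 - x.tanh() * x.tanh(),
            Activation::Sigmoid => {
                let s = 1.0 / (1.0 + (-x).exp());
                s * (1.0 - s)
            }
        }
    }
}

A single block holding an Activation value could then share the forward/backward loops instead of repeating them once per activation function.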