SSE Vector Math Library Speed

Flix
Posts: 456
Joined: Tue Dec 25, 2007 1:06 pm

SSE Vector Math Library Speed

Post by Flix »

Hi everybody.

I've done some speed tests on a few common functions in the SSE Vector Math Library (sinf4(...), acosf4(...) and sincosf4(...)) and found that, while acosf4(...) is two times faster than the standard acos(...) function, sinf4(...) and sincosf4(...) are much slower (about 10x) than the standard sin(...) and cos(...) functions.

I don't know whether this depends on the compilation options I used. Anyway, I've rewritten the functions using a different approach; the main problem is that I'm not able to achieve a decent approximation.
Since I'm very new to SSE, I'm posting some code here in the hope that somebody can modify and improve it, so that it may become useful. (I added an atan(...) implementation too, but it's still too slow, so don't take it into consideration.)

P.S. Of course, the timings do not include the time spent converting the __m128 results into something printable on screen.

Code: Select all

// asinf4: four-lane arcsine built on the identity asin(x) = pi/2 - acos(x).
// Domain and accuracy follow acosf4 (defined elsewhere), plus one rounding
// from the final subtraction.
#define asinf4(x) _mm_sub_ps( _mm_set1_ps(0.5f*3.1415926535898f) , acosf4(x) )

// TESTED: Discrepancies appear in 3rd or 4th decimal. These errors happen in every range ([-1,1] as well).
// TIME: It's much slowler (3x-4x) than the normal atan() function; maybe because there are a lot of
// stack allocations (too many variables declared).
static inline __m128 atanf4(vec_float4 x)
{
// ALGORITHM: (Hope it works)
// It's a known fact that for X in [-1,1]: atan(X) = X - X3/3 + X5/5 - X7/7 + X9/9 - ... (it's like sin(X) without the factorial)
// Now I've found this trick that seems to work (according to my empirical results):
// X > 1 : atan(X) =  PI*0.5f - atan(1/X)
// X < -1: atan(X) = -(PI*0.5f + atan(1/X))
// PS: Here I make calculations with X>0 so:
// True X > 1 : atan(X) =  PI*0.5f - atan(1/X)
// True X < -1: atan(X) = -PI*0.5f + atan(1/X)
// In both cases 1/X is in [-1,1] interval 
//---------------------------------------------------------------------------------------------------
__m128 isXGreaterThanOne = _mm_cmpgt_ps( x,_mm_set1_ps(1.0f) );
__m128 isXLessThanMinusOne = _mm_cmplt_ps( x,_mm_set1_ps(-1.0f) );
//__m128 isPositive = _mm_cmpgt0_ps(x);
__m128 absX = fabsf4(x);	
__m128 isModuleLessThanOne = _mm_cmplt_ps( absX,_mm_set1_ps(1.0f) ); // !(isXGreaterThanOne || isXLessThanMinusOne)
__m128 xabs = vec_sel(recipf4(absX),absX,isModuleLessThanOne); 
//----------------------------------------------------------------------------------------------------
/// TODO: use xabs to calculate: atan(X) = X - X3/3 + X5/5 - X7/7 + X9/9 - ...
/// Of course it can't be calculated straight away: x5, x7 and x9 are too low floating points;
/// so they must be premultiplied in some way so that multiplications and divisions are balanced:
/// probably a multiplying constant is needed at the beginning (it must be taken away at the end).
/// The code of acosf4(...) can be taken as an example: it splits the operation in two branches
/// (but is VERY compicated...).

// Here I try to copy the acosf4(...) without additional things (and t1=0):

#define _ATAN3  _mm_set1_ps(-1.0f/3.0f)
#define _ATAN5  _mm_set1_ps(1.0f/5.0f)
#define _ATAN7  _mm_set1_ps(-1.0f/7.0f)
#define _ATAN9  _mm_set1_ps(1.0f/9.0f)
#define _ATAN11 _mm_set1_ps(-1.0f/11.0f)

__m128 xabs2 = _mm_mul_ps(xabs,  xabs);
__m128 xabs3 = _mm_mul_ps(xabs2, xabs);
__m128 xabs4 = _mm_mul_ps(xabs2, xabs2);
					
__m128 hi = vec_mul(
					vec_madd(
							vec_madd(
									_ATAN11,
									xabs2, 
									_ATAN9
									),
							xabs2, 
							_ATAN7
							),
					xabs3 
					);
// hi = x3 * (  _ATAN7 + x2 * (_ATAN9 + x2 * _ATAN11) )  =  x +  _ATAN3 x3 + _ATAN5 x5
// hi=hi*x4

__m128 lo = vec_add(
					vec_mul(
							vec_madd(
									_ATAN5,
									xabs2, 
									_ATAN3
									),
							xabs3 
							),
					xabs
					);
// lo = x + x3 * (  _ATAN3 + x2 * _ATAN5 )  =  x +  _ATAN3 x3 + _ATAN5 x5

__m128 result = vec_madd(hi, xabs4, lo);

// Inside comments below there are a few tries to balance things (in the discrepancies of decimals);
// If you try these, remember that the zero case (recipf4(0)) might happen (it does not happen if you leave things as they are):
/*
lo = vec_mul( lo, recipf4(xabs4) );
__m128 result = vec_add(hi,lo);
result=vec_mul(result,xabs4);
*/
/*
lo = vec_mul( lo, recipf4(xabs2) );
hi = vec_mul( hi, xabs2 );
__m128 result = vec_add(hi,lo);
result=vec_mul(result,xabs2);
*/

//-----------------------------------------------------------------------------------------------------
result=vec_sel(result,_mm_add_ps(_mm_set1_ps(3.1415926535898f*0.5f),negatef4(result)),isXGreaterThanOne);
return vec_sel(result,_mm_add_ps(_mm_set1_ps(-3.1415926535898f*0.5f),result),isXLessThanMinusOne);
}

// TESTED: Discrepancies appear in 5th decimal compared with the std function (probably they can be adjusted with a different factorization).
// TIME: It's faster (of 0.4x i.e. SSELength=0.6 * normalLength) than the normal sin() function and much more faster (20x or so) than the prevoius sin4f function
static inline __m128 NEW_sinf4(vec_float4 x)
{
// ALGORITHM:
// sin(X) = X - X3/3! + X5/5! - X7/7! + X9/9! - ... 
//---------------------------------------------------------------------------------------------------
#define _SIN3  _mm_set1_ps(-1.0f/(3.0f*2.0f))
#define _SIN5  _mm_set1_ps(1.0f/(5.0f*4.0f*3.0f*2.0f))
#define _SIN7  _mm_set1_ps(-1.0f/(7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))
#define _SIN9  _mm_set1_ps(1.0f/(9.0f*8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))
#define _SIN11 _mm_set1_ps(-1.0f/(11.0f*10.0f*9.0f*8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))

__m128 x2 = _mm_mul_ps(x,  x);
__m128 x3 = _mm_mul_ps(x2, x);
					
__m128 hi = vec_madd(
					vec_madd(
							_SIN11,
							x2, 
							_SIN9
							),
					x2, 
					_SIN7
					);
					
// hi =  _SIN7 + x2 * (_SIN9 + x2 * _SIN11)   
// hi=hi*x7

__m128 lo = vec_madd(
					vec_madd(
							_SIN5,
							x2, 
							_SIN3
							),
					x3,
					x
					);

// lo = x + x3 * (  _SIN3 + x2 * _SIN5 ) 

//__m128 x4 = _mm_mul_ps(x2, x2);
x=_mm_mul_ps(x2, x2);	// Reuse x to calculate x4

return vec_madd(hi, _mm_mul_ps(x,x3) , lo);
}

// TESTED: Discrepancies appear in 5th decimal compared with the std function (probably they can be adjusted with a different factorization).
// TIME: It's only a bit faster (of 0.1x i.e. SSELength=0.9 * normalLength) than the normal cos() function and much more faster (10x or so) than the prevoius cos4f function
static inline __m128 NEW_cosf4(vec_float4 x)
{
// ALGORITHM:
// cos(X) = 1 - X2/2! + X4/4! - X6/6! + X8/8! - ... 
//---------------------------------------------------------------------------------------------------
#define _COS2  _mm_set1_ps(-1.0f/(2.0f))
#define _COS4  _mm_set1_ps(1.0f/(4.0f*3.0f*2.0f))
#define _COS6  _mm_set1_ps(-1.0f/(6.0f*5.0f*4.0f*3.0f*2.0f))
#define _COS8  _mm_set1_ps(1.0f/(8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))
#define _COS10 _mm_set1_ps(-1.0f/(10.0f*9.0f*8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))
#define _COS12 _mm_set1_ps(1.0f/(12.0f*11.0f*10.0f*9.0f*8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))
#define _COS14 _mm_set1_ps(1.0f/(14.0f*13.0f*12.0f*11.0f*10.0f*9.0f*8.0f*7.0f*6.0f*5.0f*4.0f*3.0f*2.0f))

__m128 x2 = _mm_mul_ps(x,  x);

__m128 hi = vec_madd(
					vec_madd(
							_COS12,		
							x2, 
							_COS10
							),
					x2,
					_COS8 
					);
// hi =  _COS8 + x2 * (_COS10 + x2 * _COS12) ) 
// hi=hi*x8

__m128 lo = vec_madd(
					vec_madd(
							vec_madd(
									_COS6,
									x2, 
									_COS4
									),
							x2,
							_COS2 
							),
					x2,
					_mm_set1_ps(1.0f)
					);					
// lo = 1 + x2 * (  _COS2 + x2  ( _COS4 + x2 * _COS6))  

/*
//REAL CODE (for clarity)
__m128 x4 = _mm_mul_ps(x2, x2);	//Temporary Only
__m128 x8 = _mm_mul_ps(x4, x4);					
*/
// OPTIMISATION (x is not needed anymore):
x=_mm_mul_ps(x2, x2);

return vec_madd(hi, _mm_mul_ps(x, x), lo);
}