half.cpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. /*************************************************************************
  2. * Copyright (C) [2020] by Cambricon, Inc. All rights reserved
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * The above copyright notice and this permission notice shall be included in
  11. * all copies or substantial portions of the Software.
  12. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  13. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  14. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  15. * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  16. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  17. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  18. * THE SOFTWARE.
  19. *************************************************************************/
  20. #include "half.h"
  21. inline bool diffNotMuch(half a, half b) {
  22. return true;
  23. float c = static_cast<float>(a);
  24. float d = static_cast<float>(b);
  25. if (c == 0 || d == 0) {
  26. return true;
  27. }
  28. c = c > 0 ? c : 0;
  29. d = d > 0 ? d : 0;
  30. if ((c / d) > DIFF_SCALE && (c / d) > DIFF_SCALE) {
  31. return false;
  32. }
  33. return true;
  34. }
  35. half::half() {}
  36. half::~half() {}
  37. half::half(const float a) { data_ = float2half(a); }
  38. // Data Cast
  39. half::operator int() {
  40. float a = half::half2float(data_);
  41. return static_cast<int>(a);
  42. }
  43. half::operator float() {
  44. float a = half::half2float(data_);
  45. return a;
  46. }
  47. half::operator double() {
  48. float a = half::half2float(data_);
  49. return static_cast<double>(a);
  50. }
  51. std::ostream& operator<<(std::ostream& output, const half& c) {
  52. float data_f = half::half2float(c.data_);
  53. output << data_f;
  54. return output;
  55. }
  56. std::istream& operator>>(std::istream& input, half& c) {
  57. float data_f;
  58. input >> data_f;
  59. c.data_ = half::float2half(data_f);
  60. return input;
  61. }
  62. half operator+(const int& a, const half& b) {
  63. float c = a;
  64. float b_f = half::half2float(b.data_);
  65. float result = c + b_f;
  66. result = result > HALF_MAX ? HALF_MAX : result;
  67. result = result < HALF_MIN ? HALF_MIN : result;
  68. return result;
  69. }
  70. half& half::operator=(const half& a) {
  71. data_ = a.data_;
  72. return *this;
  73. }
  74. half half::operator-(void) { return (half(0) - *this); }
  75. half half::operator+(const half& a) {
  76. assert(diffNotMuch(*this, a));
  77. float data_f = half2float(data_);
  78. float a_f = half2float(a.data_);
  79. float result = data_f + a_f;
  80. result = result > HALF_MAX ? HALF_MAX : result;
  81. result = result < HALF_MIN ? HALF_MIN : result;
  82. return result;
  83. }
  84. half half::operator-(const half& a) {
  85. assert(diffNotMuch(*this, a));
  86. float data_f = half2float(data_);
  87. float a_f = half2float(a.data_);
  88. float result = data_f - a_f;
  89. result = result > HALF_MAX ? HALF_MAX : result;
  90. result = result < HALF_MIN ? HALF_MIN : result;
  91. return result;
  92. }
  93. half half::operator*(const half& a) {
  94. assert(diffNotMuch(*this, a));
  95. float data_f = half2float(data_);
  96. float a_f = half2float(a.data_);
  97. float result = data_f * a_f;
  98. result = result > HALF_MAX ? HALF_MAX : result;
  99. result = result < HALF_MIN ? HALF_MIN : result;
  100. return result;
  101. }
  102. half half::operator/(const half& a) {
  103. assert(diffNotMuch(*this, a));
  104. float data_f = half2float(data_);
  105. float a_f = half2float(a.data_);
  106. float result = data_f / a_f;
  107. result = result > HALF_MAX ? HALF_MAX : result;
  108. result = result < HALF_MIN ? HALF_MIN : result;
  109. return result;
  110. }
  111. half& half::operator+=(const half& a) {
  112. assert(diffNotMuch(*this, a));
  113. half result = *this + a;
  114. data_ = result.data_;
  115. return *this;
  116. }
  117. half& half::operator-=(const half& a) {
  118. assert(diffNotMuch(*this, a));
  119. half result = *this - a;
  120. data_ = result.data_;
  121. return *this;
  122. }
  123. half& half::operator*=(const half& a) {
  124. assert(diffNotMuch(*this, a));
  125. half result = *this * a;
  126. data_ = result.data_;
  127. return *this;
  128. }
  129. half& half::operator/=(const half& a) {
  130. assert(diffNotMuch(*this, a));
  131. half result = *this / a;
  132. data_ = result.data_;
  133. return *this;
  134. }
  135. bool half::operator<(const half& a) {
  136. float data_f = half2float(this->data_);
  137. float a_f = half2float(a.data_);
  138. return data_f < a_f ? true : false;
  139. }
  140. bool half::operator<=(const half& a) {
  141. float data_f = half2float(this->data_);
  142. float a_f = half2float(a.data_);
  143. return data_f <= a_f ? true : false;
  144. }
  145. bool half::operator>(const half& a) {
  146. float data_f = half2float(this->data_);
  147. float a_f = half2float(a.data_);
  148. return data_f > a_f ? true : false;
  149. }
  150. bool half::operator>=(const half& a) {
  151. float data_f = half2float(this->data_);
  152. float a_f = half2float(a.data_);
  153. return data_f >= a_f ? true : false;
  154. }
  155. bool half::operator==(const half& a) { return data_ == a.data_ ? true : false; }
  156. bool half::operator!=(const half& a) { return data_ != a.data_ ? true : false; }
  157. uint16_t half::float2half(const float f) {
  158. // assert((f > HALF_MIN) && (f < HALF_MAX));
  159. // assert((f == 0) || (f > HALF_PRECISION) || (f < -HALF_PRECISION));
  160. _bit32_u u;
  161. u.f = f;
  162. unsigned int bytes = u.i;
  163. unsigned char sign = (bytes >> 31) & 0x00000001;
  164. unsigned char exp = (bytes >> 23) & 0x000000FF;
  165. unsigned int eff = ((bytes >> 13) & 0x000003FF); // + ((bytes >> 12) & 0x00000001);
  166. if (exp == 0xFF) {
  167. // inf or nan
  168. exp = 0x1F;
  169. if (eff) {
  170. // nan -NaN +NaN
  171. return sign ? 0xFFFF : 0x7FFF;
  172. } else {
  173. // inf -inf +inf
  174. return sign ? 0xFC00 : 0x7C00;
  175. }
  176. } else if (exp == 0x00) {
  177. // zero or denormal
  178. if (eff) {
  179. // denormal
  180. return sign ? 0x8000 : 0x0000;
  181. } else {
  182. return sign ? 0x8000 : 0x0000;
  183. }
  184. } else if (exp - 0x7F >= 0x1F - 0x0F) {
  185. // +/- inf
  186. // inf -inf +inf
  187. return sign ? 0xFC00 : 0x7C00;
  188. } else if (exp - 0x7F <= 0x00 - 0x0F) {
  189. // denormal
  190. int shift = (0x7F - exp - 0x0E);
  191. shift = shift > 11 ? 11 : shift;
  192. return ((sign << 15) | ((0x0400 | eff) >> shift));
  193. } else {
  194. // normal number
  195. exp = ((exp - 0x7F) + 0x0F) & 0x1F;
  196. return (sign << 15) | (exp << 10) | eff;
  197. }
  198. }
  199. float half::half2float(const uint16_t f) {
  200. unsigned char sign = (f >> 15) & 0x01;
  201. unsigned char exp = (f >> 10) & 0x1F;
  202. unsigned int eff = f & 0x03FF;
  203. unsigned int result;
  204. if (exp == 0x1F) {
  205. // handle inf of nan
  206. if (eff) {
  207. // NaN
  208. result = sign ? 0xFFFFFFFF : 0x7FFFFFFF;
  209. } else {
  210. // +/- inf
  211. result = sign ? 0xFF800000 : 0x7F800000;
  212. }
  213. } else if (exp == 0x00) {
  214. if (eff) {
  215. // denormal
  216. unsigned int result_base;
  217. result = (sign << 31) | ((0x7F - 0x0E) << 23) | (eff << (23 - 10));
  218. // substruct the 1.xxxxxx in eff
  219. result_base = (sign << 31) | ((0x7F - 0x0E) << 23) | (0x00000000 << 13);
  220. _bit32_u u1, u2;
  221. u1.i = result;
  222. u2.i = result_base;
  223. return u1.f - u2.f;
  224. } else {
  225. // zero
  226. result = (sign << 31) | 0x00000000;
  227. }
  228. } else {
  229. // normal number
  230. exp = (exp - 0x0F) + 0x7F;
  231. result = (sign << 31) | (exp << 23) | (eff << (23 - 10));
  232. }
  233. _bit32_u u;
  234. u.i = result;
  235. return u.f;
  236. }