Implemented matrix multiplication, added tests for matrix mul and iter
This commit is contained in:
parent
7cad1b8dfc
commit
a998ec4339
2 changed files with 124 additions and 12 deletions
|
@ -4,6 +4,24 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
|
||||||
|
|
||||||
use num::{traits::NumAssign, Num, Signed};
|
use num::{traits::NumAssign, Num, Signed};
|
||||||
|
|
||||||
|
impl<T: Num + Clone> Matrix<T> {
|
||||||
|
pub fn multiply(&self, rhs: &Self) -> Self {
|
||||||
|
if !self.can_mul(&rhs) {
|
||||||
|
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
|
||||||
|
self.width, self.height(), rhs.width, rhs.height());
|
||||||
|
}
|
||||||
|
let mut new = Matrix::new_zeroes(rhs.width, self.height());
|
||||||
|
for (i, j) in new.indices() {
|
||||||
|
let new_elem = self.row(i).into_iter()
|
||||||
|
.zip(rhs.column(j).into_iter())
|
||||||
|
.map(|(e1, e2)| e1 * e2)
|
||||||
|
.reduce(|acc, e| acc + e);
|
||||||
|
new[i][j] = new_elem.unwrap();
|
||||||
|
}
|
||||||
|
return new;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: Num + Clone> Add for Matrix<T> {
|
impl<T: Num + Clone> Add for Matrix<T> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
fn add(self, rhs: Self) -> Self::Output {
|
fn add(self, rhs: Self) -> Self::Output {
|
||||||
|
@ -109,15 +127,7 @@ impl<T: NumAssign + Clone> DivAssign<T> for Matrix<T> {
|
||||||
impl<T: Num + Clone> Mul for Matrix<T> {
|
impl<T: Num + Clone> Mul for Matrix<T> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
fn mul(self, rhs: Self) -> Self::Output {
|
fn mul(self, rhs: Self) -> Self::Output {
|
||||||
if !self.can_mul(&rhs) {
|
self.multiply(&rhs)
|
||||||
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
|
|
||||||
self.width, self.height(), rhs.width, rhs.height());
|
|
||||||
}
|
|
||||||
let mut new = Matrix::new_zeroes(rhs.width, self.height());
|
|
||||||
for (i, j) in new.indices() {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
return new;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,8 +200,50 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Unable to multiply")]
|
#[should_panic(expected = "Unable to multiply")]
|
||||||
fn mul() {
|
fn mul_wrong_size() {
|
||||||
|
let _ = matrix![2; 1, 2, 3, 4] * matrix![2; 1, 2, 3, 4, 5, 6];
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mul() {
|
||||||
|
let first = matrix![3;
|
||||||
|
1, 2, 3,
|
||||||
|
4, 5, 6,
|
||||||
|
7, 8, 9
|
||||||
|
];
|
||||||
|
let second = matrix![3;
|
||||||
|
2, 3, 4,
|
||||||
|
5, 6, 7,
|
||||||
|
8, 9, 10
|
||||||
|
];
|
||||||
|
assert_eq!(first.multiply(&second), matrix![3;
|
||||||
|
36, 42, 48,
|
||||||
|
81, 96, 111,
|
||||||
|
126, 150, 174
|
||||||
|
]);
|
||||||
|
assert_eq!(second * first, matrix![3;
|
||||||
|
42, 51, 60,
|
||||||
|
78, 96, 114,
|
||||||
|
114, 141, 168
|
||||||
|
]);
|
||||||
|
let first = matrix![3;
|
||||||
|
1, 2, 3,
|
||||||
|
4, 5, 6
|
||||||
|
];
|
||||||
|
let second = matrix![2;
|
||||||
|
2, 3,
|
||||||
|
4, 5,
|
||||||
|
6, 7
|
||||||
|
];
|
||||||
|
assert_eq!(first.multiply(&second), matrix![2;
|
||||||
|
28, 34,
|
||||||
|
64, 79
|
||||||
|
]);
|
||||||
|
assert_eq!(second * first, matrix![3;
|
||||||
|
14, 19, 24,
|
||||||
|
24, 33, 42,
|
||||||
|
34, 47, 60
|
||||||
|
])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,6 @@ impl<'a, T: Num> Iterator for ColumnsIter<'a, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Iterator for IterIndices {
|
impl Iterator for IterIndices {
|
||||||
type Item = (usize, usize);
|
type Item = (usize, usize);
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
@ -139,7 +138,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn iter() {
|
fn iter_elem() {
|
||||||
let data = vec![1,2,3,4,5,6,7,8,9];
|
let data = vec![1,2,3,4,5,6,7,8,9];
|
||||||
let mut matrix = Matrix::new(data.clone(), 3);
|
let mut matrix = Matrix::new(data.clone(), 3);
|
||||||
for (i, e) in matrix.iter().enumerate() {
|
for (i, e) in matrix.iter().enumerate() {
|
||||||
|
@ -153,5 +152,66 @@ mod tests {
|
||||||
assert_eq!(data[i] + 2, e);
|
assert_eq!(data[i] + 2, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iter_indexed() {
|
||||||
|
let data = vec![1,2,3,4,5,6,7,8,9];
|
||||||
|
let mut matrix = Matrix::new(data.clone(), 3);
|
||||||
|
let (width, height) = (matrix.width(), matrix.height());
|
||||||
|
let mut matrix_iter = matrix.iter_indexed();
|
||||||
|
for i in 0..height {
|
||||||
|
for j in 0..width {
|
||||||
|
let (mi, mj, e) = matrix_iter.next().unwrap();
|
||||||
|
assert_eq!((mi, mj), (i, j));
|
||||||
|
assert_eq!(*e, data[i * width + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut matrix_iter = matrix.iter_indexed_mut();
|
||||||
|
for i in 0..height {
|
||||||
|
for j in 0..width {
|
||||||
|
let (mi, mj, e) = matrix_iter.next().unwrap();
|
||||||
|
assert_eq!((mi, mj), (i, j));
|
||||||
|
assert_eq!(*e, data[i * width + j]);
|
||||||
|
*e += 2;
|
||||||
|
assert_eq!(*e, data[i * width + j] + 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iter_rows() {
|
||||||
|
let mut matrix = matrix![3;
|
||||||
|
1, 2, 3,
|
||||||
|
4, 5, 6,
|
||||||
|
7, 8, 9
|
||||||
|
];
|
||||||
|
let mut matrix_iter = matrix.iter_rows();
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[1, 2, 3]);
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[4, 5, 6]);
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[7, 8, 9]);
|
||||||
|
assert_eq!(matrix_iter.next(), None);
|
||||||
|
for row in matrix.iter_rows_mut() {
|
||||||
|
row[0] += 9;
|
||||||
|
}
|
||||||
|
let mut matrix_iter = matrix.iter_rows_mut();
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[10, 2, 3]);
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[13, 5, 6]);
|
||||||
|
assert_eq!(matrix_iter.next().unwrap(), &[16, 8, 9]);
|
||||||
|
assert_eq!(matrix_iter.next(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iter_columns() {
|
||||||
|
let matrix = matrix![3;
|
||||||
|
1, 2, 3,
|
||||||
|
4, 5, 6,
|
||||||
|
7, 8, 9
|
||||||
|
];
|
||||||
|
for (j, column) in matrix.iter_columns().enumerate() {
|
||||||
|
for (i, e) in column.enumerate() {
|
||||||
|
assert_eq!(*e, matrix[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue