Implementation of the fake DbSet used in tests:
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
namespace TestHelpers
{
    /// <summary>
    /// In-memory implementation of <see cref="IDbSet{T}"/> for unit tests. Code that depends
    /// on a context's sets can be exercised against it without a database.
    /// </summary>
    /// <typeparam name="T">The entity type stored in the fake set.</typeparam>
    public class FakeDbSet<T> : IDbSet<T> where T : class
    {
        // Backing store. An ObservableCollection lets Local expose the live contents
        // instead of a disconnected snapshot.
        private readonly ObservableCollection<T> _data;

        // A single LINQ-to-Objects view over _data. ElementType, Expression and Provider
        // are all taken from this one instance so they stay consistent with each other.
        private readonly IQueryable<T> _query;

        /// <summary>Creates an empty fake set.</summary>
        public FakeDbSet()
        {
            _data = new ObservableCollection<T>();
            _query = _data.AsQueryable();
        }

        /// <summary>
        /// Not supported by this fake: it has no knowledge of entity keys.
        /// Override in a derived class to supply key lookup for a specific entity type.
        /// </summary>
        /// <param name="keyValues">The key values of the entity to find.</param>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public virtual T Find(params object[] keyValues)
        {
            throw new NotImplementedException();
        }

        /// <summary>Not supported by this fake; see <see cref="Find"/>.</summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public virtual Task<T> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
        {
            throw new NotImplementedException();
        }

        /// <summary>Adds <paramref name="item"/> to the set unless it is already present.</summary>
        /// <param name="item">The entity to add.</param>
        /// <returns>The same <paramref name="item"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
        public T Add(T item)
        {
            if (item == null)
            {
                throw new ArgumentNullException("item");
            }

            AddIfMissing(item);
            return item;
        }

        /// <summary>Removes <paramref name="item"/> from the set if present.</summary>
        /// <param name="item">The entity to remove.</param>
        /// <returns>The same <paramref name="item"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
        public T Remove(T item)
        {
            if (item == null)
            {
                throw new ArgumentNullException("item");
            }

            _data.Remove(item);
            return item;
        }

        /// <summary>
        /// Attaches <paramref name="item"/>. With no change tracking, this behaves
        /// the same as <see cref="Add"/>.
        /// </summary>
        /// <param name="item">The entity to attach.</param>
        /// <returns>The same <paramref name="item"/> instance.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="item"/> is null.</exception>
        public T Attach(T item)
        {
            if (item == null)
            {
                throw new ArgumentNullException("item");
            }

            AddIfMissing(item);
            return item;
        }

        /// <summary>Removes <paramref name="item"/> from the set if present.</summary>
        /// <param name="item">The entity to detach.</param>
        public void Detach(T item)
        {
            _data.Remove(item);
        }

        // Preserves set semantics (no duplicates under EqualityComparer<T>.Default).
        private void AddIfMissing(T item)
        {
            if (!_data.Contains(item))
            {
                _data.Add(item);
            }
        }

        Type IQueryable.ElementType
        {
            get { return _query.ElementType; }
        }

        Expression IQueryable.Expression
        {
            get { return _query.Expression; }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return _query.Provider; }
        }

        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return _data.GetEnumerator();
        }

        /// <summary>
        /// Creates a new, unattached instance of <typeparamref name="T"/> via its parameterless
        /// constructor. The instance is not added to the set.
        /// </summary>
        public T Create()
        {
            return Activator.CreateInstance<T>();
        }

        /// <summary>
        /// The live contents of the set. Changes made through this collection affect the set
        /// directly and bypass the duplicate and null checks of <see cref="Add"/>.
        /// </summary>
        public ObservableCollection<T> Local
        {
            get { return _data; }
        }

        /// <summary>
        /// Creates a new, unattached instance of <typeparamref name="TDerivedEntity"/> via its
        /// parameterless constructor. The instance is not added to the set.
        /// </summary>
        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, T
        {
            return Activator.CreateInstance<TDerivedEntity>();
        }
    }
}
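Next, modify the context so that its sets are exposed as `IDbSet<T>` rather than `DbSet<T>`; this lets tests substitute a `FakeDbSet<T>`: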
// NOTE(review): the original declaration listed DataContext in its own base list, which does
// not compile (a type cannot derive from itself). An interface such as IDataContext was
// presumably intended; TODO confirm and add it here (or in another partial part) once it exists.
/// <summary>
/// Entity Framework context for the application. <see cref="PageHits"/> is typed as
/// <see cref="IDbSet{T}"/> so that tests can substitute an in-memory implementation.
/// </summary>
public partial class DataContext : DbContext
{
    static DataContext()
    {
        // A null initializer disables EF's database initialization for this context type.
        Database.SetInitializer<DataContext>(null);
    }

    /// <summary>
    /// Creates the context using the connection string named "DataContext" from configuration.
    /// </summary>
    public DataContext()
        : base("Name=DataContext")
    {
    }

    /// <summary>The set of <see cref="PageHit"/> entities.</summary>
    public IDbSet<PageHit> PageHits { get; set; }

    /// <summary>Registers the entity configurations for this context.</summary>
    /// <param name="modelBuilder">The builder that defines the model.</param>
    protected override void OnModelCreating(DbModelBuilder modelBuilder)
    {
        modelBuilder.Configurations.Add(new PageHitMap());
    }
}
No comments:
Post a Comment